mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 02:41:24 +00:00
Compare commits
293 Commits
feat/depth
...
feat/smolv
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1f1541243a | ||
|
|
d70c810416 | ||
|
|
4c3ddb1ff5 | ||
|
|
8615f3f613 | ||
|
|
2686450d68 | ||
|
|
4913356564 | ||
|
|
673cc6b0fe | ||
|
|
2ed6519a93 | ||
|
|
72ea531017 | ||
|
|
56a934ec55 | ||
|
|
738e317caa | ||
|
|
8ba3b187a1 | ||
|
|
057c794ffe | ||
|
|
b1e83f556c | ||
|
|
da3e87ee86 | ||
|
|
1e9a6d044d | ||
|
|
3fdfcb912a | ||
|
|
c37b1fc7d0 | ||
|
|
9020635b14 | ||
|
|
83d0c390da | ||
|
|
1ff10b935c | ||
|
|
67bdf4690e | ||
|
|
8085feab6e | ||
|
|
a088c10c80 | ||
|
|
9c3d5ab7ce | ||
|
|
e84f97a8c1 | ||
|
|
6d2b8c80ab | ||
|
|
793c7c4ddd | ||
|
|
db927ab40b | ||
|
|
471b2b1b1d | ||
|
|
a15e16c072 | ||
|
|
336af85c09 | ||
|
|
54221ceea2 | ||
|
|
369ab17110 | ||
|
|
86a7edc590 | ||
|
|
77a16db529 | ||
|
|
8194897994 | ||
|
|
ca1b951e7b | ||
|
|
9d30d91021 | ||
|
|
9f437d86b6 | ||
|
|
b74a551d38 | ||
|
|
c0a2e9814d | ||
|
|
e050d0fe0a | ||
|
|
2ca030fa28 | ||
|
|
36f828221c | ||
|
|
d41d874581 | ||
|
|
bac4f61eae | ||
|
|
efa05f0ada | ||
|
|
e98b6f726b | ||
|
|
f7747d02a9 | ||
|
|
86ecd4bc2e | ||
|
|
28b86449a2 | ||
|
|
f4b834844e | ||
|
|
5bb2da4da6 | ||
|
|
f7b989ad97 | ||
|
|
3b4376aa33 | ||
|
|
a0233f53f4 | ||
|
|
34269a5d78 | ||
|
|
75507491bf | ||
|
|
88519cb14c | ||
|
|
bc0c993b25 | ||
|
|
ddf4bc2063 | ||
|
|
b7317b6c29 | ||
|
|
c026aed8f8 | ||
|
|
e425dfd624 | ||
|
|
15f79b5e5e | ||
|
|
dfdc48a7f1 | ||
|
|
6a8878a639 | ||
|
|
d38eb89f71 | ||
|
|
7ab4936b1b | ||
|
|
2ea0da2d9f | ||
|
|
725ac95b0d | ||
|
|
7b64e5498d | ||
|
|
134a707c7a | ||
|
|
182f10184f | ||
|
|
ca8c60a0ed | ||
|
|
ce47075d6b | ||
|
|
26013da699 | ||
|
|
bb31988915 | ||
|
|
2629175d2d | ||
|
|
2b4c5f49e3 | ||
|
|
22c9c4905e | ||
|
|
7960cc14ec | ||
|
|
3c15fd8537 | ||
|
|
1750a87104 | ||
|
|
0e2dc1b76f | ||
|
|
474c5478d9 | ||
|
|
f72b28738a | ||
|
|
1bd53cc7da | ||
|
|
0f5f0e4091 | ||
|
|
7128bb1769 | ||
|
|
426d48dbbf | ||
|
|
fbcb9225f5 | ||
|
|
31e0c15e55 | ||
|
|
c5676ef1b3 | ||
|
|
b319ccf688 | ||
|
|
3174e14bc0 | ||
|
|
dc530e10fe | ||
|
|
e7c5613a39 | ||
|
|
516ffc7687 | ||
|
|
7a68bf13d9 | ||
|
|
15229468d0 | ||
|
|
a9cea3e8dd | ||
|
|
89d4846590 | ||
|
|
9dfc9084e1 | ||
|
|
fd18beb3a1 | ||
|
|
26cb38a7d0 | ||
|
|
bfb8cfb432 | ||
|
|
5e3b9ba82c | ||
|
|
083d3cd419 | ||
|
|
bf996c7938 | ||
|
|
0d88eaf8eb | ||
|
|
3cd348ffe2 | ||
|
|
db03fc6dc4 | ||
|
|
56068d37ea | ||
|
|
e727688052 | ||
|
|
f1a0a663cc | ||
|
|
6e64c20cf1 | ||
|
|
b29cccb37e | ||
|
|
f161e27e96 | ||
|
|
d5f293a1c9 | ||
|
|
95033733fc | ||
|
|
c3503b774f | ||
|
|
99ebee4d16 | ||
|
|
a8ca5128b8 | ||
|
|
dd97c33814 | ||
|
|
fa45ba631b | ||
|
|
ffd8c92ce5 | ||
|
|
841d3c47e1 | ||
|
|
2c920ab178 | ||
|
|
9f630e2a41 | ||
|
|
7a32f8a72a | ||
|
|
129aa207e3 | ||
|
|
e3ad1c59fc | ||
|
|
9ff62cb08c | ||
|
|
b2aa372fcf | ||
|
|
058b8f3958 | ||
|
|
b873fe454c | ||
|
|
83d7250a22 | ||
|
|
35f9063a6c | ||
|
|
17c0800461 | ||
|
|
c8763e0ad5 | ||
|
|
0f4faddc01 | ||
|
|
8dc0af3c28 | ||
|
|
8eba704f15 | ||
|
|
ecbac17196 | ||
|
|
12cce8f2cc | ||
|
|
ef5879a02a | ||
|
|
1d24301b67 | ||
|
|
3a20ea337e | ||
|
|
b6fb536460 | ||
|
|
bfd3bb1791 | ||
|
|
4908433f9a | ||
|
|
6ce1f36002 | ||
|
|
731576be80 | ||
|
|
47fb8318b1 | ||
|
|
53172873e3 | ||
|
|
fcdae0ce8e | ||
|
|
4852b9f952 | ||
|
|
0410705aff | ||
|
|
398a8cf730 | ||
|
|
ab5c1dc392 | ||
|
|
1292304c42 | ||
|
|
b95eebff77 | ||
|
|
fbcac95662 | ||
|
|
b9db4d21a2 | ||
|
|
aecb80a9d2 | ||
|
|
c98c695127 | ||
|
|
d528078aca | ||
|
|
a648da0455 | ||
|
|
d866c2c9fd | ||
|
|
01e2228b24 | ||
|
|
c36de3a3e8 | ||
|
|
cbfaf2c544 | ||
|
|
d0278ea093 | ||
|
|
15f6b08b0e | ||
|
|
fc715db4a3 | ||
|
|
fe4bd2b6ba | ||
|
|
3f7436ff8a | ||
|
|
992d13d4e9 | ||
|
|
afe40a016b | ||
|
|
41095e3cc3 | ||
|
|
e0fa957569 | ||
|
|
c661d81409 | ||
|
|
965d42825f | ||
|
|
1238a0cd47 | ||
|
|
53c7641885 | ||
|
|
088c8371df | ||
|
|
3a52a18b0e | ||
|
|
dad2cf1178 | ||
|
|
bce5387e04 | ||
|
|
85576acc29 | ||
|
|
e7e5fca5de | ||
|
|
beb22afd81 | ||
|
|
33a4b4a5a0 | ||
|
|
d55b581ca1 | ||
|
|
24d2ffe3c6 | ||
|
|
789f29aa56 | ||
|
|
a356b12c41 | ||
|
|
e8327b8e62 | ||
|
|
c450298147 | ||
|
|
5c30b14929 | ||
|
|
a764c3e1d6 | ||
|
|
b416f287f2 | ||
|
|
aa749d4947 | ||
|
|
1394a6ab5d | ||
|
|
db9118f16f | ||
|
|
7a945d7bdc | ||
|
|
a47e535b02 | ||
|
|
6d9b431b54 | ||
|
|
347e706326 | ||
|
|
fa8ae1e89b | ||
|
|
3ff6c6860e | ||
|
|
fd89efb545 | ||
|
|
2776b57c9e | ||
|
|
0fb5f04965 | ||
|
|
7296ac97af | ||
|
|
9cbbcfb6a2 | ||
|
|
fea41b29f5 | ||
|
|
7b4d281ef5 | ||
|
|
29bb8bb20e | ||
|
|
3fe686ce9f | ||
|
|
a1b8134ef1 | ||
|
|
8fa8323c91 | ||
|
|
5f7c6ba61d | ||
|
|
223cc8a9e2 | ||
|
|
af6d8ebd5b | ||
|
|
37b1eb218a | ||
|
|
52e1fd35cb | ||
|
|
7459dfccb6 | ||
|
|
73740ecf4b | ||
|
|
1b81e49214 | ||
|
|
d813c75b76 | ||
|
|
3434d2ef22 | ||
|
|
b71e10da6b | ||
|
|
0f6e3230df | ||
|
|
2f2e42c4aa | ||
|
|
5ee0104739 | ||
|
|
e064cfcb04 | ||
|
|
b3d9494831 | ||
|
|
1217fdb6f0 | ||
|
|
d0388e1142 | ||
|
|
524aa59faa | ||
|
|
27f7829b09 | ||
|
|
7f8bf108e8 | ||
|
|
855ff027f8 | ||
|
|
3b797bb118 | ||
|
|
aea04721ae | ||
|
|
ab5479129a | ||
|
|
e6d4ac6f02 | ||
|
|
5722d365c5 | ||
|
|
3d7e60cee4 | ||
|
|
7b767d4d60 | ||
|
|
f1e3ab7794 | ||
|
|
585341ba9f | ||
|
|
23ff346027 | ||
|
|
3c5cbe7af4 | ||
|
|
f2cbd97635 | ||
|
|
c06c8d594a | ||
|
|
cd495a3a9d | ||
|
|
c99ac45cd1 | ||
|
|
13aaafeae0 | ||
|
|
2129648bf4 | ||
|
|
f5cd3f6e4e | ||
|
|
ecf5766301 | ||
|
|
11597d4f71 | ||
|
|
8b9c598cf4 | ||
|
|
b325475b38 | ||
|
|
ef137ff86a | ||
|
|
c5df821a96 | ||
|
|
7ec3d7999c | ||
|
|
712d63abbd | ||
|
|
6653999983 | ||
|
|
4bdbedc9a0 | ||
|
|
e240305e8e | ||
|
|
ccd189b264 | ||
|
|
ef1242bbd4 | ||
|
|
ebf4a04d41 | ||
|
|
4419b4ef1b | ||
|
|
ff06ca82d2 | ||
|
|
fcb01e73eb | ||
|
|
268f8d1f53 | ||
|
|
663fff0ae2 | ||
|
|
9d6af804bf | ||
|
|
f763f85213 | ||
|
|
e3e9374e2c | ||
|
|
c1a0c601e2 | ||
|
|
1ca38d9748 | ||
|
|
5a6aa64570 | ||
|
|
0b06790da0 | ||
|
|
b43dc39ba4 | ||
|
|
2b71221194 | ||
|
|
8833d735a1 |
6
Makefile
6
Makefile
@@ -178,3 +178,9 @@ test-smolvla-ete-eval:
|
||||
--env.episode_length=5 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1
|
||||
|
||||
# E2E annotation pipeline smoke test against a tiny in-memory fixture
|
||||
# dataset. Opt-in (not part of `make test-end-to-end`) and uses a stub VLM
|
||||
# backend, so it does not require a real model checkpoint or GPU.
|
||||
annotation-e2e:
|
||||
uv run python -m tests.annotations.run_e2e_smoke
|
||||
|
||||
@@ -39,8 +39,12 @@
|
||||
title: Porting Large Datasets
|
||||
- local: using_dataset_tools
|
||||
title: Using the Dataset Tools
|
||||
- local: dataset_subtask
|
||||
title: Using Subtasks in the Dataset
|
||||
- local: language_and_recipes
|
||||
title: Language Columns and Recipes
|
||||
- local: tools
|
||||
title: Tools
|
||||
- local: annotation_pipeline
|
||||
title: Annotation Pipeline
|
||||
- local: video_encoding_parameters
|
||||
title: Video encoding parameters
|
||||
- local: streaming_video_encoding
|
||||
@@ -143,6 +147,8 @@
|
||||
title: OMX
|
||||
- local: openarm
|
||||
title: OpenArm
|
||||
- local: rebot_b601
|
||||
title: reBot B601-DM
|
||||
title: "Robots"
|
||||
- sections:
|
||||
- local: phone_teleop
|
||||
|
||||
@@ -79,17 +79,13 @@ If your local computer doesn't have a powerful GPU, you can utilize Google Colab
|
||||
Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/act_policy \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_robot \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Your task description" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--policy.path=${HF_USER}/act_policy
|
||||
--task="Your task description" \ # can be skipped for ACT
|
||||
--duration=60
|
||||
```
|
||||
|
||||
199
docs/source/annotation_pipeline.mdx
Normal file
199
docs/source/annotation_pipeline.mdx
Normal file
@@ -0,0 +1,199 @@
|
||||
# Annotation Pipeline
|
||||
|
||||
`lerobot-annotate` populates the two language columns introduced by the
|
||||
[Language Columns and Recipes](./language_and_recipes) page —
|
||||
`language_persistent` and `language_events` — directly into
|
||||
`data/chunk-*/file-*.parquet`.
|
||||
|
||||
## What the pipeline produces
|
||||
|
||||
A vocabulary-discovery phase derives a small canonical wording, then three
|
||||
modules write into a per-episode staging tree, then a single writer
|
||||
rewrites the data shards in place:
|
||||
|
||||
| Style / atom | Column | Module |
|
||||
| ------------------------------------------- | --------------------- | -------------- |
|
||||
| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | `plan` |
|
||||
| `plan` (initial + refresh on interjection) | `language_persistent` | `plan` |
|
||||
| `memory` (MEM-style compression) | `language_persistent` | `plan` |
|
||||
| `task_aug` (rephrasings of canonical task) | `language_persistent` | `plan` |
|
||||
| `interjection` | `language_events` | `interjections`|
|
||||
| speech tool-call atom (`style=null`, `say`) | `language_events` | `interjections`|
|
||||
| `vqa` (user / assistant pair) | `language_events` | `vqa` |
|
||||
|
||||
The `plan` module is constrained to a **canonical vocabulary** discovered
|
||||
once per dataset by the `vocabulary` module (phase 0). It watches a few
|
||||
sample episode videos (`--vocabulary.sample_episodes`, default `3`) and
|
||||
asks the VLM to derive a small set of imperative subtask labels and
|
||||
first-person memory milestones that recur across the demos. The VLM
|
||||
picks the right number of entries itself based on what it sees in the
|
||||
clips — short pick-and-place demos get ~6 subtask labels, longer
|
||||
multi-step recipes get more. The result lands at
|
||||
`meta/canonical_vocabulary.json` (human-readable / hand-editable) and
|
||||
is reused on every subsequent run. The `plan` module then constrains
|
||||
both subtask + memory generation to those exact strings — the
|
||||
downstream low-level policy sees a small, repeatable target
|
||||
distribution instead of thousands of LLM paraphrases. Disable with
|
||||
`--vocabulary.enabled=False` to fall back to free-form generation.
|
||||
|
||||
The writer does **not** add a `tools` column to the parquet — the tool
|
||||
catalog lives at `meta/info.json["tools"]` instead (see
|
||||
[Tools](./tools)). After every annotation run the pipeline ensures the
|
||||
canonical `say` schema is present in that list, preserving any tools the
|
||||
user pre-declared.
|
||||
|
||||
If you want to declare additional tools for a dataset before annotation
|
||||
runs, edit `meta/info.json["tools"]` directly — the pipeline preserves
|
||||
anything already there. Implementations of those tools live under
|
||||
`src/lerobot/tools/`; one file per tool, registered via
|
||||
`TOOL_REGISTRY`. See the [Tools](./tools) doc for the authoring guide.
|
||||
|
||||
## Running locally
|
||||
|
||||
Install the extra and invoke the console script. Episode-level
|
||||
concurrency comes from `--executor.episode_parallelism` (default 16);
|
||||
that is the only knob the in-process executor exposes.
|
||||
|
||||
```bash
|
||||
uv sync --extra annotations
|
||||
uv run lerobot-annotate \
|
||||
--root=/path/to/dataset \
|
||||
--vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
|
||||
```
|
||||
|
||||
The pipeline attaches actual camera footage to every `plan` /
|
||||
`interjections` / `vqa` prompt by default, decoded from the dataset's
|
||||
first `observation.images.*` stream. Override with
|
||||
`--vlm.camera_key=observation.images.<name>` to pin a specific
|
||||
viewpoint. Datasets with no video tracks fall back to text-only prompts
|
||||
automatically.
|
||||
|
||||
**The `plan` module sees the whole episode as one video block.** Subtask
|
||||
decomposition gets a `{"type":"video", "video":[<frames>]}` block
|
||||
covering the entire demonstration; Qwen-VL pools temporally on its own
|
||||
and decides where to cut. There is no keyframe stride or count knob —
|
||||
`--plan.max_video_frames` (default 128) only caps the frames packed
|
||||
into the video block as a model-capacity bound. The `interjections`
|
||||
module attaches a short window of frames straddling the interjection
|
||||
timestamp. The `vqa` module grounds each VQA pair on a single frame —
|
||||
its `--vqa.K` knob sets how many consecutive frames each emission tick
|
||||
anchors, and every anchored frame gets its own VQA pair on that one
|
||||
frame (there is no per-pair frame window).
|
||||
|
||||
## Running on Hugging Face Jobs
|
||||
|
||||
Distributed annotation is delegated to
|
||||
[Hugging Face Jobs](https://huggingface.co/docs/hub/en/jobs). The repo
|
||||
ships a launcher script you copy and edit for your dataset:
|
||||
|
||||
```bash
|
||||
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||
```
|
||||
|
||||
[`examples/annotations/run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
|
||||
spawns one `h200x2` job that:
|
||||
|
||||
1. installs the branch under test plus the annotation extras,
|
||||
2. boots two vllm servers (one per GPU) for the chosen model,
|
||||
3. runs the `plan` / `interjections` / `vqa` modules across the dataset
|
||||
via `lerobot-annotate`,
|
||||
4. uploads the annotated dataset to `--push_to_hub`.
|
||||
|
||||
To target a different dataset, model, or hub repo, edit the `CMD` block
|
||||
inside the script — every flag in there maps directly onto a CLI flag of
|
||||
`lerobot-annotate` (see `lerobot-annotate --help` for the full list).
|
||||
|
||||
## Style-to-recipe consumer mapping
|
||||
|
||||
The pipeline's outputs are designed to be consumed by recipes (see
|
||||
[Language Columns and Recipes](./language_and_recipes)) — for the
|
||||
canonical PI052 blend `src/lerobot/configs/recipes/subtask_mem_vqa_speech.yaml`:
|
||||
|
||||
- low-level / high-level / memory-update branches consume
|
||||
`subtask`/`plan`/`memory` from `language_persistent`.
|
||||
- An interjection-response branch consumes `interjection` events plus
|
||||
the paired speech atom (merged into one assistant target turn via
|
||||
`tool_calls_from`) and the same-timestamp `plan` refresh.
|
||||
- A VQA branch consumes the `(vqa, user)` and `(vqa, assistant)` pairs
|
||||
from `language_events`.
|
||||
|
||||
## Why the design splits state from events
|
||||
|
||||
Two things drive the scope:
|
||||
|
||||
1. **Persistent state vs exact-event split.** Persistent rows
|
||||
(`subtask`, `plan`, `memory`) broadcast per episode and answer "what
|
||||
state is in force at this frame?". Event rows (`interjection`, `vqa`,
|
||||
speech) only appear on the exact frame whose timestamp matches the
|
||||
emission. The pipeline writes timestamps taken straight from the
|
||||
source parquet — no floating-point recomputation.
|
||||
2. **One Qwen-VL pass.** All three modules share a single VLM client
|
||||
(vLLM if available, transformers fallback) so the cost is one model
|
||||
load per dataset, not three.
|
||||
|
||||
## Module independence and staged reruns
|
||||
|
||||
Each module writes its raw output to
|
||||
`<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl`. That makes
|
||||
prompt iteration cheap — re-running one module overwrites only its own
|
||||
JSONL file before the writer composes the final parquet. Modules can be
|
||||
disabled via `--plan.enabled=false` (and likewise `--interjections.enabled`
|
||||
/ `--vqa.enabled`) to
|
||||
test them in isolation.
|
||||
|
||||
## Validation/report checks before final write
|
||||
|
||||
Before the writer runs, `StagingValidator` checks:
|
||||
|
||||
- exact frame-timestamp alignment for every event row;
|
||||
- no orphan speech / interjection pairs;
|
||||
- `plan` is refreshed at every interjection timestamp;
|
||||
- `memory` rows fall on subtask boundaries (warning, not error);
|
||||
- VQA assistant `content` parses as JSON in one of the
|
||||
bbox / keypoint / count / attribute / spatial shapes;
|
||||
- every row routes to the column dictated by `column_for_style(style)`.
|
||||
|
||||
Errors abort the writer (`--skip_validation=true` overrides for debugging).
|
||||
|
||||
## Paper inspirations per module
|
||||
|
||||
- **`plan` module — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
|
||||
atom granularity ("pick up one piece of lettuce", "place bowl to box");
|
||||
Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07)) "how, not
|
||||
what" detail.
|
||||
- **`plan` module — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596))
|
||||
compression directive: keep only minimal relevant information; functional
|
||||
outcomes preserved, specific attributes dropped.
|
||||
- **`interjections` module.** Hi Robot scenario taxonomy: negative task,
|
||||
situated correction, specific constraint, preference. Speech is a
|
||||
tool-call-only atom (`tool_calls=[{type:function, function:{name:"say",
|
||||
arguments:{text:...}}}]`).
|
||||
- **`vqa` module.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693))
|
||||
grounded features (bounding boxes in pixel `[x_min, y_min, x_max, y_max]`,
|
||||
keypoints) and Steerable VLA Policies ([Zhao 2025](https://arxiv.org/abs/2509.07626))
|
||||
multi-abstraction grounding. Pi0.7 also grounds answers across
|
||||
multiple abstraction levels.
|
||||
|
||||
Future maintainers should adjust the prompt templates in
|
||||
`src/lerobot/annotations/steerable_pipeline/prompts/` against these
|
||||
references rather than rewriting from scratch.
|
||||
|
||||
## Compute and list-size estimates
|
||||
|
||||
Per episode, the pipeline issues O(`max_steps`) `plan`-module calls,
|
||||
O(`max_interjections_per_episode`) `interjections`-module calls, and
|
||||
O(`vqa_emission_hz × episode_seconds`) `vqa`-module calls. With defaults
|
||||
(8 subtasks, 1 interjection, 1 Hz × 3 pairs) and 30-second episodes, that
|
||||
is ~50 VLM calls per episode. `language_persistent` per episode is ~10s of
|
||||
KB at most (parquet dictionary-encodes one entry per episode);
|
||||
`language_events` is empty on most frames and is bounded by the number of
|
||||
emissions, not `num_frames × num_emissions`.
|
||||
|
||||
## Reproducibility via seed and prompt hashes
|
||||
|
||||
`--seed` (default 1729) feeds the per-episode RNGs that select interjection
|
||||
timestamps and VQA question types. Combined with the deterministic prompt
|
||||
templates checked into `prompts/`, two runs at the same seed against the
|
||||
same dataset and the same model checkpoint produce byte-identical staging
|
||||
artifacts. Prompt edits are recorded by file hash; future tooling can pin
|
||||
expected `(seed, prompt_hash)` pairs into the dataset card.
|
||||
@@ -1,277 +0,0 @@
|
||||
# Using Subtasks in LeRobot Datasets
|
||||
|
||||
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
|
||||
|
||||
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
|
||||
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
|
||||
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
|
||||
|
||||
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
|
||||
|
||||
## What are Subtasks?
|
||||
|
||||
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
|
||||
|
||||
1. "Approach the apple"
|
||||
2. "Grasp the apple"
|
||||
3. "Lift the apple"
|
||||
4. "Move to basket"
|
||||
5. "Release the apple"
|
||||
|
||||
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
|
||||
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
|
||||
width="80%"
|
||||
/>
|
||||
|
||||
<p>
|
||||
<em>Figure: Overview of subtask annotation.</em>
|
||||
</p>
|
||||
|
||||
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
|
||||
|
||||
## Dataset Structure
|
||||
|
||||
Subtask information is stored in the dataset metadata:
|
||||
|
||||
```
|
||||
my-dataset/
|
||||
├── data/
|
||||
│ └── ...
|
||||
├── meta/
|
||||
│ ├── info.json
|
||||
│ ├── stats.json
|
||||
│ ├── tasks.parquet
|
||||
│ ├── subtasks.parquet # Subtask index → subtask string mapping
|
||||
│ └── episodes/
|
||||
│ └── ...
|
||||
└── videos/
|
||||
└── ...
|
||||
```
|
||||
|
||||
### Subtasks Parquet File
|
||||
|
||||
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
|
||||
|
||||
| subtask_index | subtask (index column) |
|
||||
| ------------- | ---------------------- |
|
||||
| 0 | "Approach the apple" |
|
||||
| 1 | "Grasp the apple" |
|
||||
| 2 | "Lift the apple" |
|
||||
| ... | ... |
|
||||
|
||||
### Frame-Level Annotations
|
||||
|
||||
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
|
||||
|
||||
```python
|
||||
# Example frame data in the parquet file
|
||||
{
|
||||
"index": 42,
|
||||
"timestamp": 1.4,
|
||||
"episode_index": 0,
|
||||
"task_index": 0,
|
||||
"subtask_index": 2, # References "Lift the apple"
|
||||
"observation.state": [...],
|
||||
"action": [...],
|
||||
}
|
||||
```
|
||||
|
||||
## Annotating Datasets with Subtasks
|
||||
|
||||
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
|
||||
|
||||
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
|
||||
|
||||
After completing your annotation:
|
||||
|
||||
1. Click "Push to Hub" to upload your annotated dataset
|
||||
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
|
||||
|
||||
## Loading Datasets with Subtasks
|
||||
|
||||
When you load a dataset with subtask annotations, the subtask information is automatically available:
|
||||
|
||||
```python
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
# Load a dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Access a sample
|
||||
sample = dataset[100]
|
||||
|
||||
# The sample includes both task and subtask information
|
||||
print(sample["task"]) # "Collect the fruit"
|
||||
print(sample["subtask"]) # "Grasp the apple"
|
||||
print(sample["task_index"]) # tensor(0)
|
||||
print(sample["subtask_index"]) # tensor(2)
|
||||
```
|
||||
|
||||
### Checking for Subtask Support
|
||||
|
||||
You can check if a dataset has subtask annotations:
|
||||
|
||||
```python
|
||||
# Check if subtasks are available
|
||||
has_subtasks = (
|
||||
"subtask_index" in dataset.features
|
||||
and dataset.meta.subtasks is not None
|
||||
)
|
||||
|
||||
if has_subtasks:
|
||||
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
|
||||
print("Subtasks:", list(dataset.meta.subtasks.index))
|
||||
```
|
||||
|
||||
## Using Subtasks for Training
|
||||
|
||||
### With the Tokenizer Processor
|
||||
|
||||
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
|
||||
|
||||
```python
|
||||
from lerobot.processor import TokenizerProcessorStep
|
||||
|
||||
# Create a tokenizer processor step
|
||||
tokenizer_processor = TokenizerProcessorStep(
|
||||
tokenizer_name_or_path="google/paligemma-3b-pt-224",
|
||||
padding="max_length",
|
||||
max_length=64,
|
||||
)
|
||||
|
||||
# The processor will automatically tokenize subtasks if present in the batch
|
||||
# and add them to the observation under:
|
||||
# - "observation.subtask.tokens"
|
||||
# - "observation.subtask.attention_mask"
|
||||
```
|
||||
|
||||
When subtasks are available in the batch, the tokenizer processor adds:
|
||||
|
||||
- `observation.subtask.tokens`: Tokenized subtask text
|
||||
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
|
||||
|
||||
### DataLoader with Subtasks
|
||||
|
||||
```python
|
||||
import torch
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=16,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
for batch in dataloader:
|
||||
# Access subtask information in the batch
|
||||
subtasks = batch["subtask"] # List of subtask strings
|
||||
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
|
||||
|
||||
# Use for training hierarchical policies or reward models
|
||||
print(f"Batch subtasks: {set(subtasks)}")
|
||||
```
|
||||
|
||||
## Example Datasets with Subtask Annotations
|
||||
|
||||
Try loading a dataset with subtask annotations:
|
||||
|
||||
```python
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
# Example dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Explore the subtasks
|
||||
print("Available subtasks:")
|
||||
for subtask_name in dataset.meta.subtasks.index:
|
||||
print(f" - {subtask_name}")
|
||||
|
||||
# Get subtask distribution
|
||||
subtask_counts = {}
|
||||
for i in range(len(dataset)):
|
||||
sample = dataset[i]
|
||||
subtask = sample["subtask"]
|
||||
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
|
||||
|
||||
print("\nSubtask distribution:")
|
||||
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
|
||||
print(f" {subtask}: {count} frames")
|
||||
```
|
||||
|
||||
## Use Cases
|
||||
|
||||
### 1. Hierarchical Policy Training
|
||||
|
||||
Train policies that predict both actions and current subtask:
|
||||
|
||||
```python
|
||||
class HierarchicalPolicy(nn.Module):
|
||||
def __init__(self, num_subtasks):
|
||||
super().__init__()
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
|
||||
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
actions = self.action_head(features)
|
||||
subtask_logits = self.subtask_head(features)
|
||||
return actions, subtask_logits
|
||||
```
|
||||
|
||||
### 2. Stage-Aware Reward Modeling (SARM)
|
||||
|
||||
Build reward models that understand task progression:
|
||||
|
||||
```python
|
||||
# SARM predicts:
|
||||
# - Stage: Which subtask is being executed (discrete)
|
||||
# - Progress: How far along the subtask (continuous 0-1)
|
||||
|
||||
class SARMRewardModel(nn.Module):
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
stage_logits = self.stage_classifier(features)
|
||||
progress = self.progress_regressor(features)
|
||||
return stage_logits, progress
|
||||
```
|
||||
|
||||
### 3. Progress Visualization
|
||||
|
||||
Monitor robot execution by tracking subtask progression:
|
||||
|
||||
```python
|
||||
def visualize_execution(model, observations):
|
||||
for t, obs in enumerate(observations):
|
||||
action, subtask_logits = model(obs)
|
||||
predicted_subtask = subtask_names[subtask_logits.argmax()]
|
||||
print(f"t={t}: Executing '{predicted_subtask}'")
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### LeRobotDataset Properties
|
||||
|
||||
| Property | Type | Description |
|
||||
| --------------------------- | ---------------------- | ------------------------------------------ |
|
||||
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
|
||||
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
|
||||
|
||||
### Sample Keys
|
||||
|
||||
When subtasks are available, each sample includes:
|
||||
|
||||
| Key | Type | Description |
|
||||
| --------------- | -------------- | ------------------------------------ |
|
||||
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
|
||||
| `subtask` | `str` | Natural language subtask description |
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
|
||||
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
|
||||
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation
|
||||
@@ -105,10 +105,12 @@ These results demonstrate GR00T's strong generalization capabilities across dive
|
||||
|
||||
### Evaluate in your hardware setup
|
||||
|
||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
|
||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Policy Deployment (lerobot-rollout)](./inference). For example:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
lerobot-rollout\
|
||||
--strategy.type=sentry \
|
||||
--strategy.upload_every_n_episodes=5 \
|
||||
--robot.type=bi_so_follower \
|
||||
--robot.left_arm_port=/dev/ttyACM1 \
|
||||
--robot.right_arm_port=/dev/ttyACM0 \
|
||||
@@ -119,14 +121,12 @@ lerobot-record \
|
||||
}' \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=<user>/eval_groot-bimanual \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--policy.path=<user>/groot-bimanual \ # your trained model
|
||||
--dataset.episode_time_s=30 \
|
||||
--dataset.reset_time_s=10
|
||||
--duration=600
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
@@ -68,13 +68,13 @@ from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760431541",
|
||||
id="my_red_robot_arm",
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
)
|
||||
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem58760431551",
|
||||
id="my_blue_leader_arm",
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
|
||||
robot = SO101Follower(robot_config)
|
||||
@@ -108,13 +108,13 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
||||
--teleop.type=koch_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=my_awesome_leader_arm \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem5AB90687491 \
|
||||
--robot.id=my_follower_arm \
|
||||
--robot.cameras="{front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/tty.usbmodem5AB90689011 \
|
||||
--teleop.id=my_leader_arm \
|
||||
--display_data=true
|
||||
```
|
||||
</hfoption>
|
||||
@@ -122,34 +122,48 @@ lerobot-teleoperate \
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
import time
|
||||
from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig
|
||||
from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
|
||||
|
||||
camera_config = {
|
||||
"front": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30)
|
||||
}
|
||||
|
||||
robot_config = KochFollowerConfig(
|
||||
port="/dev/tty.usbmodem585A0076841",
|
||||
id="my_red_robot_arm",
|
||||
cameras=camera_config
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
cameras={
|
||||
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||
}
|
||||
)
|
||||
|
||||
teleop_config = KochLeaderConfig(
|
||||
port="/dev/tty.usbmodem58760431551",
|
||||
id="my_blue_leader_arm",
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
|
||||
robot = KochFollower(robot_config)
|
||||
teleop_device = KochLeader(teleop_config)
|
||||
init_rerun(session_name="teleoperation")
|
||||
|
||||
robot = SO101Follower(robot_config)
|
||||
teleop_device = SO101Leader(teleop_config)
|
||||
robot.connect()
|
||||
teleop_device.connect()
|
||||
|
||||
TARGET_HZ = 30
|
||||
TIME_PER_FRAME = 1.0 / TARGET_HZ
|
||||
|
||||
while True:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
observation = robot.get_observation()
|
||||
action = teleop_device.get_action()
|
||||
robot.send_action(action)
|
||||
log_rerun_data(observation=observation, action=action)
|
||||
|
||||
elapsed_time = time.perf_counter() - start_time
|
||||
sleep_time = TIME_PER_FRAME - elapsed_time
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -202,10 +216,11 @@ lerobot-record \
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
from lerobot.teleoperators.so_leader.config_so_leader import SO101LeaderConfig
|
||||
from lerobot.teleoperators.so_leader.so_leader import SO101Leader
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
@@ -218,71 +233,56 @@ EPISODE_TIME_SEC = 60
|
||||
RESET_TIME_SEC = 10
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
|
||||
# Create robot configuration
|
||||
robot_config = SO100FollowerConfig(
|
||||
id="my_awesome_follower_arm",
|
||||
cameras={
|
||||
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error.
|
||||
},
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
)
|
||||
|
||||
teleop_config = SO100LeaderConfig(
|
||||
id="my_awesome_leader_arm",
|
||||
port="/dev/tty.usbmodem585A0077581",
|
||||
)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
teleop = SO100Leader(teleop_config)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/<dataset_repo_id>",
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Create the required processors
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
def main():
|
||||
# Create robot configuration
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
cameras={
|
||||
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||
}
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO101Follower(robot_config)
|
||||
teleop = SO101Leader(teleop_config)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/<dataset_repo_id>",
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Create the required processors
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
@@ -291,26 +291,50 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
teleop.disconnect()
|
||||
dataset.push_to_hub()
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# finalize dataset
|
||||
log_say("Finalizing dataset...")
|
||||
dataset.finalize()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
teleop.disconnect()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -348,7 +372,7 @@ The `record` function provides a suite of tools for capturing and managing data
|
||||
##### 2. Checkpointing and Resuming
|
||||
|
||||
- Checkpoints are automatically created during recording.
|
||||
- If an issue occurs, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset !
|
||||
- If an issue occurs or you want to record additional episodes in the same dataset, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset! Make sure that you also set `--dataset.root="local_path"`, it's a local path to save the new part of the dataset and is required to resume.
|
||||
- To start recording from scratch, **manually delete** the dataset directory.
|
||||
|
||||
##### 3. Recording Parameters
|
||||
@@ -422,7 +446,7 @@ from lerobot.utils.utils import log_say
|
||||
|
||||
episode_idx = 0
|
||||
|
||||
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm")
|
||||
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem5AB90687491", id="my_follower_arm")
|
||||
|
||||
robot = SO100Follower(robot_config)
|
||||
robot.connect()
|
||||
@@ -490,6 +514,83 @@ Additionally you can provide extra `tags` or specify a `license` for your model
|
||||
|
||||
If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
|
||||
|
||||
#### Train using Hugging Face Jobs
|
||||
|
||||
Hugging Face jobs let's you easily select hardware and run the training in the cloud. So if you don't have a powerful GPU or you need more VRAM or just want to train a model much faster use HF Jobs! It's pay as you go and you simply pay for each second of use, you can see the pricing and additional information [here](https://huggingface.co/docs/hub/jobs).
|
||||
|
||||
To run the training use this command:
|
||||
|
||||
<hfoptions id="train_with_hf_jobs">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
hf jobs run \
|
||||
--flavor a10g-small \
|
||||
--timeout 4h \
|
||||
--secrets HF_TOKEN \
|
||||
huggingface/lerobot-gpu:latest \
|
||||
-- \
|
||||
python -m lerobot.scripts.lerobot_train \
|
||||
--dataset.repo_id=username/dataset \
|
||||
--policy.type=act \
|
||||
--steps=5000 \
|
||||
--batch_size=16 \
|
||||
--policy.device=cuda \
|
||||
--policy.repo_id=username/your_policy \
|
||||
--log_freq=100
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from huggingface_hub import run_job, get_token
|
||||
|
||||
run_name = "act_so101_hf_jobs"
|
||||
dataset_id = "username/dataset"
|
||||
user_hub_id = "username"
|
||||
|
||||
command_args = [
|
||||
"python", "-m", "lerobot.scripts.lerobot_train",
|
||||
"--dataset.repo_id", dataset_id,
|
||||
"--policy.type", "act",
|
||||
"--steps", "5000",
|
||||
"--batch_size", "16",
|
||||
"--num_workers", "4",
|
||||
"--policy.device", "cuda",
|
||||
"--log_freq", "100",
|
||||
"--save_freq", "1000",
|
||||
"--save_checkpoint", "true",
|
||||
"--wandb.enable", "false",
|
||||
"--policy.repo_id", f"{user_hub_id}/{run_name}"
|
||||
]
|
||||
|
||||
print(f"Submitting job '{run_name}' to Hugging Face Infrastructure...")
|
||||
|
||||
job_info = run_job(
|
||||
image="huggingface/lerobot-gpu:latest",
|
||||
command=command_args,
|
||||
flavor="a10g-small",
|
||||
timeout="4h",
|
||||
secrets={"HF_TOKEN": get_token()}
|
||||
)
|
||||
|
||||
print("\n🚀 Job successfully launched!")
|
||||
print(f"🔹 Job ID: {job_info.id}")
|
||||
print(f"🔗 Live UI Dashboard & Logs: {job_info.url}")
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
You can modify the `--flavor` to use different hardware, for example: `t4-small`, `a100-large`, `h200`. Use `hf jobs hardware` to see the full list with pricing.
|
||||
Depending on the model you want to train and the hardware you selected you can also modify the `--batch_size` and `--number_of_workers`.
|
||||
For longer training sessions increase the timeout.
|
||||
|
||||
Once the training is started you can go to [Jobs](https://huggingface.co/settings/jobs) and see if your jobs is running as well as all the outputs. Sometimes it takes a few minutes to schedule your job so be patient.
|
||||
|
||||
After training the model will be pushed to hub and you can use it as any other model with LeRobot.
|
||||
|
||||
#### Upload policy checkpoints
|
||||
|
||||
Once training is done, upload the latest checkpoint with:
|
||||
|
||||
152
docs/source/language_and_recipes.mdx
Normal file
152
docs/source/language_and_recipes.mdx
Normal file
@@ -0,0 +1,152 @@
|
||||
# Language columns and recipes
|
||||
|
||||
Most LeRobot datasets ship with a single `task` string per episode — fine for
|
||||
short, single-instruction skills, but not enough for the longer-horizon,
|
||||
multi-modal robot policies the field is moving toward (high-level planning,
|
||||
memory, interjections, VQA, tool use). To support those policies without
|
||||
forking the dataset format, LeRobot extends `LeRobotDataset` with two optional
|
||||
language columns and a small recipe layer that turns those rows into
|
||||
chat-style training samples on the fly.
|
||||
|
||||
The design splits cleanly into three layers:
|
||||
|
||||
1. **Data in the dataset** — language annotations stored next to frames in
|
||||
`data/chunk-*/file-*.parquet` as two optional columns (`language_persistent`
|
||||
and `language_events`). Datasets without these columns keep their existing
|
||||
behavior.
|
||||
2. **Recipe** — a YAML file that declares which annotation rows to bind and
|
||||
how to lay them out as chat turns (`role`, `content`, optional images,
|
||||
optional tool calls). Recipes are pure config; no Python required to add a
|
||||
new one.
|
||||
3. **Training format** — at sample time, `RenderMessagesStep` resolves the
|
||||
recipe against the per-frame annotations and emits HF-style `messages` plus
|
||||
LeRobot-specific sidecars (`message_streams`, `target_message_indices`)
|
||||
that policy processors consume.
|
||||
|
||||
This page describes each layer in turn.
|
||||
|
||||
## Layer 1 — language columns in the dataset
|
||||
|
||||
The two optional columns live next to frame data in
|
||||
`data/chunk-*/file-*.parquet`:
|
||||
|
||||
- `language_persistent`: a list of rows broadcast across every frame in an episode for state that remains active, such as `subtask`, `plan`, and `memory`.
|
||||
- `language_events`: a list of rows only on the exact frame where an event was emitted, such as `interjection`, `vqa`, and speech tool calls.
|
||||
|
||||
Both columns share the same row shape (event rows omit `timestamp` because the
|
||||
frame the row sits on already provides it):
|
||||
|
||||
```text
|
||||
role: string
|
||||
content: string | null
|
||||
style: string | null
|
||||
timestamp: float32 # persistent rows only
|
||||
camera: string | null # observation.images.* feature key, view-dependent rows only
|
||||
tool_calls: list[Json] | null
|
||||
```
|
||||
|
||||
The `camera` field tags rows whose `content` is grounded in a specific camera
|
||||
view. Rows of view-dependent styles (`vqa` and `trace`) MUST set `camera` to
|
||||
the matching `observation.images.*` feature key. Rows of every other style —
|
||||
including `motion`, which describes robot-frame primitives in joint / Cartesian
|
||||
terms — MUST leave `camera` as `null`. Pipeline writers and the validator
|
||||
enforce this via `validate_camera_field(style, camera)`.
|
||||
|
||||
`meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations.
|
||||
|
||||
### Architecture
|
||||
|
||||
The language stack itself has three internal modules backing layer 1:
|
||||
|
||||
1. `lerobot.datasets.language` defines the schema, style registry, and `column_for_style`.
|
||||
2. `lerobot.datasets.language_render` resolves rows and renders messages.
|
||||
3. `RenderMessagesStep` turns dataset samples into `messages`, `message_streams`, and `target_message_indices`.
|
||||
|
||||
`LeRobotDataset` stays recipe-agnostic. It passes `language_persistent` and `language_events` through when present, and unannotated datasets keep their existing behavior.
|
||||
|
||||
## Layer 2 — recipe anatomy
|
||||
|
||||
Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`. They
|
||||
declare which annotation rows to pull (via `bindings`) and how to compose them
|
||||
into chat turns (`messages`).
|
||||
|
||||
```yaml
|
||||
messages:
|
||||
- { role: user, content: "${task}", stream: high_level }
|
||||
- { role: assistant, content: "${subtask}", stream: low_level, target: true }
|
||||
```
|
||||
|
||||
A recipe can also branch into a weighted **blend** of sub-recipes. At sample
|
||||
time, exactly one branch is selected deterministically from the sample index,
|
||||
so different frames train different objectives (e.g. memory updates vs.
|
||||
low-level execution vs. VQA) without any Python wiring.
|
||||
|
||||
### Temporal semantics
|
||||
|
||||
Persistent styles are active after emission until replaced:
|
||||
|
||||
- `active_at(t, style=subtask)`
|
||||
- `nth_prev(style=memory, offset=1)`
|
||||
- `nth_next(style=subtask, offset=1)`
|
||||
|
||||
Event styles only exist on their exact timestamp:
|
||||
|
||||
- `emitted_at(t, style=interjection)`
|
||||
- `emitted_at(t, style=vqa, role=user, camera=observation.images.top)`
|
||||
- `emitted_at(t, role=assistant, tool_name=say)`
|
||||
|
||||
Exact event matching has no tolerance window, so writers must stamp event rows with frame timestamps from the parquet data.
|
||||
|
||||
### View-dependent resolution
|
||||
|
||||
For view-dependent styles (`vqa` and `trace`), the resolver gains a
|
||||
`camera=` filter parallel to `role=` and `tool_name=`. Datasets with multiple
|
||||
cameras typically emit one (`vqa`, `user`) + (`vqa`, `assistant`) pair per
|
||||
camera at the same timestamp; without `camera=`, those resolvers see two
|
||||
matches and raise an ambiguity error. Recipes consume each camera through its
|
||||
own binding plus a matching image block, e.g.
|
||||
|
||||
```yaml
|
||||
ask_vqa_top:
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- { type: image, feature: observation.images.top }
|
||||
- { type: text, text: "${vqa_query}" }
|
||||
- {
|
||||
role: assistant,
|
||||
content: "${vqa}",
|
||||
stream: high_level,
|
||||
target: true,
|
||||
if_present: vqa,
|
||||
}
|
||||
```
|
||||
|
||||
Add one such sub-recipe per camera the dataset records.
|
||||
|
||||
## Layer 3 — training format
|
||||
|
||||
Rendered samples use HF-style chat messages plus LeRobot sidecars:
|
||||
|
||||
```python
|
||||
sample["messages"]
|
||||
sample["message_streams"]
|
||||
sample["target_message_indices"]
|
||||
```
|
||||
|
||||
The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone, which keeps the same dataset usable across SmolVLA, Pi0.5, and any future VLM that expects OpenAI-style chat messages.
|
||||
|
||||
## Blends
|
||||
|
||||
Blend recipes select one weighted sub-recipe deterministically from the sample index.
|
||||
`recipes/subtasks_vqa.yaml` trains the core blend — high-level subtask prediction, low-level execution, and VQA. `recipes/subtask_mem_vqa_speech.yaml` is the fuller variant that also adds memory updates and spoken interjection responses.
|
||||
|
||||
## Graceful absence
|
||||
|
||||
If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.
|
||||
If an event-scoped branch is selected on a frame without the required event row, rendering returns `None`, allowing a loader to retry another sample.
|
||||
186
docs/source/rebot_b601.mdx
Normal file
186
docs/source/rebot_b601.mdx
Normal file
@@ -0,0 +1,186 @@
|
||||
# reBot B601-DM
|
||||
|
||||
[reBot B601-DM](https://wiki.seeedstudio.com/rebot_arm_b601_dm_lerobot/) is an open-source, low-cost robot arm from Seeed Studio for embodied-AI and imitation learning. It comes as a **follower** arm (the `B601-DM`, a 6-DOF arm plus gripper driven by Damiao CAN motors) and a **leader** arm (the `StarArm102` / `reBot Arm 102`, driven by FashionStar UART smart servos) used to teleoperate it.
|
||||
|
||||
This page covers **calibration** and **teleoperation** for both single-arm and bimanual (dual-arm) setups.
|
||||
|
||||
<div style="display: flex; align-items: center; gap: 10px;">
|
||||
<img
|
||||
src="https://files.seeedstudio.com/wiki/robotics/projects/lerobot/b601dm_zeroposition.jpg"
|
||||
alt="reBot B601-DM follower arm at its zero position"
|
||||
width="48%"
|
||||
/>
|
||||
<img
|
||||
src="https://files.seeedstudio.com/wiki/robotics/projects/lerobot/102_zeroposition.jpg"
|
||||
alt="reBot Arm 102 leader arm at its zero position"
|
||||
width="48%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
_Left: the B601-DM follower at its zero position. Right: the reBot Arm 102 leader at its zero position. Images courtesy of [Seeed Studio](https://wiki.seeedstudio.com/rebot_arm_b601_dm_lerobot/)._
|
||||
|
||||
## Install LeRobot 🤗
|
||||
|
||||
Follow our [Installation Guide](./installation), then install the reBot support:
|
||||
|
||||
```bash
|
||||
pip install -e ".[rebot]"
|
||||
```
|
||||
|
||||
This pulls in `motorbridge` (CAN motor control for the B601-DM follower) and `motorbridge-smart-servo` (FashionStar UART servos for the reBot Arm 102 leader).
|
||||
|
||||
## Registered device types
|
||||
|
||||
| Type | Kind |
|
||||
| ------------------------ | -------------------------------------------- |
|
||||
| `rebot_b601_follower` | single-arm B601-DM follower robot |
|
||||
| `bi_rebot_b601_follower` | bimanual (dual-arm) follower robot |
|
||||
| `rebot_102_leader` | single-arm reBot Arm 102 leader teleoperator |
|
||||
| `bi_rebot_102_leader` | bimanual (dual-arm) leader teleoperator |
|
||||
|
||||
The bimanual types compose two single-arm instances and namespace each arm's
|
||||
observation/action keys with a `left_` / `right_` prefix. Per-arm settings are
|
||||
passed through nested `left_arm_config.*` / `right_arm_config.*` arguments.
|
||||
|
||||
## Find the USB ports
|
||||
|
||||
For each device, find the USB port associated with its motor bus using:
|
||||
|
||||
```bash
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
On Linux, remove `brltty` (`sudo apt remove brltty`) so it does not hold the
|
||||
leader's USB serial port. You may also need to grant access to the serial
|
||||
devices: `sudo chmod 666 /dev/ttyACM* /dev/ttyUSB*`.
|
||||
</Tip>
|
||||
|
||||
## Calibration
|
||||
|
||||
Neither arm stores a persistent hardware calibration: every time it connects, the motors are re-zeroed against the pose the arm is physically holding. Calibration simply records that zero pose. When prompted, **manually move the arm to its zero position** (the default sit-down pose shown above, gripper fully closed) and press <kbd>ENTER</kbd>.
|
||||
|
||||
### Follower (B601-DM)
|
||||
|
||||
<hfoptions id="calibrate-follower">
|
||||
<hfoption id="Single arm">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--robot.type=rebot_b601_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=follower \
|
||||
--robot.can_adapter=damiao
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Dual arm">
|
||||
|
||||
Connect the bimanual follower; calibration runs for the left arm, then the right arm.
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--robot.type=bi_rebot_b601_follower \
|
||||
--robot.id=bi_follower \
|
||||
--robot.left_arm_config.port=/dev/ttyACM0 \
|
||||
--robot.left_arm_config.can_adapter=damiao \
|
||||
--robot.right_arm_config.port=/dev/ttyACM1 \
|
||||
--robot.right_arm_config.can_adapter=damiao
|
||||
```
|
||||
|
||||
Per-arm calibration files are saved with `_left` / `_right` suffixes on the id.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Leader (reBot Arm 102)
|
||||
|
||||
<hfoptions id="calibrate-leader">
|
||||
<hfoption id="Single arm">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=rebot_102_leader \
|
||||
--teleop.port=/dev/ttyUSB0 \
|
||||
--teleop.id=leader
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Dual arm">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=bi_rebot_102_leader \
|
||||
--teleop.id=bi_leader \
|
||||
--teleop.left_arm_config.port=/dev/ttyUSB0 \
|
||||
--teleop.right_arm_config.port=/dev/ttyUSB1
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Teleoperation
|
||||
|
||||
Once both arms are calibrated, drive the follower with the leader. The follower talks to its CAN bus through a Damiao serial bridge (`can_adapter=damiao`, the default) or a SocketCAN adapter (`can_adapter=socketcan`). See the [OpenArm page](./openarm) for more details on the SocketCAN adapter configuration.
|
||||
|
||||
<hfoptions id="teleoperate">
|
||||
<hfoption id="Single arm">
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=rebot_b601_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=follower \
|
||||
--robot.can_adapter=damiao \
|
||||
--teleop.type=rebot_102_leader \
|
||||
--teleop.port=/dev/ttyUSB0 \
|
||||
--teleop.id=leader
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Dual arm">
|
||||
|
||||
The bimanual leader and follower reuse the single-arm classes; each arm is
|
||||
configured through nested `left_arm_config.*` / `right_arm_config.*` arguments,
|
||||
so a bimanual reBot Arm 102 leader drives a bimanual B601-DM follower.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=bi_rebot_b601_follower \
|
||||
--robot.id=bi_follower \
|
||||
--robot.left_arm_config.port=/dev/ttyACM0 \
|
||||
--robot.left_arm_config.can_adapter=damiao \
|
||||
--robot.right_arm_config.port=/dev/ttyACM1 \
|
||||
--robot.right_arm_config.can_adapter=damiao \
|
||||
--teleop.type=bi_rebot_102_leader \
|
||||
--teleop.id=bi_leader \
|
||||
--teleop.left_arm_config.port=/dev/ttyUSB0 \
|
||||
--teleop.right_arm_config.port=/dev/ttyUSB1
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
<Tip>
|
||||
The leader and follower share the same joint names (`shoulder_pan,
|
||||
shoulder_lift, elbow_flex, wrist_flex, wrist_yaw, wrist_roll, gripper`), so
|
||||
leader actions map directly onto the follower.
|
||||
</Tip>
|
||||
|
||||
If the motion of a joint is reversed, flip its sign in the leader's `joint_directions` (the gripper also carries a scale to widen its range to the follower):
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=rebot_b601_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.can_adapter=damiao \
|
||||
--teleop.type=rebot_102_leader \
|
||||
--teleop.port=/dev/ttyUSB0 \
|
||||
--teleop.joint_directions='{"shoulder_pan":-1,"shoulder_lift":-1,"elbow_flex":1,"wrist_flex":1,"wrist_yaw":1,"wrist_roll":-1,"gripper":-6}'
|
||||
```
|
||||
|
||||
## Recording datasets
|
||||
|
||||
Swap `lerobot-teleoperate` for `lerobot-record` (with the same `--robot.*` / `--teleop.*` arguments, plus `--dataset.*`) to record demonstrations for training. See [Imitation Learning for Robots](./il_robots) for the full workflow.
|
||||
|
||||
For hardware assembly and wiring, see the [Seeed Studio reBot wiki](https://wiki.seeedstudio.com/rebot_arm_b601_dm_lerobot/).
|
||||
@@ -97,22 +97,22 @@ Similarly for when recording an episode, it is recommended that you are logged i
|
||||
Once you are logged in, you can run inference in your setup by doing:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \ # <- Use your port
|
||||
--robot.id=my_blue_follower_arm \ # <- Use your robot id
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras
|
||||
--dataset.single_task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
|
||||
--dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub
|
||||
--dataset.episode_time_s=50 \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
|
||||
# <- RTC optional, use when running on low power hardware \
|
||||
# --inference.type=rtc \
|
||||
# --inference.rtc.execution_horizon=10 \
|
||||
# --inference.rtc.max_guidance_weight=10.0 \
|
||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
||||
# --teleop.type=so100_leader \
|
||||
# --teleop.port=/dev/ttyACM0 \
|
||||
# --teleop.id=my_red_leader_arm \
|
||||
# --display_data=true #optional use if you want to see the camera stream \
|
||||
--policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model
|
||||
```
|
||||
|
||||
|
||||
210
docs/source/tools.mdx
Normal file
210
docs/source/tools.mdx
Normal file
@@ -0,0 +1,210 @@
|
||||
# Tools
|
||||
|
||||
LeRobot v3.1 supports **tool calls** in policies — assistant messages can
|
||||
emit structured invocations like `say(text="OK, starting now")` that the
|
||||
runtime dispatches to a real implementation (TTS, controller, logger, …).
|
||||
|
||||
This page covers:
|
||||
|
||||
1. Where the tool catalog lives.
|
||||
2. How the annotation pipeline produces tool-call atoms.
|
||||
3. How to add your own tool.
|
||||
|
||||
## Where tools are declared
|
||||
|
||||
Two layers.
|
||||
|
||||
**The catalog** — a list of OpenAI-style function schemas — lives at
|
||||
`meta/info.json["tools"]` on each dataset. Example:
|
||||
|
||||
```json
|
||||
{
|
||||
"features": { "...": "..." },
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"description": "Speak a short utterance to the user via the TTS executor.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The verbatim text to speak."
|
||||
}
|
||||
},
|
||||
"required": ["text"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Read it via the dataset metadata accessor:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
|
||||
meta = LeRobotDatasetMetadata(repo_id="pepijn/super_poulain_final_annotations")
|
||||
tools = meta.tools # list[dict] — OpenAI tool schemas
|
||||
```
|
||||
|
||||
If the dataset's `info.json` doesn't declare any tools, `meta.tools`
|
||||
returns `DEFAULT_TOOLS` from `lerobot.datasets.language` — currently a
|
||||
single-entry list with the canonical `say` schema. So unannotated
|
||||
datasets and chat-template consumers keep working without any
|
||||
configuration:
|
||||
|
||||
```python
|
||||
prompt_str = tokenizer.apply_chat_template(
|
||||
sample["messages"],
|
||||
tools=meta.tools, # works either way
|
||||
add_generation_prompt=False,
|
||||
tokenize=False,
|
||||
)
|
||||
```
|
||||
|
||||
**The implementations** — runnable Python — will live under
|
||||
`src/lerobot/tools/`, one file per tool. The runtime dispatcher and
|
||||
the canonical `say` implementation (wrapping Kyutai's pocket-tts) are
|
||||
not part of the catalog layer described here; today this layer ships
|
||||
only the schema storage and the `DEFAULT_TOOLS` fallback constant.
|
||||
|
||||
## Per-row tool _invocations_
|
||||
|
||||
The catalog above describes _what can be called_. The actual _call_ — the
|
||||
function name plus the argument values — is stored per-row, on the
|
||||
assistant atoms in `language_events`:
|
||||
|
||||
```python
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"style": null,
|
||||
"timestamp": 12.4,
|
||||
"camera": null,
|
||||
"tool_calls": [
|
||||
{ "type": "function",
|
||||
"function": { "name": "say", "arguments": { "text": "On it." } } }
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Recipes splice these into rendered messages via `tool_calls_from`:
|
||||
|
||||
```yaml
|
||||
user_interjection_response:
|
||||
bindings:
|
||||
speech: "emitted_at(t, role=assistant, tool_name=say)"
|
||||
messages:
|
||||
- { role: user, content: "${task}", stream: high_level }
|
||||
- {
|
||||
role: assistant,
|
||||
content: "${current_plan}",
|
||||
stream: high_level,
|
||||
target: true,
|
||||
tool_calls_from: speech,
|
||||
}
|
||||
```
|
||||
|
||||
The model's training target is one assistant turn that carries both the
|
||||
plan text _and_ the `say` tool call. At inference, the runtime parses
|
||||
the generated text back into structured `tool_calls` and dispatches to
|
||||
the matching implementation.
|
||||
|
||||
## How to add your own tool
|
||||
|
||||
> **Note:** Steps 2 and 3 below describe the runtime layer
|
||||
> (`src/lerobot/tools/`, the `Tool` protocol, `TOOL_REGISTRY`,
|
||||
> `get_tools(meta)`) which is not part of the catalog layer shipped
|
||||
> today — those modules don't yet exist in the tree. Step 1 alone is
|
||||
> enough to make the tool visible to the chat template via
|
||||
> `meta.tools` so the model can learn to _generate_ the call;
|
||||
> executing the call at inference requires the runtime layer.
|
||||
|
||||
Three steps. Concrete example: a `record_observation` tool the policy
|
||||
can call to capture an extra observation outside the regular control
|
||||
loop.
|
||||
|
||||
### Step 1 — declare the schema
|
||||
|
||||
Add an entry under `meta/info.json["tools"]`. Either edit the file
|
||||
directly on disk _before_ running the annotation pipeline (it'll be
|
||||
preserved) or hand it to `lerobot-annotate` via a config flag.
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": [
|
||||
{ "type": "function", "function": { "name": "say", "...": "..." } },
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "record_observation",
|
||||
"description": "Capture a high-resolution still image for the user.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"label": {
|
||||
"type": "string",
|
||||
"description": "Short label for the saved image."
|
||||
}
|
||||
},
|
||||
"required": ["label"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
The schema follows OpenAI's function-calling convention exactly, so the
|
||||
chat template can render it natively.
|
||||
|
||||
### Step 2 — implement the call
|
||||
|
||||
Create `src/lerobot/tools/record_observation.py`:
|
||||
|
||||
```python
|
||||
from .base import Tool
|
||||
from typing import Any
|
||||
|
||||
RECORD_OBSERVATION_SCHEMA: dict[str, Any] = { "...": "..." } # mirrors the JSON above
|
||||
|
||||
|
||||
class RecordObservationTool:
|
||||
name = "record_observation"
|
||||
schema = RECORD_OBSERVATION_SCHEMA
|
||||
|
||||
def __init__(self, schema: dict | None = None, output_dir: str = "."):
|
||||
self.output_dir = output_dir
|
||||
|
||||
def call(self, arguments: dict) -> str:
|
||||
label = arguments["label"]
|
||||
# ... save the latest camera frame to <output_dir>/<label>.png ...
|
||||
return f"saved {label}.png"
|
||||
```
|
||||
|
||||
One file per tool keeps dependencies isolated — `record_observation`
|
||||
might pull `pillow`, while `say` pulls `pocket-tts`. Users installing
|
||||
only the tools they need avoid heavy transitive deps.
|
||||
|
||||
### Step 3 — register it
|
||||
|
||||
Add to `src/lerobot/tools/registry.py`:
|
||||
|
||||
```python
|
||||
from .record_observation import RecordObservationTool
|
||||
|
||||
TOOL_REGISTRY["record_observation"] = RecordObservationTool
|
||||
```
|
||||
|
||||
That's it. At runtime `get_tools(meta)` looks up each schema in
|
||||
`meta.tools`, instantiates the matching registered class, and returns
|
||||
a name → instance dict the dispatcher can route into.
|
||||
|
||||
If you want to use a tool _without_ writing an implementation (e.g. for
|
||||
training-time chat-template formatting only), step 1 alone is enough —
|
||||
the model still learns to _generate_ the call. Steps 2 and 3 are only
|
||||
needed to actually _execute_ it at inference.
|
||||
@@ -82,7 +82,7 @@ After the first episode of a video stream is encoded, the encoder configuration
|
||||
"video.pix_fmt": "yuv420p",
|
||||
"video.fps": 30,
|
||||
"video.channels": 3,
|
||||
"is_depth_map": false,
|
||||
"video.is_depth_map": false,
|
||||
"video.g": 2,
|
||||
"video.crf": 30,
|
||||
"video.preset": "fast",
|
||||
@@ -97,7 +97,7 @@ After the first episode of a video stream is encoded, the encoder configuration
|
||||
|
||||
Two sources contribute to the `info` block:
|
||||
|
||||
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `is_depth_map`, plus `audio.*` if an audio stream is present.
|
||||
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `video.is_depth_map`, plus `audio.*` if an audio stream is present.
|
||||
- **Encoder-derived** (taken from `VideoEncoderConfig`): `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.video_backend`, `video.extra_options`.
|
||||
|
||||
<Tip>
|
||||
|
||||
121
examples/annotations/run_hf_job.py
Normal file
121
examples/annotations/run_hf_job.py
Normal file
@@ -0,0 +1,121 @@
|
||||
#!/usr/bin/env python
|
||||
"""Launch ``lerobot-annotate`` on a Hugging Face job (vllm + Qwen3.6 MoE).
|
||||
|
||||
Spawns one ``h200x4`` job that:
|
||||
|
||||
1. installs this branch of ``lerobot`` plus the annotation extras,
|
||||
2. boots four vllm servers (one per H200) with Qwen3.6-35B-A3B-FP8,
|
||||
3. runs the plan + vqa modules across the dataset in free-form
|
||||
mode — phase 0 (canonical vocabulary discovery) is disabled so
|
||||
every episode's subtasks + memory are generated independently;
|
||||
interjections is also disabled, which short-circuits the
|
||||
plan_update phase that depends on it,
|
||||
4. uploads the annotated dataset to ``--dest_repo_id`` (when set)
|
||||
or back to ``--repo_id``.
|
||||
|
||||
Usage:
|
||||
|
||||
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||
|
||||
Adjust ``CMD`` below to point at your own dataset / target hub repo.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from huggingface_hub import get_token, run_job
|
||||
|
||||
token = os.environ.get("HF_TOKEN") or get_token()
|
||||
if not token:
|
||||
raise RuntimeError("No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`")
|
||||
|
||||
CMD = (
|
||||
"apt-get update -qq && apt-get install -y -qq git ffmpeg && "
|
||||
"pip install --no-deps "
|
||||
"'lerobot @ git+https://github.com/huggingface/lerobot.git@feat/language-annotation-pipeline' && "
|
||||
"pip install --upgrade-strategy only-if-needed "
|
||||
"datasets pyarrow av jsonlines draccus gymnasium torchcodec mergedeep pyyaml-include toml typing-inspect "
|
||||
"openai && "
|
||||
"export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && "
|
||||
"export VLLM_VIDEO_BACKEND=pyav && "
|
||||
"lerobot-annotate "
|
||||
"--repo_id=pepijn223/robocasa_smoke_2atomic_v3 "
|
||||
"--dest_repo_id=pepijn223/robocasa_smoke_2atomic_v3_annotated "
|
||||
"--push_to_hub=true "
|
||||
"--vlm.backend=openai "
|
||||
"--vlm.model_id=Qwen/Qwen3.6-35B-A3B-FP8 "
|
||||
"--vlm.parallel_servers=4 "
|
||||
"--vlm.num_gpus=4 "
|
||||
'--vlm.serve_command="vllm serve Qwen/Qwen3.6-35B-A3B-FP8 '
|
||||
# 4× the context (32768 → 131072) so long episodes at 1 Hz fit even
|
||||
# at full Qwen vision resolution: 90 frames @ ~700 vision tokens/frame
|
||||
# ≈ 63 k tokens, comfortably under 131 k. On 1× H200 (144 GB) the
|
||||
# 35B-FP8 model leaves plenty of room for the bigger KV cache.
|
||||
"--tensor-parallel-size 1 --max-model-len 131072 "
|
||||
'--gpu-memory-utilization 0.85 --uvicorn-log-level warning --port {port}" '
|
||||
"--vlm.serve_ready_timeout_s=1800 "
|
||||
"--vlm.client_concurrency=256 "
|
||||
"--vlm.max_new_tokens=512 "
|
||||
# Low temperature for VQA: bbox + keypoint are coordinate-regression
|
||||
# tasks where sampling noise directly degrades localization
|
||||
# (overlapping boxes, drifted points). 0.2 keeps the model decisive
|
||||
# while still letting question/label phrasing vary across frames.
|
||||
"--vlm.temperature=0.2 "
|
||||
"--executor.episode_parallelism=64 "
|
||||
"--vlm.chat_template_kwargs='{\"enable_thinking\": false}' "
|
||||
# Whole-scene agentview is the right choice for subtask reasoning +
|
||||
# VQA on robocasa: the wrist (``robot0_eye_in_hand``) usually only
|
||||
# sees the gripper + nearby object, which hurts "what is happening
|
||||
# in this episode" decomposition. Override per-dataset if your
|
||||
# cameras are named differently (inspect ``meta/info.json``).
|
||||
"--vlm.camera_key=observation.images.robot0_agentview_left "
|
||||
# Phase 0 — canonical vocabulary discovery DISABLED. This dataset's
|
||||
# episodes span heterogeneous tasks/scenes, so a single shared
|
||||
# subtask + memory vocabulary would be too narrow — each episode
|
||||
# generates its subtasks + memory free-form instead.
|
||||
"--vocabulary.enabled=false "
|
||||
# Phase 1 — plan module (subtasks + plan + memory + task_aug).
|
||||
"--plan.enabled=true "
|
||||
"--plan.frames_per_second=1.0 "
|
||||
"--plan.use_video_url=true "
|
||||
"--plan.use_video_url_fps=1.0 "
|
||||
# Force coarse, composite subtasks (``pick up X`` = approach + grasp
|
||||
# + lift in one span, not three). 3 s is large enough to host a
|
||||
# full grasp-or-place composite at typical 20 fps robocasa speeds;
|
||||
# any candidate span shorter than this gets merged into a neighbour
|
||||
# by the prompt's authoring rules (see module_1_subtasks.txt).
|
||||
"--plan.min_subtask_seconds=3.0 "
|
||||
# Cap so the VLM can't drift into micro-segmentation. Combined with
|
||||
# the composite-action rules in the prompt, this targets ~3-6
|
||||
# meaningful spans per episode for typical pick-and-place demos.
|
||||
"--plan.plan_max_steps=9 "
|
||||
# ``off`` keeps the dataset's canonical ``record.episode_task`` as-is
|
||||
# — no per-episode VLM "what is this video about" call. Switch to
|
||||
# ``if_short`` (default) only if some episodes have placeholder /
|
||||
# missing canonical tasks; ``always`` overrides every episode's task.
|
||||
"--plan.derive_task_from_video=off "
|
||||
# 0 disables the task_aug pass entirely (see PlanConfig.n_task_rephrasings
|
||||
# docstring) — no per-episode paraphrase generation, no task_aug rows.
|
||||
"--plan.n_task_rephrasings=0 "
|
||||
# Phase 2 — interjections OFF (also skips phase 3 plan_update,
|
||||
# see executor.py:_run_plan_update_phase guard).
|
||||
"--interjections.enabled=false "
|
||||
# Phase 4 — general VQA. K=1 keeps each VQA answer on its own
|
||||
# emission frame (no temporal smear); see VqaConfig.K docstring.
|
||||
# 3 Hz cadence: at 20 fps source, that's a VQA tick every ~7 frames.
|
||||
# NOTE: VQA emits per-camera, so for robocasa (3 cameras) each tick
|
||||
# produces 3 (user, assistant) row pairs — total call volume ~= 3 *
|
||||
# 3 Hz * mean_episode_seconds * n_episodes.
|
||||
"--vqa.enabled=true "
|
||||
"--vqa.K=1 "
|
||||
"--vqa.vqa_emission_hz=3.0"
|
||||
)
|
||||
|
||||
job = run_job(
|
||||
image="vllm/vllm-openai:latest",
|
||||
command=["bash", "-c", CMD],
|
||||
flavor="h200x4",
|
||||
secrets={"HF_TOKEN": token},
|
||||
timeout="24h",
|
||||
)
|
||||
print(f"Job URL: {job.url}")
|
||||
print(f"Job ID: {job.id}")
|
||||
74
examples/benchmark/bench_pi052_kernels.slurm
Normal file
74
examples/benchmark/bench_pi052_kernels.slurm
Normal file
@@ -0,0 +1,74 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=bench-pi052-kernels
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --time=01:30:00
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --gpus-per-task=1
|
||||
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_kernels_%j.out
|
||||
|
||||
# HF kernels exploration via Liger's apply_liger_kernel_to_paligemma.
|
||||
# Baseline (SDPA, no kernels) vs. per-subkernel ablations vs. all-on.
|
||||
# Same harness as bench_pi052_step.py — only the --kernels flag varies
|
||||
# across runs so any delta is attributable to the patched op(s).
|
||||
#
|
||||
# Subkernels exercised: rope, rms_norm, geglu, layer_norm.
|
||||
# Skipped: cross_entropy / fused_linear_cross_entropy — pi052 calls
|
||||
# F.cross_entropy directly and bypasses PaliGemma's forward, so those
|
||||
# patches wouldn't fire without model-code changes (separate PR).
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||
|
||||
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
||||
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
||||
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
||||
|
||||
# /fsx triton cache is shared across nodes with different glibc versions
|
||||
# — kernels built on one node trip GLIBC_2.34-not-found on another. Use
|
||||
# a node-local cache per job to side-step that.
|
||||
export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}"
|
||||
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}"
|
||||
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
|
||||
|
||||
echo "=== Node: $(hostname) ==="
|
||||
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader
|
||||
ldd --version | head -1
|
||||
|
||||
# Liger isn't in our standard env yet — install on the compute node so
|
||||
# the slurm log captures the exact version that produced the numbers.
|
||||
python -m pip install -q --upgrade 'liger-kernel'
|
||||
python - <<'PY' || true
|
||||
from importlib.metadata import version, PackageNotFoundError
|
||||
try:
|
||||
print("liger-kernel", version("liger-kernel"))
|
||||
except PackageNotFoundError:
|
||||
print("liger-kernel: not importable")
|
||||
import liger_kernel.transformers as t
|
||||
print("apply_liger_kernel_to_paligemma:", hasattr(t, "apply_liger_kernel_to_paligemma"))
|
||||
PY
|
||||
|
||||
run() {
|
||||
echo
|
||||
echo "--- $* ---"
|
||||
python examples/benchmark/bench_pi052_step.py "$@" || true
|
||||
}
|
||||
|
||||
# -- Baseline (no kernels) at the BS we actually train at. --
|
||||
run --attn sdpa --batch-size 8 --kernels none
|
||||
run --attn sdpa --batch-size 16 --kernels none
|
||||
|
||||
# -- Per-subkernel ablations at BS=16 to isolate each contributor. --
|
||||
run --attn sdpa --batch-size 16 --kernels rms_norm
|
||||
run --attn sdpa --batch-size 16 --kernels geglu
|
||||
run --attn sdpa --batch-size 16 --kernels layer_norm
|
||||
run --attn sdpa --batch-size 16 --kernels rope
|
||||
|
||||
# -- All-on, both BS to compare against the matched baselines above. --
|
||||
run --attn sdpa --batch-size 8 --kernels all
|
||||
run --attn sdpa --batch-size 16 --kernels all
|
||||
|
||||
# -- Headroom check: does kernels-all let BS=24 fit (baseline OOMs near here)? --
|
||||
run --attn sdpa --batch-size 24 --kernels none
|
||||
run --attn sdpa --batch-size 24 --kernels all
|
||||
338
examples/benchmark/bench_pi052_step.py
Normal file
338
examples/benchmark/bench_pi052_step.py
Normal file
@@ -0,0 +1,338 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Benchmark ``PI052Policy.forward + backward`` on a single GPU.
|
||||
|
||||
Compares the new SDPA attention path against the eager baseline by
|
||||
monkeypatching ``sdpa_attention_forward`` before the first model
|
||||
forward — so both runs share identical Q/K/V plumbing and only the
|
||||
attention kernel differs. Reports steps/sec and peak GPU memory.
|
||||
|
||||
SLURM-only:
|
||||
|
||||
sbatch examples/benchmark/bench_pi052_step.slurm
|
||||
|
||||
Or one-off:
|
||||
|
||||
srun --partition=hopper-prod --qos=high --gpus=1 --time=15 \\
|
||||
python examples/benchmark/bench_pi052_step.py --attn sdpa --batch-size 8
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _maybe_patch_eager() -> None:
|
||||
"""Swap ``sdpa_attention_forward`` for the original eager forward.
|
||||
|
||||
Must be called BEFORE PI052Policy is instantiated — the layer
|
||||
compute functions resolve the symbol at call time (module-level
|
||||
lookup), so this patch covers both pi05 and pi052 KI paths."""
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
from lerobot.policies.pi05 import modeling_pi05
|
||||
|
||||
modeling_pi05.sdpa_attention_forward = modeling_gemma.eager_attention_forward
|
||||
|
||||
|
||||
_LIGER_SUBKERNELS = ("rope", "rms_norm", "geglu", "layer_norm")
|
||||
|
||||
|
||||
def _maybe_patch_liger(spec: str) -> dict:
|
||||
"""Globally patch PaliGemma/Gemma/Siglip modules with Liger Triton kernels.
|
||||
|
||||
Must be called BEFORE PI052Policy is instantiated — Liger replaces
|
||||
classes inside ``transformers.models.{gemma,gemma2,siglip,paligemma}``,
|
||||
so any model built after the call picks up the fused forwards.
|
||||
|
||||
``spec`` is a comma-separated subset of {rope, rms_norm, geglu,
|
||||
layer_norm} (also ``all`` and ``none``). ``cross_entropy`` and
|
||||
``fused_linear_cross_entropy`` are intentionally skipped — pi052's
|
||||
losses use ``F.cross_entropy`` directly (not ``nn.CrossEntropyLoss``)
|
||||
and never traverse ``PaliGemmaForConditionalGeneration.forward``,
|
||||
so neither patch would fire without invasive model-code changes.
|
||||
"""
|
||||
enabled = dict.fromkeys(_LIGER_SUBKERNELS, False)
|
||||
if spec in ("", "none"):
|
||||
return enabled
|
||||
tokens = [t.strip() for t in spec.split(",") if t.strip()]
|
||||
if tokens == ["all"]:
|
||||
enabled = dict.fromkeys(_LIGER_SUBKERNELS, True)
|
||||
else:
|
||||
for t in tokens:
|
||||
if t not in enabled:
|
||||
raise SystemExit(f"Unknown liger subkernel: {t!r}. Choose from {_LIGER_SUBKERNELS} or 'all'.")
|
||||
enabled[t] = True
|
||||
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_paligemma
|
||||
|
||||
apply_liger_kernel_to_paligemma(
|
||||
rope=enabled["rope"],
|
||||
rms_norm=enabled["rms_norm"],
|
||||
geglu=enabled["geglu"],
|
||||
layer_norm=enabled["layer_norm"],
|
||||
cross_entropy=False,
|
||||
fused_linear_cross_entropy=False,
|
||||
)
|
||||
return enabled
|
||||
|
||||
|
||||
def _maybe_patch_flex() -> None:
|
||||
"""Swap ``sdpa_attention_forward`` for a FlexAttention-backed forward.
|
||||
|
||||
Experimental: builds a per-call ``score_mod`` from the additive
|
||||
mask and dispatches to a compiled ``flex_attention`` kernel.
|
||||
|
||||
Known issue on torch 2.7.1: dynamo errors out with
|
||||
``FlexAttentionHigherOrderVariable() has no type`` when the
|
||||
``score_mod`` closure captures a per-call bias tensor. A proper
|
||||
port needs ``create_block_mask(mask_mod, ...)`` plumbed at the
|
||||
PI05Pytorch.forward level so a BlockMask object can be passed
|
||||
down to the layer compute, not a per-call closure. Left as
|
||||
future work; keep this stub for benchmark experimentation."""
|
||||
import torch
|
||||
from torch.nn.attention.flex_attention import flex_attention
|
||||
|
||||
from lerobot.policies.pi05 import modeling_pi05
|
||||
|
||||
compiled_flex = torch.compile(flex_attention, dynamic=True)
|
||||
|
||||
def flex_forward(module, query, key, value, attention_mask, scaling, dropout=0.0):
|
||||
n_rep = module.num_key_value_groups
|
||||
if n_rep > 1:
|
||||
key = key.repeat_interleave(n_rep, dim=1)
|
||||
value = value.repeat_interleave(n_rep, dim=1)
|
||||
|
||||
bias = attention_mask # (B, 1, Lq, Lk) additive
|
||||
|
||||
def score_mod(score, b, h, q_idx, kv_idx):
|
||||
return score + bias[b, 0, q_idx, kv_idx]
|
||||
|
||||
attn_output = compiled_flex(query, key, value, score_mod=score_mod, scale=scaling)
|
||||
return attn_output.transpose(1, 2).contiguous(), None
|
||||
|
||||
modeling_pi05.sdpa_attention_forward = flex_forward
|
||||
|
||||
|
||||
def _build_policy(args, device: torch.device):
|
||||
"""Random-init PI052Policy at production-relevant shapes."""
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.pi052.configuration_pi052 import PI052Config
|
||||
from lerobot.policies.pi052.modeling_pi052 import PI052Policy
|
||||
|
||||
# Production has ``unfreeze_lm_head=True`` + ``text_loss_weight>0``,
|
||||
# which flips ``train_expert_only=False`` in __post_init__ and
|
||||
# makes the whole PaliGemma + Gemma-expert stack trainable. We
|
||||
# mirror that here so the optimizer-state count reflects reality;
|
||||
# the loss path still goes through ``PI05Policy.forward`` because
|
||||
# ``text_labels`` / FAST tokens are absent from the synthetic batch
|
||||
# (see ``PI052Policy.forward`` early-return).
|
||||
config = PI052Config(
|
||||
max_action_dim=args.action_dim,
|
||||
max_state_dim=args.state_dim,
|
||||
dtype=args.dtype,
|
||||
knowledge_insulation=args.knowledge_insulation,
|
||||
text_loss_weight=1e-3 if args.train_full else 0.0,
|
||||
flow_loss_weight=1.0,
|
||||
enable_fast_action_loss=False,
|
||||
unfreeze_lm_head=args.train_full,
|
||||
tokenizer_max_length=args.lang_tokens,
|
||||
device="cuda",
|
||||
compile_model=args.compile_model,
|
||||
compile_mode=args.compile_mode,
|
||||
)
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(args.state_dim,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(args.action_dim,)),
|
||||
}
|
||||
policy = PI052Policy(config)
|
||||
policy.to(device)
|
||||
if args.gradient_checkpointing:
|
||||
policy.model.gradient_checkpointing_enable()
|
||||
policy.train()
|
||||
return policy, config
|
||||
|
||||
|
||||
def _build_batch(args, config, device: torch.device) -> dict:
|
||||
"""Synthetic batch matching the training-loop input contract."""
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
)
|
||||
|
||||
B = args.batch_size
|
||||
L = args.lang_tokens
|
||||
return {
|
||||
OBS_LANGUAGE_TOKENS: torch.randint(0, 250000, (B, L), device=device),
|
||||
OBS_LANGUAGE_ATTENTION_MASK: torch.ones(B, L, dtype=torch.bool, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(B, 3, 224, 224, device=device),
|
||||
"observation.images.base_0_rgb_padding_mask": torch.ones(B, dtype=torch.bool, device=device),
|
||||
"observation.state": torch.randn(B, args.state_dim, device=device),
|
||||
ACTION: torch.randn(B, config.chunk_size, args.action_dim, device=device),
|
||||
"action_is_pad": torch.zeros(B, config.chunk_size, dtype=torch.bool, device=device),
|
||||
"task": ["bench task"] * B,
|
||||
}
|
||||
|
||||
|
||||
def _step(policy, batch, optimizer=None) -> torch.Tensor:
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
if optimizer is not None:
|
||||
optimizer.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
else:
|
||||
for p in policy.parameters():
|
||||
if p.grad is not None:
|
||||
p.grad = None
|
||||
return loss.detach()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--attn", choices=["sdpa", "eager", "flex"], default="sdpa")
|
||||
parser.add_argument(
|
||||
"--kernels",
|
||||
default="none",
|
||||
help=(
|
||||
"Liger sub-kernels to enable, comma-separated. Choose from "
|
||||
f"{_LIGER_SUBKERNELS} or use 'all' / 'none' (default). Applied "
|
||||
"via apply_liger_kernel_to_paligemma() BEFORE model build."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compile",
|
||||
dest="compile_model",
|
||||
action="store_true",
|
||||
help="Set policy.config.compile_model=True (torch.compile the forward).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compile-mode",
|
||||
default="default",
|
||||
help="torch.compile mode (default | reduce-overhead | max-autotune).",
|
||||
)
|
||||
parser.add_argument("--batch-size", type=int, default=8)
|
||||
parser.add_argument("--warmup", type=int, default=8)
|
||||
parser.add_argument("--steps", type=int, default=40)
|
||||
parser.add_argument("--lang-tokens", type=int, default=512)
|
||||
parser.add_argument("--dtype", choices=["bfloat16", "float32"], default="bfloat16")
|
||||
parser.add_argument("--action-dim", type=int, default=14)
|
||||
parser.add_argument("--state-dim", type=int, default=14)
|
||||
parser.add_argument("--knowledge-insulation", action="store_true", default=True)
|
||||
parser.add_argument(
|
||||
"--gradient-checkpointing",
|
||||
dest="gradient_checkpointing",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--optimizer",
|
||||
choices=["none", "adamw", "adamw_fused"],
|
||||
default="adamw_fused",
|
||||
help=(
|
||||
"Whether to include an AdamW step in the timed iteration. "
|
||||
"'none' mirrors the fwd+bwd-only original bench; 'adamw' / "
|
||||
"'adamw_fused' add the realistic ~2x param-bytes optimizer "
|
||||
"state and ``optimizer.step()`` cost."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train-full",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help=(
|
||||
"Mirror production: unfreeze the PaliGemma backbone (full "
|
||||
"~3B trainable params) instead of training only the 300M "
|
||||
"action expert."
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
raise SystemExit("Benchmark requires CUDA; submit via slurm (srun/sbatch).")
|
||||
|
||||
if args.attn == "eager":
|
||||
_maybe_patch_eager()
|
||||
elif args.attn == "flex":
|
||||
_maybe_patch_flex()
|
||||
|
||||
liger_flags = _maybe_patch_liger(args.kernels)
|
||||
|
||||
device = torch.device("cuda")
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
policy, config = _build_policy(args, device)
|
||||
batch = _build_batch(args, config, device)
|
||||
|
||||
optimizer = None
|
||||
trainable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
if args.optimizer != "none":
|
||||
trainable = [p for p in policy.parameters() if p.requires_grad]
|
||||
optimizer = torch.optim.AdamW(
|
||||
trainable, lr=5e-5, fused=(args.optimizer == "adamw_fused")
|
||||
)
|
||||
|
||||
for _ in range(args.warmup):
|
||||
_step(policy, batch, optimizer)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
starter = torch.cuda.Event(enable_timing=True)
|
||||
ender = torch.cuda.Event(enable_timing=True)
|
||||
starter.record()
|
||||
for _ in range(args.steps):
|
||||
_step(policy, batch, optimizer)
|
||||
ender.record()
|
||||
torch.cuda.synchronize()
|
||||
total_ms = starter.elapsed_time(ender)
|
||||
step_ms = total_ms / args.steps
|
||||
peak_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||
optim_gb = 0.0
|
||||
if optimizer is not None:
|
||||
for st in optimizer.state.values():
|
||||
for v in st.values():
|
||||
if torch.is_tensor(v):
|
||||
optim_gb += v.numel() * v.element_size() / (1024**3)
|
||||
|
||||
liger_on = ",".join(k for k, v in liger_flags.items() if v) or "none"
|
||||
name = (
|
||||
f"{args.attn:>5} | BS={args.batch_size} | L={args.lang_tokens} | "
|
||||
f"KI={args.knowledge_insulation} | GC={args.gradient_checkpointing} | "
|
||||
f"compile={args.compile_model} | liger={liger_on} | opt={args.optimizer} | dtype={args.dtype}"
|
||||
)
|
||||
print(
|
||||
f"{name}\n step_ms={step_ms:.1f} steps/sec={1000.0 / step_ms:.3f} "
|
||||
f"peak_mem={peak_gb:.2f} GiB optim_state={optim_gb:.2f} GiB "
|
||||
f"trainable_params={trainable_params / 1e9:.2f}B"
|
||||
)
|
||||
|
||||
del policy, batch
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
36
examples/benchmark/bench_pi052_step.slurm
Normal file
36
examples/benchmark/bench_pi052_step.slurm
Normal file
@@ -0,0 +1,36 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=bench-pi052-attn
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --time=00:30:00
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --gpus-per-task=1
|
||||
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_%j.out
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||
|
||||
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
||||
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
||||
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
||||
|
||||
echo "=== Node: $(hostname) ==="
|
||||
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader
|
||||
|
||||
python -c "import torch; print('torch', torch.__version__, 'cuda', torch.version.cuda)"
|
||||
|
||||
run() {
|
||||
echo
|
||||
echo "--- $* ---"
|
||||
python examples/benchmark/bench_pi052_step.py "$@" || true
|
||||
}
|
||||
|
||||
# Attention parity benchmark — same shapes, different attention kernel.
|
||||
run --attn eager --batch-size 8
|
||||
run --attn sdpa --batch-size 8
|
||||
|
||||
# Headroom benchmark — does SDPA's memory cut allow a bigger micro-batch?
|
||||
run --attn sdpa --batch-size 12
|
||||
run --attn sdpa --batch-size 16
|
||||
run --attn sdpa --batch-size 24
|
||||
39
examples/benchmark/bench_pi052_step_v2.slurm
Normal file
39
examples/benchmark/bench_pi052_step_v2.slurm
Normal file
@@ -0,0 +1,39 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=bench-pi052-v2
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --time=00:45:00
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --gpus-per-task=1
|
||||
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v2_%j.out
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||
|
||||
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
||||
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
||||
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
||||
|
||||
echo "=== Node: $(hostname) ==="
|
||||
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader
|
||||
|
||||
run() {
|
||||
echo
|
||||
echo "--- $* ---"
|
||||
python examples/benchmark/bench_pi052_step.py "$@" || true
|
||||
}
|
||||
|
||||
# A: GC ON — see if the selective-AC change (one less recompute level)
|
||||
# narrows the eager vs SDPA gap at BS=8.
|
||||
run --attn eager --batch-size 8
|
||||
run --attn sdpa --batch-size 8
|
||||
|
||||
# B: GC OFF — isolate the raw attention-kernel cost & memory delta.
|
||||
run --attn eager --batch-size 4 --no-gradient-checkpointing
|
||||
run --attn sdpa --batch-size 4 --no-gradient-checkpointing
|
||||
|
||||
# C: SDPA + GC headroom sweep — where does it OOM?
|
||||
run --attn sdpa --batch-size 16
|
||||
run --attn sdpa --batch-size 24
|
||||
run --attn sdpa --batch-size 32
|
||||
36
examples/benchmark/bench_pi052_step_v3.slurm
Normal file
36
examples/benchmark/bench_pi052_step_v3.slurm
Normal file
@@ -0,0 +1,36 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=bench-pi052-v3
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --time=00:45:00
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --gpus-per-task=1
|
||||
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v3_%j.out
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||
|
||||
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
||||
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
||||
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
||||
|
||||
echo "=== Node: $(hostname) ==="
|
||||
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader
|
||||
|
||||
run() {
|
||||
echo
|
||||
echo "--- $* ---"
|
||||
python examples/benchmark/bench_pi052_step.py "$@" || true
|
||||
}
|
||||
|
||||
# Compile sweep: does torch.compile + SDPA give a non-trivial boost on
|
||||
# top of the bare SDPA path?
|
||||
run --attn sdpa --batch-size 8 --compile
|
||||
run --attn sdpa --batch-size 16 --compile
|
||||
|
||||
# FlexAttention sweep (experimental): score_mod adds the additive bias
|
||||
# in-kernel; expect a long first-step compile, then SDPA-or-better steady
|
||||
# state.
|
||||
run --attn flex --batch-size 8
|
||||
run --attn flex --batch-size 16
|
||||
41
examples/benchmark/bench_pi052_step_v4.slurm
Normal file
41
examples/benchmark/bench_pi052_step_v4.slurm
Normal file
@@ -0,0 +1,41 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=bench-pi052-v4
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --time=01:00:00
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --gpus-per-task=1
|
||||
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v4_%j.out
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||
|
||||
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
||||
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
||||
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
||||
|
||||
# /fsx triton cache is shared across nodes with different glibc versions
|
||||
# — kernels built on one node trip GLIBC_2.34-not-found on another. Use
|
||||
# a node-local cache per job to side-step that.
|
||||
export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}"
|
||||
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}"
|
||||
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
|
||||
|
||||
echo "=== Node: $(hostname) ==="
|
||||
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader
|
||||
ldd --version | head -1
|
||||
|
||||
run() {
|
||||
echo
|
||||
echo "--- $* ---"
|
||||
python examples/benchmark/bench_pi052_step.py "$@" || true
|
||||
}
|
||||
|
||||
# compile path on top of SDPA + selective AC
|
||||
run --attn sdpa --batch-size 8 --compile
|
||||
run --attn sdpa --batch-size 16 --compile
|
||||
|
||||
# FlexAttention experimental
|
||||
run --attn flex --batch-size 8
|
||||
run --attn flex --batch-size 16
|
||||
33
examples/benchmark/bench_pi052_step_v5.slurm
Normal file
33
examples/benchmark/bench_pi052_step_v5.slurm
Normal file
@@ -0,0 +1,33 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=bench-pi052-v5
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --time=00:45:00
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --gpus-per-task=1
|
||||
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v5_%j.out
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||
|
||||
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
||||
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
||||
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
||||
export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}"
|
||||
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}"
|
||||
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
|
||||
|
||||
echo "=== Node: $(hostname) ==="
|
||||
|
||||
run() {
|
||||
echo
|
||||
echo "--- $* ---"
|
||||
python examples/benchmark/bench_pi052_step.py "$@" || true
|
||||
}
|
||||
|
||||
# compile_mode=default (graph-only, no autotune) is the right knob with
|
||||
# gradient checkpointing — max-autotune in v4 was 2x slower than no-compile.
|
||||
run --attn sdpa --batch-size 8 --compile --compile-mode default
|
||||
run --attn sdpa --batch-size 16 --compile --compile-mode default
|
||||
run --attn sdpa --batch-size 8 --compile --compile-mode reduce-overhead
|
||||
31
examples/benchmark/bench_pi052_step_v6.slurm
Normal file
31
examples/benchmark/bench_pi052_step_v6.slurm
Normal file
@@ -0,0 +1,31 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=bench-pi052-v6-bs32
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --time=00:30:00
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --gpus-per-task=1
|
||||
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v6_%j.out
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||
|
||||
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
||||
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
||||
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
||||
export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}"
|
||||
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}"
|
||||
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
|
||||
|
||||
echo "=== Node: $(hostname) ==="
|
||||
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
|
||||
|
||||
run() {
|
||||
echo
|
||||
echo "--- $* ---"
|
||||
python examples/benchmark/bench_pi052_step.py "$@" || true
|
||||
}
|
||||
|
||||
# BS=32 with the production settings (SDPA + compile=default).
|
||||
run --attn sdpa --batch-size 32 --compile --compile-mode default
|
||||
39
examples/benchmark/bench_pi052_step_v7.slurm
Normal file
39
examples/benchmark/bench_pi052_step_v7.slurm
Normal file
@@ -0,0 +1,39 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=bench-pi052-v7-opt
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --time=00:45:00
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --gpus-per-task=1
|
||||
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v7_%j.out
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||
|
||||
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
||||
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
||||
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
||||
export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}"
|
||||
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}"
|
||||
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
|
||||
|
||||
echo "=== Node: $(hostname) ==="
|
||||
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
|
||||
|
||||
run() {
|
||||
echo
|
||||
echo "--- $* ---"
|
||||
python examples/benchmark/bench_pi052_step.py "$@" || true
|
||||
}
|
||||
|
||||
# Realistic full-step memory: fwd + bwd + AdamW step. The original
|
||||
# sweep was fwd+bwd-only and undercounted memory by the optimizer-
|
||||
# state size (~2x param bytes for AdamW). This run confirms BS=16
|
||||
# and BS=32 still fit with the optimizer in residency.
|
||||
run --attn sdpa --batch-size 16 --compile --compile-mode default --optimizer adamw_fused
|
||||
run --attn sdpa --batch-size 32 --compile --compile-mode default --optimizer adamw_fused
|
||||
|
||||
# Without compile, in case the production cluster has compile issues.
|
||||
run --attn sdpa --batch-size 16 --optimizer adamw_fused
|
||||
run --attn sdpa --batch-size 32 --optimizer adamw_fused
|
||||
36
examples/benchmark/bench_pi052_step_v8.slurm
Normal file
36
examples/benchmark/bench_pi052_step_v8.slurm
Normal file
@@ -0,0 +1,36 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=bench-pi052-v8-bs40-dtype
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --time=00:45:00
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --gpus-per-task=1
|
||||
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v8_%j.out
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||
|
||||
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
||||
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
||||
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
||||
export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}"
|
||||
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}"
|
||||
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
|
||||
|
||||
echo "=== Node: $(hostname) ==="
|
||||
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
|
||||
|
||||
run() {
|
||||
echo
|
||||
echo "--- $* ---"
|
||||
python examples/benchmark/bench_pi052_step.py "$@" || true
|
||||
}
|
||||
|
||||
# Confirm BS=40 fits on a single H100 with the optimizer in residency.
|
||||
run --attn sdpa --batch-size 40 --compile --compile-mode default --optimizer adamw_fused
|
||||
|
||||
# Dtype A/B at modest batch — fp32 needs ~2x the memory of bf16, so we
|
||||
# drop to BS=4 to keep both runs comparable instead of OOMing fp32.
|
||||
run --attn sdpa --batch-size 4 --optimizer adamw_fused --dtype bfloat16
|
||||
run --attn sdpa --batch-size 4 --optimizer adamw_fused --dtype float32
|
||||
29
examples/benchmark/fsdp_pi052.yaml
Normal file
29
examples/benchmark/fsdp_pi052.yaml
Normal file
@@ -0,0 +1,29 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: FSDP
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
fsdp_config:
|
||||
fsdp_activation_checkpointing: false
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_backward_prefetch: BACKWARD_PRE
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_forward_prefetch: false
|
||||
fsdp_offload_params: false
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_state_dict_type: SHARDED_STATE_DICT
|
||||
fsdp_sync_module_states: true
|
||||
fsdp_transformer_layer_cls_to_wrap: GemmaDecoderLayer,SiglipEncoderLayer
|
||||
fsdp_use_orig_params: true
|
||||
fsdp_version: 2
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -15,10 +15,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes.
|
||||
Create MP4 (or GIF) videos with per-frame progress overlay for specified episodes.
|
||||
|
||||
Downloads datasets from HuggingFace, seeks directly into the episode segment
|
||||
of the source video, draws a progress line on each frame, and writes the result.
|
||||
The progress data is read from a parquet file that lives alongside the dataset
|
||||
(configurable via ``--progress-file``).
|
||||
|
||||
Usage:
|
||||
python examples/dataset/create_progress_videos.py \
|
||||
@@ -56,22 +58,26 @@ SCORE_FONT_SCALE = 0.8
|
||||
TASK_FONT_SCALE = 0.55
|
||||
|
||||
|
||||
def download_episode_metadata(repo_id: str, episode: int) -> Path:
|
||||
"""Download only the metadata and sarm_progress files for a dataset.
|
||||
def download_episode_metadata(
|
||||
repo_id: str, episode: int, progress_file: str = "sarm_progress.parquet"
|
||||
) -> Path:
|
||||
"""Download only the metadata and per-frame progress file for a dataset.
|
||||
|
||||
Args:
|
||||
repo_id: HuggingFace dataset repository ID.
|
||||
episode: Episode index (used for logging only; all meta is fetched).
|
||||
progress_file: Filename of the per-frame progress parquet inside the
|
||||
dataset repo.
|
||||
|
||||
Returns:
|
||||
Local cache path for the downloaded snapshot.
|
||||
"""
|
||||
logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode)
|
||||
logging.info("[1/4] Downloading metadata + %s for %s (episode %d) ...", progress_file, repo_id, episode)
|
||||
local_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**", "sarm_progress.parquet"],
|
||||
allow_patterns=["meta/**", progress_file],
|
||||
ignore_patterns=["*.mp4"],
|
||||
)
|
||||
)
|
||||
@@ -215,25 +221,28 @@ def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
|
||||
return video_path
|
||||
|
||||
|
||||
def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None:
|
||||
"""Load sarm_progress values for an episode.
|
||||
def load_progress_data(
|
||||
local_path: Path, episode: int, progress_file: str = "sarm_progress.parquet"
|
||||
) -> np.ndarray | None:
|
||||
"""Load per-frame progress values for an episode.
|
||||
|
||||
Args:
|
||||
local_path: Dataset cache root.
|
||||
episode: Episode index.
|
||||
progress_file: Filename of the per-frame progress parquet.
|
||||
|
||||
Returns:
|
||||
Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
|
||||
"""
|
||||
parquet_path = local_path / "sarm_progress.parquet"
|
||||
parquet_path = local_path / progress_file
|
||||
if not parquet_path.exists():
|
||||
logging.warning("sarm_progress.parquet not found")
|
||||
logging.warning("%s not found", progress_file)
|
||||
return None
|
||||
df = pd.read_parquet(parquet_path)
|
||||
logging.info(" sarm_progress.parquet columns: %s", list(df.columns))
|
||||
logging.info(" %s columns: %s", progress_file, list(df.columns))
|
||||
episode_df = df[df["episode_index"] == episode].copy()
|
||||
if episode_df.empty:
|
||||
logging.warning("No sarm_progress rows for episode %d", episode)
|
||||
logging.warning("No progress rows for episode %d in %s", episode, progress_file)
|
||||
return None
|
||||
episode_df = episode_df.sort_values("frame_index")
|
||||
|
||||
@@ -576,6 +585,7 @@ def process_dataset(
|
||||
camera_key: str | None,
|
||||
output_dir: Path,
|
||||
create_gif: bool = False,
|
||||
progress_file: str = "sarm_progress.parquet",
|
||||
) -> Path | None:
|
||||
"""Full pipeline: download, extract metadata, composite progress, write output.
|
||||
|
||||
@@ -585,6 +595,8 @@ def process_dataset(
|
||||
camera_key: Camera key to use, or None for auto-selection.
|
||||
output_dir: Directory to write output files.
|
||||
create_gif: If True, also generate a GIF from the MP4.
|
||||
progress_file: Filename of the per-frame progress parquet inside the
|
||||
dataset repo.
|
||||
|
||||
Returns:
|
||||
Path to the final output file, or None on failure.
|
||||
@@ -592,7 +604,7 @@ def process_dataset(
|
||||
safe_name = repo_id.replace("/", "_")
|
||||
logging.info("Processing: %s | episode %d", repo_id, episode)
|
||||
|
||||
local_path = download_episode_metadata(repo_id, episode)
|
||||
local_path = download_episode_metadata(repo_id, episode, progress_file)
|
||||
logging.info(" Local cache: %s", local_path)
|
||||
|
||||
episode_meta = load_episode_meta(local_path, episode, camera_key)
|
||||
@@ -600,9 +612,9 @@ def process_dataset(
|
||||
|
||||
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
|
||||
|
||||
progress_data = load_progress_data(local_path, episode)
|
||||
progress_data = load_progress_data(local_path, episode, progress_file)
|
||||
if progress_data is None:
|
||||
logging.error("Could not load sarm_progress data. Skipping overlay.")
|
||||
logging.error("Could not load progress data from %s. Skipping overlay.", progress_file)
|
||||
return None
|
||||
|
||||
logging.info(" Progress frames: %d", len(progress_data))
|
||||
@@ -627,7 +639,7 @@ def process_dataset(
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes."
|
||||
description="Create MP4/GIF videos with per-frame progress overlay for dataset episodes."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
@@ -658,6 +670,15 @@ def main() -> None:
|
||||
action="store_true",
|
||||
help="Also generate a GIF from the MP4 output.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--progress-file",
|
||||
type=str,
|
||||
default="sarm_progress.parquet",
|
||||
help=(
|
||||
"Filename of the per-frame progress parquet inside the dataset repo "
|
||||
"(default: 'sarm_progress.parquet')."
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
@@ -670,6 +691,7 @@ def main() -> None:
|
||||
camera_key=args.camera_key,
|
||||
output_dir=args.output_dir,
|
||||
create_gif=args.gif,
|
||||
progress_file=args.progress_file,
|
||||
)
|
||||
|
||||
if result:
|
||||
|
||||
1053
examples/port_datasets/slurm_build_robocasa_composite_seen.py
Normal file
1053
examples/port_datasets/slurm_build_robocasa_composite_seen.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -85,6 +85,11 @@ dependencies = [
|
||||
"termcolor>=2.4.0,<4.0.0",
|
||||
"tqdm>=4.66.0,<5.0.0",
|
||||
|
||||
# Training utilities
|
||||
# EMA of policy parameters (Diffusion Policy / pi05 style). Tiny
|
||||
# pure-python dependency — preferred over a hand-rolled implementation.
|
||||
"ema-pytorch>=0.7.7,<1.0.0",
|
||||
|
||||
# Build tools (required by opencv-python-headless on some platforms)
|
||||
"cmake>=3.29.0.1,<4.2.0",
|
||||
"setuptools>=71.0.0,<81.0.0",
|
||||
@@ -95,7 +100,7 @@ dependencies = [
|
||||
|
||||
# ── Feature-scoped extras ──────────────────────────────────
|
||||
dataset = [
|
||||
"datasets>=4.0.0,<5.0.0",
|
||||
"datasets>=4.7.0,<5.0.0",
|
||||
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
|
||||
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
|
||||
"lerobot[av-dep]",
|
||||
@@ -138,8 +143,11 @@ dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
|
||||
# Common
|
||||
av-dep = ["av>=15.0.0,<16.0.0"]
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||
# NOTE: 0.9.16 links against liburdfdom_sensor.so.4, which is unavailable on Ubuntu 24.04
|
||||
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
|
||||
placo-dep = ["placo>=0.9.6,<0.9.16"]
|
||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||
sentencepiece-dep = ["sentencepiece>=0.2.0,<0.3.0"] # FAST action tokenizer backend (pi052, pi0_fast)
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
@@ -151,6 +159,8 @@ pyserial-dep = ["pyserial>=3.5,<4.0"]
|
||||
deepdiff-dep = ["deepdiff>=7.0.1,<9.0.0"]
|
||||
pynput-dep = ["pynput>=1.7.8,<1.9.0"]
|
||||
pyzmq-dep = ["pyzmq>=26.2.1,<28.0.0"]
|
||||
motorbridge-dep = ["motorbridge>=0.3.2,<0.4.0"]
|
||||
motorbridge-smart-servo-dep = ["motorbridge-smart-servo>=0.0.4,<0.1.0"]
|
||||
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0", "lerobot[pyserial-dep]", "lerobot[deepdiff-dep]"]
|
||||
@@ -174,6 +184,9 @@ unitree_g1 = [
|
||||
"lerobot[pygame-dep]",
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||
# Seeed Studio reBot B601-DM follower (motorbridge / CAN) + StarArm102 / reBot Arm 102
|
||||
# leader (motorbridge-smart-servo / FashionStar UART servos).
|
||||
rebot = ["lerobot[motorbridge-dep]", "lerobot[motorbridge-smart-servo-dep]"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
|
||||
@@ -190,7 +203,7 @@ wallx = [
|
||||
"torchdiffeq>=0.2.4,<0.3.0",
|
||||
"lerobot[qwen-vl-utils-dep]",
|
||||
]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]", "lerobot[sentencepiece-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
|
||||
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
|
||||
groot = [
|
||||
@@ -212,6 +225,26 @@ hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||
|
||||
# Annotation pipeline (lerobot-annotate). vllm is the preferred backend
|
||||
# on Linux, with a transformers fallback elsewhere; openai is the default
|
||||
# backend and talks to any OpenAI-compatible server (``vllm serve`` /
|
||||
# ``transformers serve`` / hosted endpoints). Distributed execution is
|
||||
# delegated to Hugging Face Jobs (see examples/annotations/run_hf_job.py).
|
||||
annotations = [
|
||||
"lerobot[dataset]",
|
||||
"lerobot[transformers-dep]",
|
||||
"openai>=1.40,<2.0",
|
||||
"vllm>=0.6.0,<1.0.0; sys_platform == 'linux'",
|
||||
]
|
||||
|
||||
# Tool implementations under src/lerobot/tools/. Each tool's dependencies
|
||||
# are isolated so adding a new tool doesn't bloat the base install.
|
||||
# Currently only `say` (Kyutai pocket-tts; CPU-only, ~100M params).
|
||||
tools = [
|
||||
"pocket-tts>=1.0.0,<3.0.0",
|
||||
"scipy>=1.11.0,<2.0.0", # SayTool.output_dir uses scipy.io.wavfile
|
||||
]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
||||
@@ -260,6 +293,7 @@ all = [
|
||||
"lerobot[lekiwi]",
|
||||
"lerobot[openarms]",
|
||||
"lerobot[reachy2]",
|
||||
"lerobot[rebot]",
|
||||
"lerobot[kinematics]",
|
||||
"lerobot[intelrealsense]",
|
||||
"lerobot[diffusion]",
|
||||
@@ -301,7 +335,10 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
|
||||
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
||||
# Interactive hierarchical-VLA runtime for PI052 (PaliGemma backbone).
|
||||
lerobot-pi052-runtime="lerobot.scripts.lerobot_pi052_runtime:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
|
||||
@@ -319,7 +356,7 @@ torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json"]
|
||||
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
47
scripts/build_robocasa_smoke.sh
Executable file
47
scripts/build_robocasa_smoke.sh
Executable file
@@ -0,0 +1,47 @@
|
||||
#!/bin/bash
|
||||
# Build a tiny RoboCasa smoke dataset (2 short atomic tasks, all episodes) for
|
||||
# fast end-to-end training validation before the real run.
|
||||
#
|
||||
# Defaults: target/human, OpenStandMixerHead + NavigateKitchen (~1k episodes,
|
||||
# ~131k frames, ~109 min @ 20 fps), 2 SLURM workers on hopper-cpu.
|
||||
#
|
||||
# Override via env: TASKS, REPO_ID, WORK_DIR, WORKERS, CPUS, PARTITION, LOCAL=1.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||
source ~/miniconda3/etc/profile.d/conda.sh
|
||||
conda activate lerobot
|
||||
|
||||
REPO_ID="${REPO_ID:-${HF_USER:?HF_USER is unset}/robocasa_smoke_2atomic_v3}"
|
||||
WORK_DIR="${WORK_DIR:-/fsx/${USER}/robocasa/datasets/v1.0}"
|
||||
ROBOCASA_ROOT="${ROBOCASA_ROOT:-/fsx/${USER}/robocasa}"
|
||||
LOGS_DIR="${LOGS_DIR:-/fsx/${USER}/logs/robocasa}"
|
||||
TASKS="${TASKS:-OpenStandMixerHead NavigateKitchen}"
|
||||
WORKERS="${WORKERS:-2}"
|
||||
CPUS="${CPUS:-8}"
|
||||
PARTITION="${PARTITION:-hopper-cpu}"
|
||||
LOCAL="${LOCAL:-0}"
|
||||
|
||||
ARGS=(
|
||||
examples/port_datasets/slurm_build_robocasa_composite_seen.py
|
||||
--repo-id="$REPO_ID"
|
||||
--work-dir="$WORK_DIR"
|
||||
--robocasa-root="$ROBOCASA_ROOT"
|
||||
--split=target --source=human
|
||||
--tasks $TASKS
|
||||
--workers="$WORKERS"
|
||||
--cpus-per-task="$CPUS"
|
||||
--partition="$PARTITION"
|
||||
--mem-per-cpu=4G
|
||||
--time=04:00:00
|
||||
--logs-dir="$LOGS_DIR"
|
||||
--job-name=port_robocasa_smoke
|
||||
)
|
||||
if [[ "$LOCAL" == "1" ]]; then
|
||||
ARGS+=(--slurm=0)
|
||||
fi
|
||||
|
||||
echo "Smoke dataset: $REPO_ID"
|
||||
echo "Tasks: $TASKS"
|
||||
python "${ARGS[@]}"
|
||||
15
src/lerobot/annotations/__init__.py
Normal file
15
src/lerobot/annotations/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
50
src/lerobot/annotations/steerable_pipeline/__init__.py
Normal file
50
src/lerobot/annotations/steerable_pipeline/__init__.py
Normal file
@@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Steerable annotation pipeline producing ``language_persistent`` and
|
||||
``language_events`` columns for LeRobot datasets.
|
||||
|
||||
The pipeline is decomposed into three independently runnable modules whose
|
||||
outputs are staged per-episode before a final parquet rewrite:
|
||||
|
||||
- :mod:`.modules.plan_subtasks_memory` (the ``plan`` module) — persistent styles
|
||||
- :mod:`.modules.interjections_and_speech` (the ``interjections`` module) — event styles + speech
|
||||
- :mod:`.modules.general_vqa` (the ``vqa`` module) — event-style VQA pairs
|
||||
"""
|
||||
|
||||
from .config import AnnotationPipelineConfig
|
||||
from .validator import StagingValidator, ValidationReport
|
||||
from .vocabulary import (
|
||||
VOCABULARY_FILENAME,
|
||||
Vocabulary,
|
||||
VocabularyDiscoveryModule,
|
||||
load_vocabulary,
|
||||
save_vocabulary,
|
||||
vocabulary_path,
|
||||
)
|
||||
from .writer import LanguageColumnsWriter
|
||||
|
||||
__all__ = [
|
||||
"VOCABULARY_FILENAME",
|
||||
"AnnotationPipelineConfig",
|
||||
"LanguageColumnsWriter",
|
||||
"StagingValidator",
|
||||
"ValidationReport",
|
||||
"Vocabulary",
|
||||
"VocabularyDiscoveryModule",
|
||||
"load_vocabulary",
|
||||
"save_vocabulary",
|
||||
"vocabulary_path",
|
||||
]
|
||||
251
src/lerobot/annotations/steerable_pipeline/config.py
Normal file
251
src/lerobot/annotations/steerable_pipeline/config.py
Normal file
@@ -0,0 +1,251 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class VocabularyConfig:
|
||||
"""Phase 0 — dataset-level canonical vocabulary discovery.
|
||||
|
||||
Watches the first ``sample_episodes`` episode videos and asks the VLM
|
||||
to derive a small canonical vocabulary (subtask labels + memory
|
||||
milestones) that every episode in the dataset will reuse. The VLM
|
||||
decides the count itself from what it sees in the clips — short
|
||||
pick-and-place demos get ~6 labels, longer multi-step recipes more.
|
||||
The output lands at ``meta/canonical_vocabulary.json`` and feeds
|
||||
phase 1's subtask + memory generation as both a prompt-side
|
||||
constraint and a post-VLM validation gate.
|
||||
|
||||
Why this exists: free-form LLM rephrasing per episode produces near-
|
||||
unique subtask strings, which makes the downstream low-level policy's
|
||||
conditioning effectively noise — at inference the policy generates a
|
||||
*new* paraphrase the action expert has never seen and produces tiny
|
||||
cautious actions. Forcing every episode onto the same small set of
|
||||
canonical strings gives the action expert dense supervision per
|
||||
string and a small target distribution to learn against.
|
||||
|
||||
Set ``enabled=False`` to fall back to free-form generation (original
|
||||
behaviour). ``reuse_existing=True`` keeps a hand-edited vocabulary
|
||||
file from being clobbered on re-runs.
|
||||
"""
|
||||
|
||||
enabled: bool = True
|
||||
sample_episodes: int = 3
|
||||
max_video_frames_per_episode: int = 32
|
||||
# When True (default), an existing meta/canonical_vocabulary.json is
|
||||
# loaded as-is and no VLM call is made — lets operators hand-edit the
|
||||
# file. Set False to always rediscover from the sample episodes.
|
||||
reuse_existing: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlanConfig:
|
||||
"""``plan`` module: plan + subtasks + memory + task augmentation.
|
||||
|
||||
The ``plan`` module attaches the whole episode as one Qwen-VL video
|
||||
block; ``max_video_frames`` only caps the frames packed in (a
|
||||
model-capacity bound, not an annotation-logic knob).
|
||||
"""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
# Number of ``task_aug`` rephrasings emitted at ``t=0``. The renderer's
|
||||
# ``${task}`` binding rotates among them per ``sample_idx``. ``0`` disables.
|
||||
n_task_rephrasings: int = 10
|
||||
|
||||
# When to derive the task from the video instead of using
|
||||
# ``record.episode_task``: ``off``, ``if_short`` (short / placeholder /
|
||||
# missing canonical task), or ``always``. The derived task replaces the
|
||||
# canonical one for every ``plan``-module prompt; ``meta/tasks.parquet``
|
||||
# is never modified.
|
||||
derive_task_from_video: str = "if_short"
|
||||
derive_task_min_words: int = 3
|
||||
|
||||
# Frame sampling for the subtask-decomposition prompt.
|
||||
frames_per_second: float = 1.0
|
||||
max_video_frames: int = 128
|
||||
|
||||
min_subtask_seconds: float = 1.5
|
||||
plan_max_steps: int = 8
|
||||
|
||||
# When True (and backend supports it, e.g. ``openai``), the ``plan``
|
||||
# module sends a ``video_url`` block pointing at a per-episode mp4
|
||||
# subclip and lets the server sample frames at ``use_video_url_fps``.
|
||||
use_video_url: bool = False
|
||||
use_video_url_fps: float = 1.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterjectionsConfig:
|
||||
"""``interjections`` module: interjections + paired speech."""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
# Each interjection emits a paired ``(interjection, speech)`` event row
|
||||
# and triggers a ``plan`` refresh at the same timestamp via the
|
||||
# ``plan`` module.
|
||||
max_interjections_per_episode: int = 3
|
||||
interjection_min_t: float = 2.0
|
||||
|
||||
# Visual context attached to the interjection prompt: a short window
|
||||
# of frames centered on the chosen timestamp so the VLM sees the
|
||||
# ongoing motion rather than a single frozen frame.
|
||||
interjection_window_seconds: float = 2.0
|
||||
interjection_window_frames: int = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class VqaConfig:
|
||||
"""``vqa`` module: general VQA."""
|
||||
|
||||
enabled: bool = True
|
||||
vqa_emission_hz: float = 1.0
|
||||
K: int = 1
|
||||
"""How many *consecutive* frames each emission tick anchors a VQA pair
|
||||
to. The VLM grounds its answer (bbox / keypoint coordinates, count, …)
|
||||
against the *first* anchored frame's image, so anchoring K>1 frames
|
||||
copies that same answer onto later frames where the scene has already
|
||||
moved — stale labels. Default ``1``: a VQA pair lands on exactly its
|
||||
emission frame, no temporal smear. Raise it only to trade label
|
||||
precision for more (noisier) VQA frames."""
|
||||
question_types: tuple[str, ...] = ("bbox", "keypoint", "count", "attribute", "spatial")
|
||||
|
||||
|
||||
@dataclass
|
||||
class VlmConfig:
|
||||
"""Shared Qwen-VL client configuration."""
|
||||
|
||||
# One of ``vllm``, ``transformers``, ``openai``, or ``stub`` (tests).
|
||||
# ``openai`` talks to a local OpenAI-compatible server; the CLI
|
||||
# auto-spawns one when ``auto_serve=True``.
|
||||
backend: str = "openai"
|
||||
model_id: str = "Qwen/Qwen3.6-35B-A3B-FP8"
|
||||
|
||||
# OpenAI-compatible server endpoint; ``EMPTY`` works for local servers.
|
||||
api_base: str = "http://localhost:8000/v1"
|
||||
api_key: str = "EMPTY"
|
||||
|
||||
# When True with ``backend=openai``, the CLI probes ``api_base`` and
|
||||
# spawns a server if none answers (default: ``transformers serve``).
|
||||
# Set to False to fail fast when pointing at a remote endpoint.
|
||||
auto_serve: bool = True
|
||||
serve_port: int = 8000
|
||||
# Override the auto-serve command. ``{port}`` is substituted per replica
|
||||
# when ``parallel_servers > 1``.
|
||||
serve_command: str | None = None
|
||||
|
||||
# Run multiple independent inference servers for round-robin client
|
||||
# routing (each pinned to a GPU via ``CUDA_VISIBLE_DEVICES`` and bound
|
||||
# to ``serve_port + i``). ``num_gpus=0`` means one GPU per replica.
|
||||
parallel_servers: int = 1
|
||||
num_gpus: int = 0
|
||||
client_concurrency: int = 16
|
||||
serve_ready_timeout_s: float = 600.0
|
||||
|
||||
max_new_tokens: int = 512
|
||||
temperature: float = 0.2
|
||||
json_mode: bool = True
|
||||
batch_size: int = 4
|
||||
tensor_parallel_size: int = 1
|
||||
|
||||
# Fraction of GPU memory vllm allocates for weights + KV cache.
|
||||
gpu_memory_utilization: float = 0.9
|
||||
# Cap context length (None = model default). On 80 GB H100 a 30B BF16
|
||||
# model often needs <= 8192 to leave KV-cache headroom.
|
||||
max_model_len: int | None = None
|
||||
trust_remote_code: bool = False
|
||||
|
||||
# Override the camera stream used for keyframe attachment. None picks
|
||||
# the first ``observation.images.*`` key the dataset declares.
|
||||
camera_key: str | None = None
|
||||
# Forwarded as ``extra_body.chat_template_kwargs`` on every chat call;
|
||||
# use to pass model-specific flags such as ``{"enable_thinking": false}``.
|
||||
chat_template_kwargs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutorConfig:
|
||||
"""Executor settings.
|
||||
|
||||
Distributed execution is provided by Hugging Face Jobs (see
|
||||
``examples/annotation/run_hf_job.py``); this config only controls
|
||||
intra-process episode concurrency.
|
||||
"""
|
||||
|
||||
# Episodes processed concurrently within each module phase. Each
|
||||
# in-flight episode dispatches 3-5 dependent VLM calls, so this is the
|
||||
# main knob for saturating ``parallel_servers`` and ``client_concurrency``.
|
||||
episode_parallelism: int = 16
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnnotationPipelineConfig:
|
||||
"""Top-level config for ``lerobot-annotate``.
|
||||
|
||||
The writer rewrites ``data/chunk-*/file-*.parquet`` in place. Multiple
|
||||
revisions of the same dataset live in separate copies.
|
||||
"""
|
||||
|
||||
# Hub dataset id. Used as the download source when ``root`` is unset,
|
||||
# and as the destination repo when ``push_to_hub`` is enabled and
|
||||
# ``dest_repo_id`` is unset.
|
||||
repo_id: str | None = None
|
||||
|
||||
# Optional separate Hub dataset id to push the annotated result to. When
|
||||
# unset, ``push_to_hub`` uploads back to ``repo_id`` (annotate in place);
|
||||
# when set, the source ``repo_id`` is left untouched.
|
||||
dest_repo_id: str | None = None
|
||||
|
||||
root: Path | None = None
|
||||
|
||||
# Defaults to ``<root>/.annotate_staging/`` when unset.
|
||||
staging_dir: Path | None = None
|
||||
|
||||
seed: int = 1729
|
||||
|
||||
vocabulary: VocabularyConfig = field(default_factory=VocabularyConfig)
|
||||
plan: PlanConfig = field(default_factory=PlanConfig)
|
||||
interjections: InterjectionsConfig = field(default_factory=InterjectionsConfig)
|
||||
vqa: VqaConfig = field(default_factory=VqaConfig)
|
||||
|
||||
vlm: VlmConfig = field(default_factory=VlmConfig)
|
||||
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
|
||||
|
||||
skip_validation: bool = False
|
||||
only_episodes: tuple[int, ...] | None = None
|
||||
|
||||
# Keyframe decode backend. When unset, the pipeline decodes with the
|
||||
# ffmpeg CLI: it decodes AV1 and runs each decode as an isolated child
|
||||
# process, which is both crash-safe and safe under the concurrent
|
||||
# decode the executor performs (torchcodec is not thread-safe and
|
||||
# SIGSEGVs there). Set to ``"torchcodec"`` or ``"pyav"`` to pin an
|
||||
# in-process decoder when its build is known thread-safe.
|
||||
video_backend: str | None = None
|
||||
|
||||
# When True, upload the annotated dataset to the Hugging Face Hub:
|
||||
# to ``dest_repo_id`` if set, otherwise back to ``repo_id``. One of
|
||||
# the two must be set for this to take effect.
|
||||
push_to_hub: bool = False
|
||||
push_private: bool = False
|
||||
push_commit_message: str | None = None
|
||||
|
||||
def resolved_staging_dir(self, root: Path) -> Path:
|
||||
return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging"
|
||||
325
src/lerobot/annotations/steerable_pipeline/executor.py
Normal file
325
src/lerobot/annotations/steerable_pipeline/executor.py
Normal file
@@ -0,0 +1,325 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""In-process executor that runs the annotation phases.
|
||||
|
||||
The executor plans **seven phases** in the dependency order from the plan:
|
||||
|
||||
phase 0: vocabulary discovery — derive a small canonical vocabulary
|
||||
from the first few sample-episode videos (subtask labels +
|
||||
memory milestones) and persist it next to the dataset; the
|
||||
``plan`` module then constrains every per-episode generation
|
||||
to those strings, so the downstream policy sees a small,
|
||||
repeatable conditioning distribution
|
||||
phase 1: ``plan`` module (plan + subtasks + memory)
|
||||
phase 2: ``interjections`` module (interjections + speech)
|
||||
phase 3: ``plan`` plan-update pass — re-runs plan emission at every
|
||||
interjection timestamp produced by phase 2
|
||||
phase 4: ``vqa`` module (VQA)
|
||||
phase 5: validator
|
||||
phase 6: writer
|
||||
|
||||
Phase 3 is why the ``plan`` module must be re-entered after the
|
||||
``interjections`` module — to refresh ``plan`` rows at interjection
|
||||
timestamps.
|
||||
|
||||
Distributed execution is provided by Hugging Face Jobs (see
|
||||
``examples/annotations/run_hf_job.py``); the runner inside the job
|
||||
invokes ``lerobot-annotate`` which uses this in-process executor.
|
||||
Episode-level concurrency is controlled by
|
||||
``ExecutorConfig.episode_parallelism``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .config import AnnotationPipelineConfig
|
||||
from .reader import EpisodeRecord, iter_episodes
|
||||
from .staging import EpisodeStaging
|
||||
from .validator import StagingValidator
|
||||
from .writer import LanguageColumnsWriter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PhaseResult:
|
||||
"""Summary of one pipeline phase across all episodes."""
|
||||
|
||||
name: str
|
||||
episodes_processed: int
|
||||
episodes_skipped: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineRunSummary:
|
||||
"""Aggregated result returned by :meth:`Executor.run`."""
|
||||
|
||||
phases: list[PhaseResult]
|
||||
written_paths: list[Path]
|
||||
validation_report: Any # ValidationReport, kept Any to avoid import cycle
|
||||
|
||||
|
||||
@dataclass
|
||||
class Executor:
|
||||
"""Run all six phases over a dataset root in-process.
|
||||
|
||||
Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism``
|
||||
(a thread pool); cluster-level concurrency comes from running this
|
||||
executor inside a Hugging Face Job. Tests construct the executor
|
||||
directly with stub modules.
|
||||
"""
|
||||
|
||||
config: AnnotationPipelineConfig
|
||||
plan: Any # PlanSubtasksMemoryModule
|
||||
interjections: Any # InterjectionsAndSpeechModule
|
||||
vqa: Any # GeneralVqaModule
|
||||
writer: LanguageColumnsWriter
|
||||
validator: StagingValidator
|
||||
vocabulary: Any = None # VocabularyDiscoveryModule | None
|
||||
|
||||
def run(self, root: Path) -> PipelineRunSummary:
|
||||
records = list(iter_episodes(root, only_episodes=self.config.only_episodes))
|
||||
n = len(records)
|
||||
if n == 0:
|
||||
raise ValueError(f"No episodes found under {root}/data/")
|
||||
|
||||
print(f"[annotate] {n} episodes total", flush=True)
|
||||
|
||||
staging_dir = self.config.resolved_staging_dir(root)
|
||||
staging_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
phases: list[PhaseResult] = []
|
||||
|
||||
# Phase 0: vocabulary discovery. Mutates ``self.plan.vocabulary``
|
||||
# so subsequent per-episode plan calls see the canonical labels.
|
||||
phases.append(self._run_vocabulary_phase(records, root))
|
||||
|
||||
# Phase 1: ``plan`` module (plan + subtasks + memory)
|
||||
phases.append(self._run_module_phase("plan", records, staging_dir, self.plan))
|
||||
# Phase 2: ``interjections`` module (interjections + speech). It
|
||||
# reads the ``plan`` module's subtask rows from the same staging
|
||||
# tree to ground the interjection prompt in the correct local subtask.
|
||||
phases.append(self._run_module_phase("interjections", records, staging_dir, self.interjections))
|
||||
# Phase 3: ``plan`` plan-update pass at interjection timestamps.
|
||||
phases.append(self._run_plan_update_phase(records, staging_dir))
|
||||
# Phase 4: ``vqa`` module (VQA)
|
||||
phases.append(self._run_module_phase("vqa", records, staging_dir, self.vqa))
|
||||
|
||||
print("[annotate] running validator...", flush=True)
|
||||
report = self.validator.validate(records, staging_dir)
|
||||
if not report.ok and not self.config.skip_validation:
|
||||
raise RuntimeError(f"Staging validation failed: {report.summary()}")
|
||||
print(f"[annotate] validator: {report.summary()}", flush=True)
|
||||
|
||||
print(f"[annotate] writing parquet shards into {root}/data/...", flush=True)
|
||||
written = self.writer.write_all(records, staging_dir, root)
|
||||
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
|
||||
|
||||
# Keep meta/info.json aligned with the parquet schema we just wrote
|
||||
# (language columns advertised; canonical ``say`` tool registered for
|
||||
# PI052 / Pi0.5 / dataset-visualizer consumers via
|
||||
# ``LeRobotDatasetMetadata.tools``). Idempotent and additive: existing
|
||||
# user metadata is preserved.
|
||||
self._ensure_annotation_metadata_in_info(root)
|
||||
|
||||
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_annotation_metadata_in_info(root: Path) -> None:
|
||||
"""Write language features and canonical tools to ``meta/info.json``.
|
||||
|
||||
``LanguageColumnsWriter`` adds ``language_persistent`` and
|
||||
``language_events`` to parquet shards. The metadata must advertise
|
||||
those columns too, otherwise non-streaming ``LeRobotDataset`` loads
|
||||
cast against the old schema and fail on the extra parquet columns.
|
||||
"""
|
||||
from lerobot.datasets.io_utils import load_info, write_info # noqa: PLC0415
|
||||
from lerobot.datasets.language import SAY_TOOL_SCHEMA, language_feature_info # noqa: PLC0415
|
||||
|
||||
info_path = root / "meta" / "info.json"
|
||||
if not info_path.exists():
|
||||
return
|
||||
try:
|
||||
info = load_info(root)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f"[annotate] could not read {info_path}: {exc}", flush=True)
|
||||
return
|
||||
|
||||
changed = False
|
||||
|
||||
merged_features = {**info.features, **language_feature_info()}
|
||||
if merged_features != info.features:
|
||||
info.features = merged_features
|
||||
changed = True
|
||||
|
||||
existing = info.tools or []
|
||||
names = {(t.get("function") or {}).get("name") for t in existing if isinstance(t, dict)}
|
||||
if SAY_TOOL_SCHEMA["function"]["name"] not in names:
|
||||
info.tools = [*existing, SAY_TOOL_SCHEMA]
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
write_info(info, root)
|
||||
print(
|
||||
"[annotate] meta/info.json: "
|
||||
f"language_features={list(language_feature_info())}, "
|
||||
f"tools={[t['function']['name'] for t in (info.tools or [])]}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
def _run_vocabulary_phase(
|
||||
self, records: list[EpisodeRecord], root: Path
|
||||
) -> PhaseResult:
|
||||
"""Discover (or load) the canonical vocabulary, wire it into ``self.plan``.
|
||||
|
||||
Returns a ``PhaseResult`` whose ``episodes_processed`` is the number
|
||||
of sample episodes consulted (0 when disabled or no VLM call was
|
||||
needed); ``episodes_skipped`` is always ``0`` because vocabulary is
|
||||
a once-per-dataset artifact, not a per-episode product.
|
||||
"""
|
||||
from .vocabulary import load_vocabulary, save_vocabulary # noqa: PLC0415
|
||||
|
||||
if self.vocabulary is None or not getattr(self.vocabulary, "enabled", False):
|
||||
print(
|
||||
"[annotate] phase=vocabulary skipped (module disabled or unset)",
|
||||
flush=True,
|
||||
)
|
||||
return PhaseResult(name="vocabulary", episodes_processed=0, episodes_skipped=0)
|
||||
|
||||
existing = load_vocabulary(root)
|
||||
if existing is not None and self.config.vocabulary.reuse_existing:
|
||||
print(
|
||||
f"[annotate] phase=vocabulary reusing {root / 'meta' / 'canonical_vocabulary.json'} "
|
||||
f"({len(existing.subtasks)} subtask labels, "
|
||||
f"{len(existing.memory_milestones)} memory milestones)",
|
||||
flush=True,
|
||||
)
|
||||
self.plan.vocabulary = existing
|
||||
return PhaseResult(name="vocabulary", episodes_processed=0, episodes_skipped=0)
|
||||
|
||||
sample_n = max(1, min(int(self.config.vocabulary.sample_episodes), len(records)))
|
||||
print(
|
||||
f"[annotate] phase=vocabulary discovering from {sample_n} sample episode(s)...",
|
||||
flush=True,
|
||||
)
|
||||
t0 = time.time()
|
||||
vocab = self.vocabulary.discover(records[:sample_n], existing=existing)
|
||||
if vocab is None:
|
||||
print(
|
||||
"[annotate] phase=vocabulary returned no vocabulary — "
|
||||
"plan module will fall back to free-form generation",
|
||||
flush=True,
|
||||
)
|
||||
return PhaseResult(name="vocabulary", episodes_processed=0, episodes_skipped=0)
|
||||
|
||||
save_path = save_vocabulary(root, vocab)
|
||||
print(
|
||||
f"[annotate] phase=vocabulary wrote {save_path} "
|
||||
f"({len(vocab.subtasks)} subtask labels, "
|
||||
f"{len(vocab.memory_milestones)} memory milestones) in "
|
||||
f"{time.time() - t0:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
self.plan.vocabulary = vocab
|
||||
return PhaseResult(name="vocabulary", episodes_processed=sample_n, episodes_skipped=0)
|
||||
|
||||
def _run_module_phase(
|
||||
self,
|
||||
name: str,
|
||||
records: list[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
module: Any,
|
||||
) -> PhaseResult:
|
||||
if not module.enabled:
|
||||
print(f"[annotate] phase={name} skipped (module disabled)", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
|
||||
n = len(records)
|
||||
parallelism = max(1, min(self.config.executor.episode_parallelism, n))
|
||||
print(
|
||||
f"[annotate] phase={name} starting on {n} episode(s) (parallelism={parallelism})",
|
||||
flush=True,
|
||||
)
|
||||
t0 = time.time()
|
||||
|
||||
def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]:
|
||||
i, record = idx_record
|
||||
ep_start = time.time()
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
return i, record.episode_index, time.time() - ep_start
|
||||
|
||||
processed = 0
|
||||
if parallelism == 1:
|
||||
for i, record in enumerate(records, 1):
|
||||
_, ep_idx, elapsed = _do((i, record))
|
||||
processed += 1
|
||||
print(
|
||||
f"[annotate] {name} episode {i}/{n} (idx={ep_idx}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=parallelism) as pool:
|
||||
futures = [pool.submit(_do, (i, r)) for i, r in enumerate(records, 1)]
|
||||
for fut in as_completed(futures):
|
||||
i, ep_idx, elapsed = fut.result()
|
||||
processed += 1
|
||||
print(
|
||||
f"[annotate] {name} episode {processed}/{n} "
|
||||
f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
total = time.time() - t0
|
||||
print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
|
||||
|
||||
def _run_plan_update_phase( # noqa: PLR0915
|
||||
self, records: list[EpisodeRecord], staging_dir: Path
|
||||
) -> PhaseResult:
|
||||
"""Re-emit ``plan`` rows at each timestamp the ``interjections`` module produced.
|
||||
|
||||
The ``plan`` module owns the prompt; the ``interjections`` module
|
||||
produced the timestamps. This phase therefore calls back into the
|
||||
``plan`` module with the interjection timestamps so its existing
|
||||
prompt path is reused.
|
||||
"""
|
||||
if not self.plan.enabled or not self.interjections.enabled:
|
||||
return PhaseResult(
|
||||
name="plan_update", episodes_processed=0, episodes_skipped=len(records)
|
||||
)
|
||||
processed = 0
|
||||
for record in records:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
interjection_rows = [
|
||||
row for row in staging.read("interjections") if row.get("style") == "interjection"
|
||||
]
|
||||
interjection_times = [float(row["timestamp"]) for row in interjection_rows]
|
||||
interjection_texts = [str(row.get("content") or "") for row in interjection_rows]
|
||||
if interjection_times:
|
||||
self.plan.run_plan_updates(record, staging, interjection_times, interjection_texts)
|
||||
processed += 1
|
||||
# Episodes without any interjections are skipped (no plan refresh
|
||||
# needed); count them so the summary's processed+skipped == total.
|
||||
return PhaseResult(
|
||||
name="plan_update",
|
||||
episodes_processed=processed,
|
||||
episodes_skipped=len(records) - processed,
|
||||
)
|
||||
483
src/lerobot/annotations/steerable_pipeline/frames.py
Normal file
483
src/lerobot/annotations/steerable_pipeline/frames.py
Normal file
@@ -0,0 +1,483 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Keyframe extraction for the annotation pipeline.
|
||||
|
||||
Modules attach decoded camera frames to their VLM prompts so the model can
|
||||
ground subtask decomposition, interjection scenarios, and VQA in actual
|
||||
visual content. The pipeline shares one provider across modules and one
|
||||
episode at a time, with a small per-episode cache so multiple modules
|
||||
querying the same timestamp pay decode cost once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.video_utils import decode_video_frames
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FrameProvider(Protocol):
|
||||
"""Decodes camera frames at episode-relative timestamps."""
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` feature keys this provider can decode."""
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return one decoded frame per timestamp from ``camera_key`` (or default).
|
||||
|
||||
Frames are ``torch.Tensor`` (``C, H, W`` uint8) — the shape
|
||||
:func:`lerobot.datasets.video_utils.decode_video_frames` returns.
|
||||
:func:`to_image_blocks` converts them to PIL only at the VLM-message
|
||||
boundary.
|
||||
|
||||
Empty list if the camera is unavailable. ``camera_key=None`` falls back
|
||||
to the provider's default camera so existing single-camera callers
|
||||
(the ``plan`` and ``interjections`` modules) keep working unchanged.
|
||||
"""
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` decoded frames covering the whole episode.
|
||||
|
||||
Sampling is uniform across the episode duration. Frames are
|
||||
``torch.Tensor`` (``C, H, W`` uint8); :func:`to_video_block` wraps
|
||||
them into one ``{"type":"video", "video":<list>}`` block for a
|
||||
Qwen-VL-compatible model that pools temporally itself. Empty list if
|
||||
no camera available.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class _NullProvider:
|
||||
"""No-op provider used when the dataset has no video keys or in tests."""
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
return []
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
|
||||
def null_provider() -> FrameProvider:
|
||||
return _NullProvider()
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoFrameProvider:
|
||||
"""Decodes frames from the dataset's ``observation.images.*`` streams.
|
||||
|
||||
By default the *first* camera key is used for the ``plan`` module
|
||||
(subtask decomposition) and the ``interjections`` module (interjection
|
||||
scenarios) — those prompts care about *what is happening*, not which
|
||||
angle. The ``vqa`` module instead iterates over every camera in
|
||||
:attr:`camera_keys` so each frame's
|
||||
grounded answer (bbox/keypoint/...) is tagged with the camera it was
|
||||
grounded against.
|
||||
|
||||
``camera_key`` overrides the default-camera choice but does not restrict
|
||||
:attr:`camera_keys`. Pass ``camera_key`` explicitly to ``frames_at`` /
|
||||
``video_for_episode`` to read a non-default stream.
|
||||
|
||||
Caches up to ``cache_size`` decoded frames per process to keep
|
||||
co-timestamped ``interjections`` + ``plan`` plan-update calls cheap.
|
||||
"""
|
||||
|
||||
root: Path
|
||||
camera_key: str | None = None
|
||||
tolerance_s: float = 1e-2
|
||||
cache_size: int = 256
|
||||
# Keyframe decode backend. ``None`` uses the ffmpeg CLI — the
|
||||
# concurrency- and crash-safe default for the pipeline's threaded
|
||||
# decode. Set to ``"torchcodec"`` or ``"pyav"`` to pin an in-process
|
||||
# decoder when the build is known thread-safe.
|
||||
video_backend: str | None = None
|
||||
_meta: Any = field(default=None, init=False, repr=False)
|
||||
_cache: dict = field(default_factory=dict, init=False, repr=False)
|
||||
_camera_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||
# Pipeline runs the three module phases under a ThreadPoolExecutor (see
|
||||
# ``ExecutorConfig.episode_parallelism``); guard the dict cache and the
|
||||
# one-shot warn flag against concurrent updates from worker threads.
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
|
||||
|
||||
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
|
||||
# ``camera_keys`` covers both image- and video-stored cameras and is
|
||||
# always defined on the metadata (``[]`` in the worst case), so it is
|
||||
# the single source we need here.
|
||||
keys = list(self._meta.camera_keys)
|
||||
# Last-resort fallback: if metadata didn't surface anything but the
|
||||
# caller explicitly named a camera (``--vlm.camera_key=...``), trust
|
||||
# them — the key is by definition known to exist on the dataset.
|
||||
if not keys and self.camera_key:
|
||||
keys = [self.camera_key]
|
||||
self._camera_keys = keys
|
||||
if self.camera_key is None:
|
||||
self.camera_key = keys[0] if keys else None
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` keys available on this dataset."""
|
||||
return list(self._camera_keys)
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if not timestamps or target is None:
|
||||
return []
|
||||
|
||||
out: list[Any] = []
|
||||
misses: list[float] = []
|
||||
miss_indices: list[int] = []
|
||||
with self._lock:
|
||||
for i, ts in enumerate(timestamps):
|
||||
key = (record.episode_index, target, round(float(ts), 6))
|
||||
cached = self._cache.get(key)
|
||||
if cached is not None:
|
||||
out.append(cached)
|
||||
else:
|
||||
out.append(None)
|
||||
misses.append(float(ts))
|
||||
miss_indices.append(i)
|
||||
|
||||
if misses:
|
||||
decoded = self._decode(record.episode_index, misses, target)
|
||||
# ``_decode`` returns exactly one frame per requested timestamp,
|
||||
# or an empty list if decoding failed wholesale. A partial list
|
||||
# would mean a frame/timestamp misalignment, so only pair them up
|
||||
# when the counts match (``strict=True`` then guards regressions).
|
||||
if len(decoded) == len(miss_indices):
|
||||
with self._lock:
|
||||
for i, frame in zip(miss_indices, decoded, strict=True):
|
||||
out[i] = frame
|
||||
key = (record.episode_index, target, round(float(timestamps[i]), 6))
|
||||
if len(self._cache) >= self.cache_size:
|
||||
self._cache.pop(next(iter(self._cache)))
|
||||
self._cache[key] = frame
|
||||
# filter out any None left over from decode failures
|
||||
return [frame for frame in out if frame is not None]
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` frames uniformly sampled across the episode.
|
||||
|
||||
The whole episode duration is covered; the model picks subtask
|
||||
boundaries from the temporal pooling it does internally. Frames are
|
||||
``torch.Tensor`` (see :meth:`frames_at`).
|
||||
"""
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if max_frames <= 0 or target is None or not record.frame_timestamps:
|
||||
return []
|
||||
n_frames = min(max_frames, len(record.frame_timestamps))
|
||||
if n_frames == len(record.frame_timestamps):
|
||||
timestamps = list(record.frame_timestamps)
|
||||
else:
|
||||
t0 = record.frame_timestamps[0]
|
||||
t_last = record.frame_timestamps[-1]
|
||||
if t_last <= t0:
|
||||
timestamps = [float(t0)] * n_frames
|
||||
else:
|
||||
step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0
|
||||
timestamps = [float(t0 + i * step) for i in range(n_frames)]
|
||||
return self.frames_at(record, timestamps, camera_key=target)
|
||||
|
||||
def episode_clip_path(self, record: EpisodeRecord, cache_dir: Path) -> Path | None:
|
||||
"""Extract the episode's subclip to ``cache_dir/ep_{idx:06d}.mp4``.
|
||||
|
||||
Returns ``None`` if the dataset has no video tracks. Skips
|
||||
re-extract when the cached clip already exists. Re-encodes to
|
||||
H.264 (libx264) so the resulting mp4 is decodable by every
|
||||
downstream video processor — stream-copy would inherit the
|
||||
source codec (often AV1 in modern LeRobot datasets), which
|
||||
vllm's libav build cannot decode.
|
||||
"""
|
||||
import subprocess # noqa: PLC0415
|
||||
|
||||
if self.camera_key is None:
|
||||
return None
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_path = cache_dir / f"ep_{record.episode_index:06d}.mp4"
|
||||
if out_path.exists() and out_path.stat().st_size > 0:
|
||||
return out_path
|
||||
ep = self._meta.episodes[record.episode_index]
|
||||
from_timestamp = float(ep[f"videos/{self.camera_key}/from_timestamp"])
|
||||
to_timestamp = float(ep[f"videos/{self.camera_key}/to_timestamp"])
|
||||
src = self.root / self._meta.get_video_file_path(record.episode_index, self.camera_key)
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-loglevel",
|
||||
"error",
|
||||
"-ss",
|
||||
f"{from_timestamp:.3f}",
|
||||
"-to",
|
||||
f"{to_timestamp:.3f}",
|
||||
"-i",
|
||||
str(src),
|
||||
"-c:v",
|
||||
"libx264",
|
||||
"-preset",
|
||||
"ultrafast",
|
||||
"-crf",
|
||||
"23",
|
||||
"-pix_fmt",
|
||||
"yuv420p",
|
||||
"-an",
|
||||
str(out_path),
|
||||
]
|
||||
try:
|
||||
subprocess.run(cmd, check=True, timeout=300)
|
||||
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError):
|
||||
return None
|
||||
return out_path if out_path.exists() and out_path.stat().st_size > 0 else None
|
||||
|
||||
def _decode(self, episode_index: int, timestamps: list[float], camera_key: str) -> list[Any]:
|
||||
"""Decode ``timestamps`` from the episode's video as ``(C, H, W)`` tensors.
|
||||
|
||||
Delegates to :func:`lerobot.datasets.video_utils.decode_video_frames`
|
||||
(torchcodec by default, PyAV fallback) rather than a bespoke decoder.
|
||||
Returns one frame per requested timestamp, or ``[]`` if decoding
|
||||
failed wholesale — callers treat ``[]`` as "no frames available".
|
||||
"""
|
||||
ep = self._meta.episodes[episode_index]
|
||||
from_timestamp = ep[f"videos/{camera_key}/from_timestamp"]
|
||||
shifted = [from_timestamp + ts for ts in timestamps]
|
||||
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
|
||||
|
||||
# Default to the ffmpeg CLI. The pipeline decodes under a 16-wide
|
||||
# ThreadPoolExecutor and the in-process decoders are unsafe there:
|
||||
# torchcodec is not thread-safe and SIGSEGVs under concurrent decode
|
||||
# (a crash no try/except can catch), PyAV can likewise segfault on
|
||||
# AV1, and lerobot's ``pyav`` backend routes through the removed
|
||||
# ``torchvision.io.VideoReader``. ``_decode_frames_ffmpeg`` shells
|
||||
# out per frame: each decode is an isolated child process, so it is
|
||||
# both crash-safe and concurrency-safe. ``video_backend`` can pin
|
||||
# ``torchcodec`` / ``pyav`` explicitly for callers that know their
|
||||
# build is safe.
|
||||
chain = [self.video_backend] if self.video_backend else ["ffmpeg"]
|
||||
|
||||
exc: Exception | None = None
|
||||
for backend in chain:
|
||||
try:
|
||||
if backend == "ffmpeg":
|
||||
return _decode_frames_ffmpeg(video_path, shifted)
|
||||
if backend in ("pyav", "av"):
|
||||
return _decode_frames_av(video_path, shifted)
|
||||
# Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp.
|
||||
decoded = decode_video_frames(
|
||||
video_path, shifted, self.tolerance_s, backend=backend, return_uint8=True
|
||||
)
|
||||
return list(decoded)
|
||||
except Exception as e: # noqa: PERF203
|
||||
exc = e
|
||||
|
||||
# Every backend raised. Log loudly the first time so a silent
|
||||
# vqa-module no-op (every prompt skipped because frames_at returned
|
||||
# []) is debuggable from the job log instead of post-hoc parquet
|
||||
# inspection. Subsequent failures stay quiet.
|
||||
with self._lock:
|
||||
already_warned = getattr(self, "_warned_decode_fail", False)
|
||||
if not already_warned:
|
||||
self._warned_decode_fail = True
|
||||
if not already_warned:
|
||||
logger.warning(
|
||||
"VideoFrameProvider._decode failed for episode=%s camera=%s "
|
||||
"video_path=%s backends=%s: %s",
|
||||
episode_index,
|
||||
camera_key,
|
||||
video_path,
|
||||
chain,
|
||||
exc,
|
||||
exc_info=exc,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def make_frame_provider(
|
||||
root: Path, camera_key: str | None = None, video_backend: str | None = None
|
||||
) -> FrameProvider:
|
||||
"""Build a :class:`VideoFrameProvider` if videos are present, else null."""
|
||||
try:
|
||||
provider = VideoFrameProvider(root=root, camera_key=camera_key, video_backend=video_backend)
|
||||
except Exception:
|
||||
return null_provider()
|
||||
if provider.camera_key is None:
|
||||
return null_provider()
|
||||
return provider
|
||||
|
||||
|
||||
def _decode_frames_ffmpeg(video_path: Path, timestamps: list[float]) -> list[Any]:
|
||||
"""Decode the frames nearest to ``timestamps`` via the ffmpeg CLI.
|
||||
|
||||
Runs one ``ffmpeg`` process per timestamp, seeking with ``-ss`` and
|
||||
piping a single PNG to stdout. Unlike the in-process decoders this
|
||||
survives a hostile container: a full ffmpeg build decodes AV1 (the codec
|
||||
modern LeRobot datasets use) where torchcodec raises and PyAV can
|
||||
SIGSEGV, and a crash stays isolated to the child process — a non-zero
|
||||
exit is a catchable error, not a segfault of the whole job. Returns one
|
||||
``(C, H, W)`` uint8 tensor per timestamp.
|
||||
"""
|
||||
import io # noqa: PLC0415
|
||||
import subprocess # noqa: PLC0415
|
||||
|
||||
import numpy as np # noqa: PLC0415
|
||||
|
||||
frames: list[Any] = []
|
||||
for ts in timestamps:
|
||||
proc = subprocess.run(
|
||||
[
|
||||
"ffmpeg", "-nostdin", "-loglevel", "error",
|
||||
"-ss", f"{max(ts, 0.0):.3f}",
|
||||
"-i", str(video_path),
|
||||
"-frames:v", "1",
|
||||
"-f", "image2pipe", "-vcodec", "png", "pipe:1",
|
||||
],
|
||||
capture_output=True,
|
||||
check=True,
|
||||
timeout=120,
|
||||
)
|
||||
if not proc.stdout:
|
||||
raise RuntimeError(f"ffmpeg returned no frame for t={ts:.3f}s of {video_path}")
|
||||
img = PIL.Image.open(io.BytesIO(proc.stdout)).convert("RGB")
|
||||
frames.append(torch.from_numpy(np.asarray(img).copy()).permute(2, 0, 1).contiguous())
|
||||
return frames
|
||||
|
||||
|
||||
def _decode_frames_av(video_path: Path, timestamps: list[float]) -> list[Any]:
|
||||
"""Decode the frames nearest to ``timestamps`` using PyAV directly.
|
||||
|
||||
lerobot's ``decode_video_frames(backend="pyav")`` routes through
|
||||
``torchvision.io.VideoReader``, removed in torchvision 0.23+. This helper
|
||||
talks to the ``av`` package directly. Note PyAV can SIGSEGV on AV1
|
||||
streams in some builds — prefer ``_decode_frames_ffmpeg`` as the default
|
||||
fallback; this stays available behind ``video_backend="pyav"``. Returns
|
||||
one ``(C, H, W)`` uint8 tensor per timestamp.
|
||||
"""
|
||||
import av # noqa: PLC0415
|
||||
|
||||
first_ts = min(timestamps)
|
||||
last_ts = max(timestamps)
|
||||
loaded_frames: list[torch.Tensor] = []
|
||||
loaded_ts: list[float] = []
|
||||
with av.open(str(video_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
# Seek to the keyframe at or before the first requested timestamp.
|
||||
offset = max(int(first_ts / stream.time_base), 0) if stream.time_base else 0
|
||||
container.seek(offset, stream=stream, backward=True, any_frame=False)
|
||||
for idx, frame in enumerate(container.decode(stream)):
|
||||
ts = frame.time
|
||||
if ts is None:
|
||||
ts = float(frame.pts * stream.time_base) if frame.pts is not None else float(idx)
|
||||
loaded_ts.append(ts)
|
||||
loaded_frames.append(
|
||||
torch.from_numpy(frame.to_ndarray(format="rgb24")).permute(2, 0, 1).contiguous()
|
||||
)
|
||||
if ts >= last_ts:
|
||||
break
|
||||
if not loaded_frames:
|
||||
raise RuntimeError(f"PyAV decoded no frames from {video_path}")
|
||||
ts_tensor = torch.tensor(loaded_ts)
|
||||
return [loaded_frames[int(torch.argmin((ts_tensor - q).abs()))] for q in timestamps]
|
||||
|
||||
|
||||
def _frame_to_pil(frame: Any) -> Any:
|
||||
"""Materialise a decoded frame as a ``PIL.Image`` for the VLM message.
|
||||
|
||||
Frames flow through the provider as ``torch.Tensor`` (``C, H, W`` uint8,
|
||||
straight from :func:`decode_video_frames`); PIL is only created here, at
|
||||
the VLM-message boundary, because the chat backends expect PIL images /
|
||||
data URLs. Non-tensor inputs (e.g. test stubs) pass through untouched.
|
||||
"""
|
||||
if not isinstance(frame, torch.Tensor):
|
||||
return frame
|
||||
array = frame.detach().cpu()
|
||||
if array.ndim == 3 and array.shape[0] in (1, 3):
|
||||
array = array.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
|
||||
if array.shape[-1] == 1:
|
||||
array = array.squeeze(-1)
|
||||
return PIL.Image.fromarray(array.to(torch.uint8).numpy())
|
||||
|
||||
|
||||
def to_image_blocks(frames: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Convert decoded frames to Qwen-VL-compatible image content blocks."""
|
||||
return [{"type": "image", "image": _frame_to_pil(frame)} for frame in frames]
|
||||
|
||||
|
||||
def to_video_block(frames: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Wrap a list of decoded frames as one Qwen-VL video block.
|
||||
|
||||
Returns ``[]`` when the list is empty, so the caller can splat the result
|
||||
into a content array without a separate emptiness check.
|
||||
"""
|
||||
if not frames:
|
||||
return []
|
||||
return [{"type": "video", "video": [_frame_to_pil(frame) for frame in frames]}]
|
||||
|
||||
|
||||
def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]:
|
||||
"""Wrap a video file URL as one ``video_url`` block.
|
||||
|
||||
Used by the ``openai`` backend (transformers serve / vllm serve /
|
||||
ktransformers serve), where the server handles frame sampling.
|
||||
Returns ``[]`` when ``url`` is ``None`` so the caller can splat.
|
||||
"""
|
||||
if not url:
|
||||
return []
|
||||
return [{"type": "video_url", "video_url": {"url": url}, "fps": fps}]
|
||||
@@ -0,0 +1,25 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .general_vqa import GeneralVqaModule
|
||||
from .interjections_and_speech import InterjectionsAndSpeechModule
|
||||
from .plan_subtasks_memory import PlanSubtasksMemoryModule
|
||||
|
||||
__all__ = [
|
||||
"GeneralVqaModule",
|
||||
"InterjectionsAndSpeechModule",
|
||||
"PlanSubtasksMemoryModule",
|
||||
]
|
||||
@@ -0,0 +1,228 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``vqa`` module: general VQA at a timed cadence.
|
||||
|
||||
Every ``1/hz`` seconds an emission tick fires; each tick anchors ``K``
|
||||
consecutive frames, and every anchored frame gets its own VQA pair. Each
|
||||
pair is grounded on that single anchor frame — there is no per-pair frame
|
||||
window. For datasets with multiple cameras, every anchored frame produces
|
||||
one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
|
||||
generated against that camera's frame and stamped with the matching
|
||||
``camera`` field on the emitted rows. The resolver disambiguates via
|
||||
``camera=...``; recipes that consume VQA do so through one sub-recipe
|
||||
per camera (see ``recipes/subtasks_vqa.yaml``).
|
||||
|
||||
Within a single (frame, camera) we still emit at most one ``(vqa, user)``
|
||||
and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.
|
||||
|
||||
Question types covered (per the plan's ``vqa`` table): bbox, keypoint,
|
||||
count, attribute, spatial. The assistant's ``content`` is a JSON string
|
||||
whose schema depends on the question type. Malformed JSON triggers one
|
||||
retry inside :meth:`VlmClient.generate_json`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import VqaConfig
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord
|
||||
from ..staging import EpisodeStaging
|
||||
from ..validator import classify_vqa_answer
|
||||
from ..vlm_client import VlmClient
|
||||
|
||||
|
||||
def _emission_anchor_indices(frame_timestamps: Sequence[float], hz: float, k: int) -> list[int]:
|
||||
"""Return the relative frame indices to anchor VQA emissions to.
|
||||
|
||||
For each emission tick (every ``1/hz`` seconds), we anchor ``k``
|
||||
consecutive frames starting at the tick. Ticks fall on the nearest
|
||||
available source frame timestamp.
|
||||
"""
|
||||
if hz <= 0 or k <= 0 or not frame_timestamps:
|
||||
return []
|
||||
t0 = frame_timestamps[0]
|
||||
t_last = frame_timestamps[-1]
|
||||
period = 1.0 / hz
|
||||
indices: list[int] = []
|
||||
t = t0
|
||||
while t <= t_last + 1e-9:
|
||||
# find the index of the nearest frame to t
|
||||
nearest_i = min(range(len(frame_timestamps)), key=lambda i: abs(frame_timestamps[i] - t))
|
||||
for offset in range(k):
|
||||
j = nearest_i + offset
|
||||
if j >= len(frame_timestamps):
|
||||
break
|
||||
if not indices or indices[-1] != j:
|
||||
indices.append(j)
|
||||
t += period
|
||||
# dedupe while preserving order
|
||||
seen: set[int] = set()
|
||||
deduped: list[int] = []
|
||||
for i in indices:
|
||||
if i in seen:
|
||||
continue
|
||||
seen.add(i)
|
||||
deduped.append(i)
|
||||
return deduped
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneralVqaModule:
|
||||
"""Emit grounded VQA pairs at a timed cadence."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: VqaConfig
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
if not record.frame_timestamps:
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
rng = random.Random(f"{self.seed}:{record.episode_index}:vqa")
|
||||
anchor_idx = _emission_anchor_indices(
|
||||
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
|
||||
)
|
||||
cameras = self._target_cameras()
|
||||
if not cameras:
|
||||
# No camera available — emit nothing rather than producing
|
||||
# untagged rows that would fail validation. Surface a loud one-
|
||||
# time warning so this is never silently a no-op.
|
||||
if not getattr(self, "_warned_no_camera", False):
|
||||
logging.getLogger(__name__).warning(
|
||||
"vqa module found no cameras on the frame provider — "
|
||||
"every episode will emit zero VQA rows. Check that the "
|
||||
"dataset declares observation.images.* features in "
|
||||
"meta/info.json; passing --vlm.camera_key=<key> at the "
|
||||
"CLI now also seeds the cameras list as a fallback."
|
||||
)
|
||||
self._warned_no_camera = True
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
|
||||
# Build all messages first (one per (frame, camera)), then issue them
|
||||
# as a single batched generate_json call so the client can fan them
|
||||
# out concurrently.
|
||||
per_call: list[tuple[float, str, str, list[dict[str, Any]]]] = []
|
||||
for idx in anchor_idx:
|
||||
ts = float(record.frame_timestamps[idx])
|
||||
qtype = rng.choice(self.config.question_types)
|
||||
for camera in cameras:
|
||||
messages = self._build_messages(record, qtype, ts, camera)
|
||||
# Skip cameras that decoded to zero frames at this ts: no point
|
||||
# asking the VLM to ground a bbox without an image.
|
||||
if not _has_image_block(messages):
|
||||
continue
|
||||
per_call.append((ts, camera, qtype, messages))
|
||||
|
||||
if not per_call:
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
|
||||
results = self.vlm.generate_json([m for _, _, _, m in per_call])
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
for (ts, camera, _qtype, _messages), result in zip(per_call, results, strict=True):
|
||||
qa = self._postprocess(result)
|
||||
if qa is None:
|
||||
continue
|
||||
question, answer = qa
|
||||
rows.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": question,
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": json.dumps(answer, sort_keys=True),
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("vqa", rows)
|
||||
|
||||
def _target_cameras(self) -> list[str]:
|
||||
"""Return the cameras the ``vqa`` module should iterate per anchored frame.
|
||||
|
||||
Defaults to every camera the provider exposes. Datasets with no
|
||||
cameras (or test/null providers) yield an empty list, which makes
|
||||
``run_episode`` a no-op.
|
||||
"""
|
||||
return list(getattr(self.frame_provider, "camera_keys", []) or [])
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
question_type: str,
|
||||
frame_timestamp: float,
|
||||
camera_key: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
prompt = load_prompt("module_3_vqa").format(
|
||||
episode_task=record.episode_task,
|
||||
question_type=question_type,
|
||||
)
|
||||
images = self.frame_provider.frames_at(
|
||||
record, [frame_timestamp], camera_key=camera_key
|
||||
)
|
||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
|
||||
if not isinstance(result, dict):
|
||||
return None
|
||||
question = result.get("question")
|
||||
answer = result.get("answer")
|
||||
if not isinstance(question, str) or not question.strip():
|
||||
return None
|
||||
if not isinstance(answer, dict):
|
||||
return None
|
||||
# The validator will enforce shape; here we just sanity-check that the
|
||||
# answer matches *some* known shape so we can drop garbage early.
|
||||
if classify_vqa_answer(answer) is None:
|
||||
return None
|
||||
return question.strip(), answer
|
||||
|
||||
|
||||
def _has_image_block(messages: list[dict[str, Any]]) -> bool:
|
||||
"""Return True if any user content block is a populated image block."""
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "image":
|
||||
return True
|
||||
return False
|
||||
@@ -0,0 +1,210 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``interjections`` module: interjections + paired speech (EVENT styles + speech atoms).
|
||||
|
||||
Two sub-passes:
|
||||
|
||||
1. At ``t=0``, emit ONLY a speech tool-call atom (acknowledgement of the
|
||||
canonical task). No interjection row — the canonical task is already the
|
||||
user utterance from ``meta/tasks.parquet``.
|
||||
|
||||
2. For mid-episode interruptions, emit a co-timestamped pair:
|
||||
{role:user, style:interjection, content:<text>}
|
||||
speech atom (role:assistant, style:None, tool_calls=[say(...)])
|
||||
Both rows go in ``language_events`` at the same timestamp.
|
||||
|
||||
The ``plan`` module's :meth:`run_plan_updates` reuses this module's
|
||||
interjection timestamps to refresh the ``plan`` row at the same instant.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import InterjectionsConfig
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
from ..writer import speech_atom
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterjectionsAndSpeechModule:
|
||||
"""Generate task-start speech and mid-episode interjection/speech pairs."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: InterjectionsConfig
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
rows: list[dict[str, Any]] = []
|
||||
if record.frame_timestamps:
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
initial = self._initial_speech(record)
|
||||
if initial:
|
||||
rows.append(speech_atom(t0, initial))
|
||||
# Pull the ``plan`` module's subtask spans for this episode so the
|
||||
# interjection prompt can ground itself in the actual current
|
||||
# subtask at each chosen timestamp. The ``plan`` module ran first.
|
||||
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||
subtask_spans = reconstruct_subtask_spans(staging.read("plan"), episode_end_t=episode_end_t)
|
||||
rows.extend(self._mid_episode_interjections(record, subtask_spans))
|
||||
staging.write("interjections", rows)
|
||||
|
||||
@staticmethod
|
||||
def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None:
|
||||
current: str | None = None
|
||||
for span in spans:
|
||||
if float(span["start"]) <= t:
|
||||
current = span.get("text")
|
||||
else:
|
||||
break
|
||||
return current
|
||||
|
||||
def _initial_speech(self, record: EpisodeRecord) -> str | None:
|
||||
prompt = load_prompt("module_2_initial_speech").format(
|
||||
episode_task=record.episode_task,
|
||||
)
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict) and isinstance(result.get("text"), str):
|
||||
text = result["text"].strip()
|
||||
if text:
|
||||
return text
|
||||
return None
|
||||
|
||||
def _mid_episode_interjections(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
subtask_spans: Sequence[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Generate interjections aligned with the actual demo trajectory.
|
||||
|
||||
Teleop data is frozen — the robot already executed every step in
|
||||
the video. A *counterfactual* interjection like "actually skip
|
||||
the wipe" contradicts what then happens in the video, which is
|
||||
what qwen36moe-10/11 surfaced as low-quality interjections.
|
||||
|
||||
Instead, anchor every interjection at a subtask boundary and
|
||||
write it as a natural user request for the *upcoming* subtask.
|
||||
The robot's visible next behavior IS the interjection's effect,
|
||||
so the training signal stays consistent: interjection text →
|
||||
plan refresh → action stream all line up.
|
||||
"""
|
||||
if self.config.max_interjections_per_episode <= 0:
|
||||
return []
|
||||
if len(subtask_spans) < 2:
|
||||
# Need at least one transition (subtask 0 → subtask 1).
|
||||
return []
|
||||
# Deterministic per-episode RNG so reruns are stable across SLURM jobs.
|
||||
rng = random.Random(f"{self.seed}:{record.episode_index}:interjection")
|
||||
|
||||
# Boundaries: the start time of every subtask except the first
|
||||
# (which is just t0 and is covered by the initial-task speech atom).
|
||||
boundaries: list[tuple[float, str, str]] = []
|
||||
for i in range(1, len(subtask_spans)):
|
||||
ts = float(subtask_spans[i]["start"])
|
||||
if ts < self.config.interjection_min_t:
|
||||
continue
|
||||
prev_text = (subtask_spans[i - 1].get("text") or "").strip()
|
||||
next_text = (subtask_spans[i].get("text") or "").strip()
|
||||
if not next_text:
|
||||
continue
|
||||
boundaries.append((ts, prev_text, next_text))
|
||||
if not boundaries:
|
||||
return []
|
||||
|
||||
n = min(self.config.max_interjections_per_episode, len(boundaries))
|
||||
chosen = sorted(rng.sample(boundaries, n), key=lambda b: b[0])
|
||||
|
||||
out: list[dict[str, Any]] = []
|
||||
for t, prev_subtask, next_subtask in chosen:
|
||||
t_snap = snap_to_frame(t, record.frame_timestamps)
|
||||
# Window straddles the boundary so the VLM sees the end of the
|
||||
# previous subtask and the start of the next one — same
|
||||
# conditioning the policy will see at training time.
|
||||
window_ts = self._window_timestamps(t_snap, record.frame_timestamps)
|
||||
prompt = load_prompt("module_2_interjection").format(
|
||||
episode_task=record.episode_task,
|
||||
prev_subtask=prev_subtask or "(starting from initial state)",
|
||||
next_subtask=next_subtask,
|
||||
timestamp=t_snap,
|
||||
window_seconds=self.config.interjection_window_seconds,
|
||||
)
|
||||
images = self.frame_provider.frames_at(record, window_ts)
|
||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||
messages = [{"role": "user", "content": content}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if not isinstance(result, dict):
|
||||
continue
|
||||
interjection_text = result.get("interjection")
|
||||
speech_text = result.get("speech")
|
||||
if not isinstance(interjection_text, str) or not interjection_text.strip():
|
||||
continue
|
||||
if not isinstance(speech_text, str) or not speech_text.strip():
|
||||
continue
|
||||
out.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": interjection_text.strip(),
|
||||
"style": "interjection",
|
||||
"timestamp": t_snap,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
out.append(speech_atom(t_snap, speech_text.strip()))
|
||||
return out
|
||||
|
||||
def _window_timestamps(self, t_anchor: float, frame_timestamps: Sequence[float]) -> list[float]:
|
||||
"""Return a small set of frame timestamps centered on ``t_anchor``.
|
||||
|
||||
The window straddles the subtask boundary the interjection sits
|
||||
on: roughly half the frames cover the end of the previous
|
||||
subtask, half cover the start of the next one. The VLM therefore
|
||||
sees BOTH what just finished AND what's about to start, which is
|
||||
the conditioning we need to write a natural "now please do X"
|
||||
request that matches the visible upcoming behavior.
|
||||
"""
|
||||
if not frame_timestamps:
|
||||
return [t_anchor]
|
||||
n = max(1, int(self.config.interjection_window_frames))
|
||||
if n == 1:
|
||||
return [t_anchor]
|
||||
window = float(self.config.interjection_window_seconds)
|
||||
step = window / max(1, n - 1)
|
||||
# Center the window on the anchor so half lands before, half after.
|
||||
start_offset = -window / 2.0
|
||||
targets = [t_anchor + start_offset + step * i for i in range(n)]
|
||||
last_ts = float(frame_timestamps[-1])
|
||||
snapped: list[float] = []
|
||||
seen: set[float] = set()
|
||||
for tgt in targets:
|
||||
clamped = min(last_ts, max(0.0, tgt))
|
||||
t = snap_to_frame(clamped, frame_timestamps)
|
||||
if t not in seen:
|
||||
seen.add(t)
|
||||
snapped.append(t)
|
||||
return snapped or [t_anchor]
|
||||
@@ -0,0 +1,617 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``plan`` module: subtask decomposition + plan + memory (PERSISTENT styles)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ..config import PlanConfig
|
||||
from ..frames import (
|
||||
FrameProvider,
|
||||
VideoFrameProvider,
|
||||
null_provider,
|
||||
to_video_block,
|
||||
to_video_url_block,
|
||||
)
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
from ..vocabulary import Vocabulary
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlanSubtasksMemoryModule:
|
||||
"""Generate subtask spans, plan, and memory rows.
|
||||
|
||||
All output is persistent (lives in ``language_persistent``):
|
||||
|
||||
- ``subtask`` rows: one per span, stamped at the span's *start* timestamp
|
||||
(snapped to an exact frame).
|
||||
- ``plan`` rows: emitted at ``t=0``; refreshed at every interjection
|
||||
timestamp via :meth:`run_plan_updates` (called by the executor after
|
||||
the ``interjections`` module completes).
|
||||
- ``memory`` rows: emitted at each subtask boundary (= subtask start
|
||||
timestamp from the second subtask onward).
|
||||
"""
|
||||
|
||||
vlm: VlmClient
|
||||
config: PlanConfig
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
vocabulary: Vocabulary | None = None
|
||||
"""When set, the module constrains subtask + memory generation to the
|
||||
canonical strings in ``vocabulary``. Phase 0 (vocabulary discovery)
|
||||
populates this once per dataset; ``None`` falls back to free-form
|
||||
generation (original behaviour)."""
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
rows: list[dict[str, Any]] = []
|
||||
# Resolve the task that drives every other ``plan``-module prompt.
|
||||
# May be the canonical ``record.episode_task`` (default), or a fresh
|
||||
# description derived from the video when the canonical task is
|
||||
# empty / placeholder / forced-off (see PlanConfig.derive_task_*).
|
||||
effective_task = self._resolve_effective_task(record)
|
||||
# ``task_aug`` rows at t=0 (role=user), one per rephrasing — the
|
||||
# message renderer rotates ``${task}`` deterministically through
|
||||
# them so the policy sees diverse phrasings during training.
|
||||
t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
|
||||
if self.config.n_task_rephrasings > 0 and effective_task:
|
||||
rephrasings = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
|
||||
# Always include the effective task itself as the first variant
|
||||
# so the rotation is guaranteed to cover the source-of-truth
|
||||
# phrasing, not just synthetic alternatives.
|
||||
seen: set[str] = set()
|
||||
ordered = [effective_task, *rephrasings]
|
||||
for phrasing in ordered:
|
||||
key = phrasing.strip()
|
||||
if not key or key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
rows.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": key,
|
||||
"style": "task_aug",
|
||||
"timestamp": t0,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
|
||||
subtask_spans = self._generate_subtasks(record, task=effective_task)
|
||||
# subtask rows
|
||||
for span in subtask_spans:
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": span["text"],
|
||||
"style": "subtask",
|
||||
"timestamp": snap_to_frame(span["start"], record.frame_timestamps),
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
# Plan rows at every subtask boundary — including t=0 (start of
|
||||
# the first subtask). Because the plan is just a numbered list
|
||||
# of *still-todo* subtasks, re-emitting at each boundary makes
|
||||
# the active plan shrink as work progresses: at frame t the
|
||||
# rendered ``${plan}`` is the most recent emission, which
|
||||
# contains exactly the subtasks that started at or after the
|
||||
# current span. Saves the runtime from having to derive
|
||||
# "what's still left" at inference time.
|
||||
for span in subtask_spans:
|
||||
boundary_t = snap_to_frame(span["start"], record.frame_timestamps)
|
||||
plan_text = self._generate_plan(
|
||||
record, subtask_spans, refresh_t=boundary_t, task=effective_task
|
||||
)
|
||||
if plan_text is not None:
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": plan_text,
|
||||
"style": "plan",
|
||||
"timestamp": float(boundary_t),
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
# memory rows at every subtask boundary except the very first start
|
||||
prior_memory = ""
|
||||
for i, span in enumerate(subtask_spans[1:], start=1):
|
||||
completed = subtask_spans[i - 1]["text"]
|
||||
remaining = [s["text"] for s in subtask_spans[i:]]
|
||||
mem_text = self._generate_memory(record, prior_memory, completed, remaining, task=effective_task)
|
||||
if mem_text:
|
||||
ts = snap_to_frame(span["start"], record.frame_timestamps)
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": mem_text,
|
||||
"style": "memory",
|
||||
"timestamp": ts,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
prior_memory = mem_text
|
||||
staging.write("plan", rows)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Task derivation + rephrasings
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_PLACEHOLDER_TASKS: frozenset[str] = frozenset(
|
||||
{
|
||||
"debug",
|
||||
"test",
|
||||
"tbd",
|
||||
"todo",
|
||||
"n/a",
|
||||
"na",
|
||||
"untitled",
|
||||
"unnamed",
|
||||
"default",
|
||||
"placeholder",
|
||||
}
|
||||
)
|
||||
|
||||
def _resolve_effective_task(self, record: EpisodeRecord) -> str:
|
||||
"""Decide which task string drives the ``plan`` module for this episode.
|
||||
|
||||
Returns the user-supplied ``record.episode_task`` unless
|
||||
``derive_task_from_video`` says otherwise (see config docstring).
|
||||
Falls back gracefully to the canonical task if video derivation
|
||||
fails.
|
||||
"""
|
||||
canonical = (record.episode_task or "").strip()
|
||||
mode = (self.config.derive_task_from_video or "off").strip().lower()
|
||||
if mode == "always":
|
||||
derived = self._derive_task_from_video(record)
|
||||
return derived or canonical
|
||||
if mode == "if_short" and self._task_seems_bad(canonical):
|
||||
derived = self._derive_task_from_video(record)
|
||||
if derived:
|
||||
return derived
|
||||
return canonical
|
||||
|
||||
def _task_seems_bad(self, task: str) -> bool:
|
||||
if not task:
|
||||
return True
|
||||
if len(task.split()) < int(self.config.derive_task_min_words):
|
||||
return True
|
||||
return task.lower() in self._PLACEHOLDER_TASKS
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# VLM call helpers (factored out: every ``plan``-module prompt below follows
|
||||
# the same "build messages → single VLM call → pull a named field"
|
||||
# shape, only differing in field name + post-processing).
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _vlm_field(self, messages: list[dict[str, Any]], field: str) -> Any:
|
||||
"""Run a single VLM call and return ``result[field]`` or ``None``.
|
||||
|
||||
Centralizes the ``vlm.generate_json([m])[0]`` + ``isinstance(dict)``
|
||||
dance every prompt-call site needs.
|
||||
"""
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict):
|
||||
return result.get(field)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _text_message(text: str) -> list[dict[str, Any]]:
|
||||
"""One-shot text-only user message wrapped for ``generate_json``."""
|
||||
return [{"role": "user", "content": [{"type": "text", "text": text}]}]
|
||||
|
||||
def _video_message(self, record: EpisodeRecord, prompt: str) -> list[dict[str, Any]]:
|
||||
"""User message combining the episode video block with ``prompt``."""
|
||||
content = [*self._episode_video_block(record), {"type": "text", "text": prompt}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
|
||||
"""Ask the VLM "what is this video about" with no task hint at all."""
|
||||
text = self._vlm_field(self._video_message(record, load_prompt("module_1_video_task")), "task")
|
||||
return text.strip() if isinstance(text, str) and text.strip() else None
|
||||
|
||||
def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]:
|
||||
"""Generate ``n`` text-only paraphrases of ``base_task``."""
|
||||
if n <= 0 or not base_task:
|
||||
return []
|
||||
prompt = load_prompt("module_1_task_rephrasings").format(base_task=base_task, n=n)
|
||||
raw = self._vlm_field(self._text_message(prompt), "rephrasings")
|
||||
if not isinstance(raw, list):
|
||||
return []
|
||||
out = [item.strip().strip('"').strip("'") for item in raw if isinstance(item, str)]
|
||||
return [s for s in out if s][:n]
|
||||
|
||||
def _episode_video_block(self, record: EpisodeRecord) -> list[dict[str, Any]]:
|
||||
"""Same video block ``_generate_subtasks`` builds — extracted helper."""
|
||||
if not record.frame_timestamps:
|
||||
return []
|
||||
if self.config.use_video_url and isinstance(self.frame_provider, VideoFrameProvider):
|
||||
cache_dir = Path(self.frame_provider.root) / ".annotate_staging" / ".video_clips"
|
||||
clip = self.frame_provider.episode_clip_path(record, cache_dir)
|
||||
return (
|
||||
to_video_url_block(f"file://{clip}", fps=self.config.use_video_url_fps)
|
||||
if clip is not None
|
||||
else []
|
||||
)
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
target_count = max(1, int(round(episode_duration * self.config.frames_per_second)))
|
||||
target_count = min(target_count, self.config.max_video_frames)
|
||||
video_frames = self.frame_provider.video_for_episode(record, target_count)
|
||||
return to_video_block(video_frames)
|
||||
|
||||
def run_plan_updates(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
staging: EpisodeStaging,
|
||||
interjection_times: Sequence[float],
|
||||
interjection_texts: Sequence[str] | None = None,
|
||||
) -> None:
|
||||
"""Append additional ``plan`` rows at every interjection timestamp.
|
||||
|
||||
Plans refresh ONLY on user interjections — subtask generation
|
||||
runs ~1 Hz at inference, but plan re-emission is event-driven.
|
||||
Now also forwards the interjection's own text into the prompt so
|
||||
the refreshed plan can actually reflect the user's correction
|
||||
(the previous version told the model "an interjection happened"
|
||||
without telling it what the user said).
|
||||
"""
|
||||
existing = staging.read("plan")
|
||||
# Pass the episode's last frame timestamp so the final subtask
|
||||
# span is closed (otherwise its ``end`` equals its ``start``,
|
||||
# zero duration, and the "current subtask at refresh_t" lookup
|
||||
# in ``_generate_plan`` misses any refresh that lands inside it).
|
||||
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||
spans = reconstruct_subtask_spans(existing, episode_end_t=episode_end_t)
|
||||
already_planned: set[float] = {float(r["timestamp"]) for r in existing if r.get("style") == "plan"}
|
||||
new_rows = list(existing)
|
||||
|
||||
texts: list[str | None] = (
|
||||
[None] * len(interjection_times)
|
||||
if interjection_texts is None
|
||||
else [str(t) if t else None for t in interjection_texts]
|
||||
)
|
||||
for raw_t, inter_text in zip(interjection_times, texts, strict=True):
|
||||
t = snap_to_frame(raw_t, record.frame_timestamps)
|
||||
if t in already_planned:
|
||||
continue
|
||||
already_planned.add(t)
|
||||
plan_text = self._generate_plan(record, spans, refresh_t=t, interjection=inter_text)
|
||||
if plan_text is not None:
|
||||
new_rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": plan_text,
|
||||
"style": "plan",
|
||||
"timestamp": t,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("plan", new_rows)
|
||||
|
||||
def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
|
||||
if record.row_count == 0 or not record.frame_timestamps:
|
||||
return []
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
prompt = load_prompt("module_1_subtasks").format(
|
||||
episode_task=(task if task is not None else record.episode_task),
|
||||
min_subtask_seconds=self.config.min_subtask_seconds,
|
||||
max_steps=self.config.plan_max_steps,
|
||||
episode_duration=f"{episode_duration:.3f}",
|
||||
vocabulary_block=self._subtask_vocabulary_block(),
|
||||
)
|
||||
messages = self._video_message(record, prompt)
|
||||
spans = self._vlm_field(messages, "subtasks")
|
||||
# When a vocabulary is in force, do a single targeted retry if
|
||||
# any returned subtask is off-vocab — strict exact-match only,
|
||||
# no fuzzy snapping. The retry includes the offending strings
|
||||
# and the full canonical list so the VLM can correct itself.
|
||||
if self.vocabulary is not None and self.vocabulary.subtasks and spans:
|
||||
invalid = self._invalid_subtasks(spans)
|
||||
if invalid:
|
||||
logger.info(
|
||||
"episode %d: VLM emitted %d off-vocab subtask(s) (%s); retrying once",
|
||||
record.episode_index,
|
||||
len(invalid),
|
||||
invalid,
|
||||
)
|
||||
retry_msg = self._build_subtask_retry_message(messages, invalid)
|
||||
retried = self._vlm_field(retry_msg, "subtasks")
|
||||
if retried:
|
||||
spans = retried
|
||||
|
||||
if not spans:
|
||||
return []
|
||||
# clamp to [t0, t_last] and sort
|
||||
t0 = record.frame_timestamps[0]
|
||||
t_last = record.frame_timestamps[-1]
|
||||
cleaned: list[dict[str, Any]] = []
|
||||
for span in spans:
|
||||
try:
|
||||
start = float(span["start"])
|
||||
end = float(span["end"])
|
||||
text = str(span["text"]).strip()
|
||||
except (KeyError, ValueError, TypeError):
|
||||
continue
|
||||
start = max(t0, min(start, t_last))
|
||||
end = max(t0, min(end, t_last))
|
||||
if end < start:
|
||||
start, end = end, start
|
||||
if not text:
|
||||
continue
|
||||
text = self._canonicalize_subtask(text)
|
||||
if not text:
|
||||
continue
|
||||
cleaned.append({"text": text, "start": start, "end": end})
|
||||
cleaned.sort(key=lambda s: s["start"])
|
||||
cleaned = self._dedupe_starts_to_distinct_frames(cleaned, record)
|
||||
if self.vocabulary is not None and self.vocabulary.subtasks and not cleaned:
|
||||
logger.warning(
|
||||
"episode %d: every VLM subtask was off-vocab even after retry — "
|
||||
"episode left empty (extend meta/canonical_vocabulary.json to "
|
||||
"cover the missing phase)",
|
||||
record.episode_index,
|
||||
)
|
||||
return cleaned
|
||||
|
||||
@staticmethod
|
||||
def _dedupe_starts_to_distinct_frames(
|
||||
spans: list[dict[str, Any]], record: EpisodeRecord
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Bump same-frame subtask starts onto distinct frames.
|
||||
|
||||
Two consecutive VLM spans whose ``start`` rounds to the same
|
||||
source frame (after :func:`snap_to_frame`) would otherwise emit
|
||||
two ``style=subtask`` rows at the identical persistent
|
||||
timestamp. The training-time renderer's ``active_at(t,
|
||||
style=subtask)`` resolver can't disambiguate that and raises
|
||||
``Ambiguous resolver for style='subtask'``.
|
||||
|
||||
Walk the (sorted-by-start) spans, snap each to its frame, and
|
||||
if the snapped frame is already taken push the span onto the
|
||||
next unused frame so both subtasks survive on distinct
|
||||
timestamps. If the episode ends before a free frame is found,
|
||||
the trailing span is dropped with a warning — better than
|
||||
poisoning the render.
|
||||
"""
|
||||
if not spans:
|
||||
return spans
|
||||
frames = record.frame_timestamps
|
||||
if not frames:
|
||||
return spans
|
||||
used: set[float] = set()
|
||||
out: list[dict[str, Any]] = []
|
||||
for span in spans:
|
||||
ts = snap_to_frame(span["start"], frames)
|
||||
if ts in used:
|
||||
next_ts = next((f for f in frames if f > ts and f not in used), None)
|
||||
if next_ts is None:
|
||||
logger.warning(
|
||||
"episode %d: subtask %r snapped to occupied frame "
|
||||
"%.3f and no free later frame exists — dropping",
|
||||
record.episode_index,
|
||||
span.get("text"),
|
||||
ts,
|
||||
)
|
||||
continue
|
||||
ts = next_ts
|
||||
used.add(ts)
|
||||
new_span = {**span, "start": ts}
|
||||
if float(new_span.get("end", ts)) < ts:
|
||||
new_span["end"] = ts
|
||||
out.append(new_span)
|
||||
return out
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Canonical-vocabulary helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _subtask_vocabulary_block(self) -> str:
|
||||
"""Bullet-list of canonical subtasks the VLM must pick from.
|
||||
|
||||
Returns an empty string when no vocabulary is configured —
|
||||
``module_1_subtasks.txt`` then falls back to its free-form
|
||||
rules (original behaviour).
|
||||
"""
|
||||
if self.vocabulary is None or not self.vocabulary.subtasks:
|
||||
return ""
|
||||
bullets = "\n".join(f"- {s}" for s in self.vocabulary.subtasks)
|
||||
return (
|
||||
"You MUST choose each subtask label verbatim from this canonical "
|
||||
"vocabulary — pick the closest match for each phase of the demo, "
|
||||
"and reuse the SAME string every time that phase recurs. The "
|
||||
"low-level policy is conditioned on these exact strings; any "
|
||||
"novel paraphrase you invent will make its conditioning OOD.\n"
|
||||
"Canonical subtask labels:\n"
|
||||
f"{bullets}\n\n"
|
||||
)
|
||||
|
||||
def _memory_vocabulary_block(self) -> str:
|
||||
"""Bullet-list of canonical memory milestones the VLM must pick from."""
|
||||
if self.vocabulary is None or not self.vocabulary.memory_milestones:
|
||||
return ""
|
||||
bullets = "\n".join(f"- {m}" for m in self.vocabulary.memory_milestones)
|
||||
return (
|
||||
"Compose the memory by picking ONLY from this canonical milestone "
|
||||
"list — append a milestone (or rewrite the running memory to "
|
||||
"compress past ones) using these exact phrases. Do not invent new "
|
||||
"wording: every paraphrase weakens the downstream conditioning.\n"
|
||||
"Canonical memory milestones:\n"
|
||||
f"{bullets}\n\n"
|
||||
)
|
||||
|
||||
_NORMALIZE_STRIP_TOKENS: frozenset[str] = frozenset({"the", "a", "an"})
|
||||
|
||||
def _canonicalize_subtask(self, text: str) -> str:
|
||||
"""Validate ``text`` against the canonical vocabulary; no fuzzy snap.
|
||||
|
||||
Without a vocabulary, the original text passes through. With a
|
||||
vocabulary, accept the span only if its normalised form (lower-
|
||||
cased, articles stripped, whitespace collapsed) matches a
|
||||
canonical entry exactly — the canonical wording is returned so
|
||||
the supervised string is byte-identical across episodes.
|
||||
|
||||
Off-vocab spans are dropped (empty string). Upstream
|
||||
``_generate_subtasks`` triggers a targeted retry before reaching
|
||||
the drop path; this function never snaps or warps a span into
|
||||
a different label.
|
||||
"""
|
||||
if self.vocabulary is None or not self.vocabulary.subtasks:
|
||||
return text.strip()
|
||||
normalised = self._normalize(text)
|
||||
if not normalised:
|
||||
return ""
|
||||
for candidate in self.vocabulary.subtasks:
|
||||
if self._normalize(candidate) == normalised:
|
||||
return candidate
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
def _normalize(cls, text: str) -> str:
|
||||
"""Lowercase, strip articles, collapse whitespace, drop punctuation."""
|
||||
words = [
|
||||
w.strip(".,:;\"'!?()")
|
||||
for w in text.lower().replace(",", " ").split()
|
||||
]
|
||||
return " ".join(w for w in words if w and w not in cls._NORMALIZE_STRIP_TOKENS)
|
||||
|
||||
def _invalid_subtasks(self, spans: list[dict[str, Any]]) -> list[str]:
|
||||
"""Return the unique off-vocab subtask strings the VLM produced."""
|
||||
seen: list[str] = []
|
||||
for span in spans:
|
||||
text = str((span or {}).get("text") or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
if self._canonicalize_subtask(text):
|
||||
continue
|
||||
if text not in seen:
|
||||
seen.append(text)
|
||||
return seen
|
||||
|
||||
def _build_subtask_retry_message(
|
||||
self, original_messages: list[dict[str, Any]], invalid: list[str]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Compose a one-shot correction prompt naming the off-vocab strings."""
|
||||
assert self.vocabulary is not None
|
||||
canonical = "\n".join(f"- {s}" for s in self.vocabulary.subtasks)
|
||||
invalid_list = "\n".join(f"- {s!r}" for s in invalid)
|
||||
correction = (
|
||||
"Your previous response included subtask labels that are NOT in "
|
||||
"the canonical vocabulary:\n"
|
||||
f"{invalid_list}\n\n"
|
||||
"Re-emit the same segmentation (same number of spans, same start/end "
|
||||
"timestamps where they were valid) but replace every off-vocab "
|
||||
"label with the EXACT canonical string for that phase, copied "
|
||||
"verbatim from this list:\n"
|
||||
f"{canonical}\n\n"
|
||||
"Strict rules:\n"
|
||||
"- Output strings must be byte-for-byte identical to entries above.\n"
|
||||
"- No articles, no adverbs, no extra words.\n"
|
||||
"- If a phase truly has no canonical match, omit that span entirely.\n"
|
||||
"Return the same JSON shape as before."
|
||||
)
|
||||
# Append the correction as an additional user turn; the model
|
||||
# sees the original prompt + its prior output is implied by the
|
||||
# conversation context (the VLM client is stateless, so we
|
||||
# re-send the original content plus this correction).
|
||||
retry_messages = [
|
||||
{
|
||||
"role": m.get("role", "user"),
|
||||
"content": (
|
||||
m.get("content")
|
||||
if isinstance(m.get("content"), str)
|
||||
else list(m.get("content") or [])
|
||||
),
|
||||
}
|
||||
for m in original_messages
|
||||
]
|
||||
retry_messages.append({"role": "user", "content": correction})
|
||||
return retry_messages
|
||||
|
||||
def _generate_plan(
|
||||
self,
|
||||
record: EpisodeRecord, # noqa: ARG002 (kept for signature stability)
|
||||
subtask_spans: Sequence[dict[str, Any]],
|
||||
*,
|
||||
refresh_t: float | None = None,
|
||||
interjection: str | None = None, # noqa: ARG002
|
||||
task: str | None = None, # noqa: ARG002
|
||||
) -> str | None:
|
||||
"""Deterministic plan = numbered list of *still-todo* subtasks.
|
||||
|
||||
Previously this called the VLM with a prompt that asked it to
|
||||
compress the subtasks into a "compact hierarchical plan". That
|
||||
produced longer-than-necessary plans, cost an extra VLM round-trip
|
||||
per episode (plus one per interjection on refresh), and could
|
||||
diverge from the actual subtask sequence the model is going to
|
||||
execute. Replacing it with a plain summarisation keeps the plan
|
||||
tightly aligned with the upcoming subtasks and removes the VLM
|
||||
call entirely.
|
||||
|
||||
Layout — short imperative fragments prefixed by "N. ":
|
||||
|
||||
1. <subtask 1>
|
||||
2. <subtask 2>
|
||||
...
|
||||
|
||||
On a refresh at ``refresh_t`` (called from ``run_plan_updates``
|
||||
on interjection events, and from ``run_episode`` at every subtask
|
||||
boundary), only subtasks whose start is at or after ``refresh_t``
|
||||
are included — the plan shrinks as work progresses, so it always
|
||||
describes what's left.
|
||||
"""
|
||||
if not subtask_spans:
|
||||
return None
|
||||
remaining = [
|
||||
s
|
||||
for s in subtask_spans
|
||||
if refresh_t is None or float(s.get("start", 0.0)) >= float(refresh_t)
|
||||
]
|
||||
if not remaining:
|
||||
# Past the last subtask boundary on a late refresh — nothing
|
||||
# left to plan; emit None so the caller skips the row.
|
||||
return None
|
||||
return "\n".join(
|
||||
f"{i}. {span.get('text', '').strip()}" for i, span in enumerate(remaining, start=1)
|
||||
)
|
||||
|
||||
def _generate_memory(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
prior_memory: str,
|
||||
completed: str,
|
||||
remaining: Sequence[str],
|
||||
*,
|
||||
task: str | None = None,
|
||||
) -> str:
|
||||
prompt = load_prompt("module_1_memory").format(
|
||||
episode_task=(task if task is not None else record.episode_task),
|
||||
prior_memory=prior_memory or "(none)",
|
||||
completed_subtask=completed,
|
||||
remaining_subtasks=", ".join(remaining) if remaining else "(none)",
|
||||
vocabulary_block=self._memory_vocabulary_block(),
|
||||
)
|
||||
memory = self._vlm_field(self._text_message(prompt), "memory")
|
||||
return memory.strip() if isinstance(memory, str) else ""
|
||||
@@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Prompt templates loaded as plain text.
|
||||
|
||||
One file per use site. Templates use ``str.format(**vars)`` substitution; we
|
||||
intentionally avoid jinja2 here so the templates remain inspectable in
|
||||
plain editors and roundtrip cleanly through ``ruff format``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
def load(name: str) -> str:
|
||||
"""Read prompt template ``name.txt`` from the ``prompts/`` directory."""
|
||||
path = _DIR / f"{name}.txt"
|
||||
return path.read_text(encoding="utf-8")
|
||||
@@ -0,0 +1,53 @@
|
||||
You are inspecting {n_episodes} sample episode video(s) from a teleoperated
|
||||
robot dataset. Every episode in the dataset performs the SAME task; the
|
||||
user originally asked: "{episode_task}".
|
||||
|
||||
Watch all the clips and produce a SHORT canonical vocabulary that every
|
||||
episode in this dataset will reuse. The downstream low-level policy is
|
||||
conditioned on these strings — duplicate phrasings (e.g. "grasp blue
|
||||
cube" vs "pick up the blue cube") would destroy the conditioning, so
|
||||
pick one wording per concept and reuse it everywhere.
|
||||
|
||||
Decide how many entries each list needs YOURSELF based on what you see —
|
||||
the smallest set that still covers every recurring phase in the demos.
|
||||
A simple two-object pick-and-place might need ~6 subtask labels and 2
|
||||
memory milestones; a long multi-step recipe needs more. Err on the side
|
||||
of FEWER — extra entries that don't recur across episodes weaken the
|
||||
conditioning.
|
||||
|
||||
You output two lists:
|
||||
|
||||
1. `subtasks`: imperative, telegraphic commands the robot can execute.
|
||||
- Verb-first. Drop articles, adverbs, qualifiers.
|
||||
- Consistent object nouns (if the task says "cube", every subtask says
|
||||
"cube" — never "block" / "object").
|
||||
- Atomic — one skill per subtask (gripper-open events, contact, regrasps,
|
||||
transitions all become cut points).
|
||||
- Each label must recur across the demos. If you see a motion only
|
||||
once across all sample clips, it probably isn't a canonical phase.
|
||||
- Good: "move to blue cube", "grasp blue cube", "lift blue cube",
|
||||
"place blue cube in box", "release blue cube", "retract arm".
|
||||
- Bad: "the robot arm moves towards the blue cube" (third person,
|
||||
too long), "carefully pick up the cube" (adverb, article),
|
||||
"carrying the yellow cube over the green basket" (gerund — should
|
||||
be imperative "transport yellow cube to green basket").
|
||||
|
||||
2. `memory_milestones`: first-person past-tense sentences the running
|
||||
memory composes from. Each subtask phase that produces a lasting
|
||||
change should have a milestone; transient motions (move, retract)
|
||||
should NOT.
|
||||
- First person, past tense. Start with "I".
|
||||
- One sentence. Functional outcome only — no grasp / motion detail.
|
||||
- Good: "I picked up the blue cube.", "I placed the blue cube in
|
||||
the green box.", "I wiped the counter."
|
||||
- Bad: "The robot arm grasped the blue cube." (third person),
|
||||
"I carefully grasped the blue cube with the parallel gripper."
|
||||
(irrelevant detail), "I moved towards the blue cube." (transient
|
||||
motion — should be omitted, not memorialised).
|
||||
|
||||
Output strictly valid JSON of shape:
|
||||
|
||||
{{
|
||||
"subtasks": ["<verb phrase>", ...],
|
||||
"memory_milestones": ["I <past-tense sentence>.", ...]
|
||||
}}
|
||||
@@ -0,0 +1,36 @@
|
||||
You are updating the robot's compressed semantic memory at the boundary of
|
||||
a completed subtask.
|
||||
|
||||
Reference (verbatim from MEM, Torne 2026):
|
||||
"Remove or compress information in the language memory whenever
|
||||
appropriate. Keep ONLY the minimal set of relevant information for future
|
||||
task execution. Specific object attributes (colors, precise quantities of
|
||||
each item) get discarded when their details won't affect subsequent
|
||||
actions. Functional outcomes (where items went, how many) are preserved."
|
||||
|
||||
Episode task: "{episode_task}"
|
||||
Previous memory: {prior_memory}
|
||||
Just-completed subtask: "{completed_subtask}"
|
||||
Remaining subtasks (for relevance judgement only): {remaining_subtasks}
|
||||
|
||||
{vocabulary_block}Write the memory as a short FIRST-PERSON, PAST-TENSE narrative of what the
|
||||
robot has accomplished so far — the running story it would tell itself.
|
||||
|
||||
Authoring rules:
|
||||
- First person, past tense. Every sentence starts with "I": "I picked
|
||||
up...", "I opened...", "I moved to...".
|
||||
- One or two short sentences. Extend the previous memory with the
|
||||
just-completed subtask; do not rewrite it from scratch.
|
||||
- Keep WHAT happened (functional outcomes — where items went, how many),
|
||||
drop HOW (grasp details, motions).
|
||||
- Compress completed steps and drop object attributes (colors, exact
|
||||
counts) once they no longer affect the remaining subtasks.
|
||||
|
||||
Example (MEM, Torne 2026):
|
||||
Before: "I prepared the pot and got the potatoes, milk, and butter. I
|
||||
moved to the drawer."
|
||||
After: "I prepared the pot and got the ingredients. I opened the
|
||||
drawer with the masher."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{ "memory": "<one or two short first-person past-tense sentences>" }}
|
||||
@@ -0,0 +1,80 @@
|
||||
You are labeling a teleoperated robot demonstration.
|
||||
|
||||
The user originally asked: "{episode_task}"
|
||||
|
||||
You are shown the entire demonstration as a single video. Watch the
|
||||
whole clip, then segment it into a list of consecutive atomic subtasks
|
||||
the robot performs.
|
||||
|
||||
{vocabulary_block}Authoring rules — Hi Robot atom granularity, pi0.7-style short prompts:
|
||||
|
||||
- Each subtask = one COMPOSITE atomic skill the low-level policy can
|
||||
execute end-to-end. A "skill" bundles its own approach motion with
|
||||
its terminal action — do NOT split the approach off as its own
|
||||
subtask. The whole-arm policy already learns to reach as part of
|
||||
every manipulation primitive.
|
||||
- Write each subtask as an IMPERATIVE COMMAND, starting with one of
|
||||
these verbs (extend only when none fits):
|
||||
pick up <obj> — approach + grasp + lift in one subtask
|
||||
put <obj> on/in <loc> — transport + release in one subtask
|
||||
place <obj> on/in <loc> — synonym of "put"; pick one and stay consistent
|
||||
push <obj> — contact + linear shove
|
||||
pull <obj> — contact + linear retract
|
||||
turn <knob/dial/handle> — rotary actuation
|
||||
press <button> — single-press contact
|
||||
open <drawer/door/lid> — full open motion
|
||||
close <drawer/door/lid> — full close motion
|
||||
pour <src> into <dst> — tilt + flow
|
||||
insert <obj> into <slot>— alignment + push-fit
|
||||
go to <loc> — ONLY when no grasp / actuation follows
|
||||
(e.g. a pure relocation between phases).
|
||||
If the next subtask grasps something at
|
||||
that location, drop "go to ..." and just
|
||||
write "pick up ..." instead.
|
||||
- Forbidden ultra-fine splits — the VLM is NOT allowed to emit these
|
||||
as standalone subtasks; fold them into the parent composite:
|
||||
"move to X" → fold into "pick up X" (or whatever follows)
|
||||
"reach for X" → fold into "pick up X"
|
||||
"grasp X" → fold into "pick up X"
|
||||
"lift X" → fold into "pick up X" (or "put X on Y" if it's
|
||||
the transport phase of a place)
|
||||
"release X" → fold into "put X on Y" (or "place X in Y")
|
||||
- Keep it SHORT — a verb phrase, not a sentence. Drop articles
|
||||
("the", "a") and adverbs ("carefully", "slowly"). Add a "how"
|
||||
detail (which hand, which grasp point) ONLY when it is needed to
|
||||
disambiguate. Every subtask must begin with one of the verbs
|
||||
above (no leading nouns, no "then", no "first").
|
||||
- NEVER use third person. Never write "the robot", "the arm", "the
|
||||
gripper moves", "it picks up" — the robot is implied. Command it,
|
||||
do not describe it.
|
||||
- Use the exact object nouns from the task above. If the task says
|
||||
"cube", every subtask says "cube" — never switch to "block". If it
|
||||
says "box", never switch to "bin"/"container". Keep vocabulary
|
||||
consistent across the whole episode.
|
||||
- Good: "pick up blue cube", "put blue cube in box", "open drawer",
|
||||
"turn red knob", "press start button", "go to sink".
|
||||
- Bad: "move to blue cube" (approach as its own subtask — forbidden,
|
||||
must be folded into "pick up blue cube"); "the robot arm moves
|
||||
towards the blue cube" (third person, too long); "carefully pick
|
||||
up the cube" (adverb, article); "release the yellow block"
|
||||
("block" when the task said "cube", and "release" must be folded
|
||||
into a "put"/"place" subtask).
|
||||
- Subtasks are non-overlapping and cover the full episode in order.
|
||||
Choose the cut points yourself based on what you see in the video
|
||||
(gripper open/close events, contact, regrasps, transitions).
|
||||
- Each subtask spans at least {min_subtask_seconds} seconds. If a
|
||||
candidate span would be shorter, merge it into its neighbour
|
||||
rather than emitting it.
|
||||
- Do not exceed {max_steps} subtasks total. Fewer, larger composites
|
||||
are preferred over many micro-steps.
|
||||
- Every subtask's [start_time, end_time] must lie within
|
||||
[0.0, {episode_duration}] seconds.
|
||||
|
||||
Output strictly valid JSON of shape:
|
||||
|
||||
{{
|
||||
"subtasks": [
|
||||
{{"text": "<short imperative verb phrase>", "start": <float>, "end": <float>}},
|
||||
...
|
||||
]
|
||||
}}
|
||||
@@ -0,0 +1,32 @@
|
||||
You are generating training data for a Hi Robot-style policy. We need
|
||||
{n} alternative phrasings of the same robot task so the policy sees
|
||||
diverse user prompts during training instead of the same canonical
|
||||
string repeated every frame.
|
||||
|
||||
Original task:
|
||||
"{base_task}"
|
||||
|
||||
Generate exactly {n} alternative phrasings of the same task. Vary:
|
||||
|
||||
- formality (casual / polite / curt)
|
||||
- verbosity (mostly short imperative; occasional polite request)
|
||||
- word choice (synonyms, different verbs)
|
||||
- sentence structure (imperative / question / suggestion)
|
||||
|
||||
Hard rules:
|
||||
- Each phrasing MUST preserve the exact meaning of the original task.
|
||||
Do not change which object is involved, the destination, or the
|
||||
action. Do not add extra steps. Do not invent new objects.
|
||||
- Each phrasing must be a short phrase or sentence, plain prose, no
|
||||
markdown, no quotes, no list numbers.
|
||||
- Phrasings must be distinct — no near-duplicates.
|
||||
- Output exactly {n} entries.
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"rephrasings": [
|
||||
"<phrasing 1>",
|
||||
"<phrasing 2>",
|
||||
...
|
||||
]
|
||||
}}
|
||||
@@ -0,0 +1,17 @@
|
||||
The video above shows a robot manipulation episode in full. Look at
|
||||
the entire video and describe in ONE concise sentence what the robot
|
||||
is doing.
|
||||
|
||||
Rules:
|
||||
- One sentence, in natural English, like a user instruction.
|
||||
- Capture the goal of the demonstration, not low-level motions.
|
||||
Example: "place the yellow cube into the red bin" — not "move the
|
||||
end-effector down 5cm and close the gripper".
|
||||
- 4 to 15 words. Plain prose, no markdown, no bullets, no quotes.
|
||||
- Do not invent objects or actions that aren't visible.
|
||||
- Do not output anything other than the JSON object below.
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"task": "<single concise sentence describing what the robot does in this video>"
|
||||
}}
|
||||
@@ -0,0 +1,12 @@
|
||||
The user just asked the robot: "{episode_task}".
|
||||
|
||||
Generate a short verbal acknowledgement the robot would speak back before
|
||||
beginning the task. Style: compact, confident, friendly.
|
||||
|
||||
Examples (Hi Robot, Shi 2025): "Sure, I won't put cheese on it.",
|
||||
"OK, starting with the sponge.", "Got it.".
|
||||
|
||||
Prefer very short replies: "Got it.", "On it.", "OK."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{ "text": "<the spoken acknowledgement>" }}
|
||||
@@ -0,0 +1,46 @@
|
||||
You are generating training data for a Hi Robot-style hierarchical
|
||||
robot policy. The robot in this demonstration has ALREADY executed
|
||||
every step shown in the video — we cannot retroactively change the
|
||||
action stream. To keep training data consistent with the video, the
|
||||
"interjection" must align with what the robot is *about to do next* in
|
||||
the demonstration, framed as a natural mid-task user request.
|
||||
|
||||
The episode's overall task: "{episode_task}".
|
||||
|
||||
The images above show roughly {window_seconds:.1f} seconds straddling a
|
||||
subtask boundary in the demonstration:
|
||||
|
||||
- Subtask the robot just finished: "{prev_subtask}"
|
||||
- Subtask the robot is about to start: "{next_subtask}"
|
||||
- Time into episode: {timestamp:.2f}s
|
||||
|
||||
Write ONE compact interjection the user would naturally say at this
|
||||
moment to prompt / confirm / encourage the robot to do "{next_subtask}".
|
||||
Keep it like a mid-task coaching cue, not a full instruction paragraph.
|
||||
Also write the robot's compact verbal acknowledgement.
|
||||
|
||||
Hard rules:
|
||||
|
||||
- The interjection MUST be consistent with the next subtask. The user
|
||||
cannot ask for something different from what the robot then does in
|
||||
the video. If you're tempted to say "actually skip X" or "do Y
|
||||
instead", DO NOT — those would contradict the demonstration.
|
||||
- The interjection must reference an object, location, or action that
|
||||
is plausible given the visible scene and the next subtask text.
|
||||
- One short phrase or sentence each. Conversational, not robotic.
|
||||
- Prefer direct cues: "{next_subtask}, please."; "Now {next_subtask}."
|
||||
- Keep robot speech very short: "OK.", "On it.", "Doing that."
|
||||
|
||||
Style examples (vary the phrasing — don't reuse these verbatim):
|
||||
- "Now go ahead and {next_subtask}."
|
||||
- "Great, can you {next_subtask} next?"
|
||||
- "{next_subtask}, please."
|
||||
- "Before you continue, please {next_subtask}."
|
||||
- "Looking good — {next_subtask} now."
|
||||
- "Okay, {next_subtask}."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"interjection": "<short cue from the user, asking for the next subtask>",
|
||||
"speech": "<short robot acknowledgement>"
|
||||
}}
|
||||
@@ -0,0 +1,57 @@
|
||||
You are generating a frame-grounded visual question/answer pair for
|
||||
chain-of-thought training. Reference: ECoT (Zawalski 2024) and Steerable
|
||||
Policies — both train policies on grounded features such as bounding box
|
||||
pixel coordinates, keypoints, counts, attributes, and spatial relations.
|
||||
|
||||
The frame shows a robot working on: "{episode_task}".
|
||||
|
||||
QUALITY BAR — read before answering:
|
||||
|
||||
- Only label objects you are highly confident about. If you are not
|
||||
sure what an object is, do NOT include it. A short, certain answer
|
||||
beats a long, speculative one.
|
||||
- For coordinate-grounded answers (bbox, keypoint) only emit a label
|
||||
when you can localize the object *tightly and precisely*. If the
|
||||
object is occluded, ambiguous, off-frame, or you can't pin its
|
||||
extent, return an empty detections list / pick a different object
|
||||
rather than guessing.
|
||||
- Prefer task-relevant objects (the thing the robot is manipulating
|
||||
or interacting with) over background clutter.
|
||||
|
||||
Question types and the EXACT answer JSON shape required for each:
|
||||
|
||||
bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
|
||||
"bbox": [x1, y1, x2, y2]}}, ...]}}
|
||||
Pixel coordinates (x_min, y_min, x_max, y_max). Emit
|
||||
AT MOST 3 detections, and *only* the highest-confidence
|
||||
ones — 1 tight, certain detection is preferred over 3
|
||||
loose ones. Each box must be tight (no >10% padding
|
||||
around the object) and the label must be specific
|
||||
("red mug" not "object"). Return an empty list if no
|
||||
object meets the bar.
|
||||
ECoT example: "a white cup [124, 25, 176, 113]".
|
||||
|
||||
keypoint => {{"label": "<point>", "point_format": "xy",
|
||||
"point": [x, y]}}
|
||||
Pick ONE high-confidence, precisely-localizable point
|
||||
(e.g. a graspable handle, a button center, the gripper
|
||||
tip). The point must land within a few pixels of the
|
||||
feature. Do not emit a coarse "somewhere on the object"
|
||||
point — pick a different question type if no such
|
||||
point exists in this frame.
|
||||
|
||||
count => {{"label": "<obj>", "count": <int>,
|
||||
"note": "<optional short note>"}}
|
||||
|
||||
attribute => {{"label": "<obj>", "attribute": "<color|shape|state|...>",
|
||||
"value": "<observed value>"}}
|
||||
|
||||
spatial => {{"subject": "<obj>", "relation": "<left_of|right_of|on|in|"
|
||||
"above|below|near>", "object": "<obj>"}}
|
||||
|
||||
Generate a question of type "{question_type}". Output strictly valid JSON:
|
||||
|
||||
{{
|
||||
"question": "<short, frame-grounded question>",
|
||||
"answer": <object whose shape matches the schema above>
|
||||
}}
|
||||
274
src/lerobot/annotations/steerable_pipeline/reader.py
Normal file
274
src/lerobot/annotations/steerable_pipeline/reader.py
Normal file
@@ -0,0 +1,274 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Datatrove-shaped reader.
|
||||
|
||||
The reader walks ``data/chunk-*/file-*.parquet`` and yields one record per
|
||||
episode containing:
|
||||
|
||||
- ``episode_index``: int
|
||||
- ``frame_timestamps``: tuple[float, ...]
|
||||
- ``frame_indices``: tuple[int, ...]
|
||||
- ``episode_task``: str (canonical task from ``meta/tasks.parquet``)
|
||||
- ``data_path``: pathlib.Path of the source parquet shard
|
||||
- ``frames_df``: pandas.DataFrame slice for the episode (only loaded on demand)
|
||||
|
||||
This shape lets each module operate per-episode without loading all parquet
|
||||
rows into memory at once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.io_utils import load_tasks
|
||||
from lerobot.datasets.utils import DEFAULT_TASKS_PATH
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeRecord:
|
||||
"""Per-episode record yielded by the reader."""
|
||||
|
||||
episode_index: int
|
||||
episode_task: str
|
||||
frame_timestamps: tuple[float, ...]
|
||||
frame_indices: tuple[int, ...]
|
||||
data_path: Path
|
||||
row_offset: int # row offset within the parquet file where this episode starts
|
||||
row_count: int # number of rows for this episode
|
||||
|
||||
# Memoized parquet slice — populated on first ``frames_df()`` call so
|
||||
# repeat queries from different modules don't re-read the whole shard.
|
||||
_frames_df_cache: Any = field(default=None, init=False, repr=False, compare=False)
|
||||
|
||||
def frames_df(self): # type: ignore[no-untyped-def]
|
||||
"""Lazy-load the pandas slice for this episode (memoized)."""
|
||||
if self._frames_df_cache is None:
|
||||
import pandas as pd # noqa: PLC0415 - deferred for optional dataset extra
|
||||
|
||||
table = pq.read_table(self.data_path)
|
||||
df: pd.DataFrame = table.to_pandas()
|
||||
self._frames_df_cache = df.iloc[self.row_offset : self.row_offset + self.row_count].reset_index(
|
||||
drop=True
|
||||
)
|
||||
return self._frames_df_cache
|
||||
|
||||
|
||||
def reconstruct_subtask_spans(
|
||||
rows: Sequence[dict[str, Any]],
|
||||
*,
|
||||
episode_end_t: float | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Turn ``style="subtask"`` rows into ``{text, start, end}`` spans.
|
||||
|
||||
Each span's ``end`` is the next span's ``start``. The final span's
|
||||
``end`` defaults to its own ``start`` (zero-duration) — pass
|
||||
``episode_end_t`` to extend it to the episode's last frame instead,
|
||||
which is what downstream consumers (memory, interjection boundary
|
||||
selection) expect.
|
||||
|
||||
Used by the ``plan`` module (plan-update pass) and the
|
||||
``interjections`` module (interjection anchoring), which both need the
|
||||
same span shape.
|
||||
"""
|
||||
sorted_rows = sorted(
|
||||
(r for r in rows if r.get("style") == "subtask"),
|
||||
key=lambda r: float(r["timestamp"]),
|
||||
)
|
||||
spans: list[dict[str, Any]] = []
|
||||
for r in sorted_rows:
|
||||
t = float(r["timestamp"])
|
||||
if spans:
|
||||
spans[-1]["end"] = t
|
||||
spans.append({"text": r.get("content") or "", "start": t, "end": t})
|
||||
if spans and episode_end_t is not None and float(episode_end_t) > spans[-1]["start"]:
|
||||
spans[-1]["end"] = float(episode_end_t)
|
||||
return spans
|
||||
|
||||
|
||||
def snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
|
||||
"""Snap an arbitrary float to the nearest exact source frame timestamp.
|
||||
|
||||
Modules use this when emitting event-style rows so the row's
|
||||
timestamp matches a real parquet frame: event rows must land on an
|
||||
exact frame, otherwise the per-frame event lookup the writer does
|
||||
would never match them.
|
||||
"""
|
||||
if not frame_timestamps:
|
||||
return float(t)
|
||||
nearest = min(frame_timestamps, key=lambda f: abs(f - t))
|
||||
return float(nearest)
|
||||
|
||||
|
||||
def _load_tasks_lookup(root: Path) -> dict[int, str]:
|
||||
"""Map ``task_index -> task`` from ``meta/tasks.parquet``.
|
||||
|
||||
Returns an empty dict when the file is absent — the task description is
|
||||
derived later from the video if needed. Reuses the library-level
|
||||
:func:`lerobot.datasets.io_utils.load_tasks`, which returns the tasks
|
||||
frame indexed by task string with a ``task_index`` column.
|
||||
"""
|
||||
if not (root / DEFAULT_TASKS_PATH).exists():
|
||||
return {}
|
||||
tasks = load_tasks(root)
|
||||
return {int(idx): str(task) for task, idx in zip(tasks.index, tasks["task_index"], strict=True)}
|
||||
|
||||
|
||||
def iter_episodes(root: Path, *, only_episodes: tuple[int, ...] | None = None) -> Iterator[EpisodeRecord]:
|
||||
"""Yield :class:`EpisodeRecord` for every episode under ``root/data/``.
|
||||
|
||||
Episodes are yielded in ascending ``episode_index`` order. The reader does
|
||||
not assume a specific chunk/file layout: it scans every ``*.parquet``
|
||||
under ``data/`` and groups by ``episode_index``.
|
||||
"""
|
||||
tasks = _load_tasks_lookup(root)
|
||||
data_dir = root / "data"
|
||||
parquet_files = sorted(data_dir.rglob("*.parquet"))
|
||||
|
||||
only_set = set(only_episodes) if only_episodes is not None else None
|
||||
|
||||
for path in parquet_files:
|
||||
yield from _iter_one_path(path, tasks, only_set)
|
||||
|
||||
|
||||
def _iter_one_path(path: Path, tasks: dict[int, str], only_set: set[int] | None) -> Iterator[EpisodeRecord]:
|
||||
table = pq.read_table(path)
|
||||
names = table.column_names
|
||||
if "episode_index" not in names:
|
||||
return
|
||||
episode_col = table.column("episode_index").to_pylist()
|
||||
timestamp_col = (
|
||||
table.column("timestamp").to_pylist() if "timestamp" in names else [0.0] * len(episode_col)
|
||||
)
|
||||
frame_col = (
|
||||
table.column("frame_index").to_pylist() if "frame_index" in names else list(range(len(episode_col)))
|
||||
)
|
||||
task_col = table.column("task_index").to_pylist() if "task_index" in names else None
|
||||
|
||||
def _build(
|
||||
ep: int,
|
||||
start: int,
|
||||
end: int,
|
||||
task_idx: int | None,
|
||||
ts_buf: list[float],
|
||||
fi_buf: list[int],
|
||||
) -> EpisodeRecord | None:
|
||||
if only_set is not None and ep not in only_set:
|
||||
return None
|
||||
task = tasks.get(task_idx, "") if task_idx is not None else ""
|
||||
return EpisodeRecord(
|
||||
episode_index=ep,
|
||||
episode_task=task,
|
||||
frame_timestamps=tuple(ts_buf),
|
||||
frame_indices=tuple(fi_buf),
|
||||
data_path=path,
|
||||
row_offset=start,
|
||||
row_count=end - start,
|
||||
)
|
||||
|
||||
cur_ep: int | None = None
|
||||
start_offset = 0
|
||||
ts_buf: list[float] = []
|
||||
fi_buf: list[int] = []
|
||||
cur_task_idx: int | None = None
|
||||
|
||||
for i, ep in enumerate(episode_col):
|
||||
if cur_ep is None:
|
||||
cur_ep = ep
|
||||
start_offset = i
|
||||
ts_buf = [timestamp_col[i]]
|
||||
fi_buf = [frame_col[i]]
|
||||
cur_task_idx = task_col[i] if task_col is not None else None
|
||||
continue
|
||||
if ep != cur_ep:
|
||||
rec = _build(cur_ep, start_offset, i, cur_task_idx, ts_buf, fi_buf)
|
||||
if rec is not None:
|
||||
yield rec
|
||||
cur_ep = ep
|
||||
start_offset = i
|
||||
ts_buf = [timestamp_col[i]]
|
||||
fi_buf = [frame_col[i]]
|
||||
cur_task_idx = task_col[i] if task_col is not None else None
|
||||
else:
|
||||
ts_buf.append(timestamp_col[i])
|
||||
fi_buf.append(frame_col[i])
|
||||
|
||||
if cur_ep is not None:
|
||||
rec = _build(cur_ep, start_offset, len(episode_col), cur_task_idx, ts_buf, fi_buf)
|
||||
if rec is not None:
|
||||
yield rec
|
||||
|
||||
|
||||
def gather_data_paths(root: Path) -> list[Path]:
|
||||
"""Return every ``data/chunk-*/file-*.parquet`` path under ``root``."""
|
||||
return sorted((root / "data").rglob("*.parquet"))
|
||||
|
||||
|
||||
def episode_offsets_per_path(path: Path) -> dict[int, tuple[int, int]]:
|
||||
"""Return ``{episode_index: (row_offset, row_count)}`` for one parquet."""
|
||||
table = pq.read_table(path, columns=["episode_index"])
|
||||
episode_col = table.column("episode_index").to_pylist()
|
||||
out: dict[int, tuple[int, int]] = {}
|
||||
cur_ep: int | None = None
|
||||
start = 0
|
||||
for i, ep in enumerate(episode_col):
|
||||
if cur_ep is None:
|
||||
cur_ep = ep
|
||||
start = i
|
||||
continue
|
||||
if ep != cur_ep:
|
||||
out[cur_ep] = (start, i - start)
|
||||
cur_ep = ep
|
||||
start = i
|
||||
if cur_ep is not None:
|
||||
out[cur_ep] = (start, len(episode_col) - start)
|
||||
return out
|
||||
|
||||
|
||||
def keyframe_indices(record: EpisodeRecord, k: int) -> list[int]:
|
||||
"""Return ``k`` evenly spaced row indices into the episode (relative)."""
|
||||
n = record.row_count
|
||||
if k <= 0 or n == 0:
|
||||
return []
|
||||
if k >= n:
|
||||
return list(range(n))
|
||||
step = (n - 1) / (k - 1) if k > 1 else 0.0
|
||||
return [int(round(i * step)) for i in range(k)] if k > 1 else [n // 2]
|
||||
|
||||
|
||||
def lookup_data_path(root: Path, episode_index: int) -> tuple[Path, int, int] | None:
|
||||
"""Find the parquet file containing ``episode_index`` and its slice bounds."""
|
||||
for path in gather_data_paths(root):
|
||||
offsets = episode_offsets_per_path(path)
|
||||
if episode_index in offsets:
|
||||
start, count = offsets[episode_index]
|
||||
return path, start, count
|
||||
return None
|
||||
|
||||
|
||||
def episode_frame_timestamps(root: Path, episode_index: int) -> tuple[Any, list[float]]:
|
||||
"""Return the parquet path and per-frame timestamps for ``episode_index``."""
|
||||
found = lookup_data_path(root, episode_index)
|
||||
if found is None:
|
||||
raise ValueError(f"Episode {episode_index} not found under {root}/data/")
|
||||
path, start, count = found
|
||||
table = pq.read_table(path, columns=["timestamp"])
|
||||
timestamps = table.column("timestamp").to_pylist()[start : start + count]
|
||||
return path, [float(t) for t in timestamps]
|
||||
104
src/lerobot/annotations/steerable_pipeline/staging.py
Normal file
104
src/lerobot/annotations/steerable_pipeline/staging.py
Normal file
@@ -0,0 +1,104 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Per-episode staging.
|
||||
|
||||
Each module writes its raw output as a JSONL file under
|
||||
``<staging_dir>/episode_{ep:06d}/<module>.jsonl``. The writer reads back this
|
||||
staging tree and partitions rows into the two language columns.
|
||||
|
||||
JSONL is preferred over parquet here because the staging artifact is meant to
|
||||
be human-inspectable, easy to diff between prompt iterations, and trivially
|
||||
appended to. The final dataset format is parquet; staging is just an
|
||||
intermediate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Iterable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
ModuleName = str
|
||||
|
||||
_MODULES: tuple[ModuleName, ...] = (
|
||||
"plan",
|
||||
"interjections",
|
||||
"vqa",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeStaging:
|
||||
"""Filesystem layout for a single episode's staged module outputs."""
|
||||
|
||||
root: Path
|
||||
episode_index: int
|
||||
|
||||
@property
|
||||
def episode_dir(self) -> Path:
|
||||
return self.root / f"episode_{self.episode_index:06d}"
|
||||
|
||||
def path_for(self, module: ModuleName) -> Path:
|
||||
if module not in _MODULES:
|
||||
raise ValueError(f"Unknown module {module!r}; expected one of {_MODULES}")
|
||||
return self.episode_dir / f"{module}.jsonl"
|
||||
|
||||
def write(self, module: ModuleName, rows: Iterable[dict[str, Any]]) -> Path:
|
||||
path = self.path_for(module)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Atomic replace: a crash mid-write would otherwise leave a
|
||||
# half-written JSONL file that ``read()`` would then fail to
|
||||
# parse. Write to a sibling .tmp and rename so the target path
|
||||
# only ever points at a complete file.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
with tmp_path.open("w", encoding="utf-8") as f:
|
||||
for row in rows:
|
||||
f.write(json.dumps(row, ensure_ascii=False, sort_keys=True))
|
||||
f.write("\n")
|
||||
tmp_path.replace(path)
|
||||
return path
|
||||
|
||||
def read(self, module: ModuleName) -> list[dict[str, Any]]:
|
||||
path = self.path_for(module)
|
||||
if not path.exists():
|
||||
return []
|
||||
out: list[dict[str, Any]] = []
|
||||
with path.open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
out.append(json.loads(line))
|
||||
return out
|
||||
|
||||
def read_all(self) -> dict[ModuleName, list[dict[str, Any]]]:
|
||||
return {m: self.read(m) for m in _MODULES}
|
||||
|
||||
def has(self, module: ModuleName) -> bool:
|
||||
return self.path_for(module).exists()
|
||||
|
||||
|
||||
def iter_staged_episodes(root: Path) -> Iterator[int]:
|
||||
"""Yield episode indices for which any staging artifact exists."""
|
||||
if not root.exists():
|
||||
return
|
||||
for child in sorted(root.iterdir()):
|
||||
if child.is_dir() and child.name.startswith("episode_"):
|
||||
try:
|
||||
yield int(child.name.removeprefix("episode_"))
|
||||
except ValueError:
|
||||
continue
|
||||
334
src/lerobot/annotations/steerable_pipeline/validator.py
Normal file
334
src/lerobot/annotations/steerable_pipeline/validator.py
Normal file
@@ -0,0 +1,334 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Pre-write validation against staged outputs.
|
||||
|
||||
Runs after all three modules have written their per-episode artifacts but
|
||||
*before* the writer rewrites parquet shards. The validator never touches
|
||||
parquet; it only inspects the staging tree and the source frame timestamps
|
||||
exposed by :class:`EpisodeRecord`.
|
||||
|
||||
Checks (per the plan's "Intermediate staging and validation" section):
|
||||
|
||||
- exact timestamp alignment against source frame timestamps
|
||||
- no orphan speech / interjection pairs
|
||||
- plan / memory emission consistency (events have a paired persistent row)
|
||||
- VQA assistant ``content`` is valid JSON (one of bbox / keypoint / count /
|
||||
attribute / spatial)
|
||||
- every row maps to its correct column under :func:`column_for_style`
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from lerobot.datasets.language import (
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
column_for_style,
|
||||
is_view_dependent_style,
|
||||
validate_camera_field,
|
||||
)
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
from .staging import EpisodeStaging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationReport:
|
||||
"""Outcome of one validation pass across all episodes."""
|
||||
|
||||
errors: list[str] = field(default_factory=list)
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
episodes_checked: int = 0
|
||||
|
||||
@property
|
||||
def ok(self) -> bool:
|
||||
return not self.errors
|
||||
|
||||
def add_error(self, message: str) -> None:
|
||||
self.errors.append(message)
|
||||
|
||||
def add_warning(self, message: str) -> None:
|
||||
self.warnings.append(message)
|
||||
|
||||
def summary(self) -> str:
|
||||
return f"checked={self.episodes_checked} errors={len(self.errors)} warnings={len(self.warnings)}"
|
||||
|
||||
|
||||
VQA_ANSWER_SHAPES: dict[str, set[str]] = {
|
||||
"bbox": {"detections"},
|
||||
"keypoint": {"label", "point_format", "point"},
|
||||
"count": {"label", "count"},
|
||||
"attribute": {"label", "attribute", "value"},
|
||||
"spatial": {"subject", "relation", "object"},
|
||||
}
|
||||
|
||||
|
||||
def classify_vqa_answer(payload: Any) -> str | None:
|
||||
"""Best-effort classification of a VQA answer payload to a question type."""
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
keys = set(payload.keys())
|
||||
for kind, required in VQA_ANSWER_SHAPES.items():
|
||||
if required.issubset(keys):
|
||||
return kind
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StagingValidator:
|
||||
"""Walks the staging tree and produces a :class:`ValidationReport`."""
|
||||
|
||||
timestamp_atol: float = 0.0 # exact-match by default
|
||||
dataset_camera_keys: tuple[str, ...] | None = None
|
||||
"""Known ``observation.images.*`` keys on the dataset. When set, the
|
||||
validator additionally enforces that every view-dependent row's
|
||||
``camera`` field references one of these keys. Pass ``None`` (default)
|
||||
to skip that cross-check (e.g. in unit tests with no real dataset)."""
|
||||
|
||||
def validate(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
) -> ValidationReport:
|
||||
report = ValidationReport()
|
||||
for record in records:
|
||||
self._validate_episode(record, staging_dir, report)
|
||||
report.episodes_checked += 1
|
||||
return report
|
||||
|
||||
def _validate_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
staging_dir: Path,
|
||||
report: ValidationReport,
|
||||
) -> None:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
staged = staging.read_all()
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
for module_name, rows in staged.items():
|
||||
for row in rows:
|
||||
row = {**row, "_module": module_name}
|
||||
all_rows.append(row)
|
||||
|
||||
frame_ts = set(record.frame_timestamps)
|
||||
|
||||
events: list[dict[str, Any]] = []
|
||||
persistent: list[dict[str, Any]] = []
|
||||
for row in all_rows:
|
||||
self._check_column_routing(row, report, record.episode_index)
|
||||
self._check_camera_field(
|
||||
row, report, record.episode_index, self.dataset_camera_keys
|
||||
)
|
||||
if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT:
|
||||
persistent.append(row)
|
||||
else:
|
||||
events.append(row)
|
||||
|
||||
for row in events:
|
||||
self._check_event_timestamp_alignment(row, frame_ts, report, record.episode_index)
|
||||
|
||||
self._check_speech_interjection_pairs(events, report, record.episode_index)
|
||||
self._check_plan_memory_consistency(persistent, events, report, record.episode_index)
|
||||
self._check_vqa_json(events, report, record.episode_index)
|
||||
self._check_vqa_uniqueness_per_frame_camera(events, report, record.episode_index)
|
||||
|
||||
def _check_camera_field(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
dataset_camera_keys: Sequence[str] | None,
|
||||
) -> None:
|
||||
"""Enforce the camera invariant + that the key matches the dataset's cameras."""
|
||||
style = row.get("style")
|
||||
camera = row.get("camera")
|
||||
try:
|
||||
validate_camera_field(style, camera)
|
||||
except ValueError as exc:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={row.get('_module')}: {exc}"
|
||||
)
|
||||
return
|
||||
if (
|
||||
is_view_dependent_style(style)
|
||||
and dataset_camera_keys
|
||||
and camera not in dataset_camera_keys
|
||||
):
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={row.get('_module')}: camera {camera!r} on style "
|
||||
f"{style!r} is not one of the dataset's video keys {sorted(dataset_camera_keys)!r}"
|
||||
)
|
||||
|
||||
def _check_vqa_uniqueness_per_frame_camera(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
"""Ensure at most one (vqa, user) and one (vqa, assistant) per (t, camera)."""
|
||||
counts: dict[tuple[float, str, str], int] = {}
|
||||
for row in events:
|
||||
if row.get("style") != "vqa":
|
||||
continue
|
||||
ts = row.get("timestamp")
|
||||
camera = row.get("camera")
|
||||
role = row.get("role")
|
||||
if ts is None or camera is None or role is None:
|
||||
continue # other validators flag these
|
||||
key = (float(ts), str(camera), str(role))
|
||||
counts[key] = counts.get(key, 0) + 1
|
||||
for (ts, camera, role), n in counts.items():
|
||||
if n > 1:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: {n} duplicate vqa rows at t={ts} "
|
||||
f"camera={camera!r} role={role!r}; expected at most one per (t, camera, role)"
|
||||
)
|
||||
|
||||
def _check_column_routing(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
style = row.get("style")
|
||||
module = row.get("_module")
|
||||
try:
|
||||
target_col = column_for_style(style)
|
||||
except ValueError:
|
||||
report.add_error(f"ep={episode_index} module={module}: unknown style {style!r}")
|
||||
return
|
||||
if module == "plan" and target_col != LANGUAGE_PERSISTENT:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module=plan emitted style {style!r} that routes to {target_col} (must be persistent)"
|
||||
)
|
||||
if module in {"interjections", "vqa"} and target_col != LANGUAGE_EVENTS:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={module} emitted style {style!r} that routes to {target_col} (must be events)"
|
||||
)
|
||||
|
||||
def _check_event_timestamp_alignment(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
frame_ts: set[float],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
ts = row.get("timestamp")
|
||||
if ts is None:
|
||||
report.add_error(f"ep={episode_index}: event row missing timestamp: {row!r}")
|
||||
return
|
||||
if self.timestamp_atol == 0.0:
|
||||
if float(ts) not in frame_ts:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: event row timestamp {ts!r} does not match any source frame timestamp"
|
||||
)
|
||||
else:
|
||||
if not any(abs(float(ts) - f) <= self.timestamp_atol for f in frame_ts):
|
||||
report.add_error(
|
||||
f"ep={episode_index}: event row timestamp {ts!r} not within {self.timestamp_atol}s of any frame"
|
||||
)
|
||||
|
||||
def _check_speech_interjection_pairs(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
speech_ts: dict[float, int] = {}
|
||||
interjection_ts: dict[float, int] = {}
|
||||
for row in events:
|
||||
ts = row.get("timestamp")
|
||||
if ts is None:
|
||||
continue
|
||||
ts_f = float(ts)
|
||||
if row.get("style") is None and row.get("role") == "assistant":
|
||||
speech_ts[ts_f] = speech_ts.get(ts_f, 0) + 1
|
||||
if row.get("style") == "interjection":
|
||||
interjection_ts[ts_f] = interjection_ts.get(ts_f, 0) + 1
|
||||
|
||||
for ts in interjection_ts:
|
||||
if ts not in speech_ts:
|
||||
report.add_error(f"ep={episode_index}: interjection at t={ts} has no paired speech atom")
|
||||
|
||||
def _check_plan_memory_consistency(
|
||||
self,
|
||||
persistent: Sequence[dict[str, Any]],
|
||||
events: Sequence[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
plan_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "plan"})
|
||||
memory_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "memory"})
|
||||
subtask_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "subtask"})
|
||||
interjection_ts = sorted(
|
||||
{
|
||||
float(r["timestamp"])
|
||||
for r in events
|
||||
if r.get("style") == "interjection" and r.get("timestamp") is not None
|
||||
}
|
||||
)
|
||||
|
||||
if persistent and not plan_ts:
|
||||
report.add_warning(f"ep={episode_index}: persistent rows present but no plan emitted")
|
||||
# every interjection should have a same-timestamp plan refresh
|
||||
for ts in interjection_ts:
|
||||
if ts not in set(plan_ts):
|
||||
report.add_error(
|
||||
f"ep={episode_index}: interjection at t={ts} has no co-timestamped plan update"
|
||||
)
|
||||
# memory should be emitted at subtask boundaries (subset relation)
|
||||
if memory_ts and subtask_ts:
|
||||
mem_set = set(memory_ts)
|
||||
sub_set = set(subtask_ts)
|
||||
stray = sorted(mem_set - sub_set)
|
||||
if stray:
|
||||
report.add_warning(f"ep={episode_index}: memory rows at {stray} not at any subtask boundary")
|
||||
|
||||
def _check_vqa_json(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
for row in events:
|
||||
if row.get("style") != "vqa" or row.get("role") != "assistant":
|
||||
continue
|
||||
content = row.get("content")
|
||||
if content is None:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant row at t={row.get('timestamp')} has null content"
|
||||
)
|
||||
continue
|
||||
try:
|
||||
payload = json.loads(content)
|
||||
except (TypeError, ValueError) as exc:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant content not valid JSON at t={row.get('timestamp')}: {exc}"
|
||||
)
|
||||
continue
|
||||
shape = classify_vqa_answer(payload)
|
||||
if shape is None:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant payload at t={row.get('timestamp')} does not match any known shape: keys={list(payload) if isinstance(payload, dict) else type(payload).__name__}"
|
||||
)
|
||||
703
src/lerobot/annotations/steerable_pipeline/vlm_client.py
Normal file
703
src/lerobot/annotations/steerable_pipeline/vlm_client.py
Normal file
@@ -0,0 +1,703 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Shared Qwen-VL client.
|
||||
|
||||
The pipeline uses a single shared VLM across modules. vLLM is preferred when
|
||||
available (high throughput, JSON-guided decoding); transformers is the
|
||||
fallback. A ``stub`` backend is used for unit tests so fixtures never call
|
||||
into a real model.
|
||||
|
||||
The client speaks one method, :meth:`VlmClient.generate_json`, which:
|
||||
|
||||
- accepts a list of OpenAI/HF-style multimodal messages,
|
||||
- requests JSON output (``json_mode=True`` enables guided decoding when the
|
||||
backend supports it),
|
||||
- batches requests transparently,
|
||||
- and reprompts once on a JSON parse failure with an inline correction
|
||||
message before raising.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import shlex
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import urllib.request
|
||||
from collections.abc import Callable, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
|
||||
from .config import VlmConfig
|
||||
|
||||
|
||||
class VlmClient(Protocol):
|
||||
"""Protocol every backend must implement."""
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
"""Generate one JSON-decoded response per messages list."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubVlmClient:
|
||||
"""Deterministic stub used in unit tests.
|
||||
|
||||
A test passes a callable that maps the *last user message text* (or, if
|
||||
that is empty, the full message list) to a JSON-serializable response.
|
||||
"""
|
||||
|
||||
responder: Callable[[Sequence[dict[str, Any]]], Any]
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
return [self.responder(list(messages)) for messages in messages_batch]
|
||||
|
||||
|
||||
def _strip_to_json(text: str) -> Any:
|
||||
text = text.strip()
|
||||
# Strip <think>...</think> blocks (Qwen3 Thinking style)
|
||||
while "<think>" in text and "</think>" in text:
|
||||
start = text.find("<think>")
|
||||
end = text.find("</think>", start) + len("</think>")
|
||||
text = (text[:start] + text[end:]).strip()
|
||||
# Strip ```json ... ``` fences from chat-tuned backbones
|
||||
if text.startswith("```"):
|
||||
first = text.find("\n")
|
||||
last = text.rfind("```")
|
||||
if first != -1 and last != -1 and last > first:
|
||||
text = text[first + 1 : last].strip()
|
||||
try:
|
||||
return json.loads(text)
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
pass
|
||||
# Fall back to extracting the first balanced {...} block.
|
||||
obj_text = _extract_first_json_object(text)
|
||||
if obj_text is None:
|
||||
raise json.JSONDecodeError("No JSON object found", text, 0)
|
||||
return json.loads(obj_text)
|
||||
|
||||
|
||||
def _extract_first_json_object(text: str) -> str | None:
|
||||
"""Return the first balanced ``{...}`` substring, ignoring braces in
|
||||
string literals. Returns ``None`` if no balanced block is found."""
|
||||
start = text.find("{")
|
||||
if start < 0:
|
||||
return None
|
||||
depth = 0
|
||||
in_string = False
|
||||
escape = False
|
||||
for i in range(start, len(text)):
|
||||
ch = text[i]
|
||||
if escape:
|
||||
escape = False
|
||||
continue
|
||||
if ch == "\\":
|
||||
escape = True
|
||||
continue
|
||||
# Note: ``escape`` is always False here — the ``if escape`` branch
|
||||
# above already handled and reset it.
|
||||
if ch == '"':
|
||||
in_string = not in_string
|
||||
continue
|
||||
if in_string:
|
||||
continue
|
||||
if ch == "{":
|
||||
depth += 1
|
||||
elif ch == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[start : i + 1]
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _GenericTextClient:
|
||||
"""Wraps any text-generation callable in JSON-mode + one-retry semantics."""
|
||||
|
||||
generate_text: Callable[[Sequence[Sequence[dict[str, Any]]], int, float], list[str]]
|
||||
config: VlmConfig
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
max_tok = max_new_tokens if max_new_tokens is not None else self.config.max_new_tokens
|
||||
temp = temperature if temperature is not None else self.config.temperature
|
||||
raw = self.generate_text(messages_batch, max_tok, temp)
|
||||
out: list[Any] = []
|
||||
for messages, text in zip(messages_batch, raw, strict=True):
|
||||
try:
|
||||
out.append(_strip_to_json(text))
|
||||
continue
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
pass
|
||||
retry = list(messages) + [
|
||||
{"role": "assistant", "content": text},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Your previous reply was not valid JSON. "
|
||||
"Reply with strictly valid JSON, no prose, no fences."
|
||||
),
|
||||
},
|
||||
]
|
||||
retry_text = self.generate_text([retry], max_tok, temp)[0]
|
||||
try:
|
||||
out.append(_strip_to_json(retry_text))
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
# After retry: log preview and return None instead of crashing
|
||||
# the whole pipeline. Modules treat None as "skip".
|
||||
preview = retry_text.strip().replace("\n", " ")[:200]
|
||||
print(
|
||||
f"[vlm] WARNING: failed to parse JSON after retry; preview: {preview!r}",
|
||||
flush=True,
|
||||
)
|
||||
out.append(None)
|
||||
return out
|
||||
|
||||
|
||||
def make_vlm_client(config: VlmConfig) -> VlmClient:
|
||||
"""Build the shared VLM client per the configured backend.
|
||||
|
||||
For ``stub``, callers should construct :class:`StubVlmClient` directly with
|
||||
a responder callable. ``stub`` here is rejected to make accidental misuse
|
||||
obvious.
|
||||
"""
|
||||
if config.backend == "stub":
|
||||
raise ValueError(
|
||||
"Use StubVlmClient(...) directly for the stub backend; make_vlm_client builds real clients."
|
||||
)
|
||||
if config.backend == "vllm":
|
||||
return _make_vllm_client(config)
|
||||
if config.backend == "transformers":
|
||||
return _make_transformers_client(config)
|
||||
if config.backend == "openai":
|
||||
return _make_openai_client(config)
|
||||
raise ValueError(f"Unknown VLM backend: {config.backend!r}")
|
||||
|
||||
|
||||
def _make_vllm_client(config: VlmConfig) -> VlmClient:
|
||||
try:
|
||||
from vllm import LLM, SamplingParams # type: ignore[import-not-found]
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"vllm is required for backend='vllm'. Install with `pip install lerobot[annotations]`."
|
||||
) from exc
|
||||
# Workaround for cuDNN 9.x + torch 2.8 conv3d regression that surfaces
|
||||
# as CUDNN_STATUS_NOT_INITIALIZED in Qwen-VL vision-tower patch
|
||||
# embedders. Setting LEROBOT_DISABLE_CUDNN=1 forces native PyTorch
|
||||
# convolution kernels — slower but functional.
|
||||
if os.environ.get("LEROBOT_DISABLE_CUDNN", "").lower() in {"1", "true", "yes"}:
|
||||
import torch as _torch # noqa: PLC0415 - optional GPU dep, deferred
|
||||
|
||||
_torch.backends.cudnn.enabled = False
|
||||
llm_kwargs: dict[str, Any] = {
|
||||
"model": config.model_id,
|
||||
"tensor_parallel_size": config.tensor_parallel_size,
|
||||
"gpu_memory_utilization": config.gpu_memory_utilization,
|
||||
"trust_remote_code": config.trust_remote_code,
|
||||
}
|
||||
if config.max_model_len is not None:
|
||||
llm_kwargs["max_model_len"] = config.max_model_len
|
||||
llm = LLM(**llm_kwargs)
|
||||
|
||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||
# ``guided_decoding`` would speed up parsing but its API differs across
|
||||
# vllm releases (dict vs GuidedDecodingParams). The _GenericTextClient
|
||||
# wrapper already has a one-retry JSON-recovery path, so we skip it.
|
||||
params = SamplingParams(max_tokens=max_tok, temperature=temp)
|
||||
# ``llm.chat`` handles chat-template application + multimodal input
|
||||
# extraction (image/video blocks) internally, which ``llm.generate``
|
||||
# does not.
|
||||
outputs = llm.chat([list(m) for m in batch], params)
|
||||
return [o.outputs[0].text for o in outputs]
|
||||
|
||||
return _GenericTextClient(_gen, config)
|
||||
|
||||
|
||||
def _make_transformers_client(config: VlmConfig) -> VlmClient:
|
||||
try:
|
||||
import torch # type: ignore[import-not-found]
|
||||
import transformers # type: ignore[import-not-found]
|
||||
from transformers import AutoProcessor # type: ignore[import-not-found]
|
||||
except ImportError as exc:
|
||||
raise ImportError("transformers + torch are required for backend='transformers'.") from exc
|
||||
auto_cls = getattr(transformers, "AutoModelForImageTextToText", None) or getattr(
|
||||
transformers, "AutoModelForVision2Seq", None
|
||||
)
|
||||
if auto_cls is None:
|
||||
raise ImportError(
|
||||
"Neither AutoModelForImageTextToText nor AutoModelForVision2Seq is available in this "
|
||||
"transformers version. Install transformers>=4.45 (which has AutoModelForImageTextToText) "
|
||||
"for VL models."
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(config.model_id, trust_remote_code=config.trust_remote_code)
|
||||
use_accelerate = os.environ.get("LEROBOT_TRANSFORMERS_DEVICE_MAP", "manual") != "manual"
|
||||
# ``device_map='auto'`` triggers a known std::bad_alloc on the Qwen3-VL
|
||||
# post-load dispatch path (the alloc fails in accelerate's hook setup
|
||||
# even with TBs of host RAM). Default to manual: load on CPU with
|
||||
# ``low_cpu_mem_usage=True``, then ``.to("cuda")``. Set
|
||||
# ``LEROBOT_TRANSFORMERS_DEVICE_MAP=auto`` to opt back into the old path.
|
||||
if use_accelerate:
|
||||
model = auto_cls.from_pretrained(
|
||||
config.model_id,
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=config.trust_remote_code,
|
||||
)
|
||||
else:
|
||||
import torch as _torch # noqa: PLC0415 - optional GPU dep, deferred
|
||||
|
||||
model = auto_cls.from_pretrained(
|
||||
config.model_id,
|
||||
torch_dtype=_torch.bfloat16,
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=config.trust_remote_code,
|
||||
)
|
||||
model = model.to("cuda")
|
||||
model.eval()
|
||||
|
||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||
outs: list[str] = []
|
||||
for messages in batch:
|
||||
text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
||||
inputs = processor(text=[text], return_tensors="pt").to(model.device)
|
||||
with torch.no_grad():
|
||||
gen = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_tok,
|
||||
temperature=temp,
|
||||
do_sample=temp > 0.0,
|
||||
)
|
||||
decoded = processor.batch_decode(
|
||||
gen[:, inputs["input_ids"].shape[-1] :], skip_special_tokens=True
|
||||
)[0]
|
||||
outs.append(decoded)
|
||||
return outs
|
||||
|
||||
return _GenericTextClient(_gen, config)
|
||||
|
||||
|
||||
def _make_openai_client(config: VlmConfig) -> VlmClient:
|
||||
"""Backend that talks to any OpenAI-compatible server.
|
||||
|
||||
Compatible with ``vllm serve``, ``transformers serve``,
|
||||
``ktransformers serve``, and hosted endpoints. By default the server
|
||||
is expected to be already running. Set ``auto_serve=True`` to have
|
||||
this client spawn one (default: ``transformers serve``), wait until
|
||||
it's ready, and tear it down on process exit.
|
||||
|
||||
Image blocks ``{"type":"image", "image":<PIL.Image>}`` are
|
||||
auto-converted to ``image_url`` data-URLs. Video blocks
|
||||
``{"type":"video", "video":[<PIL>...]}`` are forwarded as
|
||||
multi-frame ``video_url`` items where supported.
|
||||
"""
|
||||
try:
|
||||
from openai import OpenAI # type: ignore[import-not-found]
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"openai package is required for backend='openai'. Install with `pip install openai`."
|
||||
) from exc
|
||||
|
||||
api_base = config.api_base
|
||||
api_key = config.api_key
|
||||
auto_serve = config.auto_serve
|
||||
api_bases: list[str] = [api_base]
|
||||
|
||||
print(
|
||||
f"[lerobot-annotate] backend=openai model={config.model_id} "
|
||||
f"api_base={api_base} auto_serve={auto_serve}",
|
||||
flush=True,
|
||||
)
|
||||
if auto_serve:
|
||||
if config.parallel_servers > 1:
|
||||
print(
|
||||
f"[lerobot-annotate] spawning {config.parallel_servers} parallel servers",
|
||||
flush=True,
|
||||
)
|
||||
api_bases = _spawn_parallel_inference_servers(config)
|
||||
elif _server_is_up(api_base):
|
||||
print(f"[lerobot-annotate] reusing server already up at {api_base}", flush=True)
|
||||
else:
|
||||
print("[lerobot-annotate] no server reachable; spawning one", flush=True)
|
||||
api_base = _spawn_inference_server(config)
|
||||
api_bases = [api_base]
|
||||
print(f"[lerobot-annotate] server ready at {api_base}", flush=True)
|
||||
|
||||
clients = [OpenAI(base_url=base, api_key=api_key) for base in api_bases]
|
||||
# round-robin counter for parallel mode
|
||||
rr_counter = {"i": 0}
|
||||
|
||||
# ``mm_processor_kwargs`` is a vllm-specific extra; transformers serve
|
||||
# rejects it with HTTP 422. Send it only when explicitly opted in via
|
||||
# an env var (e.g. ``LEROBOT_OPENAI_SEND_MM_KWARGS=1`` for vllm).
|
||||
send_mm_kwargs = os.environ.get("LEROBOT_OPENAI_SEND_MM_KWARGS", "").lower() in {"1", "true", "yes"}
|
||||
|
||||
rr_lock = threading.Lock()
|
||||
|
||||
def _one_call(messages: Sequence[dict[str, Any]], max_tok: int, temp: float) -> str:
|
||||
api_messages, mm_kwargs = _to_openai_messages(messages)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": config.model_id,
|
||||
"messages": api_messages,
|
||||
"max_tokens": max_tok,
|
||||
"temperature": temp,
|
||||
}
|
||||
extra_body: dict[str, Any] = {}
|
||||
if send_mm_kwargs and mm_kwargs:
|
||||
extra_body["mm_processor_kwargs"] = {**mm_kwargs, "do_sample_frames": True}
|
||||
if config.chat_template_kwargs:
|
||||
extra_body["chat_template_kwargs"] = config.chat_template_kwargs
|
||||
if extra_body:
|
||||
kwargs["extra_body"] = extra_body
|
||||
with rr_lock:
|
||||
chosen = clients[rr_counter["i"] % len(clients)]
|
||||
rr_counter["i"] += 1
|
||||
response = chosen.chat.completions.create(**kwargs)
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||
if len(batch) <= 1 or config.client_concurrency <= 1:
|
||||
return [_one_call(messages, max_tok, temp) for messages in batch]
|
||||
# Parallel fan-out — vllm batches these on the server side.
|
||||
max_workers = min(config.client_concurrency, len(batch))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
futures = [pool.submit(_one_call, messages, max_tok, temp) for messages in batch]
|
||||
return [f.result() for f in futures]
|
||||
|
||||
return _GenericTextClient(_gen, config)
|
||||
|
||||
|
||||
def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
|
||||
"""Spawn ``config.parallel_servers`` independent vllm replicas.
|
||||
|
||||
Each replica:
|
||||
- is pinned to a single GPU via ``CUDA_VISIBLE_DEVICES``
|
||||
- listens on ``serve_port + i``
|
||||
- is shut down via the same atexit hook as the single-server path
|
||||
|
||||
Returns the list of ``api_base`` URLs the client should round-robin
|
||||
across.
|
||||
"""
|
||||
n = config.parallel_servers
|
||||
api_bases: list[str] = []
|
||||
procs: list[subprocess.Popen] = []
|
||||
ready_events: list[threading.Event] = []
|
||||
# Multiple readiness signals — uvicorn's own banner is suppressed at
|
||||
# ``--uvicorn-log-level warning``, so we also accept vllm's own
|
||||
# "Starting vLLM API server" line and the route-listing line. The
|
||||
# HTTP probe below is the ultimate fallback.
|
||||
ready_markers = (
|
||||
"Uvicorn running",
|
||||
"Application startup complete",
|
||||
"Starting vLLM API server",
|
||||
"Available routes are",
|
||||
)
|
||||
# Single lock for all server-stream threads so multibyte chars from
|
||||
# different servers don't interleave and tear UTF-8 sequences.
|
||||
print_lock = threading.Lock()
|
||||
|
||||
base_cmd = config.serve_command or (
|
||||
f"vllm serve {shlex.quote(config.model_id)} "
|
||||
f"--tensor-parallel-size 1 "
|
||||
f"--max-model-len {config.max_model_len or 32768} "
|
||||
f"--uvicorn-log-level warning"
|
||||
)
|
||||
|
||||
num_gpus = config.num_gpus if config.num_gpus > 0 else n
|
||||
for i in range(n):
|
||||
port = config.serve_port + i
|
||||
gpu = i % num_gpus
|
||||
env = os.environ.copy()
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(gpu)
|
||||
cmd = base_cmd.replace("{port}", str(port)) if "{port}" in base_cmd else f"{base_cmd} --port {port}"
|
||||
api_base = f"http://localhost:{port}/v1"
|
||||
api_bases.append(api_base)
|
||||
print(f"[server-{i}] launching on GPU {gpu} port {port}: {cmd}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
shlex.split(cmd),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
env=env,
|
||||
)
|
||||
procs.append(proc)
|
||||
ready = threading.Event()
|
||||
ready_events.append(ready)
|
||||
|
||||
def _stream(idx: int, p: subprocess.Popen, ev: threading.Event) -> None:
|
||||
# Read whole lines and emit each line atomically under the
|
||||
# shared print_lock so output from N servers stays readable.
|
||||
assert p.stdout is not None
|
||||
for line in iter(p.stdout.readline, ""):
|
||||
with print_lock:
|
||||
sys.stdout.write(f"[server-{idx}] {line}")
|
||||
if not line.endswith(("\n", "\r")):
|
||||
sys.stdout.write("\n")
|
||||
sys.stdout.flush()
|
||||
if any(m in line for m in ready_markers):
|
||||
ev.set()
|
||||
|
||||
threading.Thread(target=_stream, args=(i, proc, ready), daemon=True).start()
|
||||
|
||||
def _probe(idx: int, base: str, ev: threading.Event, p: subprocess.Popen) -> None:
|
||||
while not ev.is_set() and p.poll() is None:
|
||||
if _server_is_up(base):
|
||||
print(f"[server-{idx}] ready (http probe)", flush=True)
|
||||
ev.set()
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
threading.Thread(target=_probe, args=(i, api_base, ready, proc), daemon=True).start()
|
||||
|
||||
def _shutdown() -> None:
|
||||
for i, p in enumerate(procs):
|
||||
if p.poll() is None:
|
||||
print(f"[server-{i}] stopping pid={p.pid}", flush=True)
|
||||
p.send_signal(signal.SIGINT)
|
||||
for p in procs:
|
||||
try:
|
||||
p.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
p.kill()
|
||||
p.wait(timeout=5)
|
||||
|
||||
atexit.register(_shutdown)
|
||||
|
||||
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||
while any(not ev.is_set() for ev in ready_events) and time.monotonic() < deadline:
|
||||
for i, p in enumerate(procs):
|
||||
if p.poll() is not None:
|
||||
raise RuntimeError(
|
||||
f"[server-{i}] inference server exited unexpectedly with rc={p.returncode}"
|
||||
)
|
||||
time.sleep(2)
|
||||
if any(not ev.is_set() for ev in ready_events):
|
||||
raise RuntimeError(f"[server] not all replicas became ready within {config.serve_ready_timeout_s}s")
|
||||
print(f"[lerobot-annotate] all {n} servers ready: {api_bases}", flush=True)
|
||||
return api_bases
|
||||
|
||||
|
||||
def _server_is_up(api_base: str) -> bool:
|
||||
"""Return True if ``api_base/models`` answers 200 within 2 seconds."""
|
||||
url = api_base.rstrip("/") + "/models"
|
||||
# ``api_base`` is the user-configured local-server URL we just spawned
|
||||
# or the user passed in via ``--vlm.api_base``; the bandit B310 warning
|
||||
# is for arbitrary user-controlled URLs with file:/ schemes which
|
||||
# cannot reach this code path.
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=2) as resp: # noqa: S310 # nosec B310
|
||||
return resp.status == 200
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
|
||||
def _spawn_inference_server(config: VlmConfig) -> str:
|
||||
"""Spawn ``transformers serve`` (or ``serve_command``), wait until it
|
||||
accepts ``/v1/models``, and register a shutdown hook.
|
||||
|
||||
Streams the server's stdout/stderr to the parent terminal in
|
||||
real-time on a background thread so users can see model-load
|
||||
progress and errors as they happen.
|
||||
|
||||
Returns the full ``api_base`` URL the OpenAI client should use.
|
||||
"""
|
||||
cmd = config.serve_command
|
||||
if not cmd:
|
||||
cmd = (
|
||||
f"transformers serve {shlex.quote(config.model_id)} "
|
||||
f"--port {config.serve_port} --continuous-batching"
|
||||
)
|
||||
api_base = f"http://localhost:{config.serve_port}/v1"
|
||||
print(f"[server] launching: {cmd}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
shlex.split(cmd),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
|
||||
# Watch the server output for the uvicorn readiness banner. This is
|
||||
# more reliable than polling /v1/models because transformers serve
|
||||
# rescans its cache on every model-list request, which can exceed
|
||||
# the urllib timeout and trigger an infinite probe loop.
|
||||
ready_event = threading.Event()
|
||||
# See _spawn_parallel_inference_servers for why we accept these.
|
||||
ready_markers = (
|
||||
"Uvicorn running",
|
||||
"Application startup complete",
|
||||
"Starting vLLM API server",
|
||||
"Available routes are",
|
||||
)
|
||||
|
||||
def _probe() -> None:
|
||||
while not ready_event.is_set() and proc.poll() is None:
|
||||
if _server_is_up(api_base):
|
||||
print("[server] ready (http probe)", flush=True)
|
||||
ready_event.set()
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
threading.Thread(target=_probe, daemon=True).start()
|
||||
|
||||
def _stream_output() -> None:
|
||||
# Read raw chunks instead of iterating lines so tqdm progress
|
||||
# bars (which overwrite using \r) flush in real time.
|
||||
assert proc.stdout is not None
|
||||
buf = ""
|
||||
prefix_started = False
|
||||
while True:
|
||||
ch = proc.stdout.read(1)
|
||||
if ch == "":
|
||||
# process exited; flush any tail
|
||||
if buf:
|
||||
sys.stdout.write(buf)
|
||||
sys.stdout.flush()
|
||||
return
|
||||
if not prefix_started:
|
||||
sys.stdout.write("[server] ")
|
||||
prefix_started = True
|
||||
sys.stdout.write(ch)
|
||||
sys.stdout.flush()
|
||||
buf += ch
|
||||
if ch in ("\n", "\r"):
|
||||
if any(marker in buf for marker in ready_markers):
|
||||
ready_event.set()
|
||||
buf = ""
|
||||
prefix_started = False
|
||||
|
||||
threading.Thread(target=_stream_output, daemon=True).start()
|
||||
|
||||
def _shutdown() -> None:
|
||||
if proc.poll() is None:
|
||||
print(f"[server] stopping pid={proc.pid}", flush=True)
|
||||
proc.send_signal(signal.SIGINT)
|
||||
try:
|
||||
proc.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
proc.wait(timeout=5)
|
||||
|
||||
atexit.register(_shutdown)
|
||||
|
||||
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||
while time.monotonic() < deadline:
|
||||
if proc.poll() is not None:
|
||||
raise RuntimeError(
|
||||
f"[server] inference server exited unexpectedly with rc={proc.returncode}. "
|
||||
f"See [server] log lines above for the cause."
|
||||
)
|
||||
if ready_event.wait(timeout=2):
|
||||
return api_base
|
||||
proc.terminate()
|
||||
raise RuntimeError(f"[server] did not become ready within {config.serve_ready_timeout_s}s")
|
||||
|
||||
|
||||
def _to_openai_messages(
|
||||
messages: Sequence[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||
"""Convert internal messages to OpenAI chat format.
|
||||
|
||||
Returns ``(api_messages, mm_kwargs)``. Multimodal-processor kwargs
|
||||
(``fps`` from ``video_url`` blocks) are extracted out so the caller
|
||||
can pass them via ``extra_body.mm_processor_kwargs`` rather than
|
||||
inside the content blocks (which transformers serve rejects).
|
||||
|
||||
File-URL video blocks are inlined as base64 data URLs.
|
||||
"""
|
||||
out_messages: list[dict[str, Any]] = []
|
||||
mm_kwargs: dict[str, Any] = {}
|
||||
for message in messages:
|
||||
content = message.get("content")
|
||||
if not isinstance(content, list):
|
||||
out_messages.append({"role": message["role"], "content": content})
|
||||
continue
|
||||
out_blocks: list[dict[str, Any]] = []
|
||||
for block in content:
|
||||
block_type = block.get("type") if isinstance(block, dict) else None
|
||||
if block_type == "text":
|
||||
out_blocks.append({"type": "text", "text": block.get("text", "")})
|
||||
elif block_type == "image":
|
||||
out_blocks.append(
|
||||
{"type": "image_url", "image_url": {"url": _pil_to_data_url(block["image"])}}
|
||||
)
|
||||
elif block_type == "video":
|
||||
frames = block.get("video", [])
|
||||
for img in frames:
|
||||
out_blocks.append({"type": "image_url", "image_url": {"url": _pil_to_data_url(img)}})
|
||||
elif block_type == "video_url":
|
||||
video_url = dict(block["video_url"])
|
||||
url = video_url.get("url", "")
|
||||
if url.startswith("file://"):
|
||||
video_url["url"] = _file_to_data_url(url[len("file://") :])
|
||||
out_blocks.append({"type": "video_url", "video_url": video_url})
|
||||
fps = block.get("fps")
|
||||
if fps is not None:
|
||||
mm_kwargs["fps"] = fps
|
||||
else:
|
||||
out_blocks.append(block)
|
||||
out_messages.append({"role": message["role"], "content": out_blocks})
|
||||
return out_messages, mm_kwargs
|
||||
|
||||
|
||||
def _file_to_data_url(path: str) -> str:
|
||||
"""Read a local video file and return a base64 ``data:video/mp4`` URL."""
|
||||
with open(path, "rb") as f:
|
||||
b64 = base64.b64encode(f.read()).decode("ascii")
|
||||
return f"data:video/mp4;base64,{b64}"
|
||||
|
||||
|
||||
def _pil_to_data_url(image: Any) -> str:
|
||||
"""Encode a PIL.Image as a base64 data URL."""
|
||||
buf = io.BytesIO()
|
||||
image.save(buf, format="PNG")
|
||||
b64 = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
return f"data:image/png;base64,{b64}"
|
||||
|
||||
|
||||
def _messages_to_prompt(messages: Sequence[dict[str, Any]]) -> Any:
|
||||
"""Pass-through hook used by the vllm backend.
|
||||
|
||||
vllm exposes its own multimodal entry points that vary by version; for the
|
||||
base flow we simply forward the raw message list and let the caller's
|
||||
custom backend handle templating. Real deployments override this.
|
||||
"""
|
||||
return list(messages)
|
||||
222
src/lerobot/annotations/steerable_pipeline/vocabulary.py
Normal file
222
src/lerobot/annotations/steerable_pipeline/vocabulary.py
Normal file
@@ -0,0 +1,222 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Dataset-level canonical vocabulary discovery (Phase 0).
|
||||
|
||||
The downstream consumer of these annotations is a low-level action expert
|
||||
conditioned on the ``subtask`` string. Free-form per-episode LLM rephrasing
|
||||
gives near-unique strings per occurrence, which collapses the action
|
||||
expert's conditioning to noise and makes runtime subtask-paraphrase drift
|
||||
catastrophic. The Hi-Robot / π0.6-MEM recipe ships a small canonical
|
||||
vocabulary per environment (~10 strings) that every episode reuses; this
|
||||
module derives that vocabulary automatically from the first few episode
|
||||
videos and persists it next to the dataset.
|
||||
|
||||
Pipeline-level flow:
|
||||
|
||||
Phase 0 (here): watch N sample episodes → produce vocabulary.json
|
||||
Phase 1 (plan module): reuse vocabulary on every episode, both as
|
||||
prompt-side constraint *and* post-VLM validation
|
||||
|
||||
The vocabulary is JSON, lives at ``<root>/meta/canonical_vocabulary.json``,
|
||||
and is human-inspectable / hand-editable — if the discovered set is wrong,
|
||||
operators edit the file and re-run the pipeline without phase 0.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .config import VocabularyConfig
|
||||
from .frames import FrameProvider, null_provider, to_video_block
|
||||
from .prompts import load as load_prompt
|
||||
from .reader import EpisodeRecord
|
||||
from .vlm_client import VlmClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VOCABULARY_FILENAME = "canonical_vocabulary.json"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Vocabulary:
|
||||
"""Canonical phrasings shared across every episode of one dataset.
|
||||
|
||||
Both lists are strict: per-episode subtask + memory generation pick
|
||||
from these strings only; the downstream policy then has a small,
|
||||
repeatable target distribution to learn instead of thousands of
|
||||
LLM paraphrases.
|
||||
"""
|
||||
|
||||
subtasks: tuple[str, ...]
|
||||
"""Imperative subtask labels — what the low-level policy is conditioned
|
||||
on. Verb-first, telegraphic, consistent object nouns. Example:
|
||||
``("move to blue cube", "grasp blue cube", "lift blue cube",
|
||||
"place blue cube in box", "retract arm")``.
|
||||
"""
|
||||
|
||||
memory_milestones: tuple[str, ...]
|
||||
"""First-person past-tense milestone sentences — building blocks for
|
||||
the running memory string. Example: ``("I picked up the blue cube.",
|
||||
"I placed the blue cube in the green box.")``. Each milestone maps
|
||||
1:1 onto a completed subtask phase; ``memory_at_step_k`` is the
|
||||
concatenation of milestones for completed phases.
|
||||
"""
|
||||
|
||||
def to_json(self) -> dict[str, list[str]]:
|
||||
return {
|
||||
"subtasks": list(self.subtasks),
|
||||
"memory_milestones": list(self.memory_milestones),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, payload: dict[str, Any]) -> Vocabulary:
|
||||
subtasks = tuple(
|
||||
str(s).strip() for s in (payload.get("subtasks") or []) if str(s).strip()
|
||||
)
|
||||
memory_milestones = tuple(
|
||||
str(s).strip() for s in (payload.get("memory_milestones") or []) if str(s).strip()
|
||||
)
|
||||
return cls(subtasks=subtasks, memory_milestones=memory_milestones)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return not self.subtasks and not self.memory_milestones
|
||||
|
||||
|
||||
def vocabulary_path(root: Path) -> Path:
|
||||
"""Return the canonical on-disk location for the vocabulary file."""
|
||||
return root / "meta" / VOCABULARY_FILENAME
|
||||
|
||||
|
||||
def load_vocabulary(root: Path) -> Vocabulary | None:
|
||||
"""Read ``<root>/meta/canonical_vocabulary.json`` if present.
|
||||
|
||||
Returns ``None`` when the file does not exist — callers fall back to
|
||||
free-form (unconstrained) subtask + memory generation, preserving the
|
||||
pipeline's behaviour on datasets that never ran phase 0.
|
||||
"""
|
||||
path = vocabulary_path(root)
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
payload = json.loads(path.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError) as exc:
|
||||
logger.warning("could not read %s: %s — proceeding without vocabulary", path, exc)
|
||||
return None
|
||||
if not isinstance(payload, dict):
|
||||
logger.warning("%s is not a JSON object — ignoring", path)
|
||||
return None
|
||||
vocab = Vocabulary.from_json(payload)
|
||||
if vocab.is_empty():
|
||||
return None
|
||||
return vocab
|
||||
|
||||
|
||||
def save_vocabulary(root: Path, vocab: Vocabulary) -> Path:
|
||||
"""Atomically persist ``vocab`` to ``<root>/meta/canonical_vocabulary.json``."""
|
||||
path = vocabulary_path(root)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = path.with_suffix(path.suffix + ".tmp")
|
||||
tmp.write_text(
|
||||
json.dumps(vocab.to_json(), indent=2, ensure_ascii=False) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
tmp.replace(path)
|
||||
return path
|
||||
|
||||
|
||||
@dataclass
|
||||
class VocabularyDiscoveryModule:
|
||||
"""Derive a dataset-level canonical vocabulary from sample episodes.
|
||||
|
||||
Phase 0 of the executor: pulls ``config.sample_episodes`` episode
|
||||
videos, packs them into one Qwen-VL multi-video prompt, and asks the
|
||||
model to enumerate the small set of canonical subtask labels +
|
||||
memory milestones that recur across them. The output is persisted
|
||||
to ``meta/canonical_vocabulary.json`` and consumed by phase 1.
|
||||
"""
|
||||
|
||||
vlm: VlmClient
|
||||
config: VocabularyConfig
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def discover(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
*,
|
||||
existing: Vocabulary | None = None,
|
||||
) -> Vocabulary | None:
|
||||
"""Run vocabulary discovery against the first N sample episodes.
|
||||
|
||||
``existing`` short-circuits the VLM call when ``config.reuse_existing``
|
||||
is True and an on-disk vocabulary is already present — keeps re-runs
|
||||
cheap and lets operators hand-edit the file without it getting
|
||||
overwritten.
|
||||
"""
|
||||
if existing is not None and self.config.reuse_existing:
|
||||
logger.info(
|
||||
"vocabulary: reusing existing (%d subtasks, %d memory milestones)",
|
||||
len(existing.subtasks),
|
||||
len(existing.memory_milestones),
|
||||
)
|
||||
return existing
|
||||
|
||||
sample = list(records[: max(1, int(self.config.sample_episodes))])
|
||||
if not sample:
|
||||
return None
|
||||
|
||||
task_hint = next((r.episode_task for r in sample if r.episode_task), "")
|
||||
prompt = load_prompt("module_0_vocabulary").format(
|
||||
episode_task=task_hint or "(unspecified)",
|
||||
n_episodes=len(sample),
|
||||
)
|
||||
# Pack one video block per sample episode so the VLM sees the
|
||||
# variation across episodes (different starting poses, different
|
||||
# object placements) rather than overfitting to one trajectory.
|
||||
content: list[dict[str, Any]] = []
|
||||
for record in sample:
|
||||
video_frames = self.frame_provider.video_for_episode(
|
||||
record, int(self.config.max_video_frames_per_episode)
|
||||
)
|
||||
if video_frames:
|
||||
content.extend(to_video_block(video_frames))
|
||||
content.append({"type": "text", "text": prompt})
|
||||
messages = [{"role": "user", "content": content}]
|
||||
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if not isinstance(result, dict):
|
||||
logger.warning("vocabulary: VLM did not return a JSON object — skipping")
|
||||
return None
|
||||
|
||||
vocab = Vocabulary.from_json(result)
|
||||
if vocab.is_empty():
|
||||
logger.warning("vocabulary: VLM returned an empty vocabulary — skipping")
|
||||
return None
|
||||
logger.info(
|
||||
"vocabulary: discovered %d subtask labels + %d memory milestones from %d episodes",
|
||||
len(vocab.subtasks),
|
||||
len(vocab.memory_milestones),
|
||||
len(sample),
|
||||
)
|
||||
return vocab
|
||||
356
src/lerobot/annotations/steerable_pipeline/writer.py
Normal file
356
src/lerobot/annotations/steerable_pipeline/writer.py
Normal file
@@ -0,0 +1,356 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Final parquet rewrite.
|
||||
|
||||
For every episode the writer:
|
||||
|
||||
1. reads the staged module outputs,
|
||||
2. partitions them into a persistent slice (PERSISTENT_STYLES) and an event
|
||||
slice (EVENT_ONLY_STYLES + style=None tool-call atoms),
|
||||
3. sorts each slice deterministically,
|
||||
4. broadcasts the persistent slice across every frame in the episode,
|
||||
5. for each frame, materializes the sublist of event rows whose timestamp
|
||||
exactly equals that frame's timestamp,
|
||||
6. drops the legacy ``subtask_index`` column,
|
||||
7. writes the parquet shard back in place.
|
||||
|
||||
The writer does NOT add a dataset-level ``tools`` column. Tool *calls* are
|
||||
emitted per-row via the existing ``tool_calls`` field on the v3.1 row
|
||||
struct for every speech atom. The tool *schema* (the description
|
||||
of the ``say`` function and its parameters) is a fixed code constant —
|
||||
``SAY_TOOL_SCHEMA`` below — and downstream chat-template consumers import
|
||||
it directly rather than reading a redundant per-row column.
|
||||
|
||||
Invariants enforced here (and re-checked by the validator):
|
||||
|
||||
- per-episode persistent slice is byte-identical across every frame;
|
||||
- ``language_events`` rows on a frame all have ``timestamp == frame_ts``
|
||||
(timestamps come straight from the source parquet — never recomputed);
|
||||
- every row passes ``column_for_style(style)``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
column_for_style,
|
||||
validate_camera_field,
|
||||
)
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
from .staging import EpisodeStaging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Tool schema constants live in lerobot.datasets.language — single
|
||||
# source of truth. Re-exported here so existing imports
|
||||
# (``from lerobot.annotations.steerable_pipeline.writer import SAY_TOOL_SCHEMA``)
|
||||
# keep working.
|
||||
from lerobot.datasets.language import DEFAULT_TOOLS, SAY_TOOL_SCHEMA # noqa: F401, E402
|
||||
|
||||
|
||||
def _row_persistent_sort_key(row: dict[str, Any]) -> tuple:
|
||||
return (float(row["timestamp"]), row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _row_event_sort_key(row: dict[str, Any]) -> tuple:
|
||||
# events are bucketed per-frame, but within a frame we still want determinism
|
||||
return (
|
||||
row.get("style") or "",
|
||||
row.get("role") or "",
|
||||
row.get("camera") or "",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the persistent column's struct shape."""
|
||||
style = row.get("style")
|
||||
if style not in PERSISTENT_STYLES:
|
||||
raise ValueError(
|
||||
f"persistent slice contains row with non-persistent style {style!r}; "
|
||||
"row would be misrouted under column_for_style()"
|
||||
)
|
||||
if "timestamp" not in row:
|
||||
raise ValueError(f"persistent row missing timestamp: {row!r}")
|
||||
if "role" not in row:
|
||||
# Surface a friendly error from the writer rather than letting
|
||||
# the raw KeyError bubble out of the dict access below — modules
|
||||
# are expected to always emit ``role``, but the validator
|
||||
# currently doesn't check this so a future bug would otherwise
|
||||
# be hard to triage.
|
||||
raise ValueError(f"persistent row missing role: {row!r}")
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
return {
|
||||
"role": str(row["role"]),
|
||||
"content": None if row.get("content") is None else str(row["content"]),
|
||||
"style": style,
|
||||
"timestamp": float(row["timestamp"]),
|
||||
"camera": None if camera is None else str(camera),
|
||||
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the event column's struct shape (no timestamp)."""
|
||||
style = row.get("style")
|
||||
if style is not None and style not in EVENT_ONLY_STYLES:
|
||||
raise ValueError(
|
||||
f"event slice contains row with style {style!r}; expected None or one of {EVENT_ONLY_STYLES}"
|
||||
)
|
||||
if column_for_style(style) != LANGUAGE_EVENTS:
|
||||
raise ValueError(f"event row with style {style!r} would not route to language_events")
|
||||
if "role" not in row:
|
||||
raise ValueError(f"event row missing role: {row!r}")
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
return {
|
||||
"role": str(row["role"]),
|
||||
"content": None if row.get("content") is None else str(row["content"]),
|
||||
"style": style,
|
||||
"camera": None if camera is None else str(camera),
|
||||
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_tool_calls(value: Any) -> list[Any] | None:
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, list):
|
||||
raise ValueError(f"tool_calls must be a list or None, got {type(value).__name__}")
|
||||
return list(value)
|
||||
|
||||
|
||||
def _validate_atom_invariants(row: dict[str, Any]) -> None:
|
||||
"""At-least-one of content/tool_calls; style=None implies tool_calls."""
|
||||
has_content = row.get("content") is not None
|
||||
has_tools = row.get("tool_calls") is not None
|
||||
if not (has_content or has_tools):
|
||||
raise ValueError(f"row has neither content nor tool_calls: {row!r}")
|
||||
if row.get("style") is None and not has_tools:
|
||||
raise ValueError(f"style=None requires tool_calls: {row!r}")
|
||||
|
||||
|
||||
def _validate_speech_atom(row: dict[str, Any]) -> None:
|
||||
"""Speech atoms: role=assistant, style=None, content=None, say tool call."""
|
||||
if row.get("style") is not None:
|
||||
return # not a speech atom
|
||||
if row.get("role") != "assistant":
|
||||
raise ValueError(f"speech atom must have role=assistant: {row!r}")
|
||||
if row.get("content") is not None:
|
||||
raise ValueError(f"speech atom must have content=null: {row!r}")
|
||||
tool_calls = row.get("tool_calls")
|
||||
if not tool_calls or not isinstance(tool_calls, list):
|
||||
raise ValueError(f"speech atom must have non-empty tool_calls list: {row!r}")
|
||||
first = tool_calls[0]
|
||||
if not isinstance(first, dict):
|
||||
raise ValueError(f"speech atom tool_calls[0] must be a dict: {row!r}")
|
||||
if first.get("type") != "function":
|
||||
raise ValueError(f"speech atom tool_calls[0].type must be 'function': {row!r}")
|
||||
fn = first.get("function") or {}
|
||||
if fn.get("name") != "say":
|
||||
raise ValueError(f"speech atom tool_calls[0].function.name must be 'say': {row!r}")
|
||||
args = fn.get("arguments") or {}
|
||||
if not isinstance(args, dict) or "text" not in args or not isinstance(args["text"], str):
|
||||
raise ValueError(f"speech atom must carry 'text' string in arguments: {row!r}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class LanguageColumnsWriter:
|
||||
"""Rewrite ``data/chunk-*/file-*.parquet`` with the two language columns."""
|
||||
|
||||
drop_existing_subtask_index: bool = True
|
||||
|
||||
def write_all(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
root: Path,
|
||||
) -> list[Path]:
|
||||
episodes_by_path: dict[Path, list[EpisodeRecord]] = defaultdict(list)
|
||||
for record in records:
|
||||
episodes_by_path[record.data_path].append(record)
|
||||
|
||||
written: list[Path] = []
|
||||
for path, eps in episodes_by_path.items():
|
||||
self._rewrite_one(path, eps, staging_dir, root)
|
||||
written.append(path)
|
||||
return written
|
||||
|
||||
def _rewrite_one(
|
||||
self,
|
||||
path: Path,
|
||||
episodes: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
root: Path,
|
||||
) -> None:
|
||||
table = pq.read_table(path)
|
||||
n_rows = table.num_rows
|
||||
|
||||
# Ensure we cover every episode in the file. Episodes that don't have
|
||||
# staging artifacts are passed through with empty annotation lists —
|
||||
# this keeps the writer idempotent and safe for partial reruns.
|
||||
staged_per_ep: dict[int, dict[str, list[dict[str, Any]]]] = {}
|
||||
for record in episodes:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
staged_per_ep[record.episode_index] = staging.read_all()
|
||||
|
||||
persistent_by_ep: dict[int, list[dict[str, Any]]] = {}
|
||||
events_by_ep_ts: dict[int, dict[float, list[dict[str, Any]]]] = {}
|
||||
|
||||
for ep_index, ep_staged in staged_per_ep.items():
|
||||
persistent_rows: list[dict[str, Any]] = []
|
||||
event_rows: list[dict[str, Any]] = [] # carry timestamp until bucketed
|
||||
for _module_name, rows in ep_staged.items():
|
||||
for row in rows:
|
||||
style = row.get("style")
|
||||
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||
persistent_rows.append(row)
|
||||
else:
|
||||
event_rows.append(row)
|
||||
|
||||
persistent_rows.sort(key=_row_persistent_sort_key)
|
||||
normalized_persistent = []
|
||||
for r in persistent_rows:
|
||||
_validate_atom_invariants(r)
|
||||
_validate_speech_atom(r)
|
||||
normalized_persistent.append(_normalize_persistent_row(r))
|
||||
persistent_by_ep[ep_index] = normalized_persistent
|
||||
|
||||
buckets: dict[float, list[dict[str, Any]]] = defaultdict(list)
|
||||
for r in event_rows:
|
||||
_validate_atom_invariants(r)
|
||||
_validate_speech_atom(r)
|
||||
ts = float(r["timestamp"])
|
||||
buckets[ts].append(_normalize_event_row(r))
|
||||
for ts in list(buckets.keys()):
|
||||
buckets[ts].sort(key=_row_event_sort_key)
|
||||
events_by_ep_ts[ep_index] = buckets
|
||||
|
||||
episode_col = (
|
||||
table.column("episode_index").to_pylist() if "episode_index" in table.column_names else None
|
||||
)
|
||||
ts_col = table.column("timestamp").to_pylist() if "timestamp" in table.column_names else None
|
||||
if episode_col is None or ts_col is None:
|
||||
raise ValueError(f"{path} is missing 'episode_index' or 'timestamp' — required by the writer.")
|
||||
|
||||
per_row_persistent: list[list[dict[str, Any]]] = []
|
||||
per_row_events: list[list[dict[str, Any]]] = []
|
||||
for i in range(n_rows):
|
||||
ep = episode_col[i]
|
||||
ts = float(ts_col[i])
|
||||
per_row_persistent.append(persistent_by_ep.get(ep, []))
|
||||
buckets = events_by_ep_ts.get(ep, {})
|
||||
per_row_events.append(buckets.get(ts, []))
|
||||
|
||||
new_table = self._materialize_table(
|
||||
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
|
||||
)
|
||||
# Atomic replace: write to a sibling tmp path and rename so a crash
|
||||
# mid-write can't leave a half-written shard that ``pq.read_table``
|
||||
# would then fail to open. ``Path.replace`` is atomic on POSIX +
|
||||
# Windows when source and target sit on the same filesystem.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
pq.write_table(new_table, tmp_path)
|
||||
tmp_path.replace(path)
|
||||
|
||||
def _materialize_table(
|
||||
self,
|
||||
table: pa.Table,
|
||||
persistent: list[list[dict[str, Any]]],
|
||||
events: list[list[dict[str, Any]]],
|
||||
*,
|
||||
drop_old: bool,
|
||||
) -> pa.Table:
|
||||
cols = []
|
||||
names = []
|
||||
for name in table.column_names:
|
||||
if drop_old and name == "subtask_index":
|
||||
continue
|
||||
if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS):
|
||||
continue # we'll re-add canonical versions
|
||||
# Strip any legacy ``tools`` column previously emitted by older
|
||||
# writers — the schema no longer uses it (constant lives in
|
||||
# SAY_TOOL_SCHEMA / DEFAULT_TOOLS).
|
||||
if name == "tools":
|
||||
continue
|
||||
cols.append(table.column(name))
|
||||
names.append(name)
|
||||
|
||||
# We let pyarrow infer struct/list schema rather than passing the
|
||||
# canonical type from `lerobot.datasets.language` directly: that type
|
||||
# uses `pa.json_()` for the `tool_calls` element type, which
|
||||
# `pa.array(..., type=...)` cannot materialize from Python lists on
|
||||
# current pyarrow versions. The inferred schema round-trips through
|
||||
# parquet and `LeRobotDataset` correctly — `tests/datasets/test_language.py`
|
||||
# exercises the same flow.
|
||||
persistent_arr = pa.array(persistent)
|
||||
events_arr = pa.array(events)
|
||||
|
||||
cols.extend([persistent_arr, events_arr])
|
||||
names.extend([LANGUAGE_PERSISTENT, LANGUAGE_EVENTS])
|
||||
|
||||
return pa.Table.from_arrays(cols, names=names)
|
||||
|
||||
|
||||
def speech_atom(timestamp: float, text: str) -> dict[str, Any]:
|
||||
"""Build a canonical speech tool-call atom for the events column."""
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"style": None,
|
||||
"timestamp": float(timestamp),
|
||||
"camera": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"arguments": {"text": text},
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def normalize_rows_for_writer(
|
||||
rows: Iterable[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
"""Helper used by tests/validators to partition a flat row list into
|
||||
(persistent_rows, event_rows) using ``column_for_style``.
|
||||
"""
|
||||
persistent: list[dict[str, Any]] = []
|
||||
events: list[dict[str, Any]] = []
|
||||
for row in rows:
|
||||
if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT:
|
||||
persistent.append(row)
|
||||
else:
|
||||
events.append(row)
|
||||
return persistent, events
|
||||
@@ -199,12 +199,13 @@ class OpenCVCamera(Camera):
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
"""
|
||||
|
||||
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
|
||||
if self.config.fourcc is not None:
|
||||
self._validate_fourcc()
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
set_fourcc_after_size_and_fps = platform.system() == "Windows"
|
||||
if self.config.fourcc is not None and not set_fourcc_after_size_and_fps:
|
||||
self._validate_fourcc()
|
||||
|
||||
default_width = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_WIDTH)))
|
||||
default_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
|
||||
|
||||
@@ -222,6 +223,11 @@ class OpenCVCamera(Camera):
|
||||
else:
|
||||
self._validate_fps()
|
||||
|
||||
if self.config.fourcc is not None and set_fourcc_after_size_and_fps:
|
||||
# On Windows with DSHOW, changing the resolution can silently override the FOURCC setting.
|
||||
# Set FOURCC last to make sure the requested pixel format is actually enforced.
|
||||
self._validate_fourcc()
|
||||
|
||||
def _validate_fps(self) -> None:
|
||||
"""Validates and sets the camera's frames per second (FPS)."""
|
||||
|
||||
@@ -430,7 +436,7 @@ class OpenCVCamera(Camera):
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color frame (blocking call)
|
||||
1. Reads a color frame
|
||||
2. Stores result in latest_frame and updates timestamp (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
@@ -439,9 +445,8 @@ class OpenCVCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
stop_event = self.stop_event
|
||||
failure_count = 0
|
||||
while not stop_event.is_set():
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
raw_frame = self._read_from_hardware()
|
||||
processed_frame = self._postprocess_image(raw_frame)
|
||||
@@ -479,8 +484,6 @@ class OpenCVCamera(Camera):
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
if self.thread.is_alive():
|
||||
logger.warning(f"{self} read thread did not terminate within timeout.")
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
@@ -332,8 +332,8 @@ class RealSenseCamera(Camera):
|
||||
from the camera hardware via the RealSense pipeline.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The depth map as a NumPy array (height, width, 1)
|
||||
of type `np.uint16` (raw depth values in millimeters).
|
||||
np.ndarray: The depth map as a NumPy array (height, width)
|
||||
of type `np.uint16` (raw depth values in millimeters) and rotation.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
@@ -465,8 +465,8 @@ class RealSenseCamera(Camera):
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color/depth frame (blocking call with 10s timeout)
|
||||
2. Stores result in latest_color_frame/latest_depth_frame and updates timestamp (thread-safe)
|
||||
1. Reads a color frame with 500ms timeout
|
||||
2. Stores result in latest_frame and updates timestamp (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
@@ -474,9 +474,8 @@ class RealSenseCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
stop_event = self.stop_event
|
||||
failure_count = 0
|
||||
while not stop_event.is_set():
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
frame = self._read_from_hardware()
|
||||
color_frame_raw = frame.get_color_frame()
|
||||
@@ -487,8 +486,6 @@ class RealSenseCamera(Camera):
|
||||
depth_frame_raw = frame.get_depth_frame()
|
||||
depth_frame = np.asanyarray(depth_frame_raw.get_data())
|
||||
processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True)
|
||||
if processed_depth_frame.ndim == 2: # (H, W) -> (H, W, 1)
|
||||
processed_depth_frame = processed_depth_frame[..., np.newaxis]
|
||||
|
||||
capture_time = time.perf_counter()
|
||||
|
||||
@@ -525,8 +522,6 @@ class RealSenseCamera(Camera):
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
if self.thread.is_alive(): # pragma: no cover
|
||||
logger.warning(f"{self} read thread did not terminate within timeout.")
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
@@ -537,6 +532,7 @@ class RealSenseCamera(Camera):
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
@@ -579,6 +575,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent (color) frame captured immediately (Peeking).
|
||||
@@ -614,71 +611,6 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def async_read_depth(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""Read the latest depth frame asynchronously, in metric meters.
|
||||
|
||||
Mirrors :meth:`async_read` but returns the depth stream rather than the
|
||||
color stream. Output is ``np.uint16`` of shape ``(H, W, 1)``.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
|
||||
the background read thread is not running.
|
||||
TimeoutError: If no frame becomes available within ``timeout_ms``.
|
||||
"""
|
||||
if not self.use_depth:
|
||||
raise RuntimeError(f"{self}: cannot read depth — camera was configured with use_depth=False.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
raise TimeoutError(f"Timed out waiting for depth frame from camera {self} after {timeout_ms} ms.")
|
||||
|
||||
with self.frame_lock:
|
||||
depth_frame = self.latest_depth_frame
|
||||
self.new_frame_event.clear()
|
||||
|
||||
if depth_frame is None:
|
||||
raise RuntimeError(f"Internal error: Event set but no depth frame available for {self}.")
|
||||
|
||||
return depth_frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest_depth(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent depth frame in metric meters (peeking).
|
||||
|
||||
Non-blocking counterpart of :meth:`read_latest` for the depth stream.
|
||||
Output is ``np.uint16`` of shape ``(H, W, 1)`` in millimeters.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
|
||||
no depth frame has been captured yet.
|
||||
TimeoutError: If the latest depth frame is older than ``max_age_ms``.
|
||||
"""
|
||||
if not self.use_depth:
|
||||
raise RuntimeError(f"{self}: cannot read depth — camera was configured with use_depth=False.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
with self.frame_lock:
|
||||
depth_frame = self.latest_depth_frame
|
||||
timestamp = self.latest_timestamp
|
||||
|
||||
if depth_frame is None or timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any depth frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest depth frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return depth_frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera, stops the pipeline, and cleans up resources.
|
||||
|
||||
@@ -249,9 +249,8 @@ class ZMQCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized.")
|
||||
|
||||
stop_event = self.stop_event
|
||||
failure_count = 0
|
||||
while not stop_event.is_set():
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
frame = self._read_from_hardware()
|
||||
capture_time = time.perf_counter()
|
||||
@@ -293,8 +292,6 @@ class ZMQCamera(Camera):
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
if self.thread.is_alive():
|
||||
logger.warning(f"{self} read thread did not terminate within timeout.")
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
@@ -205,3 +205,149 @@ class WandBLogger:
|
||||
|
||||
wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
|
||||
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
|
||||
|
||||
def log_training_examples(
|
||||
self,
|
||||
batch: dict,
|
||||
step: int,
|
||||
*,
|
||||
camera_keys: list[str],
|
||||
n_samples: int = 4,
|
||||
policy=None,
|
||||
predict_actions: bool = False,
|
||||
mode: str = "train",
|
||||
) -> None:
|
||||
"""Push a ``wandb.Table`` of training-example rows for the current batch.
|
||||
|
||||
Each row is one batch element with:
|
||||
* one ``wandb.Image`` column per camera in ``camera_keys`` (CHW or
|
||||
HWC, uint8 or float in [0,1] — auto-detected),
|
||||
* any text fields present in the batch (``task`` / ``subtask`` /
|
||||
``memory`` / ``instruction``),
|
||||
* ground-truth action first/last frame (the action chunk's
|
||||
endpoints — gives a quick sense of trajectory direction),
|
||||
* if ``predict_actions=True`` and ``policy`` is supplied, the model's
|
||||
``predict_action_chunk`` first/last frame alongside.
|
||||
|
||||
This is opt-in via ``--wandb.log_examples_freq=N`` on the CLI; the
|
||||
training loop calls it once every N steps. Cheap to keep on: with
|
||||
N=4 samples and 3 cameras you upload 12 small PNGs per dump and (if
|
||||
enabled) run one extra inference forward pass.
|
||||
"""
|
||||
import logging # noqa: PLC0415
|
||||
import numpy as np # noqa: PLC0415
|
||||
import torch # noqa: PLC0415
|
||||
|
||||
if mode not in {"train", "eval"}:
|
||||
raise ValueError(mode)
|
||||
|
||||
# Batch size — first tensor-like value wins.
|
||||
bsz = next(
|
||||
(int(v.shape[0]) for v in batch.values() if hasattr(v, "shape") and v.ndim > 0),
|
||||
None,
|
||||
)
|
||||
if not bsz:
|
||||
return
|
||||
n = min(int(n_samples), bsz)
|
||||
|
||||
# Optional predicted-action forward pass on the first n samples.
|
||||
pred_actions: np.ndarray | None = None
|
||||
if predict_actions and policy is not None:
|
||||
was_training = policy.training
|
||||
try:
|
||||
policy.eval()
|
||||
sub_batch = {}
|
||||
for k, v in batch.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
sub_batch[k] = v[:n]
|
||||
elif isinstance(v, (list, tuple)):
|
||||
sub_batch[k] = list(v[:n])
|
||||
else:
|
||||
sub_batch[k] = v
|
||||
with torch.no_grad():
|
||||
pred = policy.predict_action_chunk(sub_batch)
|
||||
pred_actions = pred.detach().cpu().float().numpy()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning(
|
||||
"log_training_examples: predict_action_chunk failed (%s) — "
|
||||
"skipping predicted-action columns",
|
||||
exc,
|
||||
)
|
||||
pred_actions = None
|
||||
finally:
|
||||
if was_training:
|
||||
policy.train()
|
||||
|
||||
present_cameras = [c for c in camera_keys if c in batch]
|
||||
text_keys = [k for k in ("task", "subtask", "memory", "instruction") if k in batch]
|
||||
|
||||
columns = ["sample"]
|
||||
columns.extend(c.removeprefix("observation.images.") or c for c in present_cameras)
|
||||
columns.extend(text_keys)
|
||||
columns.append("gt_action_first")
|
||||
columns.append("gt_action_last")
|
||||
if pred_actions is not None:
|
||||
columns.append("pred_action_first")
|
||||
columns.append("pred_action_last")
|
||||
|
||||
table = self._wandb.Table(columns=columns)
|
||||
|
||||
def _to_uint8_hwc(t: torch.Tensor) -> np.ndarray:
|
||||
# Strip an outer time dim if present: (T, C, H, W) -> first frame.
|
||||
if t.ndim == 4:
|
||||
t = t[0]
|
||||
# CHW -> HWC.
|
||||
if t.ndim == 3 and t.shape[0] in (1, 3, 4) and t.shape[-1] not in (1, 3, 4):
|
||||
t = t.permute(1, 2, 0)
|
||||
arr = t.detach().cpu().float().numpy()
|
||||
if arr.size and float(arr.max()) <= 1.5:
|
||||
arr = arr * 255.0
|
||||
return np.clip(arr, 0, 255).astype(np.uint8)
|
||||
|
||||
def _action_endpoints(a: torch.Tensor) -> tuple[str, str]:
|
||||
arr = a.detach().cpu().float().numpy()
|
||||
if arr.ndim == 2: # (T, D)
|
||||
return (
|
||||
str(np.round(arr[0], 3).tolist()),
|
||||
str(np.round(arr[-1], 3).tolist()),
|
||||
)
|
||||
if arr.ndim == 1:
|
||||
rounded = np.round(arr, 3).tolist()
|
||||
return (str(rounded), str(rounded))
|
||||
return (str(arr.tolist()), str(arr.tolist()))
|
||||
|
||||
for i in range(n):
|
||||
row: list = [i]
|
||||
for cam in present_cameras:
|
||||
try:
|
||||
row.append(self._wandb.Image(_to_uint8_hwc(batch[cam][i])))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning(
|
||||
"log_training_examples: camera %s sample %d failed (%s)",
|
||||
cam,
|
||||
i,
|
||||
exc,
|
||||
)
|
||||
row.append(None)
|
||||
for tk in text_keys:
|
||||
v = batch[tk]
|
||||
if isinstance(v, (list, tuple)):
|
||||
row.append(str(v[i]) if i < len(v) else "")
|
||||
else:
|
||||
row.append(str(v))
|
||||
action = batch.get("action")
|
||||
if isinstance(action, torch.Tensor) and action.ndim >= 1:
|
||||
first, last = _action_endpoints(action[i])
|
||||
row.append(first)
|
||||
row.append(last)
|
||||
else:
|
||||
row.append("")
|
||||
row.append("")
|
||||
if pred_actions is not None:
|
||||
p = torch.from_numpy(pred_actions[i])
|
||||
pfirst, plast = _action_endpoints(p)
|
||||
row.append(pfirst)
|
||||
row.append(plast)
|
||||
table.add_data(*row)
|
||||
|
||||
self._wandb.log({f"{mode}/examples": table}, step=step)
|
||||
|
||||
@@ -24,6 +24,7 @@ Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
|
||||
from .dataset import DatasetRecordConfig
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .recipe import MessageTurn, TrainingRecipe, load_recipe
|
||||
from .types import (
|
||||
FeatureType,
|
||||
NormalizationMode,
|
||||
@@ -34,10 +35,8 @@ from .types import (
|
||||
from .video import (
|
||||
VALID_VIDEO_CODECS,
|
||||
VIDEO_ENCODER_INFO_KEYS,
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -51,14 +50,15 @@ __all__ = [
|
||||
"DatasetRecordConfig",
|
||||
"DatasetConfig",
|
||||
"EvalConfig",
|
||||
"MessageTurn",
|
||||
"PeftConfig",
|
||||
"PreTrainedConfig",
|
||||
"TrainingRecipe",
|
||||
"WandBConfig",
|
||||
"load_recipe",
|
||||
"VideoEncoderConfig",
|
||||
"DepthEncoderConfig",
|
||||
# Defaults
|
||||
"camera_encoder_defaults",
|
||||
"depth_encoder_defaults",
|
||||
# Constants
|
||||
"VALID_VIDEO_CODECS",
|
||||
"VIDEO_ENCODER_INFO_KEYS",
|
||||
|
||||
@@ -18,7 +18,7 @@ from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from .video import DepthEncoderConfig, VideoEncoderConfig, camera_encoder_defaults, depth_encoder_defaults
|
||||
from .video import VideoEncoderConfig, camera_encoder_defaults
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -60,8 +60,6 @@ class DatasetRecordConfig:
|
||||
# Video encoder settings for camera MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys,
|
||||
# e.g. ``--dataset.camera_encoder.vcodec=h264`` (see ``VideoEncoderConfig``).
|
||||
camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
# Video encoder settings for depth-map MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys.
|
||||
depth_encoder: DepthEncoderConfig = field(default_factory=depth_encoder_defaults)
|
||||
# Enable streaming video encoding: encode frames in real-time during capture instead
|
||||
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
|
||||
streaming_encoding: bool = False
|
||||
|
||||
@@ -62,6 +62,72 @@ class WandBConfig:
|
||||
run_id: str | None = None
|
||||
mode: str | None = None # Allowed values: 'online', 'offline' 'disabled'. Defaults to 'online'
|
||||
add_tags: bool = True # If True, save configuration as tags in the WandB run.
|
||||
# Periodic training-example dump (independent of ``log_freq``). When > 0,
|
||||
# every ``log_examples_freq`` steps the trainer pushes a ``wandb.Table``
|
||||
# with one row per sampled batch element containing each camera view
|
||||
# (rendered as ``wandb.Image``), any text fields present in the batch
|
||||
# (``task`` / ``subtask`` / ``memory`` / ``instruction``), and the
|
||||
# ground-truth action chunk's first + last frames. Defaults to 5000 — set
|
||||
# to 0 to disable. Only fires when ``enable=True``, so runs without wandb
|
||||
# are unaffected.
|
||||
log_examples_freq: int = 5000
|
||||
# Number of batch elements to include in each example dump.
|
||||
log_examples_n: int = 4
|
||||
# If True (default), also run ``policy.predict_action_chunk`` on the logged
|
||||
# samples (in eval mode, no_grad) and add predicted vs ground-truth action
|
||||
# columns to the table. Costs one extra forward pass per dump — negligible
|
||||
# at the 5k-step default cadence. Set to ``False`` if your policy doesn't
|
||||
# implement ``predict_action_chunk`` or you want to skip the extra forward.
|
||||
log_examples_predict_actions: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class EMAConfig:
|
||||
"""Exponential Moving Average of trainable policy parameters.
|
||||
|
||||
Diffusion / flow-matching policies (Diffusion Policy, π0/π0.5,
|
||||
pi052) benefit substantially from averaging late-training
|
||||
parameter oscillations — see Chi et al. 2023 §V.D. The official
|
||||
JAX openpi trainer ships EMA with ``ema_decay=0.99`` (default) and
|
||||
``0.999`` for its pi05_libero config; the openpi PyTorch port
|
||||
explicitly lists EMA as unsupported, and LeRobot main inherited
|
||||
that gap. Enabling this flag plugs ema-pytorch
|
||||
(https://github.com/lucidrains/ema-pytorch) into the LeRobot
|
||||
training loop with a shadow ``nn.Module`` clone of the policy.
|
||||
|
||||
Cost: 1× model params in fp32 shadow (~13 GB for pi052's 3.3B
|
||||
params) + one elementwise update per training step (~1% step time).
|
||||
|
||||
On by default — matches openpi (JAX) which ships EMA on for every
|
||||
config, and closes the gap with the openpi PyTorch port which
|
||||
explicitly lists EMA as unsupported. Set ``--ema.enable=false`` to
|
||||
disable for short runs / memory-constrained training where the
|
||||
extra fp32 shadow copy is the bottleneck.
|
||||
"""
|
||||
|
||||
enable: bool = True
|
||||
# Target EMA decay β in θ_ema ← β·θ_ema + (1-β)·θ_live (passed to
|
||||
# ema-pytorch as ``beta``).
|
||||
# 0.999 — last ~1000 steps; pi05_libero default in openpi
|
||||
# 0.99 — last ~100 steps; openpi top-level default
|
||||
# 0.75 — very fast EMA (Diffusion Policy original setting)
|
||||
# 0.9999 — very slow EMA (long classification runs)
|
||||
decay: float = 0.999
|
||||
# Skip the first N calls to ``ema.update()``; during this window
|
||||
# the shadow is just a hard copy of the live weights (no averaging).
|
||||
# Lets early-training rapid changes settle before averaging begins.
|
||||
# Maps to ema-pytorch's ``update_after_step`` (NOT a smooth decay
|
||||
# ramp like older lerobot EMA implementations).
|
||||
warmup_steps: int = 0
|
||||
# When True, the periodic eval block uses the EMA shadow model
|
||||
# directly (``ema.ema_model``) instead of the live policy. Standard
|
||||
# practice for diffusion-style policies — eval scores are usually
|
||||
# 1–3% higher than the live policy at the same step.
|
||||
use_for_eval: bool = True
|
||||
# When True, the periodic wandb training-example dump uses the EMA
|
||||
# shadow for the optional predicted-action columns (so what you see
|
||||
# in W&B matches eval behavior).
|
||||
use_for_wandb_examples: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
221
src/lerobot/configs/recipe.py
Normal file
221
src/lerobot/configs/recipe.py
Normal file
@@ -0,0 +1,221 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
MessageRole = Literal["user", "assistant", "system", "tool"]
|
||||
MessageStream = Literal["high_level", "low_level"]
|
||||
|
||||
DEFAULT_BINDINGS = {
|
||||
"subtask": "active_at(t, style=subtask)",
|
||||
"memory": "active_at(t, style=memory)",
|
||||
"plan": "active_at(t, style=plan)",
|
||||
"speech": "emitted_at(t, role=assistant, tool_name=say)",
|
||||
"interjection": "emitted_at(t, style=interjection)",
|
||||
"vqa": "emitted_at(t, style=vqa, role=assistant)",
|
||||
"vqa_query": "emitted_at(t, style=vqa, role=user)",
|
||||
}
|
||||
|
||||
PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
"""``${name}`` placeholder pattern used by both recipe binding-reference
|
||||
discovery (here) and rendered-message substitution (in ``language_render``)."""
|
||||
|
||||
_VALID_ROLES = frozenset(get_args(MessageRole))
|
||||
_VALID_STREAMS = frozenset(get_args(MessageStream))
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageTurn:
|
||||
"""A single chat-style turn in a recipe template.
|
||||
|
||||
``content`` may be a plain string, a list of HF-style multimodal blocks, or
|
||||
``None`` when ``tool_calls_from`` supplies tool-call payloads instead.
|
||||
``stream`` tags the turn for downstream filtering, ``target`` flags it as a
|
||||
training target, and ``if_present`` skips the turn when the named binding
|
||||
resolves to ``None``.
|
||||
"""
|
||||
|
||||
role: MessageRole
|
||||
content: str | list[dict[str, Any]] | None = None
|
||||
stream: MessageStream | None = None
|
||||
target: bool = False
|
||||
if_present: str | None = None
|
||||
tool_calls_from: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate role, stream, and content after dataclass construction."""
|
||||
if self.role not in _VALID_ROLES:
|
||||
raise ValueError(f"Unsupported message role: {self.role!r}")
|
||||
# ``stream`` is typed Optional only so the dataclass can keep its
|
||||
# field ordering, but recipes must always tag every turn with a
|
||||
# stream — the renderer's ``_validate_rendered`` would reject
|
||||
# ``None`` later on. Fail at construction so the bad recipe is
|
||||
# caught at YAML load time rather than at the first sample.
|
||||
if self.stream is None:
|
||||
raise ValueError(
|
||||
f"MessageTurn(role={self.role!r}) is missing a stream — "
|
||||
f"every turn must declare one of {sorted(_VALID_STREAMS)}."
|
||||
)
|
||||
if self.stream not in _VALID_STREAMS:
|
||||
raise ValueError(f"Unsupported message stream: {self.stream!r}")
|
||||
if self.content is None and self.tool_calls_from is None:
|
||||
raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
|
||||
if self.content is not None and not isinstance(self.content, (str, list)):
|
||||
raise TypeError("MessageTurn.content must be a string, a list of HF-style blocks, or None.")
|
||||
if isinstance(self.content, list):
|
||||
for block in self.content:
|
||||
if not isinstance(block, dict) or "type" not in block:
|
||||
raise ValueError(
|
||||
"Multimodal content blocks must be HF-style dictionaries with a type key."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> MessageTurn:
|
||||
"""Construct a :class:`MessageTurn` from a plain dictionary."""
|
||||
return cls(**data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingRecipe:
|
||||
"""A recipe describing how to render training samples from language rows.
|
||||
|
||||
A recipe is either a *message recipe* (``messages`` plus optional
|
||||
``bindings``) or a *blend recipe* (``blend`` mapping names to weighted
|
||||
sub-recipes). ``weight`` is only meaningful inside a blend.
|
||||
"""
|
||||
|
||||
messages: list[MessageTurn] | None = None
|
||||
bindings: dict[str, str] | None = None
|
||||
blend: dict[str, TrainingRecipe] | None = None
|
||||
weight: float | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate that exactly one of ``messages`` or ``blend`` is set."""
|
||||
if self.messages is not None and self.blend is not None:
|
||||
raise ValueError("TrainingRecipe must set only one of messages or blend.")
|
||||
if self.messages is None and self.blend is None:
|
||||
raise ValueError("TrainingRecipe must set one of messages or blend.")
|
||||
|
||||
if self.messages is not None:
|
||||
self._validate_message_recipe()
|
||||
if self.blend is not None:
|
||||
self._validate_blend_recipe()
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> TrainingRecipe:
|
||||
"""Construct a :class:`TrainingRecipe` from a nested dictionary."""
|
||||
data = dict(data)
|
||||
if data.get("messages") is not None:
|
||||
data["messages"] = [
|
||||
turn if isinstance(turn, MessageTurn) else MessageTurn.from_dict(turn)
|
||||
for turn in data["messages"]
|
||||
]
|
||||
if data.get("blend") is not None:
|
||||
data["blend"] = {
|
||||
name: recipe if isinstance(recipe, TrainingRecipe) else cls.from_dict(recipe)
|
||||
for name, recipe in data["blend"].items()
|
||||
}
|
||||
return cls(**data)
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: str | Path) -> TrainingRecipe:
|
||||
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
|
||||
import yaml # type: ignore[import-untyped]
|
||||
|
||||
with open(path) as f:
|
||||
data = yaml.safe_load(f)
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Recipe YAML must contain a mapping at the top level: {path}")
|
||||
return cls.from_dict(data)
|
||||
|
||||
def _validate_message_recipe(self) -> None:
|
||||
"""Ensure every templated binding is known and the recipe supervises something.
|
||||
|
||||
A recipe is valid if it has at least one of:
|
||||
|
||||
* a ``target: true`` assistant turn (drives text-CE supervision), or
|
||||
* a ``stream: low_level`` turn (drives flow / action supervision via
|
||||
``predict_actions=True``, even when no assistant turn is targeted —
|
||||
e.g. π0.5-style ``low_level_execution`` where the action expert
|
||||
conditions on a user-only ``${subtask}`` prompt).
|
||||
"""
|
||||
assert self.messages is not None
|
||||
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
|
||||
|
||||
for turn in self.messages:
|
||||
missing = self._referenced_bindings(turn) - known_bindings
|
||||
if missing:
|
||||
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
|
||||
|
||||
has_target = any(turn.target for turn in self.messages)
|
||||
has_low_level = any(turn.stream == "low_level" for turn in self.messages)
|
||||
if not (has_target or has_low_level):
|
||||
raise ValueError(
|
||||
"Message recipes must contain at least one supervised turn — "
|
||||
"either ``target: true`` (text CE) or ``stream: low_level`` "
|
||||
"(flow/action loss)."
|
||||
)
|
||||
|
||||
def _validate_blend_recipe(self) -> None:
|
||||
"""Ensure each blend component is a non-empty, weighted message recipe."""
|
||||
assert self.blend is not None
|
||||
if not self.blend:
|
||||
raise ValueError("Blend recipes must contain at least one component.")
|
||||
|
||||
for name, recipe in self.blend.items():
|
||||
if recipe.blend is not None:
|
||||
raise ValueError(f"Blend component {name!r} cannot itself define a blend.")
|
||||
if recipe.messages is None:
|
||||
raise ValueError(f"Blend component {name!r} must define messages.")
|
||||
if recipe.weight is None:
|
||||
raise ValueError(f"Blend component {name!r} must define weight.")
|
||||
if recipe.weight <= 0:
|
||||
raise ValueError(f"Blend component {name!r} must have a positive weight.")
|
||||
|
||||
def _referenced_bindings(self, turn: MessageTurn) -> set[str]:
|
||||
"""Return the binding names that ``turn`` references via placeholders or attributes."""
|
||||
names: set[str] = set()
|
||||
if turn.if_present is not None:
|
||||
names.add(turn.if_present)
|
||||
if turn.tool_calls_from is not None:
|
||||
names.add(turn.tool_calls_from)
|
||||
names.update(_placeholders_in_content(turn.content))
|
||||
return names
|
||||
|
||||
|
||||
def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[str]:
|
||||
"""Return the set of ``${name}`` placeholders found anywhere in ``content``."""
|
||||
if content is None:
|
||||
return set()
|
||||
if isinstance(content, str):
|
||||
return set(PLACEHOLDER_RE.findall(content))
|
||||
|
||||
names: set[str] = set()
|
||||
for block in content:
|
||||
for value in block.values():
|
||||
if isinstance(value, str):
|
||||
names.update(PLACEHOLDER_RE.findall(value))
|
||||
return names
|
||||
|
||||
|
||||
def load_recipe(path: str | Path) -> TrainingRecipe:
|
||||
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
|
||||
return TrainingRecipe.from_yaml(path)
|
||||
68
src/lerobot/configs/recipes/subtask_mem.yaml
Normal file
68
src/lerobot/configs/recipes/subtask_mem.yaml
Normal file
@@ -0,0 +1,68 @@
|
||||
# subtask_mem_vqa_speech — Hi-Robot blend + memory + spoken responses.
|
||||
#
|
||||
# Superset of subtasks_vqa.yaml. Keeps the core subtask + action + VQA
|
||||
# training, and adds two text-supervised tasks:
|
||||
#
|
||||
# high_level_subtask — predict the subtask from the task.
|
||||
# low_level_execution — flow loss with [images, subtask, state].
|
||||
# memory_update — compress progress into a memory note.
|
||||
# user_interjection_response — reply to a user interjection with a
|
||||
# spoken `say` tool call (no plan, no
|
||||
# subtask text — just the spoken reply).
|
||||
# ask_vqa_{top,wrist} — camera-grounded VQA.
|
||||
#
|
||||
# Plan is intentionally left out — memory is the only persistent
|
||||
# high-level state here, keeping the prompt short.
|
||||
#
|
||||
# Requires the dataset to carry `memory`, `interjection` and `say`-tool
|
||||
# annotations (the annotation pipeline's memory + interjection modules)
|
||||
# in addition to `subtask` and `vqa`. Sub-recipes whose `if_present`
|
||||
# bindings are missing simply don't render for that sample, so a
|
||||
# dataset without interjections still trains the rest of the blend.
|
||||
#
|
||||
# Tool-call note: the `say` tool call on the interjection-response turn
|
||||
# is flattened to a `<say>...</say>` text marker by the tokenizer step
|
||||
# (`_flatten_say_tool_calls`) so the LM head learns to emit exactly the
|
||||
# marker the runtime parses back (`_split_plan_and_say`).
|
||||
|
||||
blend:
|
||||
|
||||
high_level_subtask:
|
||||
weight: 0.30
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
|
||||
|
||||
low_level_execution:
|
||||
weight: 0.55
|
||||
messages:
|
||||
# The action expert is conditioned on the SUBTASK — at inference
|
||||
# `HighLevelSubtaskFwd` generates it via the LM head and feeds it
|
||||
# here. `stream: low_level` flips `predict_actions=True` so the
|
||||
# flow loss fires; no text-CE target (subtask prediction is owned
|
||||
# by `high_level_subtask`).
|
||||
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
|
||||
|
||||
memory_update:
|
||||
# At inference, `MemoryUpdateFwd` is triggered only on
|
||||
# `subtask_change` events (sparse). Training densely with
|
||||
# `active_at` — i.e. on every frame inside a subtask interval,
|
||||
# not just the boundary frame — supervises the same
|
||||
# (prior_memory, completed_subtask) → current_memory mapping
|
||||
# against varied observations within the interval. The model
|
||||
# learns a stateless transformation; the *when* to emit lives in
|
||||
# the inference trigger, not the model. Annotations only exist
|
||||
# for ~1% of frames as boundary events, so `emitted_at` would
|
||||
# waste 99% of the blend draws (and silently leak them into a
|
||||
# task-conditioned fallback); `active_at` lifts the renderable
|
||||
# rate to ~87% on this dataset.
|
||||
weight: 0.15
|
||||
bindings:
|
||||
prior_memory: "nth_prev(style=memory, offset=1)"
|
||||
current_memory: "active_at(t, style=memory)"
|
||||
completed_subtask: "nth_prev(style=subtask, offset=1)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
|
||||
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
|
||||
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
|
||||
99
src/lerobot/configs/recipes/subtask_mem_vqa_robocasa.yaml
Normal file
99
src/lerobot/configs/recipes/subtask_mem_vqa_robocasa.yaml
Normal file
@@ -0,0 +1,99 @@
|
||||
# subtask_mem_vqa_robocasa — Hi-Robot blend tuned for RoboCasa cameras.
|
||||
#
|
||||
# Same supervision as ``subtask_mem.yaml`` (subtask + memory) plus
|
||||
# camera-grounded VQA across the three RoboCasa camera keys produced
|
||||
# by ``slurm_build_robocasa_composite_seen.py``:
|
||||
#
|
||||
# observation.images.robot0_agentview_left (left scene view)
|
||||
# observation.images.robot0_agentview_right (right scene view)
|
||||
# observation.images.robot0_eye_in_hand (wrist)
|
||||
#
|
||||
# The annotation pipeline (``examples/annotations/run_hf_job.py``) emits
|
||||
# VQA per camera, so each anchor frame produces three (user, assistant)
|
||||
# rows tagged with their source camera. Each VQA sub-recipe consumes
|
||||
# the rows for one camera via ``camera=...`` resolver bindings.
|
||||
#
|
||||
# Spatial VQA targets (bbox / point) are rewritten from JSON to
|
||||
# PaliGemma ``<locDDDD>`` tokens by ``_messages_vqa_to_loc`` —
|
||||
# ``register_paligemma_loc_tokens`` already collapses them to single
|
||||
# detection-vocab ids so the LM head learns the pretrained pointing /
|
||||
# detection prior, not a 7-piece BPE salad.
|
||||
#
|
||||
# Interjections / spoken responses are intentionally absent — the
|
||||
# annotation job runs with ``--interjections.enabled=false``.
|
||||
|
||||
blend:
|
||||
|
||||
high_level_subtask:
|
||||
weight: 0.25
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
|
||||
|
||||
low_level_execution:
|
||||
weight: 0.45
|
||||
messages:
|
||||
# Action expert is conditioned on the SUBTASK; at inference the
|
||||
# high-level loop generates it via the LM head and feeds it here.
|
||||
# ``stream: low_level`` flips ``predict_actions=True`` so the flow
|
||||
# loss fires; subtask CE is owned by ``high_level_subtask``.
|
||||
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
|
||||
|
||||
memory_update:
|
||||
# Trained densely with ``active_at`` — every frame inside a subtask
|
||||
# interval — so the (prior_memory, completed_subtask) → current_memory
|
||||
# mapping is supervised against varied observations. The *when* to
|
||||
# emit lives in the inference trigger (subtask_change), not the
|
||||
# model. See ``subtask_mem.yaml`` for the long version of this note.
|
||||
weight: 0.15
|
||||
bindings:
|
||||
prior_memory: "nth_prev(style=memory, offset=1)"
|
||||
current_memory: "active_at(t, style=memory)"
|
||||
completed_subtask: "nth_prev(style=subtask, offset=1)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
|
||||
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
|
||||
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
|
||||
|
||||
ask_vqa_agentview_left:
|
||||
weight: 0.05
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_left)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_left)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.robot0_agentview_left}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
|
||||
ask_vqa_agentview_right:
|
||||
weight: 0.05
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_right)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_right)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.robot0_agentview_right}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
|
||||
ask_vqa_wrist:
|
||||
weight: 0.05
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_eye_in_hand)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_eye_in_hand)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.robot0_eye_in_hand}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
114
src/lerobot/configs/recipes/subtask_mem_vqa_speech.yaml
Normal file
114
src/lerobot/configs/recipes/subtask_mem_vqa_speech.yaml
Normal file
@@ -0,0 +1,114 @@
|
||||
# subtask_mem_vqa_speech — Hi-Robot blend + memory + spoken responses.
|
||||
#
|
||||
# Superset of subtasks_vqa.yaml. Keeps the core subtask + action + VQA
|
||||
# training, and adds two text-supervised tasks:
|
||||
#
|
||||
# high_level_subtask — predict the subtask from the task.
|
||||
# low_level_execution — flow loss with [images, subtask, state].
|
||||
# memory_update — compress progress into a memory note.
|
||||
# user_interjection_response — reply to a user interjection with a
|
||||
# spoken `say` tool call (no plan, no
|
||||
# subtask text — just the spoken reply).
|
||||
# ask_vqa_{top,wrist} — camera-grounded VQA.
|
||||
#
|
||||
# Plan is intentionally left out — memory is the only persistent
|
||||
# high-level state here, keeping the prompt short.
|
||||
#
|
||||
# Requires the dataset to carry `memory`, `interjection` and `say`-tool
|
||||
# annotations (the annotation pipeline's memory + interjection modules)
|
||||
# in addition to `subtask` and `vqa`. Sub-recipes whose `if_present`
|
||||
# bindings are missing simply don't render for that sample, so a
|
||||
# dataset without interjections still trains the rest of the blend.
|
||||
#
|
||||
# Tool-call note: the `say` tool call on the interjection-response turn
|
||||
# is flattened to a `<say>...</say>` text marker by the tokenizer step
|
||||
# (`_flatten_say_tool_calls`) so the LM head learns to emit exactly the
|
||||
# marker the runtime parses back (`_split_plan_and_say`).
|
||||
|
||||
blend:
|
||||
|
||||
high_level_subtask:
|
||||
weight: 0.25
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
|
||||
|
||||
low_level_execution:
|
||||
weight: 0.40
|
||||
messages:
|
||||
# The action expert is conditioned on the SUBTASK — at inference
|
||||
# `HighLevelSubtaskFwd` generates it via the LM head and feeds it
|
||||
# here. `stream: low_level` flips `predict_actions=True` so the
|
||||
# flow loss fires; no text-CE target (subtask prediction is owned
|
||||
# by `high_level_subtask`).
|
||||
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
|
||||
|
||||
memory_update:
|
||||
# At inference, `MemoryUpdateFwd` is triggered only on
|
||||
# `subtask_change` events (sparse). Training densely with
|
||||
# `active_at` — i.e. on every frame inside a subtask interval,
|
||||
# not just the boundary frame — supervises the same
|
||||
# (prior_memory, completed_subtask) → current_memory mapping
|
||||
# against varied observations within the interval. The model
|
||||
# learns a stateless transformation; the *when* to emit lives in
|
||||
# the inference trigger, not the model. Annotations only exist
|
||||
# for ~1% of frames as boundary events, so `emitted_at` would
|
||||
# waste 99% of the blend draws (and silently leak them into the
|
||||
# task-conditioned fallback); `active_at` lifts the renderable
|
||||
# rate to ~87% on Hi-Robot-style datasets.
|
||||
weight: 0.10
|
||||
bindings:
|
||||
prior_memory: "nth_prev(style=memory, offset=1)"
|
||||
current_memory: "active_at(t, style=memory)"
|
||||
completed_subtask: "nth_prev(style=subtask, offset=1)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
|
||||
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
|
||||
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
|
||||
|
||||
user_interjection_response:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
interjection: "emitted_at(t, style=interjection)"
|
||||
speech: "emitted_at(t, role=assistant, tool_name=say)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
|
||||
# Spoken reply only: the assistant turn carries no text content,
|
||||
# just a `say` tool call (`tool_calls_from: speech`). The chat
|
||||
# tokenizer flattens it to a `<say>...</say>` marker, so the
|
||||
# supervised target trains the model to respond to an
|
||||
# interjection with a spoken acknowledgement.
|
||||
- {role: assistant, stream: high_level, target: true, if_present: speech, tool_calls_from: speech}
|
||||
|
||||
# VQA is view-dependent — each camera gets its own sub-recipe so the
|
||||
# resolver disambiguates via `camera=...`. Camera keys match
|
||||
# subtasks_vqa.yaml (`front` + `wrist`); adjust to your dataset.
|
||||
ask_vqa_top:
|
||||
weight: 0.075
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.front)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.front)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.front}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
|
||||
ask_vqa_wrist:
|
||||
weight: 0.075
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.wrist}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
61
src/lerobot/configs/recipes/subtasks_vqa.yaml
Normal file
61
src/lerobot/configs/recipes/subtasks_vqa.yaml
Normal file
@@ -0,0 +1,61 @@
|
||||
# subtasks_vqa — Hi-Robot blend for PI052 (PaliGemma backbone).
|
||||
#
|
||||
# Trains two things only: subtasks and VQA. Plan and memory are
|
||||
# intentionally left out — keeps the prompt short and the training
|
||||
# surface small. The fuller blend with memory + spoken replies is
|
||||
# ``subtask_mem_vqa_speech.yaml``.
|
||||
#
|
||||
# high_level_subtask — predict the subtask from the task.
|
||||
# low_level_execution — flow loss with [images, subtask, state].
|
||||
# ask_vqa_{top,wrist} — camera-grounded VQA.
|
||||
#
|
||||
# PI052's text tokenizer renders these messages as plain
|
||||
# ``Role: content`` text (PaliGemma is not chat-pretrained).
|
||||
|
||||
blend:
|
||||
|
||||
high_level_subtask:
|
||||
weight: 0.40
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
|
||||
|
||||
low_level_execution:
|
||||
weight: 0.40
|
||||
messages:
|
||||
# The action expert is conditioned on the SUBTASK — at inference
|
||||
# the high-level loop (``HighLevelSubtaskFwd``) generates the
|
||||
# subtask via the LM head and feeds it here. The action expert's
|
||||
# prefix is [images, subtask, state]. ``stream: low_level`` flips
|
||||
# ``predict_actions=True`` so the flow loss fires; no text-CE
|
||||
# target here (subtask prediction is owned by
|
||||
# ``high_level_subtask``).
|
||||
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
|
||||
|
||||
ask_vqa_top:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.front)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.front)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.front}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
|
||||
ask_vqa_wrist:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.wrist}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
@@ -30,7 +30,7 @@ from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
||||
|
||||
from . import parser
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .default import DatasetConfig, EMAConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .rewards import RewardModelConfig
|
||||
|
||||
@@ -111,9 +111,20 @@ class TrainPipelineConfig(HubMixin):
|
||||
scheduler: LRSchedulerConfig | None = None
|
||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
ema: EMAConfig = field(default_factory=EMAConfig)
|
||||
peft: PeftConfig | None = None
|
||||
|
||||
# Sample weighting configuration (e.g., for RA-BC training)
|
||||
# VQA oversampling. When set (a fraction in (0, 1)), the training
|
||||
# dataloader uses a WeightedEpisodeAwareSampler that draws frames
|
||||
# carrying a `vqa` language annotation often enough that they make
|
||||
# up roughly this fraction of the training stream. VQA annotations
|
||||
# are typically sparse, so without this they are underrepresented.
|
||||
# `None` (default) keeps uniform episode-aware sampling.
|
||||
vqa_target_fraction: float | None = None
|
||||
|
||||
# Sample weighting configuration (e.g., for RA-BC training). Old
|
||||
# inline ``use_rabc`` / ``rabc_*`` params are migrated to this
|
||||
# field by ``_migrate_legacy_rabc_keys`` above.
|
||||
sample_weighting: SampleWeightingConfig | None = None
|
||||
|
||||
# Rename map for the observation to override the image and state keys
|
||||
|
||||
@@ -19,8 +19,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import Any, ClassVar
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
@@ -36,12 +36,11 @@ HW_VIDEO_CODECS = [
|
||||
"h264_vaapi", # Linux Intel/AMD
|
||||
"h264_qsv", # Intel Quick Sync
|
||||
]
|
||||
VALID_VIDEO_CODECS: frozenset[str] = frozenset(
|
||||
{"h264", "hevc", "libsvtav1", "ffv1", "auto", *HW_VIDEO_CODECS}
|
||||
)
|
||||
VALID_VIDEO_CODECS: frozenset[str] = frozenset({"h264", "hevc", "libsvtav1", "auto", *HW_VIDEO_CODECS})
|
||||
# Aliases for legacy video codec names.
|
||||
VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"}
|
||||
|
||||
|
||||
LIBSVTAV1_DEFAULT_PRESET: int = 12
|
||||
|
||||
# Keys persisted under ``features[*]["info"]`` as ``video.<name>`` (from :class:`VideoEncoderConfig`).
|
||||
@@ -53,19 +52,6 @@ VIDEO_ENCODER_INFO_KEYS: frozenset[str] = frozenset(
|
||||
f"video.{name}" for name in VIDEO_ENCODER_INFO_FIELD_NAMES
|
||||
)
|
||||
|
||||
# Default depth quantization and encoding parameters.
|
||||
DEPTH_QUANT_BITS: int = 12
|
||||
DEPTH_QMAX: int = (1 << DEPTH_QUANT_BITS) - 1 # 4095
|
||||
|
||||
DEFAULT_DEPTH_MIN: float = 0.01
|
||||
DEFAULT_DEPTH_MAX: float = 10.0
|
||||
DEFAULT_DEPTH_SHIFT: float = 3.5
|
||||
DEFAULT_DEPTH_USE_LOG: bool = True
|
||||
DEFAULT_DEPTH_PIX_FMT: str = "gray12le"
|
||||
|
||||
# Depth-specific tuning fields persisted under ``features[*]["info"]`` as ``video.<name>``.
|
||||
DEPTH_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset({"depth_min", "depth_max", "shift", "use_log"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoEncoderConfig:
|
||||
@@ -100,10 +86,6 @@ class VideoEncoderConfig:
|
||||
video_backend: str = "pyav"
|
||||
extra_options: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Source-data channel count this encoder is expected to handle (3 for RGB,
|
||||
# 1 for depth, etc.)
|
||||
_DEFAULT_CHANNELS: ClassVar[int] = 3
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.resolve_vcodec()
|
||||
# Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work".
|
||||
@@ -156,9 +138,7 @@ class VideoEncoderConfig:
|
||||
require_package("av", extra="dataset")
|
||||
from lerobot.datasets import check_video_encoder_parameters_pyav
|
||||
|
||||
check_video_encoder_parameters_pyav(
|
||||
self.vcodec, self.pix_fmt, self.get_codec_options(), channels=self._DEFAULT_CHANNELS
|
||||
)
|
||||
check_video_encoder_parameters_pyav(self.vcodec, self.pix_fmt, self.get_codec_options())
|
||||
|
||||
def resolve_vcodec(self) -> None:
|
||||
"""Check ``vcodec`` and, when it is ``"auto"``, pick a concrete encoder.
|
||||
@@ -238,10 +218,6 @@ class VideoEncoderConfig:
|
||||
elif self.vcodec == "h264_qsv":
|
||||
set_if("global_quality", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
elif self.vcodec == "ffv1":
|
||||
# Lossless intra-frame codec. ``crf``/``preset``/``fast_decode``
|
||||
# are not meaningful.
|
||||
set_if("threads", encoder_threads)
|
||||
else:
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
@@ -257,59 +233,3 @@ class VideoEncoderConfig:
|
||||
def camera_encoder_defaults() -> VideoEncoderConfig:
|
||||
"""Return a :class:`VideoEncoderConfig` with RGB-camera defaults."""
|
||||
return VideoEncoderConfig()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DepthEncoderConfig(VideoEncoderConfig):
|
||||
"""Encoder configuration for depth-map streams.
|
||||
|
||||
Inherits the full :class:`VideoEncoderConfig` surface (codec, GOP, CRF,
|
||||
preset, ``extra_options``…) and adds the four parameters of the depth
|
||||
quantizer.
|
||||
|
||||
Defaults flip ``vcodec`` to ``"hevc"`` (Main 12 profile) and ``pix_fmt``
|
||||
to ``"gray12le"``.
|
||||
|
||||
|
||||
Attributes:
|
||||
depth_min: Minimum depth in physical units (e.g. metres) represented
|
||||
by quantum ``0``.
|
||||
depth_max: Maximum depth represented by quantum :data:`DEPTH_QMAX`.
|
||||
shift: Pre-log offset for numerical stability near zero.
|
||||
use_log: ``True`` for logarithmic quantization (default; matches
|
||||
sensor error profile), ``False`` for linear.
|
||||
"""
|
||||
|
||||
vcodec: str = "hevc"
|
||||
pix_fmt: str = "gray12le"
|
||||
|
||||
depth_min: float = DEFAULT_DEPTH_MIN
|
||||
depth_max: float = DEFAULT_DEPTH_MAX
|
||||
shift: float = DEFAULT_DEPTH_SHIFT
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG
|
||||
|
||||
_DEFAULT_CHANNELS: ClassVar[int] = 1
|
||||
|
||||
@classmethod
|
||||
def from_video_info(cls, video_info: dict | None) -> DepthEncoderConfig:
|
||||
"""Reconstruct a :class:`DepthEncoderConfig` from a depth feature's ``info`` block.
|
||||
|
||||
Reuses :meth:`VideoEncoderConfig.from_video_info` for the base
|
||||
codec/tuning fields and then layers the depth-specific tuning
|
||||
(``depth_min`` / ``depth_max`` / ``shift`` / ``use_log``) on top.
|
||||
Missing keys fall back to the class defaults.
|
||||
"""
|
||||
base = VideoEncoderConfig.from_video_info(video_info)
|
||||
kwargs: dict[str, Any] = {f.name: getattr(base, f.name) for f in fields(base) if f.init}
|
||||
|
||||
video_info = video_info or {}
|
||||
for name in DEPTH_ENCODER_INFO_FIELD_NAMES:
|
||||
value = video_info.get(f"video.{name}")
|
||||
if value is not None:
|
||||
kwargs[name] = value
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
def depth_encoder_defaults() -> DepthEncoderConfig:
|
||||
"""Return a :class:`DepthEncoderConfig` with depth-camera defaults."""
|
||||
return DepthEncoderConfig()
|
||||
|
||||
@@ -31,21 +31,42 @@ from .dataset_tools import (
|
||||
modify_features,
|
||||
modify_tasks,
|
||||
recompute_stats,
|
||||
reencode_dataset,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from .factory import make_dataset, resolve_delta_timestamps
|
||||
from .image_writer import safe_stop_image_writer
|
||||
from .io_utils import load_episodes, write_stats
|
||||
from .language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
STYLE_REGISTRY,
|
||||
column_for_style,
|
||||
)
|
||||
from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
|
||||
from .sampler import EpisodeAwareSampler
|
||||
from .sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
|
||||
from .streaming_dataset import StreamingLeRobotDataset
|
||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||
from .video_utils import VideoEncodingManager
|
||||
|
||||
|
||||
def make_dataset(*args, **kwargs):
|
||||
from .factory import make_dataset as _make_dataset
|
||||
|
||||
return _make_dataset(*args, **kwargs)
|
||||
|
||||
|
||||
def resolve_delta_timestamps(*args, **kwargs):
|
||||
from .factory import resolve_delta_timestamps as _resolve_delta_timestamps
|
||||
|
||||
return _resolve_delta_timestamps(*args, **kwargs)
|
||||
|
||||
|
||||
# NOTE: Low-level I/O functions (cast_stats_to_numpy, get_parquet_file_size_in_mb, etc.)
|
||||
# and legacy migration constants are intentionally NOT re-exported here.
|
||||
# Import directly: ``from lerobot.datasets.io_utils import ...``
|
||||
@@ -54,10 +75,16 @@ __all__ = [
|
||||
"CODEBASE_VERSION",
|
||||
"DEFAULT_EPISODES_PATH",
|
||||
"DEFAULT_QUANTILES",
|
||||
"EVENT_ONLY_STYLES",
|
||||
"EpisodeAwareSampler",
|
||||
"WeightedEpisodeAwareSampler",
|
||||
"LANGUAGE_EVENTS",
|
||||
"LANGUAGE_PERSISTENT",
|
||||
"LeRobotDataset",
|
||||
"LeRobotDatasetMetadata",
|
||||
"MultiLeRobotDataset",
|
||||
"PERSISTENT_STYLES",
|
||||
"STYLE_REGISTRY",
|
||||
"StreamingLeRobotDataset",
|
||||
"VideoEncodingManager",
|
||||
"check_video_encoder_parameters_pyav",
|
||||
@@ -69,6 +96,7 @@ __all__ = [
|
||||
"convert_image_to_video_dataset",
|
||||
"create_initial_features",
|
||||
"create_lerobot_dataset_card",
|
||||
"column_for_style",
|
||||
"delete_episodes",
|
||||
"get_feature_stats",
|
||||
"load_episodes",
|
||||
@@ -77,6 +105,7 @@ __all__ = [
|
||||
"modify_features",
|
||||
"modify_tasks",
|
||||
"recompute_stats",
|
||||
"reencode_dataset",
|
||||
"remove_feature",
|
||||
"resolve_delta_timestamps",
|
||||
"safe_stop_image_writer",
|
||||
|
||||
@@ -512,7 +512,7 @@ def compute_episode_stats(
|
||||
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] == "string":
|
||||
if features[key]["dtype"] in {"string", "language"}:
|
||||
continue
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
@@ -550,10 +550,8 @@ def _validate_stat_value(value: np.ndarray, key: str, feature_key: str) -> None:
|
||||
if key == "count" and value.shape != (1,):
|
||||
raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.")
|
||||
|
||||
if "image" in feature_key and key != "count" and value.shape not in ((3, 1, 1), (1, 1, 1)):
|
||||
raise ValueError(
|
||||
f"Shape of quantile '{key}' must be (3,1,1) or (1,1,1) but is {value.shape} instead."
|
||||
)
|
||||
if "image" in feature_key and key != "count" and value.shape != (3, 1, 1):
|
||||
raise ValueError(f"Shape of quantile '{key}' must be (3,1,1), but is {value.shape} instead.")
|
||||
|
||||
|
||||
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
||||
|
||||
@@ -36,12 +36,12 @@ from .io_utils import (
|
||||
load_episodes,
|
||||
load_info,
|
||||
load_stats,
|
||||
load_subtasks,
|
||||
load_tasks,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from .language import DEFAULT_TOOLS, LANGUAGE_COLUMNS
|
||||
from .utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
check_version_compatibility,
|
||||
@@ -177,7 +177,6 @@ class LeRobotDatasetMetadata:
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.subtasks = load_subtasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
@@ -338,30 +337,54 @@ class LeRobotDatasetMetadata:
|
||||
"""Keys to access visual modalities stored as videos."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
|
||||
|
||||
@property
|
||||
def depth_keys(self) -> list[str]:
|
||||
"""Keys to access depth-map modalities stored as videos or images.
|
||||
|
||||
A depth key is a feature whose ``info`` dict carries ``"is_depth_map": True``
|
||||
(or the legacy ``"video.is_depth_map"`` inside ``info`` or ``video_info``).
|
||||
"""
|
||||
|
||||
def _is_depth(ft: dict) -> bool:
|
||||
info = ft.get("info") or {}
|
||||
video_info = ft.get("video_info") or {}
|
||||
return (
|
||||
info.get("is_depth_map", False)
|
||||
or info.get("video.is_depth_map", False)
|
||||
or video_info.get("video.is_depth_map", False)
|
||||
)
|
||||
|
||||
return [key for key, ft in self.features.items() if _is_depth(ft)]
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||
|
||||
@property
|
||||
def has_language_columns(self) -> bool:
|
||||
"""Return ``True`` if the dataset declares any language column.
|
||||
|
||||
Used to gate language-aware code paths (collate, render step) so
|
||||
unannotated datasets keep PyTorch's default collate behavior.
|
||||
"""
|
||||
return any(col in self.features for col in LANGUAGE_COLUMNS)
|
||||
|
||||
@property
|
||||
def tools(self) -> list[dict]:
|
||||
"""OpenAI-style tool schemas declared by this dataset.
|
||||
|
||||
Read from ``meta/info.json["tools"]``. Returns a copy, so callers
|
||||
can mutate the result safely. Falls back to
|
||||
:data:`lerobot.datasets.language.DEFAULT_TOOLS` (the canonical
|
||||
``say`` schema) when the dataset doesn't declare any — that way
|
||||
unannotated datasets and chat-template consumers
|
||||
(``apply_chat_template(messages, tools=meta.tools)``) keep
|
||||
working out of the box.
|
||||
|
||||
Implementations live under :mod:`lerobot.tools` (one file per
|
||||
tool); see ``docs/source/tools.mdx`` for the authoring guide.
|
||||
"""
|
||||
declared = self.info.tools
|
||||
if declared:
|
||||
return [dict(t) for t in declared]
|
||||
return [dict(t) for t in DEFAULT_TOOLS]
|
||||
|
||||
@tools.setter
|
||||
def tools(self, value: list[dict] | None) -> None:
|
||||
"""Persist a tool catalog to ``meta/info.json`` and reload metadata.
|
||||
|
||||
Writes ``value`` into the on-disk ``info.json`` (or clears the
|
||||
``tools`` key when ``value`` is ``None`` or empty), then reloads
|
||||
``self.info`` so the in-memory metadata matches what's on disk.
|
||||
Saves callers from hand-editing ``info.json`` and re-instantiating
|
||||
the metadata object.
|
||||
"""
|
||||
self.info.tools = [dict(t) for t in value] if value else None
|
||||
write_info(self.info, self.root)
|
||||
self.info = load_info(self.root)
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
@@ -557,7 +580,7 @@ class LeRobotDatasetMetadata:
|
||||
def update_video_info(
|
||||
self,
|
||||
video_key: str | None = None,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
) -> None:
|
||||
"""Populate per-feature video info in ``info.json``.
|
||||
|
||||
@@ -577,13 +600,9 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
video_keys = [video_key] if video_key is not None else self.video_keys
|
||||
for key in video_keys:
|
||||
existing = self.features[key].get("info") or {}
|
||||
# Skip only if real video info has already been written. The ``is_depth_map`` entry (created at feature creation) is not blocking.
|
||||
if set(existing.keys()) - {"is_depth_map"}:
|
||||
continue
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
new_info = get_video_info(video_path, video_encoder=video_encoder)
|
||||
self.info.features[key]["info"] = {**existing, **new_info}
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
self.info.features[key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
@@ -694,7 +713,6 @@ class LeRobotDatasetMetadata:
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
obj.subtasks = None
|
||||
obj.episodes = None
|
||||
obj.stats = None
|
||||
obj.info = create_empty_dataset_info(
|
||||
|
||||
@@ -22,10 +22,7 @@ from pathlib import Path
|
||||
import datasets
|
||||
import torch
|
||||
|
||||
from lerobot.configs.video import DepthEncoderConfig
|
||||
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
from .depth_utils import dequantize_depth
|
||||
from .feature_utils import (
|
||||
check_delta_timestamps,
|
||||
get_delta_indices,
|
||||
@@ -89,12 +86,6 @@ class DatasetReader:
|
||||
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
|
||||
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
|
||||
|
||||
##TODO(CarolinePascal): Should we rather use a more lightweight structure ?
|
||||
self._depth_encoder_configs: dict[str, DepthEncoderConfig] = {
|
||||
vid_key: DepthEncoderConfig.from_video_info(self._meta.features[vid_key].get("info"))
|
||||
for vid_key in self._meta.depth_keys
|
||||
}
|
||||
|
||||
def try_load(self) -> bool:
|
||||
"""Attempt to load from local cache. Returns True if data is sufficient."""
|
||||
try:
|
||||
@@ -135,10 +126,53 @@ class DatasetReader:
|
||||
def _load_hf_dataset(self) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
features = get_hf_features_from_features(self._meta.features)
|
||||
# Datasets annotated with the PR1 language columns may have been
|
||||
# written without registering those columns in ``meta/info.json``
|
||||
# (e.g. they predate ``CODEBASE_VERSION="v3.1"`` and were
|
||||
# back-filled by ``lerobot-annotate``). Probe a single parquet
|
||||
# shard and graft the column features on so the strict
|
||||
# ``Dataset.from_parquet`` cast doesn't fail with
|
||||
# ``column names don't match``.
|
||||
features = self._extend_features_with_language_columns(features)
|
||||
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
def _extend_features_with_language_columns(
|
||||
self, features: datasets.Features
|
||||
) -> datasets.Features:
|
||||
"""Add ``language_persistent`` / ``language_events`` to ``features``
|
||||
when the underlying parquet shards declare them but the metadata
|
||||
doesn't. No-op when neither column is present or both are
|
||||
already registered.
|
||||
"""
|
||||
# Find any one parquet to peek at; bail if there are none yet
|
||||
# (the dataset will fail later for an unrelated reason and we
|
||||
# want that error to surface as-is).
|
||||
try:
|
||||
sample = next((self.root / "data").glob("*/*.parquet"))
|
||||
except StopIteration:
|
||||
return features
|
||||
|
||||
from pyarrow import parquet as _pq # noqa: PLC0415
|
||||
|
||||
schema_names = set(_pq.read_schema(sample).names)
|
||||
from .language import ( # noqa: PLC0415
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
language_events_column_feature,
|
||||
language_persistent_column_feature,
|
||||
)
|
||||
|
||||
extra: dict[str, object] = {}
|
||||
if LANGUAGE_PERSISTENT in schema_names and LANGUAGE_PERSISTENT not in features:
|
||||
extra[LANGUAGE_PERSISTENT] = language_persistent_column_feature()
|
||||
if LANGUAGE_EVENTS in schema_names and LANGUAGE_EVENTS not in features:
|
||||
extra[LANGUAGE_EVENTS] = language_events_column_feature()
|
||||
if not extra:
|
||||
return features
|
||||
return datasets.Features({**features, **extra})
|
||||
|
||||
def _check_cached_episodes_sufficient(self) -> bool:
|
||||
"""Check if the cached dataset contains all requested episodes and their video files."""
|
||||
if self.hf_dataset is None or len(self.hf_dataset) == 0:
|
||||
@@ -256,18 +290,7 @@ class DatasetReader:
|
||||
self._tolerance_s,
|
||||
self._video_backend,
|
||||
return_uint8=self._return_uint8,
|
||||
is_depth=vid_key in self._meta.depth_keys,
|
||||
)
|
||||
if vid_key in self._meta.depth_keys:
|
||||
depth_encoder = self._depth_encoder_configs[vid_key]
|
||||
frames = dequantize_depth(
|
||||
frames,
|
||||
depth_min=depth_encoder.depth_min,
|
||||
depth_max=depth_encoder.depth_max,
|
||||
shift=depth_encoder.shift,
|
||||
use_log=depth_encoder.use_log,
|
||||
output_tensor=True,
|
||||
)
|
||||
return vid_key, frames.squeeze(0)
|
||||
|
||||
items = list(query_timestamps.items())
|
||||
@@ -315,9 +338,4 @@ class DatasetReader:
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self._meta.tasks.iloc[task_idx].name
|
||||
|
||||
# add subtask information if available
|
||||
if "subtask_index" in self._meta.features and self._meta.subtasks is not None:
|
||||
subtask_idx = item["subtask_index"].item()
|
||||
item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name
|
||||
|
||||
return item
|
||||
|
||||
@@ -26,7 +26,7 @@ This module provides utilities for:
|
||||
import logging
|
||||
import shutil
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -61,11 +61,13 @@ from .utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
VIDEO_DIR,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from .video_utils import (
|
||||
encode_video_frames,
|
||||
get_video_info,
|
||||
reencode_video,
|
||||
)
|
||||
|
||||
|
||||
@@ -1329,7 +1331,7 @@ def _estimate_frame_size_via_calibration(
|
||||
imgs_dir=calibration_dir,
|
||||
video_path=calibration_video_path,
|
||||
fps=fps,
|
||||
video_encoder=camera_encoder,
|
||||
camera_encoder=camera_encoder,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1813,7 +1815,7 @@ def convert_image_to_video_dataset(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
video_encoder=camera_encoder,
|
||||
camera_encoder=camera_encoder,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1860,7 +1862,7 @@ def convert_image_to_video_dataset(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info.features[img_key]["info"] = get_video_info(
|
||||
video_path, video_encoder=camera_encoder
|
||||
video_path, camera_encoder=camera_encoder
|
||||
)
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
@@ -1884,3 +1886,83 @@ def convert_image_to_video_dataset(
|
||||
|
||||
# Return new dataset
|
||||
return LeRobotDataset(repo_id=repo_id, root=output_dir)
|
||||
|
||||
|
||||
def _reencode_video_worker(args: tuple) -> Path:
|
||||
"""Picklable worker for :func:`reencode_dataset`'s process pool."""
|
||||
video_path, camera_encoder, encoder_threads = args
|
||||
reencode_video(
|
||||
input_video_path=video_path,
|
||||
output_video_path=video_path,
|
||||
camera_encoder=camera_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
overwrite=True,
|
||||
)
|
||||
return video_path
|
||||
|
||||
|
||||
def reencode_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
camera_encoder: VideoEncoderConfig,
|
||||
encoder_threads: int | None = None,
|
||||
num_workers: int | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Re-encode every video in a dataset with a new set of encoding parameters.
|
||||
|
||||
Videos are re-encoded in-place and the video information in ``info.json`` is refreshed.
|
||||
|
||||
Args:
|
||||
dataset: An existing :class:`LeRobotDataset` whose videos will be
|
||||
re-encoded.
|
||||
camera_encoder: Target encoder configuration applied to every video
|
||||
file.
|
||||
encoder_threads: Per-encoder thread count forwarded to
|
||||
:func:`reencode_video`. ``None`` lets the codec decide.
|
||||
num_workers: Number of parallel processes. ``None`` or ``0`` means
|
||||
sequential (no multiprocessing); ``1+`` spawns a
|
||||
:class:`~concurrent.futures.ProcessPoolExecutor`.
|
||||
|
||||
Returns:
|
||||
The same :class:`LeRobotDataset` instance with its metadata updated
|
||||
on disk.
|
||||
"""
|
||||
meta = dataset.meta
|
||||
video_paths_list = []
|
||||
|
||||
# Only re-encode if the videos are not already encoded with the given video encoding parameters
|
||||
for video_key in meta.video_keys:
|
||||
current_info = meta.info.features[video_key].get("info", {})
|
||||
current_encoder = VideoEncoderConfig.from_video_info(current_info)
|
||||
if current_encoder != camera_encoder:
|
||||
video_paths_list.extend((meta.root / VIDEO_DIR / video_key).rglob("*.mp4"))
|
||||
else:
|
||||
logging.info(f"{video_key} videos are already encoded with {camera_encoder}. Nothing to do.")
|
||||
|
||||
if len(video_paths_list) == 0:
|
||||
logging.warning("Dataset has no videos to re-encode.")
|
||||
return dataset
|
||||
logging.info(f"Re-encoding {len(video_paths_list)} video file(s) with {camera_encoder}")
|
||||
|
||||
worker_args = [(vp, camera_encoder, encoder_threads) for vp in video_paths_list]
|
||||
if num_workers and num_workers > 1:
|
||||
with ProcessPoolExecutor(max_workers=num_workers) as pool:
|
||||
futures = [pool.submit(_reencode_video_worker, args) for args in worker_args]
|
||||
for future in tqdm(
|
||||
as_completed(futures),
|
||||
total=len(futures),
|
||||
desc="Re-encoding videos",
|
||||
):
|
||||
future.result()
|
||||
else:
|
||||
for args in tqdm(worker_args, desc="Re-encoding videos"):
|
||||
_reencode_video_worker(args)
|
||||
|
||||
# Refresh video info in metadata for every video key.
|
||||
for vid_key in meta.video_keys:
|
||||
video_path = meta.root / meta.get_video_file_path(0, vid_key)
|
||||
meta.info.features[vid_key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
|
||||
|
||||
write_info(meta.info, meta.root)
|
||||
logging.info("Dataset metadata updated.")
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -31,12 +31,7 @@ import PIL.Image
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
|
||||
from lerobot.configs import (
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
)
|
||||
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
|
||||
|
||||
from .compute_stats import compute_episode_stats
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
@@ -53,7 +48,6 @@ from .io_utils import (
|
||||
write_info,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_DEPTH_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
update_chunk_file_indices,
|
||||
@@ -73,22 +67,17 @@ def _encode_video_worker(
|
||||
episode_index: int,
|
||||
root: Path,
|
||||
fps: int,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
) -> Path:
|
||||
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
path_template = (
|
||||
DEFAULT_DEPTH_PATH
|
||||
if video_encoder is not None and isinstance(video_encoder, DepthEncoderConfig)
|
||||
else DEFAULT_IMAGE_PATH
|
||||
)
|
||||
fpath = path_template.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
img_dir = (root / fpath).parent
|
||||
encode_video_frames(
|
||||
img_dir,
|
||||
temp_path,
|
||||
fps,
|
||||
video_encoder=video_encoder,
|
||||
camera_encoder=camera_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
overwrite=True,
|
||||
)
|
||||
@@ -108,7 +97,6 @@ class DatasetWriter:
|
||||
meta: LeRobotDatasetMetadata,
|
||||
root: Path,
|
||||
camera_encoder: VideoEncoderConfig | None,
|
||||
depth_encoder: DepthEncoderConfig | None,
|
||||
encoder_threads: int | None,
|
||||
batch_encoding_size: int,
|
||||
streaming_encoder: StreamingVideoEncoder | None = None,
|
||||
@@ -122,8 +110,6 @@ class DatasetWriter:
|
||||
root: Local dataset root directory.
|
||||
camera_encoder: Video encoder settings applied to all cameras.
|
||||
``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`.
|
||||
depth_encoder: Video encoder settings applied to all **depth** cameras.
|
||||
``None`` uses :func:`~lerobot.configs.depth_encoder_defaults`.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
@@ -135,7 +121,6 @@ class DatasetWriter:
|
||||
self._meta = meta
|
||||
self._root = root
|
||||
self._camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
self._depth_encoder = depth_encoder or depth_encoder_defaults()
|
||||
self._encoder_threads = encoder_threads
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._streaming_encoder = streaming_encoder
|
||||
@@ -160,8 +145,7 @@ class DatasetWriter:
|
||||
return ep_buffer
|
||||
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
path_template = DEFAULT_DEPTH_PATH if image_key in self._meta.depth_keys else DEFAULT_IMAGE_PATH
|
||||
fpath = path_template.format(
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self._root / fpath
|
||||
@@ -211,7 +195,6 @@ class DatasetWriter:
|
||||
if frame_index == 0 and self._streaming_encoder is not None:
|
||||
self._streaming_encoder.start_episode(
|
||||
video_keys=list(self._meta.video_keys),
|
||||
depth_video_keys=set(self._meta.video_keys) & set(self._meta.depth_keys),
|
||||
temp_dir=self._root,
|
||||
)
|
||||
|
||||
@@ -267,7 +250,14 @@ class DatasetWriter:
|
||||
for key, ft in self._meta.features.items():
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||
continue
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
stacked_values = np.stack(episode_buffer[key])
|
||||
|
||||
# `shape=(1,)` numeric features are serialized as `datasets.Value`, which expects scalars.
|
||||
# Normalizing to `(N,)` keeps save semantics stable across dependency versions.
|
||||
if tuple(ft["shape"]) == (1,) and ft["dtype"] != "string":
|
||||
stacked_values = stacked_values.reshape(episode_length)
|
||||
|
||||
episode_buffer[key] = stacked_values
|
||||
|
||||
# Wait for image writer to end, so that episode stats over images can be computed
|
||||
self._wait_image_writer()
|
||||
@@ -310,9 +300,7 @@ class DatasetWriter:
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._depth_encoder
|
||||
if video_key in self._meta.depth_keys
|
||||
else self._camera_encoder,
|
||||
self._camera_encoder,
|
||||
self._encoder_threads,
|
||||
): video_key
|
||||
for video_key in self._meta.video_keys
|
||||
@@ -523,12 +511,7 @@ class DatasetWriter:
|
||||
|
||||
# Update video info (only needed when first episode is encoded)
|
||||
if episode_index == 0:
|
||||
self._meta.update_video_info(
|
||||
video_key,
|
||||
video_encoder=self._depth_encoder
|
||||
if video_key in self._meta.depth_keys
|
||||
else self._camera_encoder,
|
||||
)
|
||||
self._meta.update_video_info(video_key, camera_encoder=self._camera_encoder)
|
||||
write_info(self._meta.info, self._meta.root)
|
||||
|
||||
metadata = {
|
||||
@@ -595,14 +578,13 @@ class DatasetWriter:
|
||||
self.image_writer.wait_until_done()
|
||||
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
|
||||
"""Use ffmpeg to convert frames stored as png/tiff into mp4 videos."""
|
||||
is_depth = video_key in self._meta.depth_keys
|
||||
"""Use ffmpeg to convert frames stored as png into mp4 videos."""
|
||||
return _encode_video_worker(
|
||||
video_key,
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._depth_encoder if is_depth else self._camera_encoder,
|
||||
self._camera_encoder,
|
||||
self._encoder_threads,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,214 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Depth encoding/decoding helpers for :class:`VideoEncoderConfig`.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Literal
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from lerobot.configs.video import (
|
||||
DEFAULT_DEPTH_MAX,
|
||||
DEFAULT_DEPTH_MIN,
|
||||
DEFAULT_DEPTH_PIX_FMT,
|
||||
DEFAULT_DEPTH_SHIFT,
|
||||
DEFAULT_DEPTH_USE_LOG,
|
||||
DEPTH_QMAX,
|
||||
)
|
||||
|
||||
from .pyav_utils import write_u16_plane
|
||||
|
||||
_MM_PER_METRE = 1000.0
|
||||
_UINT16_MAX = 65535
|
||||
|
||||
|
||||
def _validate_log_quant_params(depth_min: float, shift: float) -> None:
|
||||
"""Ensure ``log(depth_min + shift)`` is finite."""
|
||||
if depth_min + shift <= 0:
|
||||
raise ValueError(
|
||||
f"depth_min + shift must be positive for logarithmic quantization, "
|
||||
f"got depth_min={depth_min} + shift={shift} = {depth_min + shift}"
|
||||
)
|
||||
|
||||
|
||||
def _depth_input_to_float32_and_unit(
|
||||
depth: NDArray[np.integer] | NDArray[np.floating],
|
||||
input_unit: Literal["auto", "m", "mm"],
|
||||
) -> tuple[NDArray[np.float32], Literal["m", "mm"]]:
|
||||
"""Convert depth to float32 in the chosen unit, and return the resolved unit."""
|
||||
resolved_unit = (
|
||||
("m" if np.issubdtype(depth.dtype, np.floating) else "mm") if input_unit == "auto" else input_unit
|
||||
)
|
||||
return depth.astype(np.float32, order="K"), resolved_unit
|
||||
|
||||
|
||||
def quantize_depth(
|
||||
depth: NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
|
||||
video_backend: str | None = "pyav",
|
||||
input_unit: Literal["auto", "m", "mm"] = "auto",
|
||||
) -> NDArray[np.uint16] | av.VideoFrame:
|
||||
"""Quantize depth to 12-bit codes (``uint16``, values ``0…DEPTH_QMAX``).
|
||||
|
||||
Depth maps are packed into 12-bit integer frames so they fit in standard
|
||||
high-bit-depth pixel formats (e.g. ``yuv420p12le`` / ``gray12le``)
|
||||
and can be encoded by widely supported video codecs (HEVC Main 12, ffv1).
|
||||
Logarithmic quantization is the default because it allocates more quanta
|
||||
to near-range depth, which matches the (1/depth) error profile of typical
|
||||
depth sensors. Math is ported from BEHAVIOR-1K's ``obs_utils.py``.
|
||||
|
||||
**Input units**:
|
||||
|
||||
- ``input_unit="auto"`` (default): infer from dtype (floating = m, non-floating = mm).
|
||||
- ``input_unit="mm"``: interpret input values as millimetres.
|
||||
- ``input_unit="m"``: interpret input values as metres.
|
||||
|
||||
Quantization math runs in the **resolved input unit**.
|
||||
|
||||
``depth_min``, ``depth_max``, and ``shift`` are always in **metres**.
|
||||
|
||||
Args:
|
||||
depth: Depth map; ``torch.Tensor`` is moved to CPU for conversion.
|
||||
depth_min: Depth (metres) at quantum ``0``.
|
||||
depth_max: Depth (metres) at quantum :data:`DEPTH_QMAX`.
|
||||
shift: Depth shift (metres); used in log mode. Must satisfy ``depth_min + shift > 0``.
|
||||
use_log: If ``True`` (default), quantize in log space.
|
||||
video_backend: Video backend to use for encoding. Defaults to "pyav".
|
||||
input_unit: Input unit policy (``"auto"``, ``"mm"``, ``"m"``).
|
||||
|
||||
Returns:
|
||||
``numpy.ndarray``, ``dtype=uint16``, same shape as ``depth``, values in
|
||||
``[0, DEPTH_QMAX]``.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``input_unit`` is not ``"auto"``, ``"mm"``, or ``"m"``.
|
||||
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
|
||||
"""
|
||||
if input_unit not in ("auto", "m", "mm"):
|
||||
raise ValueError(f"input_unit must be 'auto', 'm', or 'mm', got {input_unit!r}")
|
||||
|
||||
if isinstance(depth, torch.Tensor):
|
||||
depth = depth.detach().cpu().numpy()
|
||||
|
||||
# Squeeze single-channel dim: (H, W, 1) or (1, H, W) → (H, W)
|
||||
if depth.ndim == 3 and (depth.shape[-1] == 1 or depth.shape[0] == 1):
|
||||
depth = depth.squeeze()
|
||||
|
||||
depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit)
|
||||
|
||||
# Convert depth_min, depth_max, and shift to the resolved input unit.
|
||||
depth_min_u = np.float32(depth_min) if resolved_unit == "m" else np.float32(depth_min * _MM_PER_METRE)
|
||||
depth_max_u = np.float32(depth_max) if resolved_unit == "m" else np.float32(depth_max * _MM_PER_METRE)
|
||||
shift_u = np.float32(shift) if resolved_unit == "m" else np.float32(shift * _MM_PER_METRE)
|
||||
|
||||
# Normalization and quantization is performed in the resolved input unit.
|
||||
if use_log:
|
||||
_validate_log_quant_params(depth_min, shift)
|
||||
log_min = math.log(float(depth_min_u + shift_u))
|
||||
log_max = math.log(float(depth_max_u + shift_u))
|
||||
norm = (np.log(depth_f + shift_u) - log_min) / (log_max - log_min)
|
||||
else:
|
||||
norm = (depth_f - depth_min_u) / (depth_max_u - depth_min_u)
|
||||
|
||||
quantized = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX).astype(np.uint16, copy=False)
|
||||
|
||||
if video_backend == "pyav":
|
||||
frame = av.VideoFrame.from_ndarray(quantized, format=pix_fmt)
|
||||
write_u16_plane(frame.planes[0], quantized)
|
||||
return frame
|
||||
else:
|
||||
return quantized
|
||||
|
||||
|
||||
def dequantize_depth(
|
||||
quantized: NDArray[np.uint16] | av.VideoFrame,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
|
||||
output_unit: Literal["m", "mm"] = "mm",
|
||||
output_tensor: bool = False,
|
||||
) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor:
|
||||
"""Inverse of :func:`quantize_depth`.
|
||||
|
||||
Tuning arguments **must match** :func:`quantize_depth`.
|
||||
|
||||
Decoding inverts the same normalized code mapping as :func:`quantize_depth`
|
||||
using ``depth_min`` / ``depth_max`` / ``shift`` (in metres), then returns
|
||||
the requested output unit.
|
||||
|
||||
Args:
|
||||
quantized: 12-bit codes ``[0, DEPTH_QMAX]``, ``dtype=uint16``.
|
||||
depth_min, depth_max, shift, use_log: Same as :func:`quantize_depth` (metres).
|
||||
output_unit: ``\"mm\"`` returns ``uint16`` millimetres (``rint``, clip
|
||||
``[0, 65535]``). ``\"m\"`` returns ``float32`` metres in
|
||||
``[depth_min, depth_max]``.
|
||||
output_tensor: If True, return a torch.Tensor instead of a numpy array.
|
||||
|
||||
Returns:
|
||||
Depth map in the requested unit and dtype.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
|
||||
ValueError: If ``output_unit`` is not ``\"m\"`` or ``\"mm\"``.
|
||||
"""
|
||||
if output_unit not in ("m", "mm"):
|
||||
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
|
||||
|
||||
if isinstance(quantized, av.VideoFrame):
|
||||
quantized = quantized.to_ndarray(format=pix_fmt)
|
||||
|
||||
norm = np.asarray(quantized, dtype=np.float32, order="K") / DEPTH_QMAX
|
||||
|
||||
depth_min_m = np.float32(depth_min)
|
||||
depth_max_m = np.float32(depth_max)
|
||||
shift_m = np.float32(shift)
|
||||
|
||||
# The de-normalization and de-quantization is performed in meters (convenience choice).
|
||||
if use_log:
|
||||
_validate_log_quant_params(depth_min, shift)
|
||||
log_min = math.log(float(depth_min_m + shift_m))
|
||||
log_max = math.log(float(depth_max_m + shift_m))
|
||||
depth_m = np.exp(norm * (log_max - log_min) + log_min) - shift_m
|
||||
else:
|
||||
depth_m = norm * (depth_max_m - depth_min_m) + depth_min_m
|
||||
depth_m = np.clip(depth_m, depth_min_m, depth_max_m).astype(np.float32, copy=False)
|
||||
|
||||
# Add single-channel dim: (H, W) → (H, W, 1)
|
||||
if depth_m.ndim == 2:
|
||||
depth_m = depth_m[..., np.newaxis]
|
||||
|
||||
# Return depth as float32 meters.
|
||||
if output_unit == "m":
|
||||
return torch.from_numpy(depth_m) if output_tensor else depth_m
|
||||
|
||||
# Return depth as uint16 millimeters.
|
||||
mm = np.rint(depth_m * _MM_PER_METRE).clip(0, _UINT16_MAX).astype(np.uint16, copy=False)
|
||||
if output_tensor:
|
||||
# torch.uint16 support is very limited, we convert to float32 instead.
|
||||
return torch.from_numpy(mm.astype(np.float32))
|
||||
else:
|
||||
return mm
|
||||
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from pprint import pformat
|
||||
|
||||
import datasets
|
||||
@@ -23,6 +24,12 @@ from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
||||
|
||||
from .language import (
|
||||
LANGUAGE_PERSISTENT,
|
||||
is_language_column,
|
||||
language_events_column_feature,
|
||||
language_persistent_column_feature,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
@@ -47,7 +54,13 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
"""
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "video":
|
||||
if is_language_column(key):
|
||||
hf_features[key] = (
|
||||
language_persistent_column_feature()
|
||||
if key == LANGUAGE_PERSISTENT
|
||||
else language_events_column_feature()
|
||||
)
|
||||
elif ft["dtype"] == "video":
|
||||
continue
|
||||
elif ft["dtype"] == "image":
|
||||
hf_features[key] = datasets.Image()
|
||||
@@ -278,6 +291,8 @@ def validate_feature_dtype_and_shape(
|
||||
return validate_feature_image_or_video(name, expected_shape, value)
|
||||
elif expected_dtype == "string":
|
||||
return validate_feature_string(name, value)
|
||||
elif expected_dtype == "language":
|
||||
return validate_feature_language(name, value)
|
||||
else:
|
||||
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
|
||||
|
||||
@@ -321,7 +336,7 @@ def validate_feature_image_or_video(
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
expected_shape (list[str]): The expected shape, e.g. (C, H, W) or (H, W, C).
|
||||
expected_shape (list[str]): The expected shape (C, H, W).
|
||||
value: The image data to validate.
|
||||
|
||||
Returns:
|
||||
@@ -357,6 +372,30 @@ def validate_feature_string(name: str, value: str) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def validate_feature_language(name: str, value) -> str:
|
||||
"""Validate a feature that is expected to hold language annotations.
|
||||
|
||||
Language columns (``language_persistent`` / ``language_events``) are
|
||||
populated after recording by the annotation pipeline, not at record time.
|
||||
Any value supplied here is dropped before the frame is written, so a
|
||||
non-empty value almost certainly signals a mistake. We warn rather than
|
||||
fail to keep recording resilient.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
value: The value to validate.
|
||||
|
||||
Returns:
|
||||
str: Always an empty string — language values are non-fatal.
|
||||
"""
|
||||
if value is not None:
|
||||
logging.warning(
|
||||
f"The feature '{name}' is a 'language' column populated by the annotation pipeline, "
|
||||
f"not at record time. The provided value will be dropped."
|
||||
)
|
||||
return ""
|
||||
|
||||
|
||||
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None:
|
||||
"""Validate the episode buffer before it's written to disk.
|
||||
|
||||
|
||||
@@ -42,41 +42,10 @@ def safe_stop_image_writer(func):
|
||||
|
||||
|
||||
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
|
||||
"""Convert a NumPy array to a PIL Image, preserving precision for grayscale.
|
||||
# TODO(aliberts): handle 1 channel and 4 for depth images
|
||||
if image_array.ndim != 3:
|
||||
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
|
||||
|
||||
Behaviour by shape:
|
||||
|
||||
- ``(H, W)`` or ``(1, H, W)`` / ``(H, W, 1)``: single-channel grayscale.
|
||||
The native dtype is preserved using the matching PIL mode
|
||||
(``I;16`` / ``F``). This is the path used for raw depth maps (no rescaling, clamping, or downcasting)
|
||||
- ``(3, H, W)`` / ``(H, W, 3)``: RGB. Channels-first inputs are transposed
|
||||
to channels-last. Float inputs in ``[0, 1]`` are scaled to ``uint8``
|
||||
(existing behaviour, gated by ``range_check``).
|
||||
|
||||
Other shapes / channel counts raise ``NotImplementedError`` or
|
||||
``ValueError``.
|
||||
"""
|
||||
# TODO(CarolinePascal): 4 dimensions RGB-D images
|
||||
if image_array.ndim not in (2, 3):
|
||||
raise ValueError(f"The array has {image_array.ndim} dimensions, but 2 or 3 is expected for an image.")
|
||||
|
||||
# Squeeze 3D single-channel inputs to 2D so depth maps work whether the
|
||||
# caller emits (H, W), (1, H, W), or (H, W, 1).
|
||||
if image_array.ndim == 3:
|
||||
if image_array.shape[0] == 1:
|
||||
image_array = image_array[0]
|
||||
elif image_array.shape[-1] == 1:
|
||||
image_array = image_array[..., 0]
|
||||
|
||||
if image_array.ndim == 2:
|
||||
if image_array.dtype not in [np.uint16, np.float32]:
|
||||
raise ValueError(
|
||||
f"Unsupported single-channel image dtype: {image_array.dtype}. "
|
||||
f"Supported dtypes: {sorted(str(d) for d in [np.uint16, np.float32])}."
|
||||
)
|
||||
return PIL.Image.fromarray(np.ascontiguousarray(image_array))
|
||||
|
||||
# 3D path: must be RGB (3 channels), channels-first or channels-last.
|
||||
if image_array.shape[0] == 3:
|
||||
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
||||
image_array = image_array.transpose(1, 2, 0)
|
||||
@@ -102,28 +71,13 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True)
|
||||
return PIL.Image.fromarray(image_array)
|
||||
|
||||
|
||||
def save_kwargs_for_path(fpath: Path, compress_level: int) -> dict:
|
||||
"""Pick the right format-specific kwargs for :meth:`PIL.Image.Image.save`.
|
||||
|
||||
PNG uses ``compress_level`` (0-9, zlib). TIFF uses ``compression`` (raw) for lossless raw depth maps.
|
||||
"""
|
||||
suffix = Path(fpath).suffix.lower()
|
||||
if suffix == ".png":
|
||||
return {"compress_level": compress_level}
|
||||
if suffix in (".tif", ".tiff"):
|
||||
return {"compression": "raw"}
|
||||
return {}
|
||||
|
||||
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1):
|
||||
"""
|
||||
Saves a NumPy array or PIL Image to a file.
|
||||
|
||||
This function handles both NumPy arrays and PIL Image objects, converting
|
||||
the former to a PIL Image before saving. It includes error handling for
|
||||
the save operation. The output format is inferred from the *fpath*
|
||||
extension: ``.png`` → PNG with ``compress_level``, ``.tiff`` / ``.tif``
|
||||
→ lossless raw depth maps (TIFF).
|
||||
the save operation.
|
||||
|
||||
Args:
|
||||
image (np.ndarray | PIL.Image.Image): The image data to save.
|
||||
@@ -147,7 +101,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
|
||||
img = image
|
||||
else:
|
||||
raise TypeError(f"Unsupported image type: {type(image)}")
|
||||
img.save(fpath, **save_kwargs_for_path(fpath, compress_level))
|
||||
img.save(fpath, compress_level=compress_level)
|
||||
except Exception as e:
|
||||
logger.error("Error writing image %s: %s", fpath, e)
|
||||
|
||||
|
||||
@@ -31,10 +31,10 @@ from torchvision import transforms
|
||||
from lerobot.utils.io_utils import load_json, write_json
|
||||
from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_dict
|
||||
|
||||
from .language import LANGUAGE_COLUMNS
|
||||
from .utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_SUBTASKS_PATH,
|
||||
DEFAULT_TASKS_PATH,
|
||||
EPISODES_DIR,
|
||||
INFO_PATH,
|
||||
@@ -186,14 +186,6 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||
return tasks
|
||||
|
||||
|
||||
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
|
||||
"""Load subtasks from subtasks.parquet if it exists."""
|
||||
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
|
||||
if subtasks_path.exists():
|
||||
return pd.read_parquet(subtasks_path)
|
||||
return None
|
||||
|
||||
|
||||
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
|
||||
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
|
||||
This function writes episode-level metadata to a single parquet file.
|
||||
@@ -265,11 +257,13 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
|
||||
dict: The batch with items converted to torch tensors.
|
||||
"""
|
||||
for key in items_dict:
|
||||
if key in LANGUAGE_COLUMNS:
|
||||
continue
|
||||
first_item = items_dict[key][0]
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
elif first_item is None:
|
||||
elif first_item is None or isinstance(first_item, dict):
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||
@@ -304,8 +298,9 @@ def item_to_torch(item: dict) -> dict:
|
||||
Returns:
|
||||
dict: Dictionary with all tensor-like items converted to torch.Tensor.
|
||||
"""
|
||||
skip_keys = {"task", *LANGUAGE_COLUMNS}
|
||||
for key, val in item.items():
|
||||
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
|
||||
if isinstance(val, (np.ndarray | list)) and key not in skip_keys:
|
||||
# Convert numpy arrays and lists to torch tensors
|
||||
item[key] = torch.tensor(val)
|
||||
return item
|
||||
|
||||
256
src/lerobot/datasets/language.py
Normal file
256
src/lerobot/datasets/language.py
Normal file
@@ -0,0 +1,256 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import datasets
|
||||
import pyarrow as pa
|
||||
|
||||
LANGUAGE_PERSISTENT = "language_persistent"
|
||||
LANGUAGE_EVENTS = "language_events"
|
||||
LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
|
||||
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls")
|
||||
EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls")
|
||||
|
||||
CORE_STYLES = {
|
||||
"subtask",
|
||||
"plan",
|
||||
"memory",
|
||||
"motion",
|
||||
"interjection",
|
||||
"vqa",
|
||||
"trace",
|
||||
"task_aug",
|
||||
}
|
||||
# Project-local styles can be registered at import time by appending to
|
||||
# ``EXTENDED_STYLES`` before ``column_for_style`` is called. Anything added
|
||||
# here is treated as a known style alongside ``CORE_STYLES`` for resolver
|
||||
# validation. Empty by default — populate from a downstream module that
|
||||
# also extends ``PERSISTENT_STYLES`` or ``EVENT_ONLY_STYLES`` to declare
|
||||
# the new style's column.
|
||||
EXTENDED_STYLES: set[str] = set()
|
||||
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
||||
|
||||
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
|
||||
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
|
||||
|
||||
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
|
||||
# styles MUST carry a non-null ``camera`` referencing an ``observation.images.*``
|
||||
# feature key. Rows of every other style MUST have ``camera=None``. ``motion``
|
||||
# is intentionally NOT in this set: motion primitives are described in
|
||||
# robot-frame (joint / Cartesian) terms, not pixel space, so they are
|
||||
# camera-agnostic. ``trace`` is the pixel-trajectory event style and IS
|
||||
# view-dependent. The ``camera`` field nevertheless lives on
|
||||
# ``PERSISTENT_ROW_FIELDS`` too so the schema, validator, and resolver
|
||||
# behave symmetrically across the two columns; persistent rows simply
|
||||
# always have ``camera=None`` in practice today.
|
||||
VIEW_DEPENDENT_STYLES = {"vqa", "trace"}
|
||||
|
||||
LanguageColumn = Literal["language_persistent", "language_events"]
|
||||
|
||||
|
||||
def _json_arrow_type() -> pa.DataType:
|
||||
"""Return the Arrow JSON type, falling back to ``string`` on older pyarrow."""
|
||||
return pa.json_() if hasattr(pa, "json_") else pa.string()
|
||||
|
||||
|
||||
def _json_feature() -> object:
|
||||
"""Return the HF feature used for tool-call payloads.
|
||||
|
||||
Older ``datasets`` versions do not expose ``datasets.Json``. The
|
||||
annotation pipeline currently emits the canonical ``say`` tool call
|
||||
shape, so use that explicit struct instead of falling back to a string
|
||||
that cannot cast structured parquet values.
|
||||
"""
|
||||
if hasattr(datasets, "Json"):
|
||||
return datasets.Json()
|
||||
return {
|
||||
"type": datasets.Value("string"),
|
||||
"function": {
|
||||
"name": datasets.Value("string"),
|
||||
"arguments": {"text": datasets.Value("string")},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def language_persistent_row_arrow_type() -> pa.StructType:
|
||||
"""Return the Arrow struct type for a single persistent language row.
|
||||
|
||||
Persistent rows carry their own ``timestamp`` because they represent a state
|
||||
that became active at a specific moment and remains active until superseded.
|
||||
``timestamp`` is ``float32`` to match the timestamp dtype LeRobotDataset
|
||||
uses for frame data.
|
||||
"""
|
||||
return pa.struct(
|
||||
[
|
||||
pa.field("role", pa.string(), nullable=False),
|
||||
pa.field("content", pa.string(), nullable=True),
|
||||
pa.field("style", pa.string(), nullable=True),
|
||||
pa.field("timestamp", pa.float32(), nullable=False),
|
||||
pa.field("camera", pa.string(), nullable=True),
|
||||
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def language_event_row_arrow_type() -> pa.StructType:
|
||||
"""Return the Arrow struct type for a single event language row.
|
||||
|
||||
Event rows have no ``timestamp`` field: each event is stored on the dataset
|
||||
row whose frame timestamp is the event's firing time.
|
||||
"""
|
||||
return pa.struct(
|
||||
[
|
||||
pa.field("role", pa.string(), nullable=False),
|
||||
pa.field("content", pa.string(), nullable=True),
|
||||
pa.field("style", pa.string(), nullable=True),
|
||||
pa.field("camera", pa.string(), nullable=True),
|
||||
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def language_persistent_arrow_type() -> pa.ListType:
|
||||
"""Return the Arrow list type for the ``language_persistent`` column."""
|
||||
return pa.list_(language_persistent_row_arrow_type())
|
||||
|
||||
|
||||
def language_events_arrow_type() -> pa.ListType:
|
||||
"""Return the Arrow list type for the ``language_events`` column."""
|
||||
return pa.list_(language_event_row_arrow_type())
|
||||
|
||||
|
||||
def language_persistent_row_feature() -> dict[str, object]:
|
||||
"""Return the HF ``datasets`` feature mapping for a persistent language row."""
|
||||
return {
|
||||
"role": datasets.Value("string"),
|
||||
"content": datasets.Value("string"),
|
||||
"style": datasets.Value("string"),
|
||||
"timestamp": datasets.Value("float32"),
|
||||
"camera": datasets.Value("string"),
|
||||
"tool_calls": datasets.List(_json_feature()),
|
||||
}
|
||||
|
||||
|
||||
def language_event_row_feature() -> dict[str, object]:
|
||||
"""Return the HF ``datasets`` feature mapping for an event language row."""
|
||||
return {
|
||||
"role": datasets.Value("string"),
|
||||
"content": datasets.Value("string"),
|
||||
"style": datasets.Value("string"),
|
||||
"camera": datasets.Value("string"),
|
||||
"tool_calls": datasets.List(_json_feature()),
|
||||
}
|
||||
|
||||
|
||||
def language_persistent_column_feature() -> datasets.List:
|
||||
"""Return the HF ``datasets`` feature for the ``language_persistent`` column."""
|
||||
return datasets.List(language_persistent_row_feature())
|
||||
|
||||
|
||||
def language_events_column_feature() -> datasets.List:
|
||||
"""Return the HF ``datasets`` feature for the ``language_events`` column."""
|
||||
return datasets.List(language_event_row_feature())
|
||||
|
||||
|
||||
def language_feature_info() -> dict[str, dict]:
|
||||
"""Return the ``info["features"]`` entries for both language columns."""
|
||||
return {
|
||||
LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None},
|
||||
LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
|
||||
def is_language_column(key: str) -> bool:
|
||||
"""Return ``True`` if ``key`` is one of the dataset's language column names."""
|
||||
return key in LANGUAGE_COLUMNS
|
||||
|
||||
|
||||
def is_view_dependent_style(style: str | None) -> bool:
|
||||
"""Return ``True`` if rows of ``style`` must be tagged with a ``camera`` key."""
|
||||
return style in VIEW_DEPENDENT_STYLES
|
||||
|
||||
|
||||
def validate_camera_field(style: str | None, camera: str | None) -> None:
|
||||
"""Enforce the ``camera`` invariant: required iff ``style`` is view-dependent.
|
||||
|
||||
Raises ``ValueError`` if a view-dependent style is missing ``camera`` or if
|
||||
a non-view-dependent style carries one. Pipeline writers and the validator
|
||||
should call this on every emitted row.
|
||||
"""
|
||||
if is_view_dependent_style(style):
|
||||
if not camera:
|
||||
raise ValueError(
|
||||
f"Rows of view-dependent style {style!r} require a non-empty 'camera' "
|
||||
f"field referencing an 'observation.images.*' feature key."
|
||||
)
|
||||
elif camera is not None:
|
||||
raise ValueError(f"Rows of style {style!r} must have camera=None; got camera={camera!r}.")
|
||||
|
||||
|
||||
# --- Tool registry --------------------------------------------------------
|
||||
# Tools declared on a dataset live in ``meta/info.json["tools"]`` as a list
|
||||
# of OpenAI-style function schemas. The runtime / training stack reads them
|
||||
# through :class:`LeRobotDatasetMetadata.tools` (with these constants as
|
||||
# fallback when the dataset doesn't declare any). Implementations live
|
||||
# under :mod:`lerobot.tools` (one file per tool); see
|
||||
# ``docs/source/tools.mdx`` for the authoring guide.
|
||||
|
||||
SAY_TOOL_SCHEMA: dict = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"description": "Speak a short utterance to the user via the TTS executor.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The verbatim text to speak.",
|
||||
}
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
},
|
||||
}
|
||||
"""Canonical schema for the ``say`` tool emitted by the steerable
|
||||
annotation pipeline (PR 2 Module 2). Single source of truth — PR 2's
|
||||
writer, PR 3's runtime tool registry, and the dataset visualizer all
|
||||
import this constant rather than duplicating the dict."""
|
||||
|
||||
DEFAULT_TOOLS: list[dict] = [SAY_TOOL_SCHEMA]
|
||||
"""Fallback tools list. Returned by ``LeRobotDatasetMetadata.tools``
|
||||
when ``meta/info.json["tools"]`` is unset, so unannotated datasets and
|
||||
chat-template consumers (``apply_chat_template(messages, tools=...)``)
|
||||
keep working out of the box."""
|
||||
|
||||
|
||||
def column_for_style(style: str | None) -> LanguageColumn:
|
||||
"""Map a language style to the column where rows of that style are stored.
|
||||
|
||||
Styles in :data:`PERSISTENT_STYLES` route to :data:`LANGUAGE_PERSISTENT`.
|
||||
Styles in :data:`EVENT_ONLY_STYLES` and the implicit ``None`` style route
|
||||
to :data:`LANGUAGE_EVENTS`.
|
||||
"""
|
||||
if style is None:
|
||||
return LANGUAGE_EVENTS
|
||||
if style in PERSISTENT_STYLES:
|
||||
return LANGUAGE_PERSISTENT
|
||||
if style in EVENT_ONLY_STYLES:
|
||||
return LANGUAGE_EVENTS
|
||||
raise ValueError(f"Unknown language style: {style!r}")
|
||||
631
src/lerobot/datasets/language_render.py
Normal file
631
src/lerobot/datasets/language_render.py
Normal file
@@ -0,0 +1,631 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import hashlib
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.recipe import DEFAULT_BINDINGS, PLACEHOLDER_RE, TrainingRecipe
|
||||
from lerobot.utils.utils import unwrap_scalar
|
||||
|
||||
from .language import LANGUAGE_PERSISTENT, column_for_style
|
||||
|
||||
LanguageRow = dict[str, Any]
|
||||
RenderedMessages = dict[str, list[Any]]
|
||||
|
||||
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
|
||||
|
||||
|
||||
def active_at(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row of ``style`` that is active at time ``t``.
|
||||
|
||||
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
|
||||
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``/
|
||||
``camera`` selector. Only valid for persistent styles.
|
||||
"""
|
||||
_validate_persistent_resolver("active_at", style)
|
||||
matches = [
|
||||
row
|
||||
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
if _timestamp(row) <= t
|
||||
]
|
||||
if not matches:
|
||||
return None
|
||||
latest_ts = max(_timestamp(row) for row in matches)
|
||||
return _select_one(
|
||||
[row for row in matches if _timestamp(row) == latest_ts],
|
||||
style=style,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
camera=camera,
|
||||
)
|
||||
|
||||
|
||||
EMITTED_AT_TOLERANCE_S = 0.1
|
||||
"""Half-window for matching persistent rows to a frame timestamp in
|
||||
``emitted_at``. Persistent timestamps come from parquet (float32) and ``t``
|
||||
is also a float32 from parquet, so in the ideal hot path an exact match
|
||||
would suffice — but any caller that derives ``t`` arithmetically (e.g.
|
||||
``frame_idx / fps``) breaks bit-equality. A 0.1 s tolerance covers
|
||||
common arithmetic drift without admitting frames that are visibly far
|
||||
apart at typical control rates (30–100 Hz). This does mean two persistent
|
||||
rows of the same selector emitted within 0.1 s of each other cannot be
|
||||
told apart by ``emitted_at`` — acceptable because persistent annotations
|
||||
(subtask / plan / memory transitions) change on a human-action timescale,
|
||||
not at the camera frame rate."""
|
||||
|
||||
|
||||
def emitted_at(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the row of ``style`` emitted at exactly time ``t``.
|
||||
|
||||
For persistent styles, this matches persistent rows whose own ``timestamp``
|
||||
is within ``EMITTED_AT_TOLERANCE_S`` of ``t`` (see that constant for why
|
||||
we use a tolerance instead of bit-equality). For event styles, the
|
||||
``events`` list is assumed to come from the dataset row at frame ``t``
|
||||
(event rows carry no timestamp of their own), so all matching event rows
|
||||
are considered emitted at ``t``. ``camera`` filters by the row's
|
||||
``camera`` field — required to disambiguate when multiple view-dependent
|
||||
rows share ``(t, role)`` across cameras.
|
||||
"""
|
||||
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||
matches = [
|
||||
row
|
||||
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
if abs(_timestamp(row) - t) <= EMITTED_AT_TOLERANCE_S
|
||||
]
|
||||
else:
|
||||
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
return _select_one(matches, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
|
||||
|
||||
def nth_prev(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
offset: int = 1,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row that was active ``offset`` steps before ``t``.
|
||||
|
||||
Walks back through chronologically sorted persistent rows of ``style``
|
||||
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
|
||||
one ``offset`` positions before the row active at ``t``. Only valid for
|
||||
persistent styles.
|
||||
"""
|
||||
return _nth_relative("nth_prev", t, persistent, style, -offset, role, tool_name, camera)
|
||||
|
||||
|
||||
def nth_next(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
offset: int = 1,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row that becomes active ``offset`` steps after ``t``.
|
||||
|
||||
Walks forward through chronologically sorted persistent rows of ``style``
|
||||
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
|
||||
one ``offset`` positions after the row active at ``t``. Only valid for
|
||||
persistent styles.
|
||||
"""
|
||||
return _nth_relative("nth_next", t, persistent, style, offset, role, tool_name, camera)
|
||||
|
||||
|
||||
def render_sample(
|
||||
*,
|
||||
recipe: TrainingRecipe,
|
||||
persistent: Sequence[LanguageRow] | None,
|
||||
events: Sequence[LanguageRow] | None,
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None = None,
|
||||
dataset_ctx: Any | None = None,
|
||||
) -> RenderedMessages | None:
|
||||
"""Render the chat-style messages for a single dataset sample.
|
||||
|
||||
Resolves the recipe's bindings against ``persistent`` and ``events`` rows
|
||||
at frame timestamp ``t``, then expands the recipe's message templates.
|
||||
Returns ``None`` if the resolved sample contains no target message.
|
||||
"""
|
||||
persistent_rows = _normalize_rows(persistent or [])
|
||||
event_rows = _normalize_rows(events or [])
|
||||
|
||||
# VQA-priority routing. A ``vqa`` annotation is sparse and
|
||||
# view-dependent; the plain weighted blend would (a) waste a draw
|
||||
# whenever it picks an ``ask_vqa*`` sub-recipe for a frame that has
|
||||
# no VQA, and (b) silently drop a VQA-annotated frame whenever it
|
||||
# picks a non-VQA sub-recipe. So: if the blend has ``ask_vqa*``
|
||||
# sub-recipes and *this* frame carries one of their VQA bindings,
|
||||
# render VQA here regardless of the weighted draw. That makes VQA's
|
||||
# recipe-side training share equal the VQA-annotation density (the
|
||||
# maximum reachable without a dataset-level oversampling sampler).
|
||||
if recipe.blend is not None:
|
||||
vqa_rendered = _render_vqa_if_present(
|
||||
recipe,
|
||||
persistent=persistent_rows,
|
||||
events=event_rows,
|
||||
t=t,
|
||||
sample_idx=sample_idx,
|
||||
task=task,
|
||||
dataset_ctx=dataset_ctx,
|
||||
)
|
||||
if vqa_rendered is not None:
|
||||
return vqa_rendered
|
||||
|
||||
selected_recipe = _select_recipe(recipe, sample_idx)
|
||||
bindings = _resolve_bindings(
|
||||
selected_recipe,
|
||||
persistent=persistent_rows,
|
||||
events=event_rows,
|
||||
t=t,
|
||||
sample_idx=sample_idx,
|
||||
task=task,
|
||||
dataset_ctx=dataset_ctx,
|
||||
)
|
||||
return _render_message_recipe(selected_recipe, bindings)
|
||||
|
||||
|
||||
def _render_vqa_if_present(
|
||||
recipe: TrainingRecipe,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
) -> RenderedMessages | None:
|
||||
"""Render an ``ask_vqa*`` sub-recipe iff this frame carries a VQA
|
||||
annotation; otherwise return ``None`` so the caller falls back to the
|
||||
normal weighted blend.
|
||||
|
||||
When several VQA sub-recipes resolve (e.g. a frame annotated for more
|
||||
than one camera), one is chosen deterministically by relative weight.
|
||||
"""
|
||||
assert recipe.blend is not None
|
||||
renderable: list[tuple[float, RenderedMessages]] = []
|
||||
for name, component in recipe.blend.items():
|
||||
if not name.startswith("ask_vqa"):
|
||||
continue
|
||||
bindings = _resolve_bindings(
|
||||
component,
|
||||
persistent=persistent,
|
||||
events=events,
|
||||
t=t,
|
||||
sample_idx=sample_idx,
|
||||
task=task,
|
||||
dataset_ctx=dataset_ctx,
|
||||
)
|
||||
rendered = _render_message_recipe(component, bindings)
|
||||
if rendered is not None:
|
||||
renderable.append((float(component.weight or 0.0), rendered))
|
||||
|
||||
if not renderable:
|
||||
return None
|
||||
if len(renderable) == 1:
|
||||
return renderable[0][1]
|
||||
|
||||
# Multiple cameras have a VQA for this frame — deterministic pick by
|
||||
# relative weight (fall back to a uniform draw if all weights are 0).
|
||||
total = sum(w for w, _ in renderable) or float(len(renderable))
|
||||
digest = hashlib.blake2b(f"vqa:{sample_idx}".encode(), digest_size=8).digest()
|
||||
draw = int.from_bytes(digest, "big") / 2**64 * total
|
||||
cumulative = 0.0
|
||||
for w, rendered in renderable:
|
||||
cumulative += w or (total / len(renderable))
|
||||
if draw < cumulative:
|
||||
return rendered
|
||||
return renderable[-1][1]
|
||||
|
||||
|
||||
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
|
||||
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
|
||||
if recipe.blend is None:
|
||||
return recipe
|
||||
|
||||
total_weight = sum(component.weight or 0.0 for component in recipe.blend.values())
|
||||
if total_weight <= 0:
|
||||
raise ValueError("Blend weights must sum to a positive value.")
|
||||
|
||||
digest = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest()
|
||||
draw = int.from_bytes(digest, "big") / 2**64 * total_weight
|
||||
cumulative = 0.0
|
||||
last_component: TrainingRecipe | None = None
|
||||
for component in recipe.blend.values():
|
||||
last_component = component
|
||||
cumulative += component.weight or 0.0
|
||||
if draw < cumulative:
|
||||
return component
|
||||
assert last_component is not None
|
||||
return last_component
|
||||
|
||||
|
||||
def _resolve_bindings(
|
||||
recipe: TrainingRecipe,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
) -> dict[str, LanguageRow | str | None]:
|
||||
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
|
||||
bindings: dict[str, LanguageRow | str | None] = {
|
||||
"task": _resolve_task(task, dataset_ctx, persistent=persistent, sample_idx=sample_idx),
|
||||
}
|
||||
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
|
||||
for name, spec in specs.items():
|
||||
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
|
||||
return bindings
|
||||
|
||||
|
||||
def _resolve_task(
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow] = (),
|
||||
sample_idx: int = 0,
|
||||
) -> str | None:
|
||||
"""Return the task string for ``sample_idx``.
|
||||
|
||||
Resolution order:
|
||||
|
||||
1. Explicit ``task`` override (caller-supplied) wins.
|
||||
2. If ``persistent`` contains rows of style ``task_aug`` (role=user),
|
||||
deterministically pick one by ``sample_idx`` so each frame of an
|
||||
episode rotates through the available rephrasings across an epoch.
|
||||
This realizes Xiao 2022 / CAST-style task-prompt diversity without
|
||||
changing ``meta/tasks.parquet`` and without forcing recipes to opt
|
||||
in: ``${task}`` automatically picks a rephrasing when one exists,
|
||||
and falls back to the canonical task otherwise. Recipes that want
|
||||
the literal canonical task can override the binding.
|
||||
3. Otherwise read the canonical task from ``dataset_ctx`` (which is
|
||||
backed by ``meta/tasks.parquet``).
|
||||
"""
|
||||
if task is not None:
|
||||
return task
|
||||
|
||||
aug_rows = [r for r in persistent if r.get("style") == "task_aug" and r.get("role") == "user"]
|
||||
if aug_rows:
|
||||
# Deterministic, blake2b-based pick keyed on sample_idx so the
|
||||
# rotation is reproducible across runs (Python's built-in ``hash``
|
||||
# is process-randomized).
|
||||
digest = hashlib.blake2b(f"task_aug:{sample_idx}".encode(), digest_size=8).digest()
|
||||
idx = int.from_bytes(digest, "big") % len(aug_rows)
|
||||
chosen = aug_rows[idx].get("content")
|
||||
if chosen:
|
||||
return str(chosen)
|
||||
|
||||
if dataset_ctx is None:
|
||||
return None
|
||||
if isinstance(dataset_ctx, dict):
|
||||
return dataset_ctx.get("task")
|
||||
return getattr(dataset_ctx, "task", None)
|
||||
|
||||
|
||||
def _resolve_spec(
|
||||
spec: str,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
) -> LanguageRow | None:
|
||||
"""Parse a single binding's resolver expression and dispatch to its function."""
|
||||
match = _RESOLVER_RE.match(spec.strip())
|
||||
if match is None:
|
||||
raise ValueError(f"Invalid resolver expression: {spec!r}")
|
||||
name = match.group("name")
|
||||
kwargs = _parse_resolver_args(match.group("args"))
|
||||
kwargs.pop("t_arg", None)
|
||||
|
||||
if name == "emitted_at":
|
||||
return emitted_at(t, persistent=persistent, events=events, **kwargs)
|
||||
if name == "active_at":
|
||||
return active_at(t, persistent=persistent, **kwargs)
|
||||
if name == "nth_prev":
|
||||
return nth_prev(t, persistent=persistent, **kwargs)
|
||||
if name == "nth_next":
|
||||
return nth_next(t, persistent=persistent, **kwargs)
|
||||
raise ValueError(f"Unknown language resolver: {name!r}")
|
||||
|
||||
|
||||
def _parse_resolver_args(args: str) -> dict[str, Any]:
|
||||
"""Parse a comma-separated resolver argument list into a kwargs dict."""
|
||||
kwargs: dict[str, Any] = {}
|
||||
if not args.strip():
|
||||
return kwargs
|
||||
|
||||
parts = [part.strip() for part in args.split(",") if part.strip()]
|
||||
for part in parts:
|
||||
if part == "t":
|
||||
kwargs["t_arg"] = True
|
||||
continue
|
||||
if "=" not in part:
|
||||
raise ValueError(f"Invalid resolver argument: {part!r}")
|
||||
key, value = (item.strip() for item in part.split("=", 1))
|
||||
if key == "offset":
|
||||
kwargs[key] = int(value)
|
||||
else:
|
||||
kwargs[key] = value.strip("\"'")
|
||||
return kwargs
|
||||
|
||||
|
||||
def _render_message_recipe(
|
||||
recipe: TrainingRecipe,
|
||||
bindings: dict[str, LanguageRow | str | None],
|
||||
) -> RenderedMessages | None:
|
||||
"""Expand ``recipe.messages`` into rendered chat messages using ``bindings``."""
|
||||
assert recipe.messages is not None
|
||||
messages: list[dict[str, Any]] = []
|
||||
streams: list[str | None] = []
|
||||
target_indices: list[int] = []
|
||||
|
||||
for turn in recipe.messages:
|
||||
if turn.if_present is not None and bindings.get(turn.if_present) is None:
|
||||
continue
|
||||
|
||||
message = {"role": turn.role}
|
||||
if turn.content is not None:
|
||||
message["content"] = _render_content(turn.content, bindings)
|
||||
|
||||
if turn.tool_calls_from is not None:
|
||||
row = bindings.get(turn.tool_calls_from)
|
||||
tool_calls = row.get("tool_calls") if isinstance(row, dict) else None
|
||||
if tool_calls:
|
||||
message["tool_calls"] = copy.deepcopy(tool_calls)
|
||||
|
||||
message_idx = len(messages)
|
||||
messages.append(message)
|
||||
streams.append(turn.stream)
|
||||
if turn.target:
|
||||
target_indices.append(message_idx)
|
||||
|
||||
# A render is meaningful if it supervises *something*: either a
|
||||
# text-CE target turn, or a ``low_level`` stream turn (flow / action
|
||||
# supervision — e.g. the flow-only ``low_level_execution`` recipe,
|
||||
# ``user(${subtask})`` with ``stream: low_level`` and no target).
|
||||
# Without this, a flow-only recipe renders to ``None`` every time
|
||||
# the blend draws it → ``predict_actions`` is never True → the
|
||||
# action expert never receives a flow loss.
|
||||
has_low_level = any(stream == "low_level" for stream in streams)
|
||||
if not target_indices and not has_low_level:
|
||||
return None
|
||||
|
||||
rendered = {
|
||||
"messages": messages,
|
||||
"message_streams": streams,
|
||||
"target_message_indices": target_indices,
|
||||
}
|
||||
_validate_rendered(rendered)
|
||||
return rendered
|
||||
|
||||
|
||||
def _render_content(
|
||||
content: str | list[dict[str, Any]],
|
||||
bindings: dict[str, LanguageRow | str | None],
|
||||
) -> str | list[dict[str, Any]]:
|
||||
"""Substitute bindings into a string or each string field of multimodal blocks."""
|
||||
if isinstance(content, str):
|
||||
return _substitute(content, bindings)
|
||||
|
||||
rendered_blocks = []
|
||||
for block in content:
|
||||
rendered_block = copy.deepcopy(block)
|
||||
for key, value in rendered_block.items():
|
||||
if isinstance(value, str):
|
||||
rendered_block[key] = _substitute(value, bindings)
|
||||
rendered_blocks.append(rendered_block)
|
||||
return rendered_blocks
|
||||
|
||||
|
||||
def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) -> str:
|
||||
"""Replace ``${name}`` placeholders in ``template`` with their bound values."""
|
||||
|
||||
def replace(match: re.Match[str]) -> str:
|
||||
"""Resolve a single ``${name}`` match to its bound string value."""
|
||||
name = match.group(1)
|
||||
if name not in bindings:
|
||||
raise ValueError(f"Unknown template binding: {name!r}")
|
||||
value = bindings[name]
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, dict):
|
||||
content = value.get("content")
|
||||
return "" if content is None else str(content)
|
||||
return str(value)
|
||||
|
||||
return PLACEHOLDER_RE.sub(replace, template)
|
||||
|
||||
|
||||
def _validate_rendered(rendered: RenderedMessages) -> None:
|
||||
"""Sanity-check the rendered output for stream/target alignment."""
|
||||
messages = rendered["messages"]
|
||||
streams = rendered["message_streams"]
|
||||
target_indices = rendered["target_message_indices"]
|
||||
|
||||
if len(streams) != len(messages):
|
||||
raise ValueError("message_streams must be aligned with messages.")
|
||||
# Valid iff it supervises something: a text-CE target turn OR a
|
||||
# ``low_level`` stream turn (flow / action supervision).
|
||||
if not target_indices and not any(s == "low_level" for s in streams):
|
||||
raise ValueError("Rendered samples must contain a target message or a low_level-stream message.")
|
||||
for idx in target_indices:
|
||||
if idx < 0 or idx >= len(messages):
|
||||
raise ValueError(f"Target message index {idx} is out of bounds.")
|
||||
# ``stream`` is enforced non-None at MessageTurn construction time
|
||||
# (see ``MessageTurn.__post_init__``), so a missing stream here would
|
||||
# mean the dataclass invariant was bypassed; no need to re-check.
|
||||
|
||||
|
||||
def _nth_relative(
|
||||
name: str,
|
||||
t: float,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None,
|
||||
offset: int,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> LanguageRow | None:
|
||||
"""Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``."""
|
||||
_validate_persistent_resolver(name, style)
|
||||
if abs(offset) < 1:
|
||||
raise ValueError(f"{name} offset must be non-zero.")
|
||||
|
||||
rows = sorted(
|
||||
_matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera),
|
||||
key=_row_sort_key,
|
||||
)
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
anchor_idx = None
|
||||
for idx, row in enumerate(rows):
|
||||
if _timestamp(row) <= t:
|
||||
anchor_idx = idx
|
||||
else:
|
||||
break
|
||||
|
||||
target_idx = (offset - 1 if offset > 0 else None) if anchor_idx is None else anchor_idx + offset
|
||||
|
||||
if target_idx is None or target_idx < 0 or target_idx >= len(rows):
|
||||
return None
|
||||
return rows[target_idx]
|
||||
|
||||
|
||||
def _validate_persistent_resolver(name: str, style: str | None) -> None:
|
||||
"""Reject calls with missing or event-only ``style`` for persistent resolvers."""
|
||||
if style is None:
|
||||
raise ValueError(f"{name} requires a persistent style.")
|
||||
if column_for_style(style) != LANGUAGE_PERSISTENT:
|
||||
raise ValueError(f"{name} cannot be used with event-only style {style!r}.")
|
||||
|
||||
|
||||
def _matching_rows(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> list[LanguageRow]:
|
||||
"""Return ``rows`` filtered by optional ``style``/``role``/``tool_name``/``camera`` selectors."""
|
||||
return [
|
||||
row
|
||||
for row in rows
|
||||
if (style is None or row.get("style") == style)
|
||||
and (role is None or row.get("role") == role)
|
||||
and (tool_name is None or _row_has_tool_name(row, tool_name))
|
||||
and (camera is None or row.get("camera") == camera)
|
||||
]
|
||||
|
||||
|
||||
def _select_one(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the single matching row, or raise if the resolver is ambiguous.
|
||||
|
||||
Multiple matches always raise — even when the caller already passed
|
||||
some selectors — because remaining ambiguity means the data has
|
||||
several rows that look identical to the resolver and the caller
|
||||
needs to pin down a specific one (e.g. add ``camera=...`` for VQA
|
||||
rows shared across cameras).
|
||||
"""
|
||||
if not rows:
|
||||
return None
|
||||
if len(rows) > 1:
|
||||
raise ValueError(
|
||||
f"Ambiguous resolver for style={style!r} role={role!r} "
|
||||
f"tool_name={tool_name!r} camera={camera!r}: {len(rows)} matching rows. "
|
||||
f"Add a selector that distinguishes them."
|
||||
)
|
||||
return rows[0]
|
||||
|
||||
|
||||
def _row_sort_key(row: LanguageRow) -> tuple[float, str, str]:
|
||||
"""Stable sort key for both persistent and event rows.
|
||||
|
||||
Event rows lack ``timestamp`` (it is implicit in the frame), so default
|
||||
to ``0.0`` — within a single frame all event rows share the same sort
|
||||
bucket and are tiebroken by ``(style, role)``.
|
||||
"""
|
||||
timestamp = row.get("timestamp")
|
||||
ts = float(unwrap_scalar(timestamp)) if timestamp is not None else 0.0
|
||||
return (ts, row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _timestamp(row: LanguageRow) -> float:
|
||||
"""Extract a row's ``timestamp`` as a Python float (unwrapping numpy scalars)."""
|
||||
return float(unwrap_scalar(row["timestamp"]))
|
||||
|
||||
|
||||
def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool:
|
||||
"""Return ``True`` if any of the row's tool calls invokes ``tool_name``."""
|
||||
for tool_call in row.get("tool_calls") or []:
|
||||
if isinstance(tool_call, str):
|
||||
continue
|
||||
function = tool_call.get("function") if isinstance(tool_call, dict) else None
|
||||
if isinstance(function, dict) and function.get("name") == tool_name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _normalize_rows(rows: Sequence[Any]) -> list[LanguageRow]:
|
||||
"""Convert pyarrow scalars / mappings into a fresh list of plain dict rows."""
|
||||
normalized = []
|
||||
for row in rows:
|
||||
if row is None:
|
||||
continue
|
||||
if hasattr(row, "as_py"):
|
||||
row = row.as_py()
|
||||
if not isinstance(row, dict):
|
||||
raise TypeError(f"Language rows must be dictionaries, got {type(row).__name__}.")
|
||||
normalized.append(dict(row))
|
||||
return normalized
|
||||
@@ -24,7 +24,7 @@ import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from lerobot.configs import DepthEncoderConfig, VideoEncoderConfig
|
||||
from lerobot.configs import VideoEncoderConfig
|
||||
from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
|
||||
|
||||
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
@@ -60,7 +60,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return_uint8: bool = False,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
@@ -187,9 +186,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
camera_encoder (VideoEncoderConfig | None, optional): Video encoder settings for cameras
|
||||
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults`
|
||||
is used by the writer.
|
||||
depth_encoder (DepthEncoderConfig | None, optional): Video encoder settings for depth cameras
|
||||
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.depth.depth_encoder_defaults`
|
||||
is used by the writer.
|
||||
encoder_threads (int | None, optional): Number of encoder threads (global). ``None`` lets the
|
||||
codec decide.
|
||||
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
|
||||
@@ -277,7 +273,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_enc = self._build_streaming_encoder(
|
||||
self.meta.fps,
|
||||
camera_encoder,
|
||||
depth_encoder,
|
||||
encoder_queue_maxsize,
|
||||
encoder_threads,
|
||||
)
|
||||
@@ -285,7 +280,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
meta=self.meta,
|
||||
root=self.root,
|
||||
camera_encoder=camera_encoder,
|
||||
depth_encoder=depth_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -328,14 +322,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def _build_streaming_encoder(
|
||||
fps: int,
|
||||
camera_encoder: VideoEncoderConfig | None,
|
||||
depth_encoder: DepthEncoderConfig | None,
|
||||
encoder_queue_maxsize: int,
|
||||
encoder_threads: int | None,
|
||||
) -> StreamingVideoEncoder:
|
||||
return StreamingVideoEncoder(
|
||||
fps=fps,
|
||||
camera_encoder=camera_encoder,
|
||||
depth_encoder=depth_encoder,
|
||||
queue_maxsize=encoder_queue_maxsize,
|
||||
encoder_threads=encoder_threads,
|
||||
)
|
||||
@@ -653,7 +645,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
metadata_buffer_size: int = 10,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
@@ -686,8 +677,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
batch-encoding videos. ``1`` means encode immediately.
|
||||
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
depth_encoder: Video encoder settings for depth cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.depth.depth_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
metadata_buffer_size: Number of episode metadata records to buffer
|
||||
@@ -731,13 +720,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
fps, camera_encoder, depth_encoder, encoder_queue_maxsize, encoder_threads
|
||||
fps, camera_encoder, encoder_queue_maxsize, encoder_threads
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
camera_encoder=camera_encoder,
|
||||
depth_encoder=depth_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -761,7 +749,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
@@ -791,8 +778,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
batch-encoding videos.
|
||||
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
depth_encoder: Video encoder settings for depth cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.depth.depth_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
image_writer_processes: Subprocesses for async image writing.
|
||||
@@ -839,13 +824,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
obj.meta.fps, camera_encoder, depth_encoder, encoder_queue_maxsize, encoder_threads
|
||||
obj.meta.fps, camera_encoder, encoder_queue_maxsize, encoder_threads
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
camera_encoder=camera_encoder,
|
||||
depth_encoder=depth_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
|
||||
@@ -24,7 +24,6 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -32,22 +31,6 @@ FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
|
||||
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
|
||||
|
||||
|
||||
def write_u16_plane(plane: av.video.plane.VideoPlane, src: np.ndarray, fill_value: int | None = None) -> None:
|
||||
"""Copy ``src`` into a uint16 plane respecting FFmpeg line padding."""
|
||||
height, width = src.shape
|
||||
stride_u16 = plane.line_size // np.dtype(np.uint16).itemsize
|
||||
dst = np.frombuffer(plane, dtype=np.uint16).reshape(height, stride_u16)
|
||||
if fill_value is not None:
|
||||
dst.fill(fill_value)
|
||||
dst[:, :width] = src
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_pix_fmt_channels(pix_fmt: str) -> int:
|
||||
"""Return the number of components (channels) for *pix_fmt*."""
|
||||
return len(av.VideoFormat(pix_fmt).components)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_codec(vcodec: str) -> av.codec.Codec | None:
|
||||
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
|
||||
@@ -159,16 +142,6 @@ def _check_pixel_format(vcodec: str, pix_fmt: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _check_pix_fmt_channels(pix_fmt: str, channels: int) -> None:
|
||||
"""Ensure *pix_fmt* can carry at least *channels* components."""
|
||||
pix_fmt_channels = get_pix_fmt_channels(pix_fmt)
|
||||
if pix_fmt_channels < channels:
|
||||
raise ValueError(
|
||||
f"pix_fmt={pix_fmt!r} carries only {pix_fmt_channels} component(s) "
|
||||
f"but the source data has {channels} channel(s)."
|
||||
)
|
||||
|
||||
|
||||
def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
|
||||
"""Validate merged encoder options (typed) against the codec's published AVOptions."""
|
||||
supported_options = _get_codec_options_by_name(vcodec)
|
||||
@@ -183,18 +156,12 @@ def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
|
||||
_check_option_value(vcodec, key, value, supported_options[key])
|
||||
|
||||
|
||||
def check_video_encoder_parameters_pyav(
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
codec_options: dict[str, Any],
|
||||
channels: int | None = None,
|
||||
) -> None:
|
||||
def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options: dict[str, Any]) -> None:
|
||||
"""Verify *config* is compatible with the bundled FFmpeg build.
|
||||
|
||||
Checks pixel format, abstract tuning-field compatibility, and each merged
|
||||
encoder option from :meth:`~lerobot.configs.video.VideoEncoderConfig.get_codec_options`
|
||||
against PyAV (including numeric ``extra_options`` present in that dict).
|
||||
When given, additionally verify that *pix_fmt* carries as many components as the source data channels.
|
||||
No-op when ``config.vcodec`` isn't in the local FFmpeg build.
|
||||
|
||||
Raises:
|
||||
@@ -204,6 +171,4 @@ def check_video_encoder_parameters_pyav(
|
||||
if not options:
|
||||
raise ValueError(f"Codec {vcodec!r} is not available in the bundled FFmpeg build")
|
||||
_check_pixel_format(vcodec, pix_fmt)
|
||||
if channels is not None:
|
||||
_check_pix_fmt_channels(pix_fmt, channels)
|
||||
_check_codec_options(vcodec, codec_options)
|
||||
|
||||
@@ -84,3 +84,66 @@ class EpisodeAwareSampler:
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.indices)
|
||||
|
||||
|
||||
class WeightedEpisodeAwareSampler(EpisodeAwareSampler):
|
||||
"""``EpisodeAwareSampler`` that draws frames *with replacement* in
|
||||
proportion to per-frame weights.
|
||||
|
||||
Used to oversample frames carrying a sparse annotation (e.g. a VQA
|
||||
question) so the policy sees them more often than their natural
|
||||
dataset density. One epoch still yields ``len(self.indices)``
|
||||
samples — the weights only change the *composition* of the stream,
|
||||
not its length. Each epoch re-draws, so the oversampled subset
|
||||
varies run to run.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_from_indices: list[int],
|
||||
dataset_to_indices: list[int],
|
||||
frame_weights,
|
||||
*,
|
||||
episode_indices_to_use: list | None = None,
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dataset_from_indices: Episode start indices (see ``EpisodeAwareSampler``).
|
||||
dataset_to_indices: Episode end indices.
|
||||
frame_weights: 1-D sequence/tensor of non-negative weights, one per
|
||||
dataset frame (length == total dataset frames). Higher weight ⇒
|
||||
that frame is sampled more often.
|
||||
episode_indices_to_use / drop_n_first_frames / drop_n_last_frames:
|
||||
Same meaning as ``EpisodeAwareSampler`` — the episode-boundary
|
||||
frame filtering is applied first, then weighting is restricted
|
||||
to the surviving frames.
|
||||
"""
|
||||
super().__init__(
|
||||
dataset_from_indices,
|
||||
dataset_to_indices,
|
||||
episode_indices_to_use=episode_indices_to_use,
|
||||
drop_n_first_frames=drop_n_first_frames,
|
||||
drop_n_last_frames=drop_n_last_frames,
|
||||
shuffle=False,
|
||||
)
|
||||
weights = torch.as_tensor(frame_weights, dtype=torch.double).flatten()
|
||||
idx = torch.tensor(self.indices, dtype=torch.long)
|
||||
if weights.numel() <= int(idx.max()):
|
||||
raise ValueError(
|
||||
f"frame_weights has {weights.numel()} entries but the sampler "
|
||||
f"references frame index {int(idx.max())}."
|
||||
)
|
||||
selected = weights[idx]
|
||||
if not torch.isfinite(selected).all() or bool((selected < 0).any()):
|
||||
raise ValueError("frame_weights must be finite and non-negative.")
|
||||
if float(selected.sum()) <= 0.0:
|
||||
# All surviving frames have zero weight — fall back to uniform.
|
||||
selected = torch.ones_like(selected)
|
||||
self._weights = selected
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
picks = torch.multinomial(self._weights, num_samples=len(self.indices), replacement=True)
|
||||
for i in picks.tolist():
|
||||
yield self.indices[i]
|
||||
|
||||
@@ -88,12 +88,10 @@ VIDEO_DIR = "videos"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
|
||||
DEFAULT_DEPTH_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.tiff"
|
||||
|
||||
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
|
||||
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
@@ -131,6 +129,9 @@ class DatasetInfo:
|
||||
# Optional metadata
|
||||
robot_type: str | None = None
|
||||
splits: dict[str, str] = field(default_factory=dict)
|
||||
# OpenAI-style tool schemas declared by the dataset. ``None`` means the
|
||||
# dataset doesn't declare any — readers fall back to ``DEFAULT_TOOLS``.
|
||||
tools: list[dict] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Coerce feature shapes from list to tuple — JSON deserialisation
|
||||
@@ -152,11 +153,15 @@ class DatasetInfo:
|
||||
"""Return a JSON-serialisable dict.
|
||||
|
||||
Converts tuple shapes back to lists so ``json.dump`` can handle them.
|
||||
Drops ``tools`` when unset so existing datasets keep a clean
|
||||
``info.json``.
|
||||
"""
|
||||
d = dataclasses.asdict(self)
|
||||
for ft in d["features"].values():
|
||||
if isinstance(ft.get("shape"), tuple):
|
||||
ft["shape"] = list(ft["shape"])
|
||||
if d.get("tools") is None:
|
||||
d.pop("tools", None)
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
@@ -361,17 +366,24 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
|
||||
hub_versions = get_repo_versions(repo_id)
|
||||
|
||||
if not hub_versions:
|
||||
raise RevisionNotFoundError(
|
||||
f"""Your dataset must be tagged with a codebase version.
|
||||
Assuming _version_ is the codebase_version value in the info.json, you can run this:
|
||||
```python
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
|
||||
```
|
||||
"""
|
||||
msg = (
|
||||
f"Repo {repo_id!r} has no codebase-version tags. The dataset "
|
||||
f"either doesn't exist on the Hub yet, or it was uploaded "
|
||||
f"without a ``v3.x``-style tag. To tag an existing dataset run:\n"
|
||||
f" from huggingface_hub import HfApi\n"
|
||||
f" HfApi().create_tag({repo_id!r}, tag='v3.0', repo_type='dataset', exist_ok=True)"
|
||||
)
|
||||
# ``RevisionNotFoundError`` extends ``HfHubHTTPError`` whose
|
||||
# ``__init__`` indexes ``response.headers`` unconditionally on
|
||||
# current ``huggingface_hub`` versions. Constructing it without
|
||||
# a real ``Response`` object crashes with either
|
||||
# ``TypeError: missing 1 required keyword-only argument`` (old
|
||||
# builds) or ``AttributeError: 'NoneType' object has no attribute
|
||||
# 'headers'`` (new builds). Skip that path entirely — this isn't
|
||||
# really an HTTP error, it's a configuration issue — and raise a
|
||||
# plain ``RuntimeError`` so the message actually reaches the
|
||||
# caller.
|
||||
raise RuntimeError(msg)
|
||||
|
||||
if target_version in hub_versions:
|
||||
return f"v{target_version}"
|
||||
|
||||
@@ -17,11 +17,13 @@ import contextlib
|
||||
import glob
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from fractions import Fraction
|
||||
from pathlib import Path
|
||||
@@ -37,16 +39,11 @@ from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.configs import (
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
)
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
|
||||
from .depth_utils import quantize_depth
|
||||
from .pyav_utils import get_pix_fmt_channels
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +53,6 @@ def decode_video_frames(
|
||||
tolerance_s: float,
|
||||
backend: str | None = None,
|
||||
return_uint8: bool = False,
|
||||
is_depth: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Decodes video frames using the specified backend.
|
||||
@@ -76,11 +72,6 @@ def decode_video_frames(
|
||||
|
||||
Currently supports torchcodec on cpu and pyav.
|
||||
"""
|
||||
if backend != "pyav" and is_depth:
|
||||
logger.warning("Decoding depth maps is only supported with the 'pyav' backend.")
|
||||
# We do not actually return uint8 here, but we avoid the 255 normalization step.
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=True, is_depth=True)
|
||||
|
||||
if backend is None:
|
||||
backend = get_safe_default_video_backend()
|
||||
if backend == "torchcodec":
|
||||
@@ -100,7 +91,6 @@ def decode_video_frames_pyav(
|
||||
tolerance_s: float,
|
||||
log_loaded_timestamps: bool = False,
|
||||
return_uint8: bool = False,
|
||||
is_depth: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Loads frames associated to the requested timestamps of a video using PyAV.
|
||||
|
||||
@@ -150,13 +140,9 @@ def decode_video_frames_pyav(
|
||||
current_ts = float(frame.pts * stream.time_base)
|
||||
if log_loaded_timestamps:
|
||||
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
if is_depth:
|
||||
arr = frame.to_ndarray(format="gray12le") # (H, W) uint12
|
||||
loaded_frames.append(torch.from_numpy(arr).unsqueeze(0).contiguous())
|
||||
else:
|
||||
arr = frame.to_ndarray(format="rgb24") # (H, W, 3)
|
||||
# Convert to CHW uint8 to match torchcodec's output layout.
|
||||
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
|
||||
# Convert to CHW uint8 to match torchcodec's output layout.
|
||||
arr = frame.to_ndarray(format="rgb24") # H, W, 3
|
||||
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
break
|
||||
@@ -207,15 +193,70 @@ def decode_video_frames_pyav(
|
||||
return closest_frames
|
||||
|
||||
|
||||
class VideoDecoderCache:
|
||||
"""Thread-safe cache for video decoders to avoid expensive re-initialization."""
|
||||
DEFAULT_DECODER_CACHE_SIZE = 100
|
||||
"""Default LRU capacity for :class:`VideoDecoderCache`.
|
||||
|
||||
def __init__(self):
|
||||
self._cache: dict[str, tuple[Any, Any]] = {}
|
||||
Sized to comfortably hold a small rolling window of episodes worth of decoders
|
||||
(typical recipes: 2-4 cameras per episode × tens of episodes in flight) while
|
||||
bounding host RAM. Each cached entry retains a torchcodec ``VideoDecoder`` plus
|
||||
an open ``fsspec`` file handle — on the order of a few MB per entry. Override
|
||||
via the ``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` env var or by passing ``max_size``
|
||||
to the constructor (``None`` restores the legacy unbounded behaviour).
|
||||
"""
|
||||
|
||||
|
||||
def _default_max_cache_size() -> int | None:
|
||||
raw = os.environ.get("LEROBOT_VIDEO_DECODER_CACHE_SIZE")
|
||||
if raw is None:
|
||||
return DEFAULT_DECODER_CACHE_SIZE
|
||||
raw = raw.strip().lower()
|
||||
if raw in ("", "none", "unbounded", "-1"):
|
||||
return None
|
||||
try:
|
||||
value = int(raw)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"LEROBOT_VIDEO_DECODER_CACHE_SIZE must be an integer, 'none', or '-1'; got {raw!r}"
|
||||
) from e
|
||||
if value <= 0:
|
||||
raise ValueError(f"LEROBOT_VIDEO_DECODER_CACHE_SIZE must be positive; got {value}")
|
||||
return value
|
||||
|
||||
|
||||
class VideoDecoderCache:
|
||||
"""Thread-safe LRU cache for torchcodec ``VideoDecoder`` instances.
|
||||
|
||||
Cached entries hold a ``VideoDecoder`` plus the open ``fsspec`` file handle
|
||||
backing it. When the cache is full and a new path is requested, the
|
||||
least-recently-used entry is evicted and its file handle is closed. This
|
||||
bounds host-RAM growth when iterating over datasets with many distinct
|
||||
video files (otherwise each ``DataLoader`` worker pins every decoder it has
|
||||
ever opened until the process exits).
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of decoders to retain. ``None`` disables
|
||||
eviction and restores legacy unbounded behaviour. Defaults to the
|
||||
value of ``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` if set, otherwise
|
||||
:data:`DEFAULT_DECODER_CACHE_SIZE`.
|
||||
"""
|
||||
|
||||
_SENTINEL: ClassVar[object] = object()
|
||||
|
||||
def __init__(self, max_size: int | None | object = _SENTINEL):
|
||||
if max_size is VideoDecoderCache._SENTINEL:
|
||||
max_size = _default_max_cache_size()
|
||||
if max_size is not None and max_size <= 0:
|
||||
raise ValueError(f"max_size must be positive or None; got {max_size}")
|
||||
self.max_size: int | None = max_size # type: ignore[assignment]
|
||||
self._cache: OrderedDict[str, tuple[Any, Any]] = OrderedDict()
|
||||
self._lock = Lock()
|
||||
|
||||
def __contains__(self, video_path: object) -> bool:
|
||||
with self._lock:
|
||||
return str(video_path) in self._cache
|
||||
|
||||
def get_decoder(self, video_path: str):
|
||||
"""Get a cached decoder or create a new one."""
|
||||
"""Get a cached decoder or create a new one, evicting LRU if at capacity."""
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
from torchcodec.decoders import VideoDecoder
|
||||
else:
|
||||
@@ -227,22 +268,36 @@ class VideoDecoderCache:
|
||||
video_path = str(video_path)
|
||||
|
||||
with self._lock:
|
||||
if video_path not in self._cache:
|
||||
file_handle = fsspec.open(video_path).__enter__()
|
||||
try:
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||
except Exception:
|
||||
file_handle.close()
|
||||
raise
|
||||
self._cache[video_path] = (decoder, file_handle)
|
||||
entry = self._cache.get(video_path)
|
||||
if entry is not None:
|
||||
self._cache.move_to_end(video_path)
|
||||
return entry[0]
|
||||
|
||||
return self._cache[video_path][0]
|
||||
file_handle = fsspec.open(video_path).__enter__()
|
||||
try:
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||
except Exception:
|
||||
file_handle.close()
|
||||
raise
|
||||
self._cache[video_path] = (decoder, file_handle)
|
||||
|
||||
# Evict LRU entries until we are back under the cap. We close
|
||||
# evicted file handles immediately; the associated ``VideoDecoder``
|
||||
# is released to the GC when its last reference goes away.
|
||||
if self.max_size is not None:
|
||||
while len(self._cache) > self.max_size:
|
||||
_evicted_path, (_evicted_decoder, evicted_handle) = self._cache.popitem(last=False)
|
||||
with contextlib.suppress(Exception):
|
||||
evicted_handle.close()
|
||||
|
||||
return decoder
|
||||
|
||||
def clear(self):
|
||||
"""Clear the cache and close file handles."""
|
||||
"""Clear the cache and close all file handles."""
|
||||
with self._lock:
|
||||
for _, file_handle in self._cache.values():
|
||||
file_handle.close()
|
||||
with contextlib.suppress(Exception):
|
||||
file_handle.close()
|
||||
self._cache.clear()
|
||||
|
||||
def size(self) -> int:
|
||||
@@ -351,17 +406,17 @@ def encode_video_frames(
|
||||
imgs_dir: Path | str,
|
||||
video_path: Path | str,
|
||||
fps: int,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
*,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
if video_encoder is None:
|
||||
video_encoder = camera_encoder_defaults()
|
||||
vcodec = video_encoder.vcodec
|
||||
pix_fmt = video_encoder.pix_fmt
|
||||
if camera_encoder is None:
|
||||
camera_encoder = camera_encoder_defaults()
|
||||
vcodec = camera_encoder.vcodec
|
||||
pix_fmt = camera_encoder.pix_fmt
|
||||
|
||||
video_path = Path(video_path)
|
||||
imgs_dir = Path(imgs_dir)
|
||||
@@ -373,8 +428,7 @@ def encode_video_frames(
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get input frames
|
||||
suffix = ".png" if not isinstance(video_encoder, DepthEncoderConfig) else ".tiff"
|
||||
template = "frame-" + ("[0-9]" * 6) + suffix
|
||||
template = "frame-" + ("[0-9]" * 6) + ".png"
|
||||
input_list = sorted(
|
||||
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
|
||||
)
|
||||
@@ -384,7 +438,7 @@ def encode_video_frames(
|
||||
with Image.open(input_list[0]) as dummy_image:
|
||||
width, height = dummy_image.size
|
||||
|
||||
video_options = video_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
|
||||
# Set logging level
|
||||
if log_level is not None:
|
||||
@@ -420,6 +474,92 @@ def encode_video_frames(
|
||||
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
|
||||
|
||||
|
||||
def reencode_video(
|
||||
input_video_path: Path | str,
|
||||
output_video_path: Path | str,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
) -> None:
|
||||
"""Re-encode a video file using the given encoder configuration.
|
||||
|
||||
Args:
|
||||
input_video_path: Existing video file to read.
|
||||
output_video_path: Path for the re-encoded file.
|
||||
camera_encoder: Encoder configuration. Defaults to :func:`camera_encoder_defaults`.
|
||||
encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`.
|
||||
log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING.
|
||||
overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning.
|
||||
"""
|
||||
|
||||
camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
|
||||
output_video_path = Path(output_video_path)
|
||||
|
||||
if output_video_path.exists() and not overwrite:
|
||||
logger.warning(f"Video file already exists: {output_video_path}. Skipping re-encode.")
|
||||
return
|
||||
|
||||
output_video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
vcodec = camera_encoder.vcodec
|
||||
pix_fmt = camera_encoder.pix_fmt
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
|
||||
tmp_output_video_path = tmp_named_file.name
|
||||
|
||||
if log_level is not None:
|
||||
logging.getLogger("libav").setLevel(log_level)
|
||||
|
||||
try:
|
||||
with av.open(input_video_path, mode="r") as src:
|
||||
try:
|
||||
in_stream = src.streams.video[0]
|
||||
except IndexError as e:
|
||||
raise ValueError(f"No video stream in {input_video_path}") from e
|
||||
|
||||
fps = (
|
||||
in_stream.base_rate
|
||||
) # We allow fractional fps though LeRobotDataset only supports integer fps
|
||||
width = int(in_stream.width)
|
||||
height = int(in_stream.height)
|
||||
|
||||
with av.open(
|
||||
tmp_output_video_path,
|
||||
mode="w",
|
||||
options={
|
||||
"movflags": "faststart"
|
||||
}, # faststart is to move the metadata to the beginning of the file to speed up loading
|
||||
) as dst:
|
||||
out_stream = dst.add_stream(vcodec, fps, options=video_options)
|
||||
out_stream.pix_fmt = pix_fmt
|
||||
out_stream.width = width
|
||||
out_stream.height = height
|
||||
|
||||
for frame in src.decode(in_stream):
|
||||
frame = frame.reformat(width=width, height=height, format=pix_fmt)
|
||||
packet = out_stream.encode(frame)
|
||||
if packet:
|
||||
dst.mux(packet)
|
||||
|
||||
packet = out_stream.encode()
|
||||
if packet:
|
||||
dst.mux(packet)
|
||||
|
||||
shutil.move(tmp_output_video_path, output_video_path)
|
||||
except Exception:
|
||||
Path(tmp_output_video_path).unlink(missing_ok=True)
|
||||
raise
|
||||
finally:
|
||||
if log_level is not None:
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
if not output_video_path.exists():
|
||||
raise OSError(f"Video re-encoding did not work. File not found: {output_video_path}.")
|
||||
|
||||
|
||||
def concatenate_video_files(
|
||||
input_video_paths: list[Path | str],
|
||||
output_video_path: Path,
|
||||
@@ -536,21 +676,22 @@ class _CameraEncoderThread(threading.Thread):
|
||||
self,
|
||||
video_path: Path,
|
||||
fps: int,
|
||||
video_encoder: VideoEncoderConfig,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
codec_options: dict[str, str],
|
||||
frame_queue: queue.Queue,
|
||||
result_queue: queue.Queue,
|
||||
stop_event: threading.Event,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
super().__init__(daemon=True)
|
||||
self.video_path = video_path
|
||||
self.fps = fps
|
||||
self.video_encoder = video_encoder
|
||||
self.is_depth = isinstance(video_encoder, DepthEncoderConfig)
|
||||
self.vcodec = vcodec
|
||||
self.pix_fmt = pix_fmt
|
||||
self.codec_options = codec_options
|
||||
self.frame_queue = frame_queue
|
||||
self.result_queue = result_queue
|
||||
self.stop_event = stop_event
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
def run(self) -> None:
|
||||
from .compute_stats import RunningQuantileStats, auto_downsample_height_width
|
||||
@@ -575,12 +716,12 @@ class _CameraEncoderThread(threading.Thread):
|
||||
# Sentinel: flush and close
|
||||
break
|
||||
|
||||
# Ensure HWC (RGB or depth) uint8 (RGB only) numpy array
|
||||
# Ensure HWC uint8 numpy array
|
||||
if isinstance(frame_data, np.ndarray):
|
||||
if frame_data.ndim == 3 and frame_data.shape[0] in (1, 3):
|
||||
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
|
||||
# CHW -> HWC
|
||||
frame_data = frame_data.transpose(1, 2, 0)
|
||||
if not self.is_depth and frame_data.dtype != np.uint8:
|
||||
if frame_data.dtype != np.uint8:
|
||||
frame_data = (frame_data * 255).astype(np.uint8)
|
||||
|
||||
# Open container on first frame (to get width/height)
|
||||
@@ -588,29 +729,15 @@ class _CameraEncoderThread(threading.Thread):
|
||||
height, width = frame_data.shape[:2]
|
||||
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
container = av.open(str(self.video_path), "w")
|
||||
output_stream = container.add_stream(
|
||||
self.video_encoder.vcodec,
|
||||
self.fps,
|
||||
options=self.video_encoder.get_codec_options(self.encoder_threads, as_strings=True),
|
||||
)
|
||||
output_stream.pix_fmt = self.video_encoder.pix_fmt
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=self.codec_options)
|
||||
output_stream.pix_fmt = self.pix_fmt
|
||||
output_stream.width = width
|
||||
output_stream.height = height
|
||||
output_stream.time_base = Fraction(1, self.fps)
|
||||
|
||||
# Encode frame with explicit timestamps
|
||||
if not self.is_depth:
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
else:
|
||||
video_frame = quantize_depth(
|
||||
frame_data,
|
||||
depth_min=self.video_encoder.depth_min,
|
||||
depth_max=self.video_encoder.depth_max,
|
||||
shift=self.video_encoder.shift,
|
||||
use_log=self.video_encoder.use_log,
|
||||
video_backend=self.video_encoder.video_backend,
|
||||
)
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
video_frame.pts = frame_count
|
||||
video_frame.time_base = Fraction(1, self.fps)
|
||||
packet = output_stream.encode(video_frame)
|
||||
@@ -669,7 +796,6 @@ class StreamingVideoEncoder:
|
||||
self,
|
||||
fps: int,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
@@ -685,7 +811,6 @@ class StreamingVideoEncoder:
|
||||
"""
|
||||
self.fps = fps
|
||||
self._camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
self._depth_encoder = depth_encoder or depth_encoder_defaults()
|
||||
self._encoder_threads = encoder_threads
|
||||
self.queue_maxsize = queue_maxsize
|
||||
|
||||
@@ -698,25 +823,18 @@ class StreamingVideoEncoder:
|
||||
self._episode_active = False
|
||||
self._closed = False
|
||||
|
||||
def start_episode(
|
||||
self, video_keys: list[str], temp_dir: Path, depth_video_keys: list[str] | None = None
|
||||
) -> None:
|
||||
def start_episode(self, video_keys: list[str], temp_dir: Path) -> None:
|
||||
"""Start encoder threads for a new episode.
|
||||
|
||||
Args:
|
||||
video_keys: List of video feature keys (e.g. ["observation.images.laptop"])
|
||||
temp_dir: Base directory for temporary MP4 files
|
||||
depth_video_keys: List of video feature keys that carry depth maps (e.g.
|
||||
["observation.images.laptop_depth"]). Defaults to ``[]`` (no depth keys).
|
||||
"""
|
||||
if self._episode_active:
|
||||
self.cancel_episode()
|
||||
|
||||
self._dropped_frames.clear()
|
||||
|
||||
if depth_video_keys is None:
|
||||
depth_video_keys = []
|
||||
|
||||
for video_key in video_keys:
|
||||
frame_queue: queue.Queue = queue.Queue(maxsize=self.queue_maxsize)
|
||||
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
||||
@@ -725,15 +843,17 @@ class StreamingVideoEncoder:
|
||||
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
|
||||
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
|
||||
|
||||
encoder = self._depth_encoder if video_key in depth_video_keys else self._camera_encoder
|
||||
vcodec = self._camera_encoder.vcodec
|
||||
codec_options = self._camera_encoder.get_codec_options(self._encoder_threads, as_strings=True)
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=self.fps,
|
||||
video_encoder=encoder,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=self._camera_encoder.pix_fmt,
|
||||
codec_options=codec_options,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
encoder_threads=self._encoder_threads,
|
||||
)
|
||||
encoder_thread.start()
|
||||
|
||||
@@ -940,13 +1060,13 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
|
||||
def get_video_info(
|
||||
video_path: Path | str,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
) -> dict:
|
||||
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
|
||||
|
||||
Args:
|
||||
video_path: Path to the encoded video file to probe.
|
||||
video_encoder: If provided, record the exact encoder settings used to encode this
|
||||
camera_encoder: If provided, record the exact encoder settings used to encode this
|
||||
video. Stream-derived values take precedence — encoder fields are only written for keys
|
||||
not already populated from the video file itself.
|
||||
"""
|
||||
@@ -966,10 +1086,13 @@ def get_video_info(
|
||||
video_info["video.width"] = video_stream.width
|
||||
video_info["video.codec"] = video_stream.codec.canonical_name
|
||||
video_info["video.pix_fmt"] = video_stream.pix_fmt
|
||||
video_info["video.is_depth_map"] = False
|
||||
|
||||
# Calculate fps from r_frame_rate
|
||||
video_info["video.fps"] = int(video_stream.base_rate)
|
||||
video_info["video.channels"] = get_pix_fmt_channels(video_stream.pix_fmt)
|
||||
|
||||
pixel_channels = get_video_pixel_channels(video_stream.pix_fmt)
|
||||
video_info["video.channels"] = pixel_channels
|
||||
|
||||
# Reset logging level
|
||||
av.logging.restore_default_callback()
|
||||
@@ -978,18 +1101,27 @@ def get_video_info(
|
||||
video_info.update(**get_audio_info(video_path))
|
||||
|
||||
# Add additional encoder configuration if provided
|
||||
if video_encoder is not None:
|
||||
for field_name, field_value in asdict(video_encoder).items():
|
||||
if camera_encoder is not None:
|
||||
for field_name, field_value in asdict(camera_encoder).items():
|
||||
# vcodec is already populated from the video stream
|
||||
if field_name == "vcodec":
|
||||
continue
|
||||
video_info.setdefault(f"video.{field_name}", field_value)
|
||||
|
||||
video_info["is_depth_map"] = isinstance(video_encoder, DepthEncoderConfig)
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
|
||||
return 1
|
||||
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
|
||||
return 4
|
||||
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
|
||||
return 3
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_video_duration_in_s(video_path: Path | str) -> float:
|
||||
"""
|
||||
Get the duration of a video file in seconds using PyAV.
|
||||
|
||||
@@ -18,12 +18,25 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.import_utils import _placo_available, require_package
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
if TYPE_CHECKING or _placo_available:
|
||||
_placo_runtime_error: ImportError | None = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import placo # type: ignore[import-not-found]
|
||||
else:
|
||||
placo = None
|
||||
try:
|
||||
import placo # type: ignore[import-not-found]
|
||||
except ImportError as _placo_import_err:
|
||||
placo = None
|
||||
_placo_runtime_error = _placo_import_err
|
||||
|
||||
|
||||
def _raise_if_placo_unusable() -> None:
|
||||
if placo is None and _placo_runtime_error is not None:
|
||||
raise ImportError(
|
||||
f"placo is installed but failed to import: {_placo_runtime_error!s}"
|
||||
) from _placo_runtime_error
|
||||
|
||||
|
||||
class RobotKinematics:
|
||||
@@ -44,6 +57,7 @@ class RobotKinematics:
|
||||
joint_names (list[str] | None): List of joint names to use for the kinematics solver
|
||||
"""
|
||||
require_package("placo", extra="placo-dep")
|
||||
_raise_if_placo_unusable()
|
||||
|
||||
self.robot = placo.RobotWrapper(urdf_path)
|
||||
self.solver = placo.KinematicsSolver(self.robot)
|
||||
|
||||
@@ -43,6 +43,7 @@ from .tables import (
|
||||
CAN_CMD_SET_ZERO,
|
||||
DEFAULT_BAUDRATE,
|
||||
DEFAULT_TIMEOUT_MS,
|
||||
HANDSHAKE_TIMEOUT_S,
|
||||
MODEL_RESOLUTION,
|
||||
MOTOR_LIMIT_PARAMS,
|
||||
NORMALIZED_DATA,
|
||||
@@ -215,14 +216,16 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
self._is_connected = False
|
||||
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
|
||||
|
||||
def _query_status_via_clear_fault(self, motor: NameOrID) -> tuple[bool, can.Message | None]:
|
||||
def _query_status_via_clear_fault(
|
||||
self, motor: NameOrID, timeout: float = RUNNING_TIMEOUT
|
||||
) -> tuple[bool, can.Message | None]:
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_id = self._get_motor_id(motor_name)
|
||||
recv_id = self._get_motor_recv_id(motor_name)
|
||||
data = [0xFF] * 7 + [CAN_CMD_CLEAR_FAULT]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self._bus().send(msg)
|
||||
return self._recv_status_via_clear_fault(expected_recv_id=recv_id)
|
||||
return self._recv_status_via_clear_fault(expected_recv_id=recv_id, timeout=timeout)
|
||||
|
||||
def _recv_status_via_clear_fault(
|
||||
self, expected_recv_id: int | None = None, timeout: float = RUNNING_TIMEOUT
|
||||
@@ -280,7 +283,7 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
faulted_motors = []
|
||||
|
||||
for motor_name in self.motors:
|
||||
has_fault, msg = self._query_status_via_clear_fault(motor_name)
|
||||
has_fault, msg = self._query_status_via_clear_fault(motor_name, timeout=HANDSHAKE_TIMEOUT_S)
|
||||
if msg is None:
|
||||
missing_motors.append(motor_name)
|
||||
elif has_fault:
|
||||
@@ -505,6 +508,87 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
|
||||
return responses
|
||||
|
||||
def _recv_all_messages_until_quiet(
|
||||
self,
|
||||
*,
|
||||
timeout: float = RUNNING_TIMEOUT,
|
||||
max_messages: int = 4096,
|
||||
) -> list[can.Message]:
|
||||
"""
|
||||
Receive frames until the bus goes quiet.
|
||||
|
||||
Args:
|
||||
timeout: Poll timeout used for each recv() call. Collection stops
|
||||
when one recv() times out (quiet gap).
|
||||
max_messages: Safety cap to prevent unbounded loops.
|
||||
"""
|
||||
out: list[can.Message] = []
|
||||
max_messages = max(1, max_messages)
|
||||
timeout = max(0.0, timeout)
|
||||
|
||||
try:
|
||||
while len(out) < max_messages:
|
||||
msg = self._bus().recv(timeout=timeout)
|
||||
if msg is None:
|
||||
break
|
||||
out.append(msg)
|
||||
except (can.CanError, OSError) as e:
|
||||
logger.debug(f"Error draining CAN RX queue on {self.port}: {e}")
|
||||
|
||||
return out
|
||||
|
||||
def _process_feedback_messages(self, messages: list[can.Message]) -> set[int]:
|
||||
"""
|
||||
Decode all received feedback frames and update cached motor states.
|
||||
|
||||
Returns:
|
||||
Set of payload recv_ids that were successfully mapped to motors.
|
||||
"""
|
||||
processed_recv_ids: set[int] = set()
|
||||
for msg in messages:
|
||||
if len(msg.data) < 1:
|
||||
logger.debug(
|
||||
f"Dropping short CAN frame on {self.port} "
|
||||
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()})"
|
||||
)
|
||||
continue
|
||||
|
||||
recv_id = int(msg.data[0])
|
||||
motor_name = self._recv_id_to_motor.get(recv_id)
|
||||
if motor_name is None:
|
||||
logger.debug(
|
||||
f"Unmapped CAN frame on {self.port} "
|
||||
f"(arb=0x{int(msg.arbitration_id):02X}, recv_id=0x{recv_id:02X}, data={bytes(msg.data).hex()})"
|
||||
)
|
||||
continue
|
||||
|
||||
self._process_response(motor_name, msg)
|
||||
processed_recv_ids.add(recv_id)
|
||||
|
||||
return processed_recv_ids
|
||||
|
||||
def flush_rx_queue(self, poll_timeout_s: float = 0.0005, max_messages: int = 4096) -> int:
|
||||
"""
|
||||
Drain pending RX frames from the CAN interface.
|
||||
|
||||
This is used by higher-level controllers to drop stale feedback before issuing
|
||||
a fresh read cycle, so subsequent state reads are based on most recent replies.
|
||||
It should also be called once when a controller instance is created/connected,
|
||||
to clear residual frames left on the interface from previous sessions.
|
||||
"""
|
||||
drained = 0
|
||||
poll_timeout_s = max(0.0, poll_timeout_s)
|
||||
max_messages = max(1, max_messages)
|
||||
try:
|
||||
while drained < max_messages:
|
||||
msg = self._bus().recv(timeout=poll_timeout_s)
|
||||
if msg is None:
|
||||
break
|
||||
drained += 1
|
||||
except (can.CanError, OSError) as e:
|
||||
logger.debug(f"Failed to flush CAN RX queue on {self.port}: {e}")
|
||||
return drained
|
||||
|
||||
def _speed_control(
|
||||
self,
|
||||
motor: NameOrID,
|
||||
@@ -644,11 +728,14 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self._bus().send(msg)
|
||||
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
|
||||
# Read every feedback frame until RX goes quiet, then decode all of them.
|
||||
# This avoids dropping useful frames when responses from different motors interleave.
|
||||
messages = self._recv_all_messages_until_quiet()
|
||||
processed_recv_ids = self._process_feedback_messages(messages)
|
||||
|
||||
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=RUNNING_TIMEOUT)
|
||||
for recv_id, motor_name in recv_id_to_motor.items():
|
||||
if msg := responses.get(recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
if recv_id not in processed_recv_ids:
|
||||
logger.warning(f"Packet drop: {motor_name} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||
|
||||
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
|
||||
"""Convert float to unsigned integer for CAN transmission."""
|
||||
@@ -711,7 +798,10 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
try:
|
||||
self._decode_motor_state(msg.data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to decode response from {motor}: {e}")
|
||||
logger.warning(
|
||||
f"Failed to decode response from {motor} "
|
||||
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()}): {e}"
|
||||
)
|
||||
|
||||
def _get_cached_value(self, motor: str, data_name: str) -> Value:
|
||||
"""Retrieve a specific value from the state cache."""
|
||||
@@ -848,20 +938,12 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
self._bus().send(msg)
|
||||
updated_motors.append(motor)
|
||||
|
||||
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in updated_motors]
|
||||
responses = self._recv_all_responses(expected_recv_ids, timeout=RUNNING_TIMEOUT)
|
||||
|
||||
for response in responses.values():
|
||||
payload_motor_name = self._recv_id_to_motor.get(response.data[0])
|
||||
if payload_motor_name is not None:
|
||||
self._process_response(payload_motor_name, response)
|
||||
else:
|
||||
# Fallback: still attempt to decode based on payload byte0 mapping.
|
||||
self._decode_motor_state(response.data)
|
||||
messages = self._recv_all_messages_until_quiet()
|
||||
processed_recv_ids = self._process_feedback_messages(messages)
|
||||
|
||||
for motor in updated_motors:
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
if recv_id not in responses:
|
||||
if recv_id not in processed_recv_ids:
|
||||
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||
|
||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||
|
||||
@@ -114,7 +114,8 @@ CAN_CMD_SAVE_PARAM = 0xAA
|
||||
CAN_PARAM_ID = 0x7FF
|
||||
|
||||
|
||||
RUNNING_TIMEOUT = 0.001
|
||||
RUNNING_TIMEOUT = 0.003
|
||||
HANDSHAKE_TIMEOUT_S = 0.05
|
||||
PARAM_TIMEOUT = 0.01
|
||||
|
||||
STATE_CACHE_TTL_S = 0.02
|
||||
|
||||
@@ -104,6 +104,8 @@ class AdamWConfig(OptimizerConfig):
|
||||
eps: float = 1e-8
|
||||
weight_decay: float = 1e-2
|
||||
grad_clip_norm: float = 10.0
|
||||
foreach: bool | None = None
|
||||
fused: bool | None = None
|
||||
|
||||
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
|
||||
@@ -24,6 +24,7 @@ from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as M
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
from .pi052.configuration_pi052 import PI052Config as PI052Config
|
||||
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
@@ -47,6 +48,7 @@ __all__ = [
|
||||
"PI0Config",
|
||||
"PI0FastConfig",
|
||||
"PI05Config",
|
||||
"PI052Config",
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
|
||||
@@ -61,6 +61,79 @@ from .wall_x.configuration_wall_x import WallXConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig
|
||||
|
||||
|
||||
def _restore_pi052_pretrained_state(
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
pretrained_path: str,
|
||||
) -> None:
|
||||
"""Transplant saved stateful blobs from a pi052 checkpoint into fresh pipelines.
|
||||
|
||||
pi052's preprocessor includes steps whose constructor args don't
|
||||
JSON-roundtrip (``RenderMessagesStep.recipe`` is a Python object,
|
||||
``ActionTokenizerProcessorStep.action_tokenizer_name`` is a
|
||||
fitted-tokenizer path that may not exist at eval time). We rebuild
|
||||
those pipelines fresh from ``config.recipe_path`` and then walk
|
||||
over the saved ``policy_{pre,post}processor.json`` files to find
|
||||
each step's ``state_file`` reference and load the bytes back into
|
||||
the corresponding fresh step. Today that's only the
|
||||
NormalizerProcessorStep / UnnormalizerProcessorStep (the action /
|
||||
state quantile stats), but the loop is generic so any future
|
||||
stateful step picks up its blob automatically.
|
||||
|
||||
Pairing is by ``registry_name`` AND position so a benign reorder
|
||||
on the saved side surfaces a warning rather than silently feeding
|
||||
the wrong tensors into the wrong step.
|
||||
"""
|
||||
import json # noqa: PLC0415
|
||||
import logging # noqa: PLC0415
|
||||
from pathlib import Path # noqa: PLC0415
|
||||
|
||||
from safetensors.torch import load_file # noqa: PLC0415
|
||||
|
||||
base = Path(pretrained_path)
|
||||
if not base.exists():
|
||||
return
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
for pipeline, config_filename in [
|
||||
(preprocessor, f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"),
|
||||
(postprocessor, f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"),
|
||||
]:
|
||||
config_path = base / config_filename
|
||||
if not config_path.exists():
|
||||
continue
|
||||
saved = json.loads(config_path.read_text())
|
||||
|
||||
for idx, (saved_step, fresh_step) in enumerate(
|
||||
zip(saved.get("steps", []), pipeline.steps, strict=False)
|
||||
):
|
||||
state_file = saved_step.get("state_file")
|
||||
if not state_file:
|
||||
continue
|
||||
saved_name = saved_step.get("registry_name")
|
||||
fresh_name = getattr(type(fresh_step), "_registry_name", None)
|
||||
if saved_name and fresh_name and saved_name != fresh_name:
|
||||
log.warning(
|
||||
"PI052 state restore: %s step %d registry name mismatch "
|
||||
"(saved=%s, fresh=%s); skipping %s",
|
||||
config_filename, idx, saved_name, fresh_name, state_file,
|
||||
)
|
||||
continue
|
||||
state_path = base / state_file
|
||||
if not state_path.exists():
|
||||
log.warning(
|
||||
"PI052 state restore: %s missing at %s; %s left at fresh init",
|
||||
state_file, base, fresh_name,
|
||||
)
|
||||
continue
|
||||
fresh_step.load_state_dict(load_file(str(state_path)))
|
||||
log.info(
|
||||
"PI052 state restore: loaded %s into %s (step %d)",
|
||||
state_file, fresh_name, idx,
|
||||
)
|
||||
|
||||
|
||||
def _reconnect_relative_absolute_steps(
|
||||
preprocessor: PolicyProcessorPipeline, postprocessor: PolicyProcessorPipeline
|
||||
) -> None:
|
||||
@@ -127,6 +200,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .pi05.modeling_pi05 import PI05Policy
|
||||
|
||||
return PI05Policy
|
||||
elif name == "pi052":
|
||||
from .pi052.modeling_pi052 import PI052Policy
|
||||
|
||||
return PI052Policy
|
||||
elif name == "gaussian_actor":
|
||||
from .gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
|
||||
@@ -167,8 +244,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
|
||||
"smolvla", "wall_x".
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05",
|
||||
"pi052", "gaussian_actor", "smolvla", "wall_x".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -191,6 +268,10 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi05":
|
||||
return PI05Config(**kwargs)
|
||||
elif policy_type == "pi052":
|
||||
from .pi052.configuration_pi052 import PI052Config
|
||||
|
||||
return PI052Config(**kwargs)
|
||||
elif policy_type == "gaussian_actor":
|
||||
return GaussianActorConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
@@ -231,6 +312,12 @@ class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
preprocessor_overrides: dict[str, Any] | None
|
||||
postprocessor_overrides: dict[str, Any] | None
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
|
||||
# Optional: HF Hub repo id of the dataset the policy is being
|
||||
# trained on. Used by policies that auto-fit pieces of their
|
||||
# preprocessing (e.g. pi052's FAST action tokenizer per
|
||||
# Pertsch et al. 2025 [64], π0.5 §III.C). When omitted, those
|
||||
# policies fall back to their universal pre-fitted tokenizers.
|
||||
dataset_repo_id: str | None
|
||||
|
||||
|
||||
def make_pre_post_processors(
|
||||
@@ -263,6 +350,29 @@ def make_pre_post_processors(
|
||||
NotImplementedError: If a processor factory is not implemented for the given
|
||||
policy configuration type.
|
||||
"""
|
||||
if pretrained_path and getattr(policy_cfg, "type", None) == "pi052":
|
||||
# pi052 pipelines don't roundtrip through the saved
|
||||
# ``policy_preprocessor.json``: ``RenderMessagesStep`` holds a
|
||||
# Python ``TrainingRecipe`` (not JSON-serializable; saved as
|
||||
# ``{}``) and ``ActionTokenizerProcessorStep`` saves a host-only
|
||||
# FAST tokenizer path. Generic ``from_pretrained`` then dies
|
||||
# with ``RenderMessagesStep.__init__() missing 1 required
|
||||
# positional argument: 'recipe'`` (job 22164494).
|
||||
#
|
||||
# Mirror ``lerobot_pi052_runtime``'s bootstrap: build pipelines
|
||||
# fresh from ``config.recipe_path`` and transplant the saved
|
||||
# stateful blobs (normalizer stats) from the checkpoint dir.
|
||||
from .pi052.processor_pi052 import make_pi052_pre_post_processors
|
||||
|
||||
preprocessor, postprocessor = make_pi052_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_repo_id=kwargs.get("dataset_repo_id"),
|
||||
)
|
||||
_restore_pi052_pretrained_state(preprocessor, postprocessor, pretrained_path)
|
||||
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
|
||||
return preprocessor, postprocessor
|
||||
|
||||
if pretrained_path:
|
||||
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
|
||||
if isinstance(policy_cfg, GrootConfig):
|
||||
@@ -357,6 +467,22 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif policy_cfg.type == "pi052":
|
||||
# NOTE: PI052Config subclasses PI05Config, so this branch MUST
|
||||
# come before the PI05Config isinstance check below (otherwise
|
||||
# pi052 would silently pick up π0.5's processor).
|
||||
from .pi052.processor_pi052 import make_pi052_pre_post_processors
|
||||
|
||||
processors = make_pi052_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
# ``dataset_repo_id`` flows in via kwargs when FAST CE is
|
||||
# enabled — the train loop sets it from ``--dataset.repo_id``.
|
||||
# When ``None``, ``make_pi052_pre_post_processors`` skips
|
||||
# the auto-fit and uses the universal tokenizer.
|
||||
dataset_repo_id=kwargs.get("dataset_repo_id"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI05Config):
|
||||
from .pi05.processor_pi05 import make_pi05_pre_post_processors
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -26,9 +26,14 @@ from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from huggingface_hub.dataclasses import strict
|
||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
|
||||
def strict(cls):
|
||||
return cls
|
||||
|
||||
AutoConfig = None
|
||||
AutoModel = None
|
||||
PretrainedConfig = object
|
||||
@@ -176,16 +181,16 @@ N_COLOR_CHANNELS = 3
|
||||
class GR00TN15Config(PretrainedConfig):
|
||||
model_type = "gr00t_n1_5"
|
||||
|
||||
backbone_cfg: dict
|
||||
action_head_cfg: dict
|
||||
action_horizon: int
|
||||
action_dim: int
|
||||
backbone_cfg: dict[str, Any] | None = None
|
||||
action_head_cfg: dict[str, Any] | None = None
|
||||
action_horizon: int = 0
|
||||
action_dim: int = 0
|
||||
compute_dtype: str = "float32"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
def __post_init__(self, **kwargs):
|
||||
self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg
|
||||
self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg
|
||||
super().__post_init__(**kwargs)
|
||||
|
||||
|
||||
# real model
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
DynamicCache = None
|
||||
modeling_gemma = None
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
@@ -141,6 +142,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
def clone_past_key_values(past_key_values):
|
||||
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
|
||||
return DynamicCache(
|
||||
tuple(
|
||||
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
@@ -227,16 +237,13 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
@@ -258,15 +265,16 @@ def compute_layer_complete(
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
cos, sin = rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
paligemma_layer = layers[0]
|
||||
scaling = paligemma_layer.self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
paligemma_layer.self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -274,13 +282,13 @@ def compute_layer_complete(
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
head_dim = paligemma_layer.self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
@@ -488,8 +496,9 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
paligemma_layers = self.paligemma.model.language_model.layers
|
||||
gemma_expert_layers = self.gemma_expert.model.layers
|
||||
rotary_emb = self.paligemma.model.language_model.rotary_emb
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
@@ -499,36 +508,39 @@ class PaliGemmaWithExpertModel(
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layer_idx in range(num_layers):
|
||||
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
|
||||
# final norm
|
||||
final_norms = (
|
||||
self.paligemma.model.language_model.norm,
|
||||
self.gemma_expert.model.norm,
|
||||
)
|
||||
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -907,7 +919,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
past_key_values = clone_past_key_values(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
|
||||
@@ -93,6 +93,21 @@ class PI05Config(PreTrainedConfig):
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.01
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
optimizer_foreach: bool | None = False
|
||||
optimizer_fused: bool | None = True
|
||||
|
||||
# LM-head LR multiplier. The PaliGemma `lm_head` projection (and its
|
||||
# tied `embed_tokens`) is the surface the LM head's first-token
|
||||
# distribution depends on. With ``knowledge_insulation`` blocking
|
||||
# action→VLM gradients, the LM head only sees gradients on text-CE
|
||||
# samples — which can be a small fraction of the mix (e.g. ~45% in
|
||||
# ``subtask_mem.yaml``). Under aggressive cosine LR decay the head's
|
||||
# first-token distribution can drift back toward PaliGemma's
|
||||
# pretrained ``<loc>`` detection prior, despite teacher-forced CE
|
||||
# staying near zero. Boosting just the LM-head LR (e.g. 5x) keeps
|
||||
# the head pinned to fine-tuning targets without perturbing the
|
||||
# backbone / vision tower / action expert. Default 1.0 = no change.
|
||||
lm_head_lr_scale: float = 1.0
|
||||
|
||||
# Scheduler settings: see openpi `CosineDecaySchedule`
|
||||
# Note: These will auto-scale if --steps < scheduler_decay_steps
|
||||
@@ -152,6 +167,8 @@ class PI05Config(PreTrainedConfig):
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
foreach=self.optimizer_foreach,
|
||||
fused=self.optimizer_fused,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
|
||||
@@ -223,6 +223,42 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
return padded_images
|
||||
|
||||
|
||||
def sdpa_attention_forward(
|
||||
module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
):
|
||||
"""Drop-in for ``modeling_gemma.eager_attention_forward`` using
|
||||
``torch.nn.functional.scaled_dot_product_attention``.
|
||||
|
||||
PyTorch SDPA picks the memory-efficient kernel for arbitrary additive
|
||||
bias masks (the FA backend only accepts causal/sliding-window). On
|
||||
H100 that is ~1.3-1.7x faster and uses ~30-40% less attention memory
|
||||
than the eager softmax(QK^T)+matmul path. Mirrors eager's signature
|
||||
and output shape (``(B, Lq, H, D)``) so call sites are unchanged.
|
||||
"""
|
||||
n_rep = module.num_key_value_groups
|
||||
if n_rep > 1:
|
||||
key = key.repeat_interleave(n_rep, dim=1)
|
||||
value = value.repeat_interleave(n_rep, dim=1)
|
||||
if attention_mask is not None and attention_mask.dtype != query.dtype:
|
||||
attention_mask = attention_mask.to(dtype=query.dtype)
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=dropout if module.training else 0.0,
|
||||
is_causal=False,
|
||||
scale=scaling,
|
||||
)
|
||||
return attn_output.transpose(1, 2).contiguous(), None
|
||||
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
@@ -261,8 +297,7 @@ def compute_layer_complete(
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
att_output, _ = sdpa_attention_forward(
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
@@ -409,6 +444,7 @@ class PaliGemmaWithExpertModel(
|
||||
params_to_keep_float32 = [
|
||||
"vision_tower",
|
||||
"multi_modal_projector",
|
||||
"lm_head",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
"model.norm",
|
||||
@@ -441,13 +477,13 @@ class PaliGemmaWithExpertModel(
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output
|
||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -617,10 +653,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def _prepare_attention_masks_4d(self, att_2d_masks):
|
||||
def _prepare_attention_masks_4d(self, att_2d_masks, dtype=None):
|
||||
"""Helper method to prepare 4D attention masks for transformer."""
|
||||
att_2d_masks_4d = att_2d_masks[:, None, :, :]
|
||||
return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
|
||||
result = torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
|
||||
if dtype is not None:
|
||||
result = result.to(dtype=dtype)
|
||||
return result
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
return torch.normal(
|
||||
@@ -662,7 +701,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Process language tokens
|
||||
def lang_embed_func(tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
||||
return lang_emb
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
return lang_emb * math.sqrt(lang_emb_dim)
|
||||
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
||||
embs.append(lang_emb)
|
||||
@@ -749,21 +789,22 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
|
||||
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
|
||||
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype)
|
||||
|
||||
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
|
||||
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, suffix_embs],
|
||||
use_cache=False,
|
||||
adarms_cond=[None, adarms_cond],
|
||||
)
|
||||
return suffix_out
|
||||
|
||||
suffix_out = self._apply_checkpoint(
|
||||
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
||||
# Selective AC: rely on the per-layer checkpoint inside
|
||||
# ``PaliGemmaWithExpertModel.forward`` (which wraps each
|
||||
# transformer block individually). The previous outer
|
||||
# ``_apply_checkpoint(forward_func, ...)`` doubled up — it
|
||||
# re-ran the full backbone forward during backward *and* each
|
||||
# block's own checkpoint re-ran during that recompute. Pure
|
||||
# waste with SDPA, which already streams attention activations.
|
||||
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, suffix_embs],
|
||||
use_cache=False,
|
||||
adarms_cond=[None, adarms_cond],
|
||||
)
|
||||
|
||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||
@@ -807,7 +848,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(
|
||||
prefix_att_2d_masks, dtype=prefix_embs.dtype
|
||||
)
|
||||
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
_, past_key_values = self.paligemma_with_expert.forward(
|
||||
@@ -877,7 +920,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
||||
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
||||
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(
|
||||
full_att_2d_masks, dtype=suffix_embs.dtype
|
||||
)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
@@ -1015,6 +1060,16 @@ class PI05Policy(PreTrainedPolicy):
|
||||
if remap_count > 0:
|
||||
print(f"Remapped {remap_count} state dict keys")
|
||||
|
||||
lm_head_key = "model.paligemma_with_expert.paligemma.lm_head.weight"
|
||||
embed_tokens_key = (
|
||||
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
||||
)
|
||||
if lm_head_key not in remapped_state_dict and embed_tokens_key in remapped_state_dict:
|
||||
remapped_state_dict[lm_head_key] = remapped_state_dict[embed_tokens_key].clone().float()
|
||||
print("Initialized PaliGemma lm_head from language token embeddings")
|
||||
elif lm_head_key in remapped_state_dict:
|
||||
remapped_state_dict[lm_head_key] = remapped_state_dict[lm_head_key].float()
|
||||
|
||||
# Load the remapped state dict into the model
|
||||
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
|
||||
|
||||
@@ -1108,8 +1163,62 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
return fixed_state_dict
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
def get_optim_params(self):
|
||||
"""Return policy parameters, optionally split into LR-scaled groups.
|
||||
|
||||
When ``config.lm_head_lr_scale != 1.0``, the PaliGemma ``lm_head``
|
||||
and its tied ``embed_tokens`` are placed in their own param
|
||||
group with ``lr = base_lr * lm_head_lr_scale``. The cosine
|
||||
scheduler multiplies both groups by the same lambda each step,
|
||||
so the ratio is preserved across decay. Default ``1.0`` =
|
||||
return ``self.parameters()`` (back-compat with existing checkpoints
|
||||
and configs).
|
||||
"""
|
||||
scale = float(getattr(self.config, "lm_head_lr_scale", 1.0))
|
||||
if scale == 1.0:
|
||||
return self.parameters()
|
||||
head_params: list[torch.nn.Parameter] = []
|
||||
other_params: list[torch.nn.Parameter] = []
|
||||
# Both ``lm_head.weight`` and the tied ``embed_tokens.weight`` —
|
||||
# boosting only the projection without the embedding pulls them
|
||||
# apart and breaks the tie that PaliGemma was pre-trained with.
|
||||
head_substrings = (
|
||||
"paligemma_with_expert.paligemma.lm_head.",
|
||||
"paligemma_with_expert.paligemma.model.language_model.embed_tokens.",
|
||||
)
|
||||
for name, p in self.named_parameters():
|
||||
if not p.requires_grad:
|
||||
continue
|
||||
if any(s in name for s in head_substrings):
|
||||
head_params.append(p)
|
||||
else:
|
||||
other_params.append(p)
|
||||
base_lr = float(self.config.optimizer_lr)
|
||||
groups: list[dict[str, object]] = []
|
||||
if other_params:
|
||||
groups.append({"params": other_params, "lr": base_lr, "name": "policy"})
|
||||
if head_params:
|
||||
groups.append(
|
||||
{"params": head_params, "lr": base_lr * scale, "name": "lm_head"}
|
||||
)
|
||||
# Sanity: head_substrings must match at least one parameter, otherwise
|
||||
# the scale silently does nothing — surface that fast.
|
||||
if not head_params:
|
||||
raise RuntimeError(
|
||||
"lm_head_lr_scale != 1.0 but no parameters matched the LM-head "
|
||||
"name patterns: "
|
||||
f"{head_substrings!r}. Did the underlying PaliGemma module rename?"
|
||||
)
|
||||
logging.info(
|
||||
"PI05Policy: LM-head LR scale = %.3g (base=%.3g, head=%.3g) over "
|
||||
"%d head params + %d other params",
|
||||
scale,
|
||||
base_lr,
|
||||
base_lr * scale,
|
||||
len(head_params),
|
||||
len(other_params),
|
||||
)
|
||||
return groups
|
||||
|
||||
def reset(self):
|
||||
"""Reset internal state - called when environment resets."""
|
||||
|
||||
42
src/lerobot/policies/pi052/__init__.py
Normal file
42
src/lerobot/policies/pi052/__init__.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""π0.5 v2 — full reproduction of the π0.5 paper's hierarchical
|
||||
inference recipe on lerobot.
|
||||
|
||||
Extends :class:`lerobot.policies.pi05.PI05Policy` with:
|
||||
|
||||
* recipe-driven training (PR 1's :class:`RenderMessagesStep`),
|
||||
* PaliGemma ``lm_head`` cross-entropy on supervised subtask spans
|
||||
(the "high-level subtask prediction" of the paper, §IV.D),
|
||||
* AR text generation at inference (:meth:`PI052Policy.select_message`),
|
||||
* per-component prompt dropout (Pi 0.7 §V.E) for regularising the
|
||||
text head against missing context at inference.
|
||||
|
||||
See ``src/lerobot/configs/recipes/subtasks_vqa.yaml`` for the
|
||||
canonical training recipe and
|
||||
``examples/training/pi052_hirobot.slurm`` for the launcher.
|
||||
"""
|
||||
|
||||
from .configuration_pi052 import PI052Config
|
||||
from .modeling_pi052 import PI052Policy
|
||||
from .processor_pi052 import make_pi052_pre_post_processors
|
||||
from .text_processor_pi052 import PI052TextTokenizerStep
|
||||
|
||||
__all__ = [
|
||||
"PI052Config",
|
||||
"PI052Policy",
|
||||
"PI052TextTokenizerStep",
|
||||
"make_pi052_pre_post_processors",
|
||||
]
|
||||
216
src/lerobot/policies/pi052/configuration_pi052.py
Normal file
216
src/lerobot/policies/pi052/configuration_pi052.py
Normal file
@@ -0,0 +1,216 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""π0.5 v2 (with text head) — reproduction of the π0.5 paper's
|
||||
hierarchical inference recipe.
|
||||
|
||||
Same architecture as the existing ``PI05Policy`` (PaliGemma 2B VLM +
|
||||
~300M Gemma action expert, joint training with FAST tokens during
|
||||
pre-train and flow matching during post-train), but with the
|
||||
PaliGemma ``lm_head`` re-enabled so the same model can be supervised
|
||||
to predict both:
|
||||
|
||||
* **subtask strings** at the high level (cross-entropy on the LM
|
||||
head), and
|
||||
* **action chunks** at the low level (flow matching on the
|
||||
action-expert tokens).
|
||||
|
||||
This is the dual-head co-training pattern from the paper:
|
||||
|
||||
L = H(x, f_θ_text) + α * ‖ω - a - f_θ_action(a_τ, o, ℓ)‖²
|
||||
|
||||
with α = 10.0 per § IV.D of arxiv:2504.16054. The π0.5 model splits
|
||||
inference into a text-prediction step followed by an action-prediction
|
||||
step, which the multi-rate ``PI052Runtime`` (in
|
||||
``lerobot.policies.pi052.inference``) drives at separate rates.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
|
||||
from ..pi05.configuration_pi05 import PI05Config
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi052")
|
||||
@dataclass
|
||||
class PI052Config(PI05Config):
|
||||
"""π0.5 with the PaliGemma LM head re-enabled for subtask prediction.
|
||||
|
||||
Recipe-driven dual-head training: the flow head supervises actions,
|
||||
the LM head supervises subtask / plan / memory / VQA text. The
|
||||
flow:text loss split is the milder 5:1 (see ``flow_loss_weight``).
|
||||
"""
|
||||
|
||||
# Recipe / language stack ---------------------------------------------
|
||||
recipe_path: str | None = "recipes/subtasks_vqa.yaml"
|
||||
"""Path (absolute or relative to ``src/lerobot/configs/``) to a
|
||||
``TrainingRecipe`` YAML. Defaults to the canonical Hi-Robot blend
|
||||
shipped alongside this policy. Set to ``None`` to disable recipe
|
||||
rendering and fall back to π0.5's single-task ``Task: ... Action:``
|
||||
prompt path (unannotated datasets keep working that way)."""
|
||||
|
||||
apply_chat_template: bool = False
|
||||
"""PaliGemma is *not* chat-pretrained — its tokenizer doesn't ship a
|
||||
chat template, so we don't apply one. The recipe renderer's output
|
||||
is concatenated as a plain prefix + assistant suffix instead,
|
||||
mirroring how the π0.5 paper's high-level inference samples text
|
||||
auto-regressively after the prefix."""
|
||||
|
||||
# Loss weights --------------------------------------------------------
|
||||
# Paper §IV.D uses α=10 between the flow and text terms, assuming
|
||||
# text is a rare auxiliary task. With the recipe stack the flow-only
|
||||
# `low_level` branch fires on a large share of samples, so α=10
|
||||
# swamps the LM head and collapses generation into degenerate
|
||||
# repetition. We use the milder 5:1 split here.
|
||||
text_loss_weight: float = 1.0
|
||||
"""Weight on the LM-head cross-entropy term. Set to ``0`` to disable
|
||||
text training entirely (reverts to flow-only / π0.5 behaviour)."""
|
||||
|
||||
flow_loss_weight: float = 5.0
|
||||
"""Weight on the action-expert flow-matching term. ``5.0`` — a milder
|
||||
flow:text split than the paper's α=10, since the flow-only
|
||||
``low_level`` recipe already gives the action expert frequent
|
||||
gradient. Lower it further if the LM head still underfits."""
|
||||
|
||||
# Backbone training ---------------------------------------------------
|
||||
unfreeze_lm_head: bool = True
|
||||
"""Whether to keep the PaliGemma ``lm_head`` unfrozen for fine-tuning.
|
||||
The existing ``PI05Policy`` zeroes / freezes the head on load
|
||||
because it never reads from it. Must be ``True`` for π0.5-style
|
||||
hierarchical inference."""
|
||||
|
||||
# Per-component prompt dropout (Pi0.7 §V.E) ---------------------------
|
||||
# Randomly drop non-target context messages so the LM head learns
|
||||
# to handle missing /
|
||||
# stale plan / memory at inference. Defaults to 0.0 so behaviour
|
||||
# is identical until explicitly enabled.
|
||||
plan_dropout_prob: float = 0.0
|
||||
memory_dropout_prob: float = 0.0
|
||||
subtask_dropout_prob: float = 0.0
|
||||
|
||||
# FAST discrete-action supervision — paper §III.B-C ------------------
|
||||
# When enabled, actions are *also* tokenised via the FAST tokenizer
|
||||
# ("physical-intelligence/fast") and supervised with cross-entropy
|
||||
# on the PaliGemma LM head — exactly as in the paper's pre-training
|
||||
# objective (Eq. 1 mixes FAST CE + flow MSE + subtask CE). The
|
||||
# ActionTokenizerProcessorStep is wired into the preprocessor
|
||||
# pipeline when this flag is set; the loss is computed in
|
||||
# PI052Policy.forward.
|
||||
enable_fast_action_loss: bool = True
|
||||
"""If True, tokenise actions with the FAST tokenizer and add a
|
||||
cross-entropy loss on the LM head. On by default to match the
|
||||
π0.5 paper's three-loss objective (text CE + FAST CE + flow MSE,
|
||||
§III.B-C Eq. 1). Set to False if you only want the
|
||||
post-training-style flow + text recipe."""
|
||||
|
||||
action_tokenizer_name: str = "physical-intelligence/fast"
|
||||
"""HF identifier for the FAST action tokenizer."""
|
||||
|
||||
max_action_tokens: int = 256
|
||||
"""Maximum number of FAST tokens per action chunk."""
|
||||
|
||||
fast_skip_tokens: int = 128
|
||||
"""Number of low-vocab tokens the FAST tokenizer skips to avoid
|
||||
collisions with PaliGemma's text vocabulary."""
|
||||
|
||||
fast_action_loss_weight: float = 1.0
|
||||
"""Weight on the FAST-action-token CE loss. Paper §III.C uses 1.0."""
|
||||
|
||||
auto_fit_fast_tokenizer: bool = False
|
||||
"""If True, the processor factory checks ``fast_tokenizer_cache_dir``
|
||||
for a previously-fitted tokenizer keyed on ``(dataset_repo_id,
|
||||
base_tokenizer_name, fit_samples)``. On cache miss, it loads
|
||||
``action_tokenizer_name`` as a base, samples
|
||||
``fast_tokenizer_fit_samples`` action chunks from the dataset, runs
|
||||
``.fit()``, saves the result, and uses *that* fitted path as the
|
||||
actual tokenizer. Pertsch et al. 2025 (FAST paper [64], π0.5 §III.C)
|
||||
explicitly recommend per-dataset fitting for best compression.
|
||||
|
||||
Off by default because the fit requires a separate pre-training
|
||||
pass over the dataset (~1-2 min on a medium dataset) and depends
|
||||
on the FAST tokenizer snapshot having a ``.fit()`` method. Opt in
|
||||
when you want paper-faithful compression; leave off to fall back
|
||||
on the universal ``physical-intelligence/fast`` codebook."""
|
||||
|
||||
fast_tokenizer_cache_dir: str = "~/.cache/lerobot/fast_tokenizers"
|
||||
"""Where fitted FAST tokenizers are stored. ``~`` expands."""
|
||||
|
||||
fast_tokenizer_fit_samples: int = 1024
|
||||
"""Number of action chunks to sample for the fit. The FAST paper uses
|
||||
a few thousand; 1024 is a reasonable default for medium datasets."""
|
||||
|
||||
# Knowledge insulation — paper §III.B --------------------------------
|
||||
# When enabled, gradients from the action expert's flow loss are
|
||||
# blocked from flowing back into the VLM's K/V projections. This
|
||||
# prevents the action loss from over-fitting the language backbone
|
||||
# to robot-specific features. Implemented in ``modeling_pi052`` as
|
||||
# a per-instance monkey-patch on ``paligemma_with_expert.forward``
|
||||
# that splits queries into VLM and action halves and ``.detach()``-s
|
||||
# the VLM K/V tensors used in the action-half's attention.
|
||||
knowledge_insulation: bool = False
|
||||
"""If True, route every transformer layer through the KI
|
||||
attention path that blocks action→VLM gradient flow on K/V."""
|
||||
|
||||
# Learning-rate defaults --------------------------------------------
|
||||
# pi052 inherits π0.5's openpi-validated optimizer config (peak LR
|
||||
# 2.5e-5, cosine→2.5e-6, 1k warmup, AdamW (0.9, 0.95), wd=0.01,
|
||||
# grad_clip=1.0). The only place pi052 needs to diverge from pi05
|
||||
# is the LM-head LR multiplier: pi05 has no text supervision so the
|
||||
# head doesn't get gradients; pi052 always has text supervision
|
||||
# (subtask / memory / VQA) via the recipe, and under KI the LM head
|
||||
# only sees gradients on ~30–45% of the batch (the text-CE mask
|
||||
# share of the recipe). Under aggressive cosine decay this is too
|
||||
# weak to keep the head pinned, so it drifts back toward PaliGemma's
|
||||
# pretrained ``<loc>`` first-token bias. 5x is the documented fix
|
||||
# (see ``PI05Config.lm_head_lr_scale`` docstring); the wiring is
|
||||
# already in ``PI05Policy.get_optim_params`` — it splits the LM head
|
||||
# + tied ``embed_tokens`` into their own param group while sharing
|
||||
# the same cosine lambda, so the 5x ratio is preserved across decay.
|
||||
lm_head_lr_scale: float = 5.0
|
||||
|
||||
# PaLM-style z-loss on text CE. Penalises the log-partition function
|
||||
# ``z = log Σ exp(logits)`` drifting away from zero — without it, large-
|
||||
# vocab models (PaliGemma is 257k) can let ``logsumexp`` grow unbounded
|
||||
# while CE stays low, because a uniform additive logit bias cancels in
|
||||
# softmax. PaLM appendix B / Chinchilla report z-loss is essential for
|
||||
# stable large-vocab CE; it especially helps under ``lm_head_lr_scale=
|
||||
# 5.0`` which amplifies drift risk on the LM head. ``1e-4`` is the
|
||||
# commonly cited weight; set 0 to disable entirely.
|
||||
text_ce_z_loss_weight: float = 1e-4
|
||||
|
||||
# Liger Triton kernels (rope + geglu + layer_norm) are now patched
|
||||
# unconditionally at model build time — see ``_enable_hf_kernels``
|
||||
# in ``modeling_pi052``. The patch is process-global, idempotent
|
||||
# and degrades gracefully if ``liger-kernel`` is missing. Measured
|
||||
# at -4.5% step time on H100 (bench job 22161421); peak memory
|
||||
# unchanged. ``fused_linear_cross_entropy`` ships separately via
|
||||
# ``_shifted_lin_ce`` / ``_fast_lin_ce``.
|
||||
use_hf_kernels: bool = True
|
||||
"""Deprecated. Liger HF kernels are patched unconditionally by
|
||||
``_enable_hf_kernels`` — this field is retained as a no-op for
|
||||
backward compatibility with checkpoints saved before commit
|
||||
d70c8104 (which still serialize ``use_hf_kernels: true`` into
|
||||
``config.json``). Loading those configs would otherwise raise
|
||||
``DecodingError: The fields use_hf_kernels are not valid for
|
||||
PI052Config`` (job 22164492). Remove in a future major bump."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
# Backbone needs gradients flowing through the text head when
|
||||
# we're training it. Override the π0.5 default
|
||||
# (``train_expert_only=True``) unless the user explicitly opts
|
||||
# out of text training via ``text_loss_weight=0``.
|
||||
if self.text_loss_weight > 0 and self.unfreeze_lm_head:
|
||||
self.train_expert_only = False
|
||||
263
src/lerobot/policies/pi052/fit_fast_tokenizer.py
Normal file
263
src/lerobot/policies/pi052/fit_fast_tokenizer.py
Normal file
@@ -0,0 +1,263 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Dataset-specific FAST action tokenizer fitting.
|
||||
|
||||
The published ``physical-intelligence/fast`` tokenizer is a *universal*
|
||||
codebook fitted on a heterogeneous mix of robot datasets. Per Pertsch
|
||||
et al. 2025 (the FAST paper, [64] in the π0.5 paper) and §III.C of
|
||||
π0.5 itself, the recommended practice is to **finetune the tokenizer on
|
||||
your specific dataset's action distribution** before training the
|
||||
policy — same way one would adapt a language tokenizer to a domain
|
||||
corpus. Without this finetune step, action sequences from your robot
|
||||
may require more tokens per chunk than necessary, lowering effective
|
||||
compression and slowing convergence of the action-CE loss.
|
||||
|
||||
This module provides a single utility, :func:`fit_fast_tokenizer`,
|
||||
that does the finetune. The training entry point invokes it
|
||||
automatically when the policy's ``enable_fast_action_loss`` and
|
||||
``auto_fit_fast_tokenizer`` flags are both ``True`` and no cached
|
||||
fitted tokenizer is found at ``fast_tokenizer_cache_dir``.
|
||||
|
||||
The fitted tokenizer is saved to
|
||||
``{cache_dir}/{dataset_hash}_{base_hash}/`` so successive training
|
||||
runs over the same dataset re-use it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Marker file the cache-hit check looks for. ``ProcessorMixin.save_pretrained``
|
||||
# writes ``processor_config.json`` (NOT ``preprocessor_config.json`` —
|
||||
# that's the image / feature-extractor convention). Centralised here so
|
||||
# the cache-hit check and the rank-N readiness wait agree on the same
|
||||
# sentinel.
|
||||
_CACHE_SENTINEL = "processor_config.json"
|
||||
|
||||
|
||||
def _dataset_signature(
|
||||
dataset_repo_id: str,
|
||||
base_tokenizer_name: str,
|
||||
n_samples: int,
|
||||
chunk_size: int,
|
||||
) -> str:
|
||||
"""Deterministic short hash for naming the cache directory.
|
||||
|
||||
Keys on (dataset, base tokenizer, sample count, chunk size) so any
|
||||
of those changing re-runs the fit. ``chunk_size`` matters because
|
||||
the tokenizer is fit on chunks of that length.
|
||||
"""
|
||||
h = hashlib.sha256()
|
||||
h.update(dataset_repo_id.encode("utf-8"))
|
||||
h.update(b"\0")
|
||||
h.update(base_tokenizer_name.encode("utf-8"))
|
||||
h.update(b"\0")
|
||||
h.update(str(n_samples).encode("utf-8"))
|
||||
h.update(b"\0")
|
||||
h.update(str(chunk_size).encode("utf-8"))
|
||||
return h.hexdigest()[:16]
|
||||
|
||||
|
||||
def fit_fast_tokenizer(
|
||||
*,
|
||||
dataset_repo_id: str,
|
||||
cache_dir: str | Path,
|
||||
base_tokenizer_name: str = "physical-intelligence/fast",
|
||||
n_samples: int = 1024,
|
||||
chunk_size: int = 50,
|
||||
seed: int = 42,
|
||||
) -> str:
|
||||
"""Fit a FAST tokenizer on a LeRobot dataset's action distribution.
|
||||
|
||||
Args:
|
||||
dataset_repo_id: HF Hub repo id of the LeRobotDataset to fit on.
|
||||
cache_dir: Directory under which to save (and look up) fitted
|
||||
tokenizers. The actual save path is
|
||||
``{cache_dir}/{signature}``.
|
||||
base_tokenizer_name: HF identifier for the base FAST tokenizer
|
||||
to finetune from. ``physical-intelligence/fast`` is the
|
||||
universal one.
|
||||
n_samples: Number of action chunks to sample for the fit. The
|
||||
FAST paper uses a few thousand; ``1024`` is a good default
|
||||
for medium datasets.
|
||||
chunk_size: Length of each action chunk (matches
|
||||
``policy.chunk_size``). The FAST tokenizer is fit on
|
||||
sequences of this length.
|
||||
seed: RNG seed for sample selection.
|
||||
|
||||
Returns:
|
||||
The local path to the fitted tokenizer. Passed directly to
|
||||
``--policy.action_tokenizer_name`` for the training run.
|
||||
|
||||
Raises:
|
||||
ImportError: If the ``transformers`` library doesn't expose
|
||||
``AutoProcessor`` or the FAST tokenizer doesn't have a
|
||||
``.fit()`` method (then you're on an older FAST snapshot —
|
||||
update to the current published model).
|
||||
FileNotFoundError: If the dataset can't be loaded.
|
||||
"""
|
||||
cache_dir = Path(cache_dir)
|
||||
sig = _dataset_signature(dataset_repo_id, base_tokenizer_name, n_samples, chunk_size)
|
||||
out_dir = cache_dir / sig
|
||||
|
||||
if out_dir.exists() and (out_dir / _CACHE_SENTINEL).exists():
|
||||
logger.info(
|
||||
"FAST tokenizer cache hit: %s — re-using fitted tokenizer for "
|
||||
"dataset=%s base=%s n_samples=%d",
|
||||
out_dir, dataset_repo_id, base_tokenizer_name, n_samples,
|
||||
)
|
||||
return str(out_dir)
|
||||
|
||||
# DDP-safe fit: only the (local) main process actually fits + saves;
|
||||
# other ranks poll the cache sentinel until the leader is done.
|
||||
# Without this guard, all N ranks fit concurrently and race on
|
||||
# ``save_pretrained`` + ``AutoProcessor.from_pretrained`` (the latter
|
||||
# copies ``processing_action_tokenizer.py`` into ``HF_MODULES_CACHE``
|
||||
# and compiles a ``.pyc`` — concurrent writers occasionally produce
|
||||
# a stale / partial ``.pyc`` and the subsequent ``from .. import
|
||||
# UniversalActionProcessor`` raises ``AttributeError``.
|
||||
is_leader = (
|
||||
int(os.environ.get("RANK", "0")) == 0
|
||||
and int(os.environ.get("LOCAL_RANK", "0")) == 0
|
||||
)
|
||||
if not is_leader:
|
||||
timeout_s = 1800.0 # 30 min — covers ~1024-sample fits on cold caches
|
||||
start = time.monotonic()
|
||||
while not (out_dir / _CACHE_SENTINEL).exists():
|
||||
if time.monotonic() - start > timeout_s:
|
||||
raise RuntimeError(
|
||||
f"FAST tokenizer fit: non-leader rank timed out after "
|
||||
f"{timeout_s:.0f}s waiting for {out_dir / _CACHE_SENTINEL}. "
|
||||
"Leader rank likely crashed during the fit."
|
||||
)
|
||||
time.sleep(2.0)
|
||||
logger.info("FAST tokenizer ready (leader populated cache): %s", out_dir)
|
||||
return str(out_dir)
|
||||
|
||||
logger.info(
|
||||
"FAST tokenizer cache miss — fitting on dataset=%s "
|
||||
"base=%s n_samples=%d chunk_size=%d → %s",
|
||||
dataset_repo_id, base_tokenizer_name, n_samples, chunk_size, out_dir,
|
||||
)
|
||||
|
||||
from transformers import AutoProcessor # noqa: PLC0415
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: PLC0415
|
||||
|
||||
# Stream a single episode's worth of action chunks at a time so
|
||||
# we don't blow memory on huge datasets. Random episode +
|
||||
# random start offset gives a reasonable spread.
|
||||
#
|
||||
# Actions are read straight from the underlying HF dataset's
|
||||
# ``action`` *column* — never via ``ds[i]``. ``ds[i]`` builds a full
|
||||
# training item (delta-timestamp expansion + video decode + image
|
||||
# transforms); a single bad video frame would then throw and, since
|
||||
# the failure was swallowed at debug level, silently starve the fit
|
||||
# of every chunk. The action column carries no video, so reading it
|
||||
# directly is both faster and immune to decode errors.
|
||||
rng = np.random.default_rng(seed)
|
||||
actions_buf: list[np.ndarray] = []
|
||||
|
||||
# Load just the metadata first to know episode boundaries.
|
||||
ds_meta_only = LeRobotDataset(dataset_repo_id, episodes=[0])
|
||||
num_episodes = ds_meta_only.meta.total_episodes
|
||||
if "action" not in ds_meta_only.features:
|
||||
available = ", ".join(sorted(ds_meta_only.features)) or "<none>"
|
||||
raise RuntimeError(
|
||||
f"FAST fit: dataset {dataset_repo_id!r} has no ``action`` feature. "
|
||||
f"Available features: {available}."
|
||||
)
|
||||
del ds_meta_only
|
||||
|
||||
samples_per_episode = max(1, n_samples // max(num_episodes, 1))
|
||||
collected = 0
|
||||
eps_visited = 0
|
||||
short_episodes = 0
|
||||
for ep_idx in rng.permutation(num_episodes):
|
||||
if collected >= n_samples:
|
||||
break
|
||||
ep_idx = int(ep_idx)
|
||||
try:
|
||||
ds = LeRobotDataset(dataset_repo_id, episodes=[ep_idx])
|
||||
ep_actions = np.asarray(ds.hf_dataset["action"], dtype=np.float32)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("FAST fit: skipping episode %d: %s", ep_idx, exc)
|
||||
continue
|
||||
if ep_actions.ndim != 2 or ep_actions.shape[0] < chunk_size:
|
||||
short_episodes += 1
|
||||
continue
|
||||
# Sample ``samples_per_episode`` contiguous chunks uniformly.
|
||||
starts = rng.integers(0, ep_actions.shape[0] - chunk_size + 1, size=samples_per_episode)
|
||||
for s in starts:
|
||||
actions_buf.append(ep_actions[int(s) : int(s) + chunk_size])
|
||||
collected += 1
|
||||
if collected >= n_samples:
|
||||
break
|
||||
eps_visited += 1
|
||||
|
||||
if not actions_buf:
|
||||
raise RuntimeError(
|
||||
f"FAST fit collected zero action chunks from {dataset_repo_id!r}: "
|
||||
f"all {num_episodes} episodes were shorter than chunk_size="
|
||||
f"{chunk_size} ({short_episodes} too short) or had an unreadable "
|
||||
"``action`` column. Lower ``chunk_size`` to match your episode "
|
||||
"lengths."
|
||||
)
|
||||
|
||||
actions = np.stack(actions_buf, axis=0).astype(np.float32) # (N, H, D)
|
||||
logger.info(
|
||||
"FAST fit: collected %d chunks of shape %s from %d episodes",
|
||||
actions.shape[0], actions.shape[1:], eps_visited,
|
||||
)
|
||||
|
||||
# Quantile-normalise per dimension before fitting.
|
||||
#
|
||||
# The FAST tokenizer DCT-transforms actions, scales by ``scale`` and
|
||||
# rounds to integer tokens; the integer *range* must fit the
|
||||
# codebook (vocab_size, default 1024). Raw motor units (e.g. encoder
|
||||
# ticks) blow that range up — hence "Vocab size 1024 is too small".
|
||||
# More importantly, at training time ``ActionTokenizerProcessorStep``
|
||||
# runs *after* the QUANTILES ``NormalizerProcessorStep``, so it
|
||||
# encodes normalised actions. Fitting on raw actions would mismatch
|
||||
# that space. We replicate QUANTILES normalisation here (per-dim
|
||||
# [q01, q99] → [-1, 1], clipped) so the fit and the training-time
|
||||
# encode see the same distribution.
|
||||
flat = actions.reshape(-1, actions.shape[-1])
|
||||
q01 = np.quantile(flat, 0.01, axis=0)
|
||||
q99 = np.quantile(flat, 0.99, axis=0)
|
||||
span = np.where((q99 - q01) > 1e-6, q99 - q01, 1.0)
|
||||
actions = np.clip((actions - q01) / span * 2.0 - 1.0, -1.0, 1.0).astype(np.float32)
|
||||
|
||||
base = AutoProcessor.from_pretrained(base_tokenizer_name, trust_remote_code=True)
|
||||
if not hasattr(base, "fit"):
|
||||
raise ImportError(
|
||||
f"Base FAST tokenizer {base_tokenizer_name!r} has no ``.fit()`` "
|
||||
"method — your transformers / model snapshot is too old. Update "
|
||||
"to the current ``physical-intelligence/fast`` revision."
|
||||
)
|
||||
|
||||
fitted = base.fit(actions)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
fitted.save_pretrained(str(out_dir))
|
||||
logger.info("FAST fit: saved fitted tokenizer to %s", out_dir)
|
||||
return str(out_dir)
|
||||
73
src/lerobot/policies/pi052/inference/__init__.py
Normal file
73
src/lerobot/policies/pi052/inference/__init__.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PI052 inference / runtime orchestration.
|
||||
|
||||
Multi-rate runtime that mirrors the recipe-time training shape:
|
||||
|
||||
low_level_execution → LowLevelForward + DispatchAction (high Hz)
|
||||
high_level_subtask → HighLevelSubtaskFwd (~1 Hz)
|
||||
memory_update → MemoryUpdateFwd (event: subtask_change)
|
||||
user_interjection_response → UserInterjectionFwd (event: stdin)
|
||||
ask_vqa_* → AskVQAFwd (event: stdin question)
|
||||
speech tool calls → DispatchToolCalls (event: tool_call_pending)
|
||||
|
||||
The CLI ``lerobot-pi052-runtime`` builds a ``PI052Runtime`` and calls
|
||||
``run()``.
|
||||
"""
|
||||
|
||||
from .repl import StdinReader
|
||||
from .runtime import PI052Runtime
|
||||
from .runtime_state import initial_runtime_state, push_log, set_if_changed, take_event
|
||||
from .steps import (
|
||||
AskVQAFwd,
|
||||
DispatchAction,
|
||||
DispatchToolCalls,
|
||||
HighLevelSubtaskFwd,
|
||||
InferenceStep,
|
||||
LowLevelForward,
|
||||
MemoryUpdateFwd,
|
||||
UserInterjectionFwd,
|
||||
)
|
||||
from .triggers import EventTrigger, HzTrigger, Tick, TickClock, Trigger
|
||||
from .ui import make_state_panel, print_robot_lines, print_user_line
|
||||
|
||||
__all__ = [
|
||||
# runtime
|
||||
"PI052Runtime",
|
||||
"StdinReader",
|
||||
# state helpers
|
||||
"initial_runtime_state",
|
||||
"push_log",
|
||||
"set_if_changed",
|
||||
"take_event",
|
||||
# triggers
|
||||
"Trigger",
|
||||
"Tick",
|
||||
"TickClock",
|
||||
"HzTrigger",
|
||||
"EventTrigger",
|
||||
# steps
|
||||
"InferenceStep",
|
||||
"LowLevelForward",
|
||||
"DispatchAction",
|
||||
"HighLevelSubtaskFwd",
|
||||
"MemoryUpdateFwd",
|
||||
"UserInterjectionFwd",
|
||||
"AskVQAFwd",
|
||||
"DispatchToolCalls",
|
||||
# UI
|
||||
"make_state_panel",
|
||||
"print_robot_lines",
|
||||
"print_user_line",
|
||||
]
|
||||
105
src/lerobot/policies/pi052/inference/repl.py
Normal file
105
src/lerobot/policies/pi052/inference/repl.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Stdin REPL event collector for the PI052 runtime.
|
||||
|
||||
Reads non-blocking stdin lines, classifies each one heuristically:
|
||||
|
||||
"stop" / "quit" / "exit" → state["stop"] = True
|
||||
"/action" / "/pause" → set state["mode"]
|
||||
ends with "?" → user_vqa_query event
|
||||
starts with "task:" or first line → set runtime task
|
||||
anything else → user_interjection event
|
||||
|
||||
Plugged into the runtime via ``event_collector=StdinReader().poll``.
|
||||
|
||||
Note: the shipped CLI (``lerobot-pi052-runtime``) drives stdin
|
||||
directly in its REPL / autonomous loops and does *not* wire this
|
||||
collector; it's kept as the documented embedding hook and for tests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import select
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class StdinReader:
|
||||
"""Non-blocking stdin line collector for the runtime loop."""
|
||||
|
||||
prompt: str = "> "
|
||||
_seen_first_line: bool = field(default=False, init=False)
|
||||
_prompted: bool = field(default=False, init=False)
|
||||
|
||||
def poll(self, state: dict[str, Any]) -> None:
|
||||
"""Drain pending stdin lines into runtime events."""
|
||||
# Print the input prompt once on every fresh tick if we don't
|
||||
# already have a pending line; matches the expected REPL feel.
|
||||
if not self._prompted:
|
||||
print(self.prompt, end="", flush=True)
|
||||
self._prompted = True
|
||||
|
||||
# ``select`` with timeout=0 makes this non-blocking. Only works
|
||||
# for actual TTY / pipe stdins; CI / scripted runs hit EOF.
|
||||
try:
|
||||
ready, _, _ = select.select([sys.stdin], [], [], 0)
|
||||
except (ValueError, OSError):
|
||||
return
|
||||
if not ready:
|
||||
return
|
||||
|
||||
line = sys.stdin.readline()
|
||||
if not line: # EOF
|
||||
state["stop"] = True
|
||||
return
|
||||
line = line.strip()
|
||||
self._prompted = False # we'll re-prompt next tick
|
||||
if not line:
|
||||
return
|
||||
|
||||
lower = line.lower()
|
||||
if lower in {"stop", "quit", "exit"}:
|
||||
state["stop"] = True
|
||||
return
|
||||
|
||||
# Slash commands flip the run mode. ``/pause`` stops the action
|
||||
# loop (the action steps gate on ``state["mode"]``); ``/action``
|
||||
# resumes it.
|
||||
if lower.split(" ", 1)[0] in {"/action", "/act", "/run"}:
|
||||
state["mode"] = "action"
|
||||
return
|
||||
if lower in {"/pause", "/p"}:
|
||||
state["mode"] = "paused"
|
||||
queue = state.get("action_queue")
|
||||
if hasattr(queue, "clear"):
|
||||
queue.clear()
|
||||
return
|
||||
|
||||
# First non-control line sets the task if no task is active.
|
||||
if not state.get("task"):
|
||||
task = line[5:].strip() if lower.startswith("task:") else line
|
||||
state["task"] = task
|
||||
print(f"[pi052] Task: {task}", flush=True)
|
||||
self._seen_first_line = True
|
||||
return
|
||||
|
||||
# Question → VQA; statement → interjection.
|
||||
if lower.endswith("?"):
|
||||
state["recent_vqa_query"] = line
|
||||
state.setdefault("events_this_tick", []).append("user_vqa_query")
|
||||
else:
|
||||
state["recent_interjection"] = line
|
||||
state.setdefault("events_this_tick", []).append("user_interjection")
|
||||
205
src/lerobot/policies/pi052/inference/runtime.py
Normal file
205
src/lerobot/policies/pi052/inference/runtime.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PI052 runtime loop.
|
||||
|
||||
Threads the multi-rate inference pipeline together with a stdin REPL
|
||||
event collector, drives ticks through :class:`TickClock`, and prints
|
||||
state-change updates to the user.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable
|
||||
|
||||
from .runtime_state import initial_runtime_state, push_log
|
||||
from .steps import (
|
||||
AskVQAFwd,
|
||||
DispatchAction,
|
||||
DispatchToolCalls,
|
||||
HighLevelSubtaskFwd,
|
||||
InferenceStep,
|
||||
LowLevelForward,
|
||||
MemoryUpdateFwd,
|
||||
)
|
||||
from .triggers import EventTrigger, HzTrigger, TickClock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PI052Runtime:
|
||||
"""Compose the inference pipeline and drive it tick-by-tick."""
|
||||
|
||||
policy: Any
|
||||
tools: dict[str, Any] = field(default_factory=dict)
|
||||
"""Name → tool-instance dict, e.g. ``{"say": SayTool(...)}``. Read
|
||||
from :func:`lerobot.tools.get_tools(meta)` when wiring the
|
||||
runtime."""
|
||||
observation_provider: Callable[[], dict | None] | None = None
|
||||
"""Closure returning the current preprocessed observation batch.
|
||||
``None`` for dry-run / language-only sessions."""
|
||||
robot_executor: Callable[[Any], None] | None = None
|
||||
"""Closure that takes one action chunk and forwards it to the
|
||||
robot. ``None`` for dry-run."""
|
||||
event_collector: Callable[[dict], None] | None = None
|
||||
"""Per-tick hook that polls external sources (stdin, network) and
|
||||
appends event names to ``state["events_this_tick"]``."""
|
||||
chunk_hz: float = 4.0
|
||||
ctrl_hz: float = 50.0
|
||||
high_level_hz: float = 1.0
|
||||
max_rate_hz: float = 50.0
|
||||
|
||||
pipeline: list[InferenceStep] = field(init=False)
|
||||
state: dict[str, Any] = field(init=False)
|
||||
_stop: bool = field(default=False, init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Subtask + memory + VQA configuration. Pipeline:
|
||||
#
|
||||
# HighLevelSubtaskFwd → generate the next subtask via the LM
|
||||
# head at ~``high_level_hz``; writes
|
||||
# ``current_subtask`` and emits
|
||||
# ``subtask_change`` on a transition.
|
||||
# MemoryUpdateFwd → on ``subtask_change``, refresh
|
||||
# ``current_memory`` from the
|
||||
# ``memory_update`` head.
|
||||
# AskVQAFwd → answer camera-grounded stdin questions.
|
||||
# LowLevelForward → action chunk conditioned on the
|
||||
# generated ``current_subtask``.
|
||||
# DispatchAction → drain the chunk to the robot.
|
||||
# DispatchToolCalls → fire any pending tool calls.
|
||||
#
|
||||
# Order matters: ``HighLevelSubtaskFwd`` must run before
|
||||
# ``MemoryUpdateFwd`` so the event is visible the same tick, and
|
||||
# both must run before ``LowLevelForward`` (which is gated on
|
||||
# "action queue empty") so the chunk consumes the freshest
|
||||
# subtask. ``UserInterjectionFwd`` is still importable but
|
||||
# disabled until plan generation is wired in.
|
||||
self.pipeline = [
|
||||
HighLevelSubtaskFwd(
|
||||
trigger=HzTrigger(self.high_level_hz),
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
# Listens for the ``subtask_change`` event raised by
|
||||
# ``HighLevelSubtaskFwd`` and refreshes ``current_memory``.
|
||||
MemoryUpdateFwd(
|
||||
trigger=EventTrigger("subtask_change"),
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
AskVQAFwd(
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
LowLevelForward(
|
||||
trigger=HzTrigger(self.chunk_hz),
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
DispatchAction(
|
||||
trigger=HzTrigger(self.ctrl_hz),
|
||||
robot_executor=self.robot_executor,
|
||||
),
|
||||
DispatchToolCalls(tools=self.tools),
|
||||
]
|
||||
self.state = initial_runtime_state()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def set_task(self, task: str) -> None:
|
||||
"""Set or replace the active task. Logged for the REPL."""
|
||||
self.state["task"] = task
|
||||
push_log(self.state, f"Task: {task}")
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop = True
|
||||
|
||||
def run(self, *, max_ticks: int | None = None) -> None:
|
||||
"""Main loop. Returns when ``stop()`` is called or after
|
||||
``max_ticks`` ticks (useful for tests / dry-run)."""
|
||||
clock = TickClock(max_rate_hz=self.max_rate_hz)
|
||||
while not self._stop:
|
||||
tick = clock.advance()
|
||||
self.state["_tick"] = tick
|
||||
self.state["events_this_tick"] = []
|
||||
self.state["log_lines"] = []
|
||||
|
||||
if self.event_collector is not None:
|
||||
self.event_collector(self.state)
|
||||
if self.state.get("stop"):
|
||||
self._stop = True
|
||||
break
|
||||
|
||||
for step in self.pipeline:
|
||||
self.state = step(self.state)
|
||||
|
||||
self._flush_logs()
|
||||
if max_ticks is not None and tick.index >= max_ticks:
|
||||
break
|
||||
|
||||
self._on_shutdown()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# REPL helper: drive one full pipeline pass and return its logs
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def step_once(self) -> list[str]:
|
||||
"""Run one tick of the pipeline and return the log lines.
|
||||
|
||||
Used by the interactive REPL: instead of a background thread,
|
||||
the CLI drives ticks synchronously after each user input. Logs
|
||||
are returned (not printed) so the caller can route them into
|
||||
the rich-Live chat scrollback.
|
||||
"""
|
||||
from .triggers import Tick # noqa: PLC0415
|
||||
|
||||
# Synthesize a tick. We don't need the real wall-clock pacing
|
||||
# here — the REPL drives the runtime, not vice versa — but
|
||||
# ``HzTrigger`` uses ``tick.monotonic_seconds`` to gate, so we
|
||||
# bump it generously so every Hz-triggered step considers
|
||||
# itself due.
|
||||
import time as _time # noqa: PLC0415
|
||||
|
||||
prev_index = self.state.get("_tick").index if isinstance(self.state.get("_tick"), Tick) else 0
|
||||
self.state["_tick"] = Tick(index=prev_index + 1, monotonic_seconds=_time.monotonic())
|
||||
self.state["log_lines"] = []
|
||||
# ``events_this_tick`` is set up by the caller before
|
||||
# ``step_once`` (the REPL pushes user-driven events first).
|
||||
self.state.setdefault("events_this_tick", [])
|
||||
|
||||
for step in self.pipeline:
|
||||
self.state = step(self.state)
|
||||
|
||||
return list(self.state.get("log_lines") or [])
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# I/O
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _flush_logs(self) -> None:
|
||||
for line in self.state.get("log_lines") or []:
|
||||
print(f"[pi052] {line}", flush=True)
|
||||
|
||||
def _on_shutdown(self) -> None:
|
||||
# Drain any queued action chunks safely.
|
||||
queue = self.state.get("action_queue")
|
||||
if isinstance(queue, deque):
|
||||
queue.clear()
|
||||
print("[pi052] runtime stopped", flush=True)
|
||||
95
src/lerobot/policies/pi052/inference/runtime_state.py
Normal file
95
src/lerobot/policies/pi052/inference/runtime_state.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Runtime state passed between inference steps each tick.
|
||||
|
||||
The runtime threads a single dict through the pipeline; this module
|
||||
documents the shape and provides factories. We use a plain ``dict``
|
||||
rather than a frozen dataclass because steps freely add and remove
|
||||
keys (``events_this_tick``, ``messages_pending``, ``tool_calls_pending``,
|
||||
…) and dataclass field churn would just get in the way.
|
||||
|
||||
Stable keys (read by multiple steps):
|
||||
|
||||
task str the current top-level task
|
||||
current_plan str | None latest plan emitted by the planner
|
||||
current_subtask str | None latest subtask the policy is executing
|
||||
current_memory str | None latest compressed memory
|
||||
recent_interjection str | None most recent user interjection text (consumed)
|
||||
|
||||
action_queue collections.deque[Tensor] pending action chunks
|
||||
tool_calls_pending list[dict] parsed but not-yet-dispatched tool calls
|
||||
|
||||
events_this_tick list[str] triggers consumed this tick
|
||||
_tick Tick current tick (set by the loop)
|
||||
|
||||
mode str "action" (run the robot) | "paused"
|
||||
(action loop stopped — robot holds)
|
||||
|
||||
log_lines list[str] human-readable status lines printed each tick
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
|
||||
|
||||
def initial_runtime_state(task: str | None = None) -> dict[str, Any]:
|
||||
"""Build a fresh runtime state dict with sensible defaults."""
|
||||
return {
|
||||
"task": task,
|
||||
"current_plan": None,
|
||||
"current_subtask": None,
|
||||
"current_memory": None,
|
||||
"recent_interjection": None,
|
||||
"action_queue": deque(),
|
||||
"tool_calls_pending": [],
|
||||
"events_this_tick": [],
|
||||
"log_lines": [],
|
||||
"mode": "action",
|
||||
"stop": False,
|
||||
}
|
||||
|
||||
|
||||
def take_event(state: dict[str, Any], event_name: str) -> bool:
|
||||
"""Pop ``event_name`` from ``events_this_tick`` if present.
|
||||
|
||||
Steps that consume an event call this so the same event doesn't
|
||||
re-fire on a sibling step within the same tick.
|
||||
"""
|
||||
events: list[str] = state.get("events_this_tick") or []
|
||||
if event_name in events:
|
||||
events.remove(event_name)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def push_log(state: dict[str, Any], line: str) -> None:
|
||||
"""Append ``line`` to the per-tick log buffer; the runtime prints
|
||||
it at the end of the tick."""
|
||||
state.setdefault("log_lines", []).append(line)
|
||||
|
||||
|
||||
def set_if_changed(state: dict[str, Any], key: str, value: Any, label: str | None = None) -> bool:
|
||||
"""Update ``state[key]`` and log a diff line if the value changed.
|
||||
|
||||
Returns ``True`` if the value actually changed.
|
||||
"""
|
||||
prev = state.get(key)
|
||||
if prev == value:
|
||||
return False
|
||||
state[key] = value
|
||||
if label is not None:
|
||||
push_log(state, f" {label}: {value}")
|
||||
return True
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user