mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
Compare commits
163 Commits
codex/roll
...
feat/langu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1e7c0d6aa1 | ||
|
|
920c6ef5a2 | ||
|
|
c37b1fc7d0 | ||
|
|
9020635b14 | ||
|
|
471b2b1b1d | ||
|
|
a15e16c072 | ||
|
|
336af85c09 | ||
|
|
54221ceea2 | ||
|
|
369ab17110 | ||
|
|
86a7edc590 | ||
|
|
8194897994 | ||
|
|
9f437d86b6 | ||
|
|
b74a551d38 | ||
|
|
c0a2e9814d | ||
|
|
bac4f61eae | ||
|
|
f4b834844e | ||
|
|
a0233f53f4 | ||
|
|
dfdc48a7f1 | ||
|
|
6a8878a639 | ||
|
|
d38eb89f71 | ||
|
|
7ab4936b1b | ||
|
|
2ea0da2d9f | ||
|
|
134a707c7a | ||
|
|
ca8c60a0ed | ||
|
|
ce47075d6b | ||
|
|
26013da699 | ||
|
|
3c15fd8537 | ||
|
|
f72b28738a | ||
|
|
1bd53cc7da | ||
|
|
7128bb1769 | ||
|
|
31e0c15e55 | ||
|
|
c5676ef1b3 | ||
|
|
5ebbdf3d05 | ||
|
|
9dfc9084e1 | ||
|
|
6e035fb169 | ||
|
|
fd18beb3a1 | ||
|
|
01dcb4c292 | ||
|
|
bd9619dfc3 | ||
|
|
0a4a7c40ad | ||
|
|
ca9028ad64 | ||
|
|
9db9c35cb4 | ||
|
|
fe96b28c74 | ||
|
|
2438df1307 | ||
|
|
f218d5ab30 | ||
|
|
04125492e4 | ||
|
|
e963e5a0c4 | ||
|
|
26ff40ddd7 | ||
|
|
6d269b28c8 | ||
|
|
b607c8458e | ||
|
|
9e83510c99 | ||
|
|
1f7b03f5f2 | ||
|
|
cb8edf17e6 | ||
|
|
5699f6cbf4 | ||
|
|
0e6114ac36 | ||
|
|
965d42825f | ||
|
|
1238a0cd47 | ||
|
|
53c7641885 | ||
|
|
088c8371df | ||
|
|
3a52a18b0e | ||
|
|
dad2cf1178 | ||
|
|
bce5387e04 | ||
|
|
c8ce413d73 | ||
|
|
82dffde7fa | ||
|
|
eaf0218bc8 | ||
|
|
a0e52d52fe | ||
|
|
85576acc29 | ||
|
|
e7e5fca5de | ||
|
|
beb22afd81 | ||
|
|
e99c55af4b | ||
|
|
408e0ca763 | ||
|
|
d55b581ca1 | ||
|
|
24d2ffe3c6 | ||
|
|
789f29aa56 | ||
|
|
a356b12c41 | ||
|
|
e8327b8e62 | ||
|
|
c450298147 | ||
|
|
5c30b14929 | ||
|
|
ce24063efd | ||
|
|
82934719db | ||
|
|
401a217597 | ||
|
|
40094b0464 | ||
|
|
8fa8323c91 | ||
|
|
fdbfc015a2 | ||
|
|
73740ecf4b | ||
|
|
1b81e49214 | ||
|
|
d813c75b76 | ||
|
|
3434d2ef22 | ||
|
|
b71e10da6b | ||
|
|
0f6e3230df | ||
|
|
2f2e42c4aa | ||
|
|
5ee0104739 | ||
|
|
e064cfcb04 | ||
|
|
b3d9494831 | ||
|
|
1217fdb6f0 | ||
|
|
d0388e1142 | ||
|
|
524aa59faa | ||
|
|
27f7829b09 | ||
|
|
7f8bf108e8 | ||
|
|
855ff027f8 | ||
|
|
3b797bb118 | ||
|
|
aea04721ae | ||
|
|
ab5479129a | ||
|
|
e6d4ac6f02 | ||
|
|
5722d365c5 | ||
|
|
3d7e60cee4 | ||
|
|
7b767d4d60 | ||
|
|
f1e3ab7794 | ||
|
|
585341ba9f | ||
|
|
23ff346027 | ||
|
|
3c5cbe7af4 | ||
|
|
f2cbd97635 | ||
|
|
c06c8d594a | ||
|
|
cd495a3a9d | ||
|
|
c99ac45cd1 | ||
|
|
13aaafeae0 | ||
|
|
2129648bf4 | ||
|
|
f5cd3f6e4e | ||
|
|
ecf5766301 | ||
|
|
11597d4f71 | ||
|
|
8b9c598cf4 | ||
|
|
b325475b38 | ||
|
|
ef137ff86a | ||
|
|
c5df821a96 | ||
|
|
7ec3d7999c | ||
|
|
712d63abbd | ||
|
|
6653999983 | ||
|
|
4bdbedc9a0 | ||
|
|
e240305e8e | ||
|
|
ccd189b264 | ||
|
|
ef1242bbd4 | ||
|
|
ebf4a04d41 | ||
|
|
4419b4ef1b | ||
|
|
ff06ca82d2 | ||
|
|
fcb01e73eb | ||
|
|
268f8d1f53 | ||
|
|
663fff0ae2 | ||
|
|
9d6af804bf | ||
|
|
f763f85213 | ||
|
|
e3e9374e2c | ||
|
|
c1a0c601e2 | ||
|
|
d656da8ccc | ||
|
|
1ca38d9748 | ||
|
|
5a6aa64570 | ||
|
|
b5f65e5332 | ||
|
|
cd6b43ea7a | ||
|
|
2236bbe7a3 | ||
|
|
cb0a944941 | ||
|
|
8a3d64033f | ||
|
|
03ee50e08f | ||
|
|
ca87ccd941 | ||
|
|
77352c495c | ||
|
|
0b06790da0 | ||
|
|
b43dc39ba4 | ||
|
|
2b71221194 | ||
|
|
8833d735a1 | ||
|
|
05a5223885 | ||
|
|
580d818aa9 | ||
|
|
587aa82021 | ||
|
|
12b88fce02 | ||
|
|
fc6c94c82a | ||
|
|
1add460678 | ||
|
|
4587c2b648 | ||
|
|
2236cdb302 |
6
.github/workflows/benchmark_tests.yml
vendored
6
.github/workflows/benchmark_tests.yml
vendored
@@ -382,6 +382,7 @@ jobs:
|
||||
--policy.path=\"\$ROBOTWIN_POLICY\" \
|
||||
--env.type=robotwin \
|
||||
--env.task=\"\$ROBOTWIN_TASKS\" \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
@@ -482,6 +483,7 @@ jobs:
|
||||
--policy.path=lerobot/smolvla_robocasa \
|
||||
--env.type=robocasa \
|
||||
--env.task=CloseFridge,OpenCabinet,OpenDrawer,TurnOnMicrowave,TurnOffStove,CloseToasterOvenDoor,SlideDishwasherRack,TurnOnSinkFaucet,NavigateKitchen,TurnOnElectricKettle \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
@@ -693,6 +695,7 @@ jobs:
|
||||
--env.task=\"\$ROBOMME_TASKS\" \
|
||||
--env.dataset_split=test \
|
||||
--env.task_ids=[0] \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
@@ -800,6 +803,7 @@ jobs:
|
||||
--env.type=libero_plus \
|
||||
--env.task=\"\$LIBERO_PLUS_SUITE\" \
|
||||
--env.task_ids=\"\$LIBERO_PLUS_TASK_IDS\" \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
@@ -900,6 +904,8 @@ jobs:
|
||||
--policy.path=lerobot/smolvla_vlabench \
|
||||
--env.type=vlabench \
|
||||
--env.task=select_fruit,select_toy,select_book,select_painting,select_drink,select_ingredient,select_billiards,select_poker,add_condiment,insert_flower \
|
||||
--env.episode_length=50 \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
|
||||
@@ -33,7 +33,7 @@ jobs:
|
||||
github.event.workflow_run.event == 'pull_request' &&
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.repository == 'huggingface/lerobot'
|
||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@9ad2de8582b56c017cb530c1165116d40433f1c6 # main
|
||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
|
||||
with:
|
||||
package_name: lerobot
|
||||
secrets:
|
||||
|
||||
4
.github/workflows/documentation.yml
vendored
4
.github/workflows/documentation.yml
vendored
@@ -55,7 +55,7 @@ jobs:
|
||||
github.repository == 'huggingface/lerobot'
|
||||
permissions:
|
||||
contents: read
|
||||
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
|
||||
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
|
||||
with:
|
||||
commit_sha: ${{ github.sha }}
|
||||
package: lerobot
|
||||
@@ -78,7 +78,7 @@ jobs:
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
|
||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
|
||||
with:
|
||||
commit_sha: ${{ github.event.pull_request.head.sha }}
|
||||
pr_number: ${{ github.event.number }}
|
||||
|
||||
3
.github/workflows/release.yml
vendored
3
.github/workflows/release.yml
vendored
@@ -152,13 +152,14 @@ jobs:
|
||||
BASE_VERSION="${VERSION%%-*}"
|
||||
echo "Installing pre-release version $BASE_VERSION from TestPyPI..."
|
||||
uv pip install \
|
||||
--torch-backend cpu \
|
||||
--index-url https://test.pypi.org/simple/ \
|
||||
--extra-index-url https://pypi.org/simple \
|
||||
--index-strategy unsafe-best-match \
|
||||
"lerobot[all]==$BASE_VERSION"
|
||||
else
|
||||
echo "Installing release version $VERSION from PyPI..."
|
||||
uv pip install "lerobot[all]==$VERSION"
|
||||
uv pip install --torch-backend cpu "lerobot[all]==$VERSION"
|
||||
fi
|
||||
- name: Check lerobot version
|
||||
run: uv run python -c "import lerobot; print(lerobot.__version__)"
|
||||
|
||||
16
.github/workflows/stale.yml
vendored
16
.github/workflows/stale.yml
vendored
@@ -19,19 +19,19 @@ on:
|
||||
workflow_dispatch:
|
||||
|
||||
# Runs at 02:00
|
||||
schedule:
|
||||
- cron: "0 2 * * *"
|
||||
# schedule:
|
||||
# - cron: "0 2 * * *"
|
||||
|
||||
env:
|
||||
CLOSE_ISSUE_MESSAGE: >
|
||||
This issue was closed because it has been stalled for 14 days with no activity.
|
||||
This issue was closed because it has been stalled for 30 days with no activity.
|
||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||
CLOSE_PR_MESSAGE: >
|
||||
This PR was closed because it has been stalled for 21 days with no activity.
|
||||
This PR was closed because it has been stalled for 30 days with no activity.
|
||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||
WARN_ISSUE_MESSAGE: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity (6 months). It will be closed if no further activity occurs.
|
||||
recent activity (1 year). It will be closed if no further activity occurs.
|
||||
Any change, comment or update to this issue will reset this count.
|
||||
Thank you for your contributions.
|
||||
WARN_PR_MESSAGE: >
|
||||
@@ -59,10 +59,10 @@ jobs:
|
||||
stale-pr-label: stale
|
||||
exempt-issue-labels: never-stale
|
||||
exempt-pr-labels: never-stale
|
||||
days-before-issue-stale: 180
|
||||
days-before-issue-close: 14
|
||||
days-before-issue-stale: 365
|
||||
days-before-issue-close: 30
|
||||
days-before-pr-stale: 365
|
||||
days-before-pr-close: 21
|
||||
days-before-pr-close: 30
|
||||
delete-branch: true
|
||||
close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }}
|
||||
close-pr-message: ${{ env.CLOSE_PR_MESSAGE }}
|
||||
|
||||
@@ -232,6 +232,8 @@ Match the policy to the user's **GPU memory** and **time budget**. Numbers below
|
||||
|
||||
All policies typically train for **5–10 epochs** (see §7).
|
||||
|
||||
> **Human-facing version:** the [Compute Hardware Guide](./docs/source/hardware_guide.mdx) reuses the table below and adds a cloud-GPU tier guide and a Hugging Face Jobs pointer.
|
||||
|
||||
| Policy | Batch | Update (ms) | Peak GPU mem (GB) | Best for |
|
||||
| ----------- | ----: | ----------: | ----------------: | ------------------------------------------------------------------------------------------------ |
|
||||
| `act` | 4 | **83.9** | **0.94** | First-time users, laptops, single-task. Fast and reliable. |
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
include src/lerobot/templates/lerobot_modelcard_template.md
|
||||
include src/lerobot/templates/lerobot_rewardmodel_modelcard_template.md
|
||||
include src/lerobot/datasets/card_template.md
|
||||
include src/lerobot/envs/metaworld_config.json
|
||||
|
||||
6
Makefile
6
Makefile
@@ -178,3 +178,9 @@ test-smolvla-ete-eval:
|
||||
--env.episode_length=5 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1
|
||||
|
||||
# E2E annotation pipeline smoke test against a tiny in-memory fixture
|
||||
# dataset. Opt-in (not part of `make test-end-to-end`) and uses a stub VLM
|
||||
# backend, so it does not require a real model checkpoint or GPU.
|
||||
annotation-e2e:
|
||||
uv run python -m tests.annotations.run_e2e_smoke
|
||||
|
||||
@@ -109,7 +109,7 @@ lerobot-train \
|
||||
|
||||
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
||||
|
||||
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies).
|
||||
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies). For GPU/RAM requirements and expected training time per policy, see the [Compute Hardware Guide](https://huggingface.co/docs/lerobot/hardware_guide).
|
||||
|
||||
## Inference & Evaluation
|
||||
|
||||
|
||||
@@ -1,288 +0,0 @@
|
||||
# Video benchmark
|
||||
|
||||
## Questions
|
||||
|
||||
What is the optimal trade-off between:
|
||||
|
||||
- maximizing loading time with random access,
|
||||
- minimizing memory space on disk,
|
||||
- maximizing success rate of policies,
|
||||
- compatibility across devices/platforms for decoding videos (e.g. video players, web browsers).
|
||||
|
||||
How to encode videos?
|
||||
|
||||
- Which video codec (`-vcodec`) to use? h264, h265, AV1?
|
||||
- What pixel format to use (`-pix_fmt`)? `yuv444p` or `yuv420p`?
|
||||
- How much compression (`-crf`)? No compression with `0`, intermediate compression with `25` or extreme with `50+`?
|
||||
- Which frequency to chose for key frames (`-g`)? A key frame every `10` frames?
|
||||
|
||||
How to decode videos?
|
||||
|
||||
- Which `decoder`? `torchvision`, `torchaudio`, `ffmpegio`, `decord`, or `nvc`?
|
||||
- What scenarios to use for the requesting timestamps during benchmark? (`timestamps_mode`)
|
||||
|
||||
## Variables
|
||||
|
||||
**Image content & size**
|
||||
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an apartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
|
||||
For these reasons, we run this benchmark on four representative datasets:
|
||||
|
||||
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
|
||||
- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
||||
- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
|
||||
- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
|
||||
|
||||
Note: The datasets used for this benchmark need to be image datasets, not video datasets.
|
||||
|
||||
**Data augmentations**
|
||||
We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robust (e.g. robust to color changes, compression, etc.).
|
||||
|
||||
### Encoding parameters
|
||||
|
||||
| parameter | values |
|
||||
| ----------- | ------------------------------------------------------------ |
|
||||
| **vcodec** | `libx264`, `libx265`, `libsvtav1` |
|
||||
| **pix_fmt** | `yuv444p`, `yuv420p` |
|
||||
| **g** | `1`, `2`, `3`, `4`, `5`, `6`, `10`, `15`, `20`, `40`, `None` |
|
||||
| **crf** | `0`, `5`, `10`, `15`, `20`, `25`, `30`, `40`, `50`, `None` |
|
||||
|
||||
Note that `crf` value might be interpreted differently by various video codecs. In other words, the same value used with one codec doesn't necessarily translate into the same compression level with another codec. In fact, the default value (`None`) isn't the same amongst the different video codecs. Importantly, it is also the case for many other ffmpeg arguments like `g` which specifies the frequency of the key frames.
|
||||
|
||||
For a comprehensive list and documentation of these parameters, see the ffmpeg documentation depending on the video codec used:
|
||||
|
||||
- h264: https://trac.ffmpeg.org/wiki/Encode/H.264
|
||||
- h265: https://trac.ffmpeg.org/wiki/Encode/H.265
|
||||
- AV1: https://trac.ffmpeg.org/wiki/Encode/AV1
|
||||
|
||||
### Decoding parameters
|
||||
|
||||
**Decoder**
|
||||
We tested two video decoding backends from torchvision:
|
||||
|
||||
- `pyav`
|
||||
- `video_reader` (requires to build torchvision from source)
|
||||
|
||||
**Requested timestamps**
|
||||
Given the way video decoding works, once a keyframe has been loaded, the decoding of subsequent frames is fast.
|
||||
This of course is affected by the `-g` parameter during encoding, which specifies the frequency of the keyframes. Given our typical use cases in robotics policies which might request a few timestamps in different random places, we want to replicate these use cases with the following scenarios:
|
||||
|
||||
- `1_frame`: 1 frame,
|
||||
- `2_frames`: 2 consecutive frames (e.g. `[t, t + 1 / fps]`),
|
||||
- `6_frames`: 6 consecutive frames (e.g. `[t + i / fps for i in range(6)]`)
|
||||
|
||||
Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.
|
||||
|
||||
Additionally, because some policies might request single timestamps that are a few frames apart, we also have the following scenario:
|
||||
|
||||
- `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`),
|
||||
|
||||
However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded.
|
||||
|
||||
## Metrics
|
||||
|
||||
**Data compression ratio (lower is better)**
|
||||
`video_images_size_ratio` is the ratio of the memory space on disk taken by the encoded video over the memory space taken by the original images. For instance, `video_images_size_ratio=25%` means that the video takes 4 times less memory space on disk compared to the original images.
|
||||
|
||||
**Loading time ratio (lower is better)**
|
||||
`video_images_load_time_ratio` is the ratio of the time it takes to decode frames from the video at a given timestamps over the time it takes to load the exact same original images. Lower is better. For instance, `video_images_load_time_ratio=200%` means that decoding from video is 2 times slower than loading the original images.
|
||||
|
||||
**Average Mean Square Error (lower is better)**
|
||||
`avg_mse` is the average mean square error between each decoded frame and its corresponding original image over all requested timestamps, and also divided by the number of pixels in the image to be comparable when switching to different image sizes.
|
||||
|
||||
**Average Peak Signal to Noise Ratio (higher is better)**
|
||||
`avg_psnr` measures the ratio between the maximum possible power of a signal and the power of corrupting noise that affects the fidelity of its representation. Higher PSNR indicates better quality.
|
||||
|
||||
**Average Structural Similarity Index Measure (higher is better)**
|
||||
`avg_ssim` evaluates the perceived quality of images by comparing luminance, contrast, and structure. SSIM values range from -1 to 1, where 1 indicates perfect similarity.
|
||||
|
||||
One aspect that can't be measured here with those metrics is the compatibility of the encoding across platforms, in particular on web browser, for visualization purposes.
|
||||
h264, h265 and AV1 are all commonly used codecs and should not pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
|
||||
|
||||
- `yuv420p` is more widely supported across various platforms, including web browsers.
|
||||
- `yuv444p` offers higher color fidelity but might not be supported as broadly.
|
||||
|
||||
<!-- **Loss of a pretrained policy (higher is better)** (not available)
|
||||
`loss_pretrained` is the result of evaluating with the selected encoding/decoding settings a policy pretrained on original images. It is easier to understand than `avg_l2_error`.
|
||||
|
||||
**Success rate after retraining (higher is better)** (not available)
|
||||
`success_rate` is the result of training and evaluating a policy with the selected encoding/decoding settings. It is the most difficult metric to get but also the very best. -->
|
||||
|
||||
## How the benchmark works
|
||||
|
||||
The benchmark evaluates both encoding and decoding of video frames on the first episode of each dataset.
|
||||
|
||||
**Encoding:** for each `vcodec` and `pix_fmt` pair, we use a default value for `g` and `crf` upon which we change a single value (either `g` or `crf`) to one of the specified values (we don't test every combination of those as this would be computationally too heavy).
|
||||
This gives a unique set of encoding parameters which is used to encode the episode.
|
||||
|
||||
**Decoding:** Then, for each of those unique encodings, we iterate through every combination of the decoding parameters `backend` and `timestamps_mode`. For each of them, we record the metrics of a number of samples (given by `--num-samples`). This is parallelized for efficiency and the number of processes can be controlled with `--num-workers`. Ideally, it's best to have a `--num-samples` that is divisible by `--num-workers`.
|
||||
|
||||
Intermediate results saved for each `vcodec` and `pix_fmt` combination in csv tables.
|
||||
These are then all concatenated to a single table ready for analysis.
|
||||
|
||||
## Caveats
|
||||
|
||||
We tried to measure the most impactful parameters for both encoding and decoding. However, for computational reasons we can't test out every combination.
|
||||
|
||||
Additional encoding parameters exist that are not included in this benchmark. In particular:
|
||||
|
||||
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
|
||||
- `-tune` which allows to optimize the encoding for certain aspects (e.g. film quality, fast decoding, etc.).
|
||||
|
||||
See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
|
||||
|
||||
Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few:
|
||||
|
||||
- `torchaudio`
|
||||
- `ffmpegio`
|
||||
- `decord`
|
||||
- `nvc`
|
||||
|
||||
Note as well that since we are mostly interested in the performance at decoding time (also because encoding is done only once before uploading a dataset), we did not measure encoding times nor have any metrics regarding encoding.
|
||||
However, besides the necessity to build ffmpeg from source, encoding did not pose any issue and it didn't take a significant amount of time during this benchmark.
|
||||
|
||||
## Install
|
||||
|
||||
Building ffmpeg from source is required to include libx265 and libaom/libsvtav1 (av1) video codecs ([compilation guide](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu)).
|
||||
|
||||
**Note:** While you still need to build torchvision with a conda-installed `ffmpeg<4.3` to use the `video_reader` decoder (as described in [#220](https://github.com/huggingface/lerobot/pull/220)), you also need another version which is custom-built with all the video codecs for encoding. For the script to then use that version, you can prepend the command above with `PATH="$HOME/bin:$PATH"`, which is where ffmpeg should be built.
|
||||
|
||||
## Adding a video decoder
|
||||
|
||||
Right now, we're only benchmarking the two video decoder available with torchvision: `pyav` and `video_reader`.
|
||||
You can easily add a new decoder to benchmark by adding it to this function in the script:
|
||||
|
||||
```diff
|
||||
def decode_video_frames(
|
||||
video_path: str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str,
|
||||
) -> torch.Tensor:
|
||||
if backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(
|
||||
video_path, timestamps, tolerance_s, backend
|
||||
)
|
||||
+ elif backend == ["your_decoder"]:
|
||||
+ return your_decoder_function(
|
||||
+ video_path, timestamps, tolerance_s, backend
|
||||
+ )
|
||||
else:
|
||||
raise NotImplementedError(backend)
|
||||
```
|
||||
|
||||
## Example
|
||||
|
||||
For a quick run, you can try these parameters:
|
||||
|
||||
```bash
|
||||
python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 2 20 None \
|
||||
--crf 10 40 None \
|
||||
--timestamps-modes 1_frame 2_frames \
|
||||
--backends pyav video_reader \
|
||||
--num-samples 5 \
|
||||
--num-workers 5 \
|
||||
--save-frames 0
|
||||
```
|
||||
|
||||
## Results
|
||||
|
||||
### Reproduce
|
||||
|
||||
We ran the benchmark with the following parameters:
|
||||
|
||||
```bash
|
||||
# h264 and h265 encodings
|
||||
python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
lerobot/paris_street \
|
||||
lerobot/kitchen \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
--crf 0 5 10 15 20 25 30 40 50 None \
|
||||
--timestamps-modes 1_frame 2_frames 6_frames \
|
||||
--backends pyav video_reader \
|
||||
--num-samples 50 \
|
||||
--num-workers 5 \
|
||||
--save-frames 1
|
||||
|
||||
# av1 encoding (only compatible with yuv420p and pyav decoder)
|
||||
python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
lerobot/paris_street \
|
||||
lerobot/kitchen \
|
||||
--vcodec libsvtav1 \
|
||||
--pix-fmt yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
--crf 0 5 10 15 20 25 30 40 50 None \
|
||||
--timestamps-modes 1_frame 2_frames 6_frames \
|
||||
--backends pyav \
|
||||
--num-samples 50 \
|
||||
--num-workers 5 \
|
||||
--save-frames 1
|
||||
```
|
||||
|
||||
The full results are available [here](https://docs.google.com/spreadsheets/d/1OYJB43Qu8fC26k_OyoMFgGBBKfQRCi4BIuYitQnq3sw/edit?usp=sharing)
|
||||
|
||||
### Parameters selected for LeRobotDataset
|
||||
|
||||
Considering these results, we chose what we think is the best set of encoding parameter:
|
||||
|
||||
- vcodec: `libsvtav1`
|
||||
- pix-fmt: `yuv420p`
|
||||
- g: `2`
|
||||
- crf: `30`
|
||||
|
||||
Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_reader` does not support it (and `pyav` doesn't require a custom build of `torchvision`).
|
||||
|
||||
### Summary
|
||||
|
||||
These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`
|
||||
|
||||
| video_images_size_ratio | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | ---------- | ------- | --------- | --------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
|
||||
| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
|
||||
| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
|
||||
| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
|
||||
|
||||
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | ------- | ------- | -------- | ------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
|
||||
| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
|
||||
| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
|
||||
| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
|
||||
|
||||
| | | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
|
||||
| | | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
|
||||
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
|
||||
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
|
||||
| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
|
||||
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
|
||||
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
|
||||
| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
|
||||
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
|
||||
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
|
||||
| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
|
||||
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
|
||||
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
|
||||
@@ -1,488 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Assess the performance of video decoding in various configurations.
|
||||
|
||||
This script will benchmark different video encoding and decoding parameters.
|
||||
See the provided README.md or run `python benchmark/video/run_video_benchmark.py --help` for usage info.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import itertools
|
||||
import random
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import PIL
|
||||
import torch
|
||||
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.video_utils import (
|
||||
decode_video_frames,
|
||||
encode_video_frames,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGE
|
||||
from lerobot.utils.utils import TimerManager
|
||||
|
||||
BASE_ENCODING = OrderedDict(
|
||||
[
|
||||
("vcodec", "libx264"),
|
||||
("pix_fmt", "yuv444p"),
|
||||
("g", 2),
|
||||
("crf", None),
|
||||
# TODO(aliberts): Add fastdecode
|
||||
# ("fastdecode", 0),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# TODO(rcadene, aliberts): move to `utils.py` folder when we want to refactor
|
||||
def parse_int_or_none(value) -> int | None:
|
||||
if value.lower() == "none":
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError as e:
|
||||
raise argparse.ArgumentTypeError(f"Invalid int or None: {value}") from e
|
||||
|
||||
|
||||
def check_datasets_formats(repo_ids: list) -> None:
|
||||
for repo_id in repo_ids:
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
raise ValueError(
|
||||
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
|
||||
)
|
||||
|
||||
|
||||
def get_directory_size(directory: Path) -> int:
|
||||
total_size = 0
|
||||
for item in directory.rglob("*"):
|
||||
if item.is_file():
|
||||
total_size += item.stat().st_size
|
||||
return total_size
|
||||
|
||||
|
||||
def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor:
|
||||
frames = []
|
||||
for ts in timestamps:
|
||||
idx = int(ts * fps)
|
||||
frame = PIL.Image.open(imgs_dir / f"frame-{idx:06d}.png")
|
||||
frame = torch.from_numpy(np.array(frame))
|
||||
frame = frame.type(torch.float32) / 255
|
||||
frame = einops.rearrange(frame, "h w c -> c h w")
|
||||
frames.append(frame)
|
||||
return torch.stack(frames)
|
||||
|
||||
|
||||
def save_decoded_frames(
|
||||
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
|
||||
) -> None:
|
||||
if save_dir.exists() and len(list(save_dir.glob("frame-*.png"))) == len(timestamps):
|
||||
return
|
||||
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i, ts in enumerate(timestamps):
|
||||
idx = int(ts * fps)
|
||||
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
|
||||
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame-{idx:06d}_decoded.png")
|
||||
shutil.copyfile(imgs_dir / f"frame-{idx:06d}.png", save_dir / f"frame-{idx:06d}_original.png")
|
||||
|
||||
|
||||
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
episode_index = 0
|
||||
ep_num_images = dataset.meta.episodes["length"][episode_index]
|
||||
if imgs_dir.exists() and len(list(imgs_dir.glob("frame-*.png"))) == ep_num_images:
|
||||
return
|
||||
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
|
||||
# We only save images from the first camera
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||
imgs_dataset = hf_dataset.select_columns(img_keys[0])
|
||||
|
||||
for i, item in enumerate(
|
||||
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
|
||||
):
|
||||
img = item[img_keys[0]]
|
||||
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
|
||||
if i >= ep_num_images - 1:
|
||||
break
|
||||
|
||||
|
||||
def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]:
|
||||
# Start at 5 to allow for 2_frames_4_space and 6_frames
|
||||
idx = random.randint(5, ep_num_images - 1)
|
||||
match timestamps_mode:
|
||||
case "1_frame":
|
||||
frame_indexes = [idx]
|
||||
case "2_frames":
|
||||
frame_indexes = [idx - 1, idx]
|
||||
case "2_frames_4_space":
|
||||
frame_indexes = [idx - 5, idx]
|
||||
case "6_frames":
|
||||
frame_indexes = [idx - i for i in range(6)][::-1]
|
||||
case _:
|
||||
raise ValueError(timestamps_mode)
|
||||
|
||||
return [idx / fps for idx in frame_indexes]
|
||||
|
||||
|
||||
def benchmark_decoding(
|
||||
imgs_dir: Path,
|
||||
video_path: Path,
|
||||
timestamps_mode: str,
|
||||
backend: str,
|
||||
ep_num_images: int,
|
||||
fps: int,
|
||||
num_samples: int = 50,
|
||||
num_workers: int = 4,
|
||||
save_frames: bool = False,
|
||||
) -> dict:
|
||||
def process_sample(sample: int, lock: Lock):
|
||||
time_benchmark = TimerManager(log=False)
|
||||
timestamps = sample_timestamps(timestamps_mode, ep_num_images, fps)
|
||||
num_frames = len(timestamps)
|
||||
result = {
|
||||
"psnr_values": [],
|
||||
"ssim_values": [],
|
||||
"mse_values": [],
|
||||
}
|
||||
|
||||
with time_benchmark, lock:
|
||||
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
|
||||
result["load_time_video_ms"] = (time_benchmark.last * 1000) / num_frames
|
||||
|
||||
with time_benchmark:
|
||||
original_frames = load_original_frames(imgs_dir, timestamps, fps)
|
||||
result["load_time_images_ms"] = (time_benchmark.last * 1000) / num_frames
|
||||
|
||||
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
|
||||
for i in range(num_frames):
|
||||
result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
|
||||
result["psnr_values"].append(
|
||||
peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0)
|
||||
)
|
||||
result["ssim_values"].append(
|
||||
structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0)
|
||||
)
|
||||
|
||||
if save_frames and sample == 0:
|
||||
save_dir = video_path.with_suffix("") / f"{timestamps_mode}_{backend}"
|
||||
save_decoded_frames(imgs_dir, save_dir, frames, timestamps, fps)
|
||||
|
||||
return result
|
||||
|
||||
load_times_video_ms = []
|
||||
load_times_images_ms = []
|
||||
mse_values = []
|
||||
psnr_values = []
|
||||
ssim_values = []
|
||||
|
||||
# A sample is a single set of decoded frames specified by timestamps_mode (e.g. a single frame, 2 frames, etc.).
|
||||
# For each sample, we record metrics (loading time and quality metrics) which are then averaged over all samples.
|
||||
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
|
||||
# Use a single shared lock for all worker threads
|
||||
shared_lock = Lock()
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(process_sample, i, shared_lock) for i in range(num_samples)]
|
||||
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
|
||||
result = future.result()
|
||||
load_times_video_ms.append(result["load_time_video_ms"])
|
||||
load_times_images_ms.append(result["load_time_images_ms"])
|
||||
psnr_values.extend(result["psnr_values"])
|
||||
ssim_values.extend(result["ssim_values"])
|
||||
mse_values.extend(result["mse_values"])
|
||||
|
||||
avg_load_time_video_ms = float(np.array(load_times_video_ms).mean())
|
||||
avg_load_time_images_ms = float(np.array(load_times_images_ms).mean())
|
||||
video_images_load_time_ratio = avg_load_time_video_ms / avg_load_time_images_ms
|
||||
|
||||
return {
|
||||
"avg_load_time_video_ms": avg_load_time_video_ms,
|
||||
"avg_load_time_images_ms": avg_load_time_images_ms,
|
||||
"video_images_load_time_ratio": video_images_load_time_ratio,
|
||||
"avg_mse": float(np.mean(mse_values)),
|
||||
"avg_psnr": float(np.mean(psnr_values)),
|
||||
"avg_ssim": float(np.mean(ssim_values)),
|
||||
}
|
||||
|
||||
|
||||
def benchmark_encoding_decoding(
|
||||
dataset: LeRobotDataset,
|
||||
video_path: Path,
|
||||
imgs_dir: Path,
|
||||
encoding_cfg: dict,
|
||||
decoding_cfg: dict,
|
||||
num_samples: int,
|
||||
num_workers: int,
|
||||
save_frames: bool,
|
||||
overwrite: bool = False,
|
||||
seed: int = 1337,
|
||||
) -> list[dict]:
|
||||
fps = dataset.fps
|
||||
|
||||
if overwrite or not video_path.is_file():
|
||||
tqdm.write(f"encoding {video_path}")
|
||||
encode_video_frames(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=encoding_cfg["vcodec"],
|
||||
pix_fmt=encoding_cfg["pix_fmt"],
|
||||
g=encoding_cfg.get("g"),
|
||||
crf=encoding_cfg.get("crf"),
|
||||
# fast_decode=encoding_cfg.get("fastdecode"),
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
episode_index = 0
|
||||
ep_num_images = dataset.meta.episodes["length"][episode_index]
|
||||
width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
|
||||
num_pixels = width * height
|
||||
video_size_bytes = video_path.stat().st_size
|
||||
images_size_bytes = get_directory_size(imgs_dir)
|
||||
video_images_size_ratio = video_size_bytes / images_size_bytes
|
||||
|
||||
random.seed(seed)
|
||||
benchmark_table = []
|
||||
for timestamps_mode in tqdm(
|
||||
decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
|
||||
):
|
||||
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
|
||||
benchmark_row = benchmark_decoding(
|
||||
imgs_dir,
|
||||
video_path,
|
||||
timestamps_mode,
|
||||
backend,
|
||||
ep_num_images,
|
||||
fps,
|
||||
num_samples,
|
||||
num_workers,
|
||||
save_frames,
|
||||
)
|
||||
benchmark_row.update(
|
||||
**{
|
||||
"repo_id": dataset.repo_id,
|
||||
"resolution": f"{width} x {height}",
|
||||
"num_pixels": num_pixels,
|
||||
"video_size_bytes": video_size_bytes,
|
||||
"images_size_bytes": images_size_bytes,
|
||||
"video_images_size_ratio": video_images_size_ratio,
|
||||
"timestamps_mode": timestamps_mode,
|
||||
"backend": backend,
|
||||
},
|
||||
**encoding_cfg,
|
||||
)
|
||||
benchmark_table.append(benchmark_row)
|
||||
|
||||
return benchmark_table
|
||||
|
||||
|
||||
def main(
|
||||
output_dir: Path,
|
||||
repo_ids: list[str],
|
||||
vcodec: list[str],
|
||||
pix_fmt: list[str],
|
||||
g: list[int],
|
||||
crf: list[int],
|
||||
# fastdecode: list[int],
|
||||
timestamps_modes: list[str],
|
||||
backends: list[str],
|
||||
num_samples: int,
|
||||
num_workers: int,
|
||||
save_frames: bool,
|
||||
):
|
||||
check_datasets_formats(repo_ids)
|
||||
encoding_benchmarks = {
|
||||
"g": g,
|
||||
"crf": crf,
|
||||
# "fastdecode": fastdecode,
|
||||
}
|
||||
decoding_benchmarks = {
|
||||
"timestamps_modes": timestamps_modes,
|
||||
"backends": backends,
|
||||
}
|
||||
headers = ["repo_id", "resolution", "num_pixels"]
|
||||
headers += list(BASE_ENCODING.keys())
|
||||
headers += [
|
||||
"timestamps_mode",
|
||||
"backend",
|
||||
"video_size_bytes",
|
||||
"images_size_bytes",
|
||||
"video_images_size_ratio",
|
||||
"avg_load_time_video_ms",
|
||||
"avg_load_time_images_ms",
|
||||
"video_images_load_time_ratio",
|
||||
"avg_mse",
|
||||
"avg_psnr",
|
||||
"avg_ssim",
|
||||
]
|
||||
file_paths = []
|
||||
for video_codec in tqdm(vcodec, desc="encodings (vcodec)"):
|
||||
for pixel_format in tqdm(pix_fmt, desc="encodings (pix_fmt)", leave=False):
|
||||
benchmark_table = []
|
||||
for repo_id in tqdm(repo_ids, desc="encodings (datasets)", leave=False):
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
|
||||
# We only use the first episode
|
||||
save_first_episode(imgs_dir, dataset)
|
||||
for duet in [
|
||||
dict(zip(encoding_benchmarks.keys(), unique_combination, strict=False))
|
||||
for unique_combination in itertools.product(*encoding_benchmarks.values())
|
||||
]:
|
||||
encoding_cfg = BASE_ENCODING.copy()
|
||||
encoding_cfg["vcodec"] = video_codec
|
||||
encoding_cfg["pix_fmt"] = pixel_format
|
||||
for key, value in duet.items():
|
||||
encoding_cfg[key] = value
|
||||
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
|
||||
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
|
||||
benchmark_table += benchmark_encoding_decoding(
|
||||
dataset,
|
||||
video_path,
|
||||
imgs_dir,
|
||||
encoding_cfg,
|
||||
decoding_benchmarks,
|
||||
num_samples,
|
||||
num_workers,
|
||||
save_frames,
|
||||
)
|
||||
|
||||
# Save intermediate results
|
||||
benchmark_df = pd.DataFrame(benchmark_table, columns=headers)
|
||||
now = dt.datetime.now()
|
||||
csv_path = (
|
||||
output_dir
|
||||
/ f"{now:%Y-%m-%d}_{now:%H-%M-%S}_{video_codec}_{pixel_format}_{num_samples}-samples.csv"
|
||||
)
|
||||
benchmark_df.to_csv(csv_path, header=True, index=False)
|
||||
file_paths.append(csv_path)
|
||||
del benchmark_df
|
||||
|
||||
# Concatenate all results
|
||||
df_list = [pd.read_csv(csv_path) for csv_path in file_paths]
|
||||
concatenated_df = pd.concat(df_list, ignore_index=True)
|
||||
concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
|
||||
concatenated_df.to_csv(concatenated_path, header=True, index=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=Path("outputs/video_benchmark"),
|
||||
help="Directory where the video benchmark outputs are written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-ids",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[
|
||||
"lerobot/pusht_image",
|
||||
"lerobot/aloha_mobile_shrimp_image",
|
||||
"lerobot/paris_street",
|
||||
"lerobot/kitchen",
|
||||
],
|
||||
help="Datasets repo-ids to test against. First episodes only are used. Must be images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vcodec",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["h264", "hevc", "libsvtav1"],
|
||||
help="Video codecs to be tested",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pix-fmt",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["yuv444p", "yuv420p"],
|
||||
help="Pixel formats (chroma subsampling) to be tested",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--g",
|
||||
type=parse_int_or_none,
|
||||
nargs="*",
|
||||
default=[1, 2, 3, 4, 5, 6, 10, 15, 20, 40, 100, None],
|
||||
help="Group of pictures sizes to be tested.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crf",
|
||||
type=parse_int_or_none,
|
||||
nargs="*",
|
||||
default=[0, 5, 10, 15, 20, 25, 30, 40, 50, None],
|
||||
help="Constant rate factors to be tested.",
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--fastdecode",
|
||||
# type=int,
|
||||
# nargs="*",
|
||||
# default=[0, 1],
|
||||
# help="Use the fastdecode tuning option. 0 disables it. "
|
||||
# "For libx264 and libx265/hevc, only 1 is possible. "
|
||||
# "For libsvtav1, 1, 2 or 3 are possible values with a higher number meaning a faster decoding optimization",
|
||||
# )
|
||||
parser.add_argument(
|
||||
"--timestamps-modes",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[
|
||||
"1_frame",
|
||||
"2_frames",
|
||||
"2_frames_4_space",
|
||||
"6_frames",
|
||||
],
|
||||
help="Timestamps scenarios to be tested.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backends",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["torchcodec", "pyav"],
|
||||
help="Torchvision decoding backend to be tested.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-samples",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of samples for each encoding x decoding config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of processes for parallelized sample processing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-frames",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Whether to save decoded frames or not. Enter a non-zero number for true.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(**vars(args))
|
||||
@@ -35,7 +35,7 @@ USER root
|
||||
ARG ROBOTWIN_SHA=0aeea2d669c0f8516f4d5785f0aa33ba812c14b4
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
cuda-nvcc-12-4 cuda-cudart-dev-12-4 \
|
||||
cuda-nvcc-12-8 cuda-cudart-dev-12-8 \
|
||||
libvulkan1 vulkan-tools \
|
||||
&& mkdir -p /usr/share/vulkan/icd.d \
|
||||
&& echo '{"file_format_version":"1.0.0","ICD":{"library_path":"libGLX_nvidia.so.0","api_version":"1.3.0"}}' \
|
||||
|
||||
@@ -18,9 +18,8 @@
|
||||
# docker build -f docker/Dockerfile.internal -t lerobot-internal .
|
||||
|
||||
# Configure the base image for CI with GPU access
|
||||
# TODO(Steven): Bump these versions
|
||||
ARG CUDA_VERSION=12.4.1
|
||||
ARG OS_VERSION=22.04
|
||||
ARG CUDA_VERSION=12.8.1
|
||||
ARG OS_VERSION=24.04
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
||||
|
||||
# Define Python version argument
|
||||
@@ -36,16 +35,13 @@ ENV DEBIAN_FRONTEND=noninteractive \
|
||||
|
||||
# Install Python, system dependencies, and uv (as root)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
software-properties-common build-essential git curl \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
build-essential git curl \
|
||||
libglib2.0-0 libgl1 libegl1 ffmpeg \
|
||||
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
|
||||
cmake pkg-config ninja-build \
|
||||
&& add-apt-repository -y ppa:deadsnakes/ppa \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
python${PYTHON_VERSION} \
|
||||
python${PYTHON_VERSION}-venv \
|
||||
python${PYTHON_VERSION}-dev \
|
||||
python${PYTHON_VERSION} \
|
||||
python${PYTHON_VERSION}-venv \
|
||||
python${PYTHON_VERSION}-dev \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||
&& mv /root/.local/bin/uv /usr/local/bin/uv \
|
||||
&& useradd --create-home --shell /bin/bash user_lerobot \
|
||||
|
||||
@@ -3,12 +3,14 @@
|
||||
title: LeRobot
|
||||
- local: installation
|
||||
title: Installation
|
||||
- local: cheat-sheet
|
||||
title: Cheat sheet
|
||||
title: Get started
|
||||
- sections:
|
||||
- local: il_robots
|
||||
title: Imitation Learning for Robots
|
||||
- local: bring_your_own_policies
|
||||
title: Bring Your Own Policies
|
||||
title: Adding a Policy
|
||||
- local: integrate_hardware
|
||||
title: Bring Your Own Hardware
|
||||
- local: hilserl
|
||||
@@ -24,6 +26,12 @@
|
||||
- local: rename_map
|
||||
title: Using Rename Map and Empty Cameras
|
||||
title: "Tutorials"
|
||||
- sections:
|
||||
- local: hardware_guide
|
||||
title: Compute Hardware Guide
|
||||
- local: torch_accelerators
|
||||
title: PyTorch accelerators
|
||||
title: "Compute & Hardware"
|
||||
- sections:
|
||||
- local: lerobot-dataset-v3
|
||||
title: Using LeRobotDataset
|
||||
@@ -31,8 +39,14 @@
|
||||
title: Porting Large Datasets
|
||||
- local: using_dataset_tools
|
||||
title: Using the Dataset Tools
|
||||
- local: dataset_subtask
|
||||
title: Using Subtasks in the Dataset
|
||||
- local: language_and_recipes
|
||||
title: Language Columns and Recipes
|
||||
- local: tools
|
||||
title: Tools
|
||||
- local: annotation_pipeline
|
||||
title: Annotation Pipeline
|
||||
- local: video_encoding_parameters
|
||||
title: Video encoding parameters
|
||||
- local: streaming_video_encoding
|
||||
title: Streaming Video Encoding
|
||||
title: "Datasets"
|
||||
@@ -47,6 +61,8 @@
|
||||
title: π₀-FAST (Pi0Fast)
|
||||
- local: pi05
|
||||
title: π₀.₅ (Pi05)
|
||||
- local: eo1
|
||||
title: EO-1
|
||||
- local: groot
|
||||
title: NVIDIA GR00T N1.5
|
||||
- local: xvla
|
||||
@@ -61,6 +77,8 @@
|
||||
title: SARM
|
||||
title: "Reward Models"
|
||||
- sections:
|
||||
- local: inference
|
||||
title: Policy Deployment (lerobot-rollout)
|
||||
- local: async
|
||||
title: Use Async Inference
|
||||
- local: rtc
|
||||
@@ -129,6 +147,8 @@
|
||||
title: OMX
|
||||
- local: openarm
|
||||
title: OpenArm
|
||||
- local: rebot_b601
|
||||
title: reBot B601-DM
|
||||
title: "Robots"
|
||||
- sections:
|
||||
- local: phone_teleop
|
||||
@@ -138,10 +158,6 @@
|
||||
- local: cameras
|
||||
title: Cameras
|
||||
title: "Sensors"
|
||||
- sections:
|
||||
- local: torch_accelerators
|
||||
title: PyTorch accelerators
|
||||
title: "Supported Hardware"
|
||||
- sections:
|
||||
- local: notebooks
|
||||
title: Notebooks
|
||||
|
||||
@@ -79,17 +79,13 @@ If your local computer doesn't have a powerful GPU, you can utilize Google Colab
|
||||
Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/act_policy \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_robot \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Your task description" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--policy.path=${HF_USER}/act_policy
|
||||
--task="Your task description" \ # can be skipped for ACT
|
||||
--duration=60
|
||||
```
|
||||
|
||||
198
docs/source/annotation_pipeline.mdx
Normal file
198
docs/source/annotation_pipeline.mdx
Normal file
@@ -0,0 +1,198 @@
|
||||
# Annotation Pipeline
|
||||
|
||||
`lerobot-annotate` populates the two language columns introduced by the
|
||||
[Language Columns and Recipes](./language_and_recipes) page —
|
||||
`language_persistent` and `language_events` — directly into
|
||||
`data/chunk-*/file-*.parquet`.
|
||||
|
||||
## What the pipeline produces
|
||||
|
||||
A vocabulary-discovery phase derives a small canonical wording, then three
|
||||
modules write into a per-episode staging tree, then a single writer
|
||||
rewrites the data shards in place:
|
||||
|
||||
| Style / atom | Column | Module |
|
||||
| ------------------------------------------- | --------------------- | -------------- |
|
||||
| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | `plan` |
|
||||
| `plan` (initial + refresh on interjection) | `language_persistent` | `plan` |
|
||||
| `memory` (MEM-style compression) | `language_persistent` | `plan` |
|
||||
| `task_aug` (rephrasings of canonical task) | `language_persistent` | `plan` |
|
||||
| `interjection` | `language_events` | `interjections`|
|
||||
| speech tool-call atom (`style=null`, `say`) | `language_events` | `interjections`|
|
||||
| `vqa` (user / assistant pair) | `language_events` | `vqa` |
|
||||
|
||||
The `plan` module is constrained to a **canonical vocabulary** discovered
|
||||
once per dataset by the `vocabulary` module (phase 0). It watches a few
|
||||
sample episode videos (`--vocabulary.sample_episodes`, default `3`) and
|
||||
asks the VLM to derive a small set of imperative subtask labels and
|
||||
first-person memory milestones that recur across the demos. The VLM
|
||||
picks the right number of entries itself based on what it sees in the
|
||||
clips — short pick-and-place demos get ~6 subtask labels, longer
|
||||
multi-step recipes get more. The result lands at
|
||||
`meta/canonical_vocabulary.json` (human-readable / hand-editable) and
|
||||
is reused on every subsequent run. The `plan` module then constrains
|
||||
both subtask + memory generation to those exact strings — the
|
||||
downstream low-level policy sees a small, repeatable target
|
||||
distribution instead of thousands of LLM paraphrases. Disable with
|
||||
`--vocabulary.enabled=False` to fall back to free-form generation.
|
||||
|
||||
The writer does **not** add a `tools` column to the parquet — the tool
|
||||
catalog lives at `meta/info.json["tools"]` instead (see
|
||||
[Tools](./tools)). After every annotation run the pipeline ensures the
|
||||
canonical `say` schema is present in that list, preserving any tools the
|
||||
user pre-declared.
|
||||
|
||||
If you want to declare additional tools for a dataset before annotation
|
||||
runs, edit `meta/info.json["tools"]` directly — the pipeline preserves
|
||||
anything already there. Implementations of those tools live under
|
||||
`src/lerobot/tools/`; one file per tool, registered via
|
||||
`TOOL_REGISTRY`. See the [Tools](./tools) doc for the authoring guide.
|
||||
|
||||
## Running locally
|
||||
|
||||
Install the extra and invoke the console script. Episode-level
|
||||
concurrency comes from `--executor.episode_parallelism` (default 16);
|
||||
that is the only knob the in-process executor exposes.
|
||||
|
||||
```bash
|
||||
uv sync --extra annotations
|
||||
uv run lerobot-annotate \
|
||||
--root=/path/to/dataset \
|
||||
--vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
|
||||
```
|
||||
|
||||
The pipeline attaches actual camera footage to every `plan` /
|
||||
`interjections` / `vqa` prompt by default, decoded from the dataset's
|
||||
first `observation.images.*` stream. Override with
|
||||
`--vlm.camera_key=observation.images.<name>` to pin a specific
|
||||
viewpoint. Datasets with no video tracks fall back to text-only prompts
|
||||
automatically.
|
||||
|
||||
**The `plan` module sees the whole episode as one video block.** Subtask
|
||||
decomposition gets a `{"type":"video", "video":[<frames>]}` block
|
||||
covering the entire demonstration; Qwen-VL pools temporally on its own
|
||||
and decides where to cut. There is no keyframe stride or count knob —
|
||||
`--plan.max_video_frames` (default 128) only caps the frames packed
|
||||
into the video block as a model-capacity bound. The `interjections`
|
||||
module attaches a short window of frames straddling the interjection
|
||||
timestamp. The `vqa` module grounds each VQA pair on a single frame —
|
||||
its `--vqa.K` knob sets how many consecutive frames each emission tick
|
||||
anchors, and every anchored frame gets its own VQA pair on that one
|
||||
frame (there is no per-pair frame window).
|
||||
|
||||
## Running on Hugging Face Jobs
|
||||
|
||||
Distributed annotation is delegated to
|
||||
[Hugging Face Jobs](https://huggingface.co/docs/hub/en/jobs). The repo
|
||||
ships a launcher script you copy and edit for your dataset:
|
||||
|
||||
```bash
|
||||
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||
```
|
||||
|
||||
[`examples/annotations/run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
|
||||
spawns one `h200x2` job that:
|
||||
|
||||
1. installs the branch under test plus the annotation extras,
|
||||
2. boots two vllm servers (one per GPU) for the chosen model,
|
||||
3. runs the `plan` / `interjections` / `vqa` modules across the dataset
|
||||
via `lerobot-annotate`,
|
||||
4. uploads the annotated dataset to `--push_to_hub`.
|
||||
|
||||
To target a different dataset, model, or hub repo, edit the `CMD` block
|
||||
inside the script — every flag in there maps directly onto a CLI flag of
|
||||
`lerobot-annotate` (see `lerobot-annotate --help` for the full list).
|
||||
|
||||
## Style-to-recipe consumer mapping
|
||||
|
||||
The pipeline's outputs are designed to be consumed by recipes (see
|
||||
[Language Columns and Recipes](./language_and_recipes)) — typically:
|
||||
|
||||
- low-level / high-level / memory-update branches consume
|
||||
`subtask`/`plan`/`memory` from `language_persistent`.
|
||||
- An interjection-response branch consumes `interjection` events plus
|
||||
the paired speech atom (merged into one assistant target turn via
|
||||
`tool_calls_from`) and the same-timestamp `plan` refresh.
|
||||
- A VQA branch consumes the `(vqa, user)` and `(vqa, assistant)` pairs
|
||||
from `language_events`.
|
||||
|
||||
## Why the design splits state from events
|
||||
|
||||
Two things drive the scope:
|
||||
|
||||
1. **Persistent state vs exact-event split.** Persistent rows
|
||||
(`subtask`, `plan`, `memory`) broadcast per episode and answer "what
|
||||
state is in force at this frame?". Event rows (`interjection`, `vqa`,
|
||||
speech) only appear on the exact frame whose timestamp matches the
|
||||
emission. The pipeline writes timestamps taken straight from the
|
||||
source parquet — no floating-point recomputation.
|
||||
2. **One Qwen-VL pass.** All three modules share a single VLM client
|
||||
(vLLM if available, transformers fallback) so the cost is one model
|
||||
load per dataset, not three.
|
||||
|
||||
## Module independence and staged reruns
|
||||
|
||||
Each module writes its raw output to
|
||||
`<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl`. That makes
|
||||
prompt iteration cheap — re-running one module overwrites only its own
|
||||
JSONL file before the writer composes the final parquet. Modules can be
|
||||
disabled via `--plan.enabled=false` (and likewise `--interjections.enabled`
|
||||
/ `--vqa.enabled`) to
|
||||
test them in isolation.
|
||||
|
||||
## Validation/report checks before final write
|
||||
|
||||
Before the writer runs, `StagingValidator` checks:
|
||||
|
||||
- exact frame-timestamp alignment for every event row;
|
||||
- no orphan speech / interjection pairs;
|
||||
- `plan` is refreshed at every interjection timestamp;
|
||||
- `memory` rows fall on subtask boundaries (warning, not error);
|
||||
- VQA assistant `content` parses as JSON in one of the
|
||||
bbox / keypoint / count / attribute / spatial shapes;
|
||||
- every row routes to the column dictated by `column_for_style(style)`.
|
||||
|
||||
Errors abort the writer (`--skip_validation=true` overrides for debugging).
|
||||
|
||||
## Paper inspirations per module
|
||||
|
||||
- **`plan` module — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
|
||||
atom granularity ("pick up one piece of lettuce", "place bowl to box");
|
||||
Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07)) "how, not
|
||||
what" detail.
|
||||
- **`plan` module — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596))
|
||||
compression directive: keep only minimal relevant information; functional
|
||||
outcomes preserved, specific attributes dropped.
|
||||
- **`interjections` module.** Hi Robot scenario taxonomy: negative task,
|
||||
situated correction, specific constraint, preference. Speech is a
|
||||
tool-call-only atom (`tool_calls=[{type:function, function:{name:"say",
|
||||
arguments:{text:...}}}]`).
|
||||
- **`vqa` module.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693))
|
||||
grounded features (bounding boxes in pixel `[x_min, y_min, x_max, y_max]`,
|
||||
keypoints) and Steerable VLA Policies ([Zhao 2025](https://arxiv.org/abs/2509.07626))
|
||||
multi-abstraction grounding. Pi0.7 also grounds answers across
|
||||
multiple abstraction levels.
|
||||
|
||||
Future maintainers should adjust the prompt templates in
|
||||
`src/lerobot/annotations/steerable_pipeline/prompts/` against these
|
||||
references rather than rewriting from scratch.
|
||||
|
||||
## Compute and list-size estimates
|
||||
|
||||
Per episode, the pipeline issues O(`max_steps`) `plan`-module calls,
|
||||
O(`max_interjections_per_episode`) `interjections`-module calls, and
|
||||
O(`vqa_emission_hz × episode_seconds`) `vqa`-module calls. With defaults
|
||||
(8 subtasks, 1 interjection, 1 Hz × 3 pairs) and 30-second episodes, that
|
||||
is ~50 VLM calls per episode. `language_persistent` per episode is ~10s of
|
||||
KB at most (parquet dictionary-encodes one entry per episode);
|
||||
`language_events` is empty on most frames and is bounded by the number of
|
||||
emissions, not `num_frames × num_emissions`.
|
||||
|
||||
## Reproducibility via seed and prompt hashes
|
||||
|
||||
`--seed` (default 1729) feeds the per-episode RNGs that select interjection
|
||||
timestamps and VQA question types. Combined with the deterministic prompt
|
||||
templates checked into `prompts/`, two runs at the same seed against the
|
||||
same dataset and the same model checkpoint produce byte-identical staging
|
||||
artifacts. Prompt edits are recorded by file hash; future tooling can pin
|
||||
expected `(seed, prompt_hash)` pairs into the dataset card.
|
||||
@@ -1,60 +1,37 @@
|
||||
# Bring Your Own Policies
|
||||
# Adding a Policy
|
||||
|
||||
This tutorial explains how to integrate your own custom policy implementations into the LeRobot ecosystem, allowing you to leverage all LeRobot tools for training, evaluation, and deployment while using your own algorithms.
|
||||
This guide walks you through implementing a custom policy and getting it to work with LeRobot's training, evaluation, and deployment tools. There are two paths:
|
||||
|
||||
## Step 1: Create a Policy Package
|
||||
- **Plugin (out-of-tree)** — ship your policy as a standalone `lerobot_policy_*` package. Faster, no PR required, easy to iterate. Right for experimentation, internal use, or when you want to publish independently.
|
||||
- **In-tree (contributed to LeRobot)** — land your policy directly in `src/lerobot/policies/`. Requires a PR, but makes your policy a first-class citizen of the library.
|
||||
|
||||
Your custom policy should be organized as an installable Python package following LeRobot's plugin conventions.
|
||||
The plugin route is usually the right starting point — promote to in-tree once the policy has stabilized and there's clear value in shipping it with the library.
|
||||
|
||||
### Package Structure
|
||||
Either way, the building blocks are the same: a configuration class, a policy class, and a processor factory. The first half of this guide covers those shared pieces; the second half covers the path-specific scaffolding ([Path A](#path-a-out-of-tree-plugin), [Path B](#path-b-contributing-in-tree)).
|
||||
|
||||
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
|
||||
A note on tone: robot-learning is an actively evolving field, and "what a policy looks like" can shift with each new architecture. The conventions described here exist because they let `lerobot-train` and `lerobot-eval` work uniformly across very different models. When a new policy genuinely doesn't fit them, raise it (in your PR, or an issue) — the conventions are not sacred.
|
||||
|
||||
```bash
|
||||
lerobot_policy_my_custom_policy/
|
||||
├── pyproject.toml
|
||||
└── src/
|
||||
└── lerobot_policy_my_custom_policy/
|
||||
├── __init__.py
|
||||
├── configuration_my_custom_policy.py
|
||||
├── modeling_my_custom_policy.py
|
||||
└── processor_my_custom_policy.py
|
||||
```
|
||||
---
|
||||
|
||||
### Package Configuration
|
||||
## Anatomy of a policy
|
||||
|
||||
Set up your `pyproject.toml`:
|
||||
Three building blocks make up every policy. The names below use `my_policy` as a placeholder — replace with your policy's name. That name is load-bearing: it must match the string you pass to `@PreTrainedConfig.register_subclass`, the `MyPolicy.name` class attribute, and the `make_<name>_pre_post_processors` factory function (more on each below).
|
||||
|
||||
```toml
|
||||
[project]
|
||||
name = "lerobot_policy_my_custom_policy"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
# your policy-specific dependencies
|
||||
]
|
||||
requires-python = ">= 3.12"
|
||||
### Configuration class
|
||||
|
||||
[build-system]
|
||||
build-backend = # your-build-backend
|
||||
requires = # your-build-system
|
||||
```
|
||||
|
||||
## Step 2: Define the Policy Configuration
|
||||
|
||||
Create a configuration class that inherits from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and registers your policy type:
|
||||
Here is a template to get you started, customize the parameters and methods as needed for your policy's architecture and training requirements.
|
||||
Inherit from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and register your policy type. Here is a template — customize the parameters and methods as needed for your policy's architecture and training requirements.
|
||||
|
||||
```python
|
||||
# configuration_my_custom_policy.py
|
||||
# configuration_my_policy.py
|
||||
from dataclasses import dataclass, field
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.optim import AdamWConfig
|
||||
from lerobot.optim import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
@PreTrainedConfig.register_subclass("my_custom_policy")
|
||||
@PreTrainedConfig.register_subclass("my_policy")
|
||||
@dataclass
|
||||
class MyCustomPolicyConfig(PreTrainedConfig):
|
||||
"""Configuration class for MyCustomPolicy.
|
||||
class MyPolicyConfig(PreTrainedConfig):
|
||||
"""Configuration class for MyPolicy.
|
||||
|
||||
Args:
|
||||
n_obs_steps: Number of observation steps to use as input
|
||||
@@ -77,16 +54,20 @@ class MyCustomPolicyConfig(PreTrainedConfig):
|
||||
raise ValueError("n_action_steps cannot exceed horizon")
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate input/output feature compatibility."""
|
||||
"""Validate input/output feature compatibility.
|
||||
|
||||
Call this explicitly from your policy's __init__ — the base class does not.
|
||||
"""
|
||||
if not self.image_features:
|
||||
raise ValueError("MyCustomPolicy requires at least one image feature.")
|
||||
raise ValueError("MyPolicy requires at least one image feature.")
|
||||
if self.action_feature is None:
|
||||
raise ValueError("MyCustomPolicy requires 'action' in output_features.")
|
||||
raise ValueError("MyPolicy requires 'action' in output_features.")
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
"""Return a LRSchedulerConfig from lerobot.optim, or None."""
|
||||
return None
|
||||
|
||||
@property
|
||||
@@ -101,8 +82,7 @@ class MyCustomPolicyConfig(PreTrainedConfig):
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
"""Relative timestep offsets for the action chunk the dataset loader returns.
|
||||
"""
|
||||
"""Relative timestep offsets for the action chunk the dataset loader returns."""
|
||||
return list(range(self.horizon))
|
||||
|
||||
@property
|
||||
@@ -110,32 +90,34 @@ class MyCustomPolicyConfig(PreTrainedConfig):
|
||||
return None
|
||||
```
|
||||
|
||||
## Step 3: Implement the Policy Class
|
||||
The string you pass to `@register_subclass` must match `MyPolicy.name` (next section) and is what users supply as `--policy.type` on the CLI. Default to `AdamW` from `lerobot.optim` for `get_optimizer_preset` unless you genuinely need otherwise.
|
||||
|
||||
Create your policy implementation by inheriting from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py):
|
||||
### Policy class
|
||||
|
||||
Inherit from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py) and set two class attributes — both are checked by `__init_subclass__`:
|
||||
|
||||
```python
|
||||
# modeling_my_custom_policy.py
|
||||
# modeling_my_policy.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Any
|
||||
|
||||
from lerobot.policies import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION
|
||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
||||
from .configuration_my_policy import MyPolicyConfig
|
||||
|
||||
class MyCustomPolicy(PreTrainedPolicy):
|
||||
config_class = MyCustomPolicyConfig # must match the string in @register_subclass
|
||||
name = "my_custom_policy"
|
||||
class MyPolicy(PreTrainedPolicy):
|
||||
config_class = MyPolicyConfig # must match the string in @register_subclass
|
||||
name = "my_policy"
|
||||
|
||||
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
|
||||
def __init__(self, config: MyPolicyConfig, dataset_stats: dict[str, Any] = None):
|
||||
super().__init__(config, dataset_stats)
|
||||
config.validate_features() # not called automatically by the base class
|
||||
self.config = config
|
||||
self.model = ... # your nn.Module here
|
||||
|
||||
def reset(self):
|
||||
"""Reset episode state."""
|
||||
"""Reset per-episode state. Called by lerobot-eval at the start of each episode."""
|
||||
...
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
@@ -147,35 +129,51 @@ class MyCustomPolicy(PreTrainedPolicy):
|
||||
...
|
||||
|
||||
def select_action(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
|
||||
"""Return a single action for the current timestep (called at inference)."""
|
||||
"""Return a single action for the current timestep (called every step at inference)."""
|
||||
...
|
||||
|
||||
def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
def forward(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict | None]:
|
||||
"""Compute the training loss.
|
||||
|
||||
Returns `(loss, output_dict)`. `output_dict` may be `None`; everything in it must be
|
||||
logging-friendly Python natives (no tensors with gradients).
|
||||
|
||||
`batch["action_is_pad"]` is a bool mask of shape (B, horizon) that marks
|
||||
timesteps padded because the episode ended before `horizon` steps, you
|
||||
timesteps padded because the episode ended before `horizon` steps; you
|
||||
can exclude those from your loss.
|
||||
"""
|
||||
actions = batch[ACTION]
|
||||
action_is_pad = batch.get("action_is_pad")
|
||||
...
|
||||
return {"loss": ...}
|
||||
return loss, {"some_loss_component": some_loss_component.item()}
|
||||
```
|
||||
|
||||
## Step 4: Add Data Processors
|
||||
The methods called by the train/eval loops:
|
||||
|
||||
Create processor functions. For a concrete reference, see [processor_act.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [processor_diffusion.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
|
||||
| Method | Used by | What it does |
|
||||
| ----------------------------------------------------------------- | ----------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `reset() -> None` | `lerobot-eval` | Clear per-episode state at the start of each episode. |
|
||||
| `select_action(batch, **kwargs) -> Tensor` | `lerobot-eval` | Return the next action `(B, action_dim)`. Called every step. |
|
||||
| `predict_action_chunk(batch, **kwargs) -> Tensor` | the policy itself | Return an action chunk `(B, chunk_size, action_dim)`. Currently abstract on the base class — raise `NotImplementedError` if your policy doesn't chunk. |
|
||||
| `forward(batch, reduction="mean") -> tuple[Tensor, dict \| None]` | `lerobot-train` | Return `(loss, output_dict)`. Accept `reduction="none"` if you want to support per-sample weighting. |
|
||||
| `get_optim_params() -> dict` | the optimizer | Return `self.parameters()` for simple policies; return a named parameter dict for [multi-optimizer policies](https://github.com/huggingface/lerobot/blob/ecd38c50d7d15b4184cf42649ff1185ee2e11eeb/src/lerobot/policies/sac/modeling_sac.py#L61-L73). |
|
||||
| `update() -> None` _(optional)_ | `lerobot-train` | Called after each optimizer step _if defined_. Use for EMA, target nets, replay buffers (TDMPC uses this). |
|
||||
|
||||
Batches are flat dictionaries keyed by the constants in [`lerobot.utils.constants`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/utils/constants.py): `OBS_STATE` (`observation.state.<motor>`), `OBS_IMAGES` (`observation.images.<camera>`), `OBS_LANGUAGE`, `ACTION`, etc. Reuse the constants — don't invent new prefixes.
|
||||
|
||||
### Processor functions
|
||||
|
||||
LeRobot uses `PolicyProcessorPipeline`s to normalize inputs and de-normalize outputs around your policy. For a concrete reference, see [`processor_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [`processor_diffusion.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
|
||||
|
||||
```python
|
||||
# processor_my_custom_policy.py
|
||||
# processor_my_policy.py
|
||||
from typing import Any
|
||||
import torch
|
||||
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
|
||||
|
||||
def make_my_custom_policy_pre_post_processors(
|
||||
def make_my_policy_pre_post_processors(
|
||||
config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
@@ -187,11 +185,48 @@ def make_my_custom_policy_pre_post_processors(
|
||||
return preprocessor, postprocessor
|
||||
```
|
||||
|
||||
**Important - function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
|
||||
**Important — function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
|
||||
|
||||
## Step 5: Package Initialization
|
||||
---
|
||||
|
||||
Expose your classes in the package's `__init__.py`:
|
||||
## Path A: Out-of-tree plugin
|
||||
|
||||
The fastest way to ship a policy: package it as a standalone Python distribution and install it alongside LeRobot. No PR required, you own the release cycle, and you can publish to PyPI under your own namespace.
|
||||
|
||||
### Package structure
|
||||
|
||||
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
|
||||
|
||||
```bash
|
||||
lerobot_policy_my_policy/
|
||||
├── pyproject.toml
|
||||
└── src/
|
||||
└── lerobot_policy_my_policy/
|
||||
├── __init__.py
|
||||
├── configuration_my_policy.py
|
||||
├── modeling_my_policy.py
|
||||
└── processor_my_policy.py
|
||||
```
|
||||
|
||||
### `pyproject.toml`
|
||||
|
||||
```toml
|
||||
[project]
|
||||
name = "lerobot_policy_my_policy"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
# your policy-specific dependencies
|
||||
]
|
||||
requires-python = ">= 3.12"
|
||||
|
||||
[build-system]
|
||||
build-backend = # your-build-backend
|
||||
requires = # your-build-system
|
||||
```
|
||||
|
||||
### Package `__init__.py`
|
||||
|
||||
Expose your classes in the package's `__init__.py` and guard against missing `lerobot`:
|
||||
|
||||
```python
|
||||
# __init__.py
|
||||
@@ -204,44 +239,148 @@ except ImportError:
|
||||
"lerobot is not installed. Please install lerobot to use this policy package."
|
||||
)
|
||||
|
||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
||||
from .modeling_my_custom_policy import MyCustomPolicy
|
||||
from .processor_my_custom_policy import make_my_custom_policy_pre_post_processors
|
||||
from .configuration_my_policy import MyPolicyConfig
|
||||
from .modeling_my_policy import MyPolicy
|
||||
from .processor_my_policy import make_my_policy_pre_post_processors
|
||||
|
||||
__all__ = [
|
||||
"MyCustomPolicyConfig",
|
||||
"MyCustomPolicy",
|
||||
"make_my_custom_policy_pre_post_processors",
|
||||
"MyPolicyConfig",
|
||||
"MyPolicy",
|
||||
"make_my_policy_pre_post_processors",
|
||||
]
|
||||
```
|
||||
|
||||
## Step 6: Installation and Usage
|
||||
|
||||
### Install Your Policy Package
|
||||
### Install and use
|
||||
|
||||
```bash
|
||||
cd lerobot_policy_my_custom_policy
|
||||
cd lerobot_policy_my_policy
|
||||
pip install -e .
|
||||
|
||||
# Or install from PyPI if published
|
||||
pip install lerobot_policy_my_custom_policy
|
||||
pip install lerobot_policy_my_policy
|
||||
```
|
||||
|
||||
### Use Your Policy
|
||||
|
||||
Once installed, your policy automatically integrates with LeRobot's training and evaluation tools:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type my_custom_policy \
|
||||
--policy.type my_policy \
|
||||
--env.type pusht \
|
||||
--steps 200000
|
||||
```
|
||||
|
||||
## Examples and Community Contributions
|
||||
---
|
||||
|
||||
## Path B: Contributing in-tree
|
||||
|
||||
When your policy has stabilized and there's clear value in shipping it with the library, you can land it directly in LeRobot. Read the general [contribution guide](./contributing) and the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md) first — that's where you'll find the testing/quality expectations every PR has to meet (`pre-commit run -a`, `pytest`, the community-review rule, etc.). What's below is the policy-specific layer on top of that.
|
||||
|
||||
### In-tree layout
|
||||
|
||||
```
|
||||
src/lerobot/policies/my_policy/
|
||||
├── __init__.py # re-exports config + modeling + processor factory
|
||||
├── configuration_my_policy.py # MyPolicyConfig + @register_subclass
|
||||
├── modeling_my_policy.py # MyPolicy(PreTrainedPolicy)
|
||||
├── processor_my_policy.py # make_my_policy_pre_post_processors
|
||||
└── README.md # symlink → ../../../../docs/source/policy_my_policy_README.md
|
||||
```
|
||||
|
||||
Two notes:
|
||||
|
||||
- The `README.md` next to the source is a **symlink** into `docs/source/policy_<name>_README.md` — the actual file lives under `docs/`. Existing policies (act, smolvla, diffusion, …) all do this; copy one of those symlinks. The policy README is conventionally minimal: paper link + BibTeX citation.
|
||||
- The user-facing tutorial — what to install, how to train, hyperparameters, benchmark numbers — lives separately at `docs/source/<my_policy>.mdx` and is registered in `_toctree.yml` under "Policies".
|
||||
|
||||
The file names are load-bearing: the factory does lazy imports by name, and the processor is discovered by the `make_<policy_name>_pre_post_processors` convention.
|
||||
|
||||
### Wiring
|
||||
|
||||
Three places need to know about your policy. All by name.
|
||||
|
||||
1. **`policies/__init__.py`** — re-export `MyPolicyConfig` and add it to `__all__`. **Don't** re-export the modeling class; it loads lazily through the factory (so `import lerobot` stays fast).
|
||||
2. **`factory.py:get_policy_class`** — add a branch returning `MyPolicy` from a lazy import.
|
||||
3. **`factory.py:make_policy_config`** and **`factory.py:make_pre_post_processors`** — same idea, two more branches.
|
||||
|
||||
Mirror an existing policy that's structurally similar to yours; the diff is small.
|
||||
|
||||
### Heavy / optional dependencies
|
||||
|
||||
Most policies need a heavy backbone (transformers, diffusers, a specific VLM SDK). The convention is **two-step gating**: a `TYPE_CHECKING`-guarded import at module top, and a `require_package` runtime check in the constructor. [`modeling_diffusion.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/modeling_diffusion.py) is the canonical reference:
|
||||
|
||||
```python
|
||||
from typing import TYPE_CHECKING
|
||||
from lerobot.utils.import_utils import _diffusers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _diffusers_available:
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
else:
|
||||
DDIMScheduler = None # keeps the symbol bindable at import time
|
||||
|
||||
class DiffusionPolicy(PreTrainedPolicy):
|
||||
def __init__(self, config):
|
||||
require_package("diffusers", extra="diffusion")
|
||||
super().__init__(config)
|
||||
...
|
||||
```
|
||||
|
||||
This way:
|
||||
|
||||
- `import lerobot.policies` keeps working without the extra installed (the symbol is just bound to `None`).
|
||||
- Type checkers see the real symbol.
|
||||
- Instantiating the policy without the extra raises a clear `ImportError` pointing at `pip install 'lerobot[diffusion]'`.
|
||||
|
||||
Add a matching extra to [`pyproject.toml`](https://github.com/huggingface/lerobot/blob/main/pyproject.toml) `[project.optional-dependencies]` and include it in the `all` extra so `pip install 'lerobot[all]'` keeps installing everything.
|
||||
|
||||
### Benchmarks and a published checkpoint
|
||||
|
||||
A new policy is much easier to review — and far more useful — when it ships with a working checkpoint and at least one number you can reproduce.
|
||||
|
||||
**Pick at least one in-tree benchmark.** LeRobot ships sim benchmarks with per-benchmark Docker images (LIBERO, LIBERO-plus, Meta-World, RoboTwin 2.0, RoboCasa365, RoboCerebra, RoboMME, VLABench and more). Pick the one that matches your policy's modality — VLAs usually go to LIBERO or VLABench; image-only BC to LIBERO or Meta-World. The full list lives under [Benchmarks](./libero) in the docs sidebar.
|
||||
|
||||
**Push the checkpoint & processors** to the Hub under `lerobot/<policy>_<benchmark>` (or your namespace if you don't have write access; a maintainer can mirror it). Use `PreTrainedPolicy.push_model_to_hub` so the repo gets `config.json`, `model.safetensors`, and a model card.
|
||||
|
||||
**Report results in your policy's MDX**, with the exact `lerobot-eval` command and hardware so anyone can re-run:
|
||||
|
||||
```markdown
|
||||
## Results
|
||||
|
||||
Evaluated on LIBERO with `lerobot/<policy>_libero`:
|
||||
|
||||
| Suite | Success rate | n_episodes |
|
||||
| -------------- | -----------: | ---------: |
|
||||
| libero_spatial | 87.5% | 50 |
|
||||
| libero_object | 93.0% | 50 |
|
||||
| libero_goal | 81.5% | 50 |
|
||||
| libero_10 | 62.0% | 50 |
|
||||
| **average** | **81.0%** | 200 |
|
||||
|
||||
Reproduce: `lerobot-eval --policy.path=lerobot/<policy>_libero --env.type=libero --env.task=libero_spatial --eval.n_episodes=50` (1× A100 40 GB).
|
||||
```
|
||||
|
||||
Use `n_episodes ≥ 50` per suite for stable success-rate estimates.
|
||||
|
||||
If your policy is real-robot-only and no sim benchmark applies, swap the sim eval for: a public training dataset on the Hub, the `lerobot-train` command, the checkpoint, and a real-robot success rate over ≥10 episodes via `lerobot-rollout --policy.path=...`.
|
||||
|
||||
### PR checklist
|
||||
|
||||
The general expectations are in [`CONTRIBUTING.md`](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md) and the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md). On top of those, reviewers will look for:
|
||||
|
||||
- [ ] `MyPolicy` and `MyPolicyConfig` cover the surface above; `__init_subclass__` accepts the class.
|
||||
- [ ] `factory.py` and `policies/__init__.py` are wired (lazy imports for modeling).
|
||||
- [ ] `make_my_policy_pre_post_processors` follows the naming convention.
|
||||
- [ ] Optional deps live behind a `[project.optional-dependencies]` extra and the `TYPE_CHECKING + require_package` guard.
|
||||
- [ ] `tests/policies/` updated; backward-compat artifact committed & policy-specific tests.
|
||||
- [ ] `src/lerobot/policies/<name>/README.md` symlinked into `docs/source/policy_<name>_README.md`; user-facing `docs/source/<name>.mdx` written and added to `_toctree.yml`.
|
||||
- [ ] At least one reproducible benchmark eval in the policy MDX with a published checkpoint (sim benchmark, or real-robot dataset + checkpoint).
|
||||
|
||||
The fastest way to get a clean PR is to copy the directory of the existing policy closest to yours, rename, and replace contents method by method. Don't wait until everything is polished — open a draft PR early and iterate with us; reviewers would much rather give feedback on a half-finished branch than a fully-merged one.
|
||||
|
||||
---
|
||||
|
||||
## Examples and community contributions
|
||||
|
||||
Check out these example policy implementations:
|
||||
|
||||
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) - Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
|
||||
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) — Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
|
||||
|
||||
Share your policy implementations with the community! 🤗
|
||||
Thanks for taking the time to bring a new policy into LeRobot. Every architecture that lands in `main` — and every plugin published by the community — makes the library a little more useful for the next person, and a little more representative of where robot learning is going. We're looking forward to seeing what you ship. 🤗
|
||||
|
||||
139
docs/source/cheat-sheet.mdx
Normal file
139
docs/source/cheat-sheet.mdx
Normal file
@@ -0,0 +1,139 @@
|
||||
# Cheat sheet
|
||||
|
||||
All of the LeRobot commands in one place. If you forgot how to use a specific command or want to learn about a new one you can do it here.
|
||||
|
||||
> [!WARNING]
|
||||
> For all of the commands listed below remember to change the ports/names/ids to your own values!
|
||||
|
||||
> [!TIP]
|
||||
> Another great way to look at all the commands and get them configured for your specific setup is to use this [Jupyter Notebook](https://github.com/huggingface/lerobot/blob/main/examples/notebooks/quickstart.ipynb).
|
||||
|
||||
### Setup and installation
|
||||
|
||||
For installation please look at [LeRobot Installation](https://huggingface.co/docs/lerobot/main/en/installation).
|
||||
|
||||
### Useful tools
|
||||
|
||||
###### Find port
|
||||
|
||||
Use this to identify which serial ports your robots are connected to. Follow the instructions in your terminal: you will be asked to unplug the USB cable and press Enter. The script will then detect and print the correct serial port for that robot.
|
||||
|
||||
```bash
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
###### Find cameras
|
||||
|
||||
Quickly find camera indices and verify their output. This command prints camera information to the terminal and saves test frames from each detected camera to `lerobot/outputs/captured_images`
|
||||
|
||||
```bash
|
||||
lerobot-find-cameras
|
||||
```
|
||||
|
||||
### Calibration
|
||||
|
||||
In most cases you will need to perform calibration just once for each robot and teleoperation device. Before performing the calibration make sure that all the joints are roughly in the middle position.
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_follower_arm
|
||||
```
|
||||
|
||||
Make sure that you use the same IDs used during calibration later for the other scripts. That's how LeRobot finds the calibration files.
|
||||
|
||||
### Teleoperation
|
||||
|
||||
Teleoperating with two cameras and displaying the data with Rerun.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_follower_arm \
|
||||
--robot.cameras="{ top: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30} }" \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/ttyACM1 \
|
||||
--teleop.id=my_leader_arm \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
### Recording a dataset
|
||||
|
||||
The dataset is automatically uploaded to the server and saved under repo_id, make sure you are logged in to your HF account with CLI:
|
||||
`hf auth login`
|
||||
|
||||
You can get the token from: [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_follower_arm \
|
||||
--robot.cameras="{ top: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30} }" \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/ttyACM1 \
|
||||
--teleop.id=my_leader_arm \
|
||||
--dataset.repo_id=${HF_USER}/so101_dataset_test \
|
||||
--dataset.num_episodes=30 \
|
||||
--dataset.single_task="put the red brick in a bowl" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
While collecting the dataset you can control the process with your keyboard:
|
||||
Control the data recording flow using keyboard shortcuts:
|
||||
|
||||
- Press **Right Arrow (`→`)**: Save episode and move to the next.
|
||||
- Press **Left Arrow (`←`)**: Delete current episode and retry.
|
||||
- Press **Escape (`ESC`)**: Stop, encode videos, and upload.
|
||||
|
||||
### Training
|
||||
|
||||
Depending on your hardware training the policy might take a few hours. That's how you train simple `ACT` policy:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/so101_dataset_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_so101_test \
|
||||
--job_name=act_so101_test \
|
||||
--policy.device=cuda \
|
||||
--wandb.enable=true \
|
||||
--policy.repo_id=${HF_USER}/policy_test \
|
||||
--steps=20000
|
||||
```
|
||||
|
||||
- Policy Types: `act`, `diffusion`, `smolvla`, `pi05`
|
||||
- Devices: `cuda` (NVIDIA), `mps` (Apple Silicon), `cpu`
|
||||
|
||||
If you want to fine-tune a specific model you can provide the path to the model. In this case path is enough and type can be skipped.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/so101_dataset_test \
|
||||
--policy.path=username/the_policy_to_finetune \
|
||||
--policy.device=cuda \
|
||||
--policy.repo_id=${HF_USER}/policy_test \
|
||||
--output_dir=outputs/train/act_so101_test \
|
||||
--steps=20000
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
Inference means running the trained policy/model on a robot. For that we use `lerobot-rollout`. You will need to provide a path to your policy. It can be a local path or a path to Hugging Face for example "lerobot/folding_latest". Your cameras configuration needs to match what was used when collecting the dataset. Duration is in seconds if unspecified, it will run forever.
|
||||
|
||||
> [!TIP]
|
||||
> If you are using the previous release V0.5.1 instead of `lerobot-rollout` you need to use `lerobot-record`. More information [here](https://huggingface.co/docs/lerobot/v0.5.1/en/il_robots#run-inference-and-evaluate-your-policy).
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM1 \
|
||||
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video1, width: 640, height: 480, fps: 30}, side: {type: opencv, index_or_path: /dev/video5, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Put lego brick into the transparent box" \
|
||||
--duration=60
|
||||
```
|
||||
@@ -1,277 +0,0 @@
|
||||
# Using Subtasks in LeRobot Datasets
|
||||
|
||||
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
|
||||
|
||||
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
|
||||
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
|
||||
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
|
||||
|
||||
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
|
||||
|
||||
## What are Subtasks?
|
||||
|
||||
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
|
||||
|
||||
1. "Approach the apple"
|
||||
2. "Grasp the apple"
|
||||
3. "Lift the apple"
|
||||
4. "Move to basket"
|
||||
5. "Release the apple"
|
||||
|
||||
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
|
||||
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
|
||||
width="80%"
|
||||
/>
|
||||
|
||||
<p>
|
||||
<em>Figure: Overview of subtask annotation.</em>
|
||||
</p>
|
||||
|
||||
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
|
||||
|
||||
## Dataset Structure
|
||||
|
||||
Subtask information is stored in the dataset metadata:
|
||||
|
||||
```
|
||||
my-dataset/
|
||||
├── data/
|
||||
│ └── ...
|
||||
├── meta/
|
||||
│ ├── info.json
|
||||
│ ├── stats.json
|
||||
│ ├── tasks.parquet
|
||||
│ ├── subtasks.parquet # Subtask index → subtask string mapping
|
||||
│ └── episodes/
|
||||
│ └── ...
|
||||
└── videos/
|
||||
└── ...
|
||||
```
|
||||
|
||||
### Subtasks Parquet File
|
||||
|
||||
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
|
||||
|
||||
| subtask_index | subtask (index column) |
|
||||
| ------------- | ---------------------- |
|
||||
| 0 | "Approach the apple" |
|
||||
| 1 | "Grasp the apple" |
|
||||
| 2 | "Lift the apple" |
|
||||
| ... | ... |
|
||||
|
||||
### Frame-Level Annotations
|
||||
|
||||
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
|
||||
|
||||
```python
|
||||
# Example frame data in the parquet file
|
||||
{
|
||||
"index": 42,
|
||||
"timestamp": 1.4,
|
||||
"episode_index": 0,
|
||||
"task_index": 0,
|
||||
"subtask_index": 2, # References "Lift the apple"
|
||||
"observation.state": [...],
|
||||
"action": [...],
|
||||
}
|
||||
```
|
||||
|
||||
## Annotating Datasets with Subtasks
|
||||
|
||||
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
|
||||
|
||||
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
|
||||
|
||||
After completing your annotation:
|
||||
|
||||
1. Click "Push to Hub" to upload your annotated dataset
|
||||
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
|
||||
|
||||
## Loading Datasets with Subtasks
|
||||
|
||||
When you load a dataset with subtask annotations, the subtask information is automatically available:
|
||||
|
||||
```python
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
# Load a dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Access a sample
|
||||
sample = dataset[100]
|
||||
|
||||
# The sample includes both task and subtask information
|
||||
print(sample["task"]) # "Collect the fruit"
|
||||
print(sample["subtask"]) # "Grasp the apple"
|
||||
print(sample["task_index"]) # tensor(0)
|
||||
print(sample["subtask_index"]) # tensor(2)
|
||||
```
|
||||
|
||||
### Checking for Subtask Support
|
||||
|
||||
You can check if a dataset has subtask annotations:
|
||||
|
||||
```python
|
||||
# Check if subtasks are available
|
||||
has_subtasks = (
|
||||
"subtask_index" in dataset.features
|
||||
and dataset.meta.subtasks is not None
|
||||
)
|
||||
|
||||
if has_subtasks:
|
||||
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
|
||||
print("Subtasks:", list(dataset.meta.subtasks.index))
|
||||
```
|
||||
|
||||
## Using Subtasks for Training
|
||||
|
||||
### With the Tokenizer Processor
|
||||
|
||||
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
|
||||
|
||||
```python
|
||||
from lerobot.processor import TokenizerProcessorStep
|
||||
|
||||
# Create a tokenizer processor step
|
||||
tokenizer_processor = TokenizerProcessorStep(
|
||||
tokenizer_name_or_path="google/paligemma-3b-pt-224",
|
||||
padding="max_length",
|
||||
max_length=64,
|
||||
)
|
||||
|
||||
# The processor will automatically tokenize subtasks if present in the batch
|
||||
# and add them to the observation under:
|
||||
# - "observation.subtask.tokens"
|
||||
# - "observation.subtask.attention_mask"
|
||||
```
|
||||
|
||||
When subtasks are available in the batch, the tokenizer processor adds:
|
||||
|
||||
- `observation.subtask.tokens`: Tokenized subtask text
|
||||
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
|
||||
|
||||
### DataLoader with Subtasks
|
||||
|
||||
```python
|
||||
import torch
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=16,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
for batch in dataloader:
|
||||
# Access subtask information in the batch
|
||||
subtasks = batch["subtask"] # List of subtask strings
|
||||
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
|
||||
|
||||
# Use for training hierarchical policies or reward models
|
||||
print(f"Batch subtasks: {set(subtasks)}")
|
||||
```
|
||||
|
||||
## Example Datasets with Subtask Annotations
|
||||
|
||||
Try loading a dataset with subtask annotations:
|
||||
|
||||
```python
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
# Example dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Explore the subtasks
|
||||
print("Available subtasks:")
|
||||
for subtask_name in dataset.meta.subtasks.index:
|
||||
print(f" - {subtask_name}")
|
||||
|
||||
# Get subtask distribution
|
||||
subtask_counts = {}
|
||||
for i in range(len(dataset)):
|
||||
sample = dataset[i]
|
||||
subtask = sample["subtask"]
|
||||
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
|
||||
|
||||
print("\nSubtask distribution:")
|
||||
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
|
||||
print(f" {subtask}: {count} frames")
|
||||
```
|
||||
|
||||
## Use Cases
|
||||
|
||||
### 1. Hierarchical Policy Training
|
||||
|
||||
Train policies that predict both actions and current subtask:
|
||||
|
||||
```python
|
||||
class HierarchicalPolicy(nn.Module):
|
||||
def __init__(self, num_subtasks):
|
||||
super().__init__()
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
|
||||
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
actions = self.action_head(features)
|
||||
subtask_logits = self.subtask_head(features)
|
||||
return actions, subtask_logits
|
||||
```
|
||||
|
||||
### 2. Stage-Aware Reward Modeling (SARM)
|
||||
|
||||
Build reward models that understand task progression:
|
||||
|
||||
```python
|
||||
# SARM predicts:
|
||||
# - Stage: Which subtask is being executed (discrete)
|
||||
# - Progress: How far along the subtask (continuous 0-1)
|
||||
|
||||
class SARMRewardModel(nn.Module):
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
stage_logits = self.stage_classifier(features)
|
||||
progress = self.progress_regressor(features)
|
||||
return stage_logits, progress
|
||||
```
|
||||
|
||||
### 3. Progress Visualization
|
||||
|
||||
Monitor robot execution by tracking subtask progression:
|
||||
|
||||
```python
|
||||
def visualize_execution(model, observations):
|
||||
for t, obs in enumerate(observations):
|
||||
action, subtask_logits = model(obs)
|
||||
predicted_subtask = subtask_names[subtask_logits.argmax()]
|
||||
print(f"t={t}: Executing '{predicted_subtask}'")
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### LeRobotDataset Properties
|
||||
|
||||
| Property | Type | Description |
|
||||
| --------------------------- | ---------------------- | ------------------------------------------ |
|
||||
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
|
||||
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
|
||||
|
||||
### Sample Keys
|
||||
|
||||
When subtasks are available, each sample includes:
|
||||
|
||||
| Key | Type | Description |
|
||||
| --------------- | -------------- | ------------------------------------ |
|
||||
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
|
||||
| `subtask` | `str` | Natural language subtask description |
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
|
||||
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
|
||||
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation
|
||||
@@ -194,7 +194,7 @@ lerobot-record \
|
||||
--dataset.single_task="Navigate around obstacles" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
|
||||
168
docs/source/eo1.mdx
Normal file
168
docs/source/eo1.mdx
Normal file
@@ -0,0 +1,168 @@
|
||||
# EO-1
|
||||
|
||||
EO-1 is a **Vision-Language-Action policy for robot control**. The LeRobot implementation integrates EO-1 with the standard LeRobot training, evaluation, processor interface.
|
||||
|
||||
## Model Overview
|
||||
|
||||
EO-1 uses a Qwen2.5-VL backbone for vision-language understanding and adds a continuous flow-matching action head for robot control. The policy formats each robot-control sample as a multimodal conversation: camera images are passed to Qwen2.5-VL, the robot state is represented with EO-1 state tokens, and the future action chunk is represented with EO-1 action tokens.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/HaomingSong/lerobot-documentation-images/resolve/main/lerobot/eo_pipeline.png"
|
||||
alt="An overview of EO-1"
|
||||
width="85%"
|
||||
/>
|
||||
|
||||
During training, EO-1 learns to denoise continuous action chunks at the action-token positions. During inference, it samples an action chunk, returns continuous actions, and executes `n_action_steps` from the chunk before sampling again.
|
||||
|
||||
### What the LeRobot Integration Covers
|
||||
|
||||
- Standard `policy.type=eo1` configuration through LeRobot
|
||||
- Qwen2.5-VL image and text preprocessing through policy processors
|
||||
- Continuous flow-matching action prediction
|
||||
- Checkpoint save/load through LeRobot policy APIs
|
||||
- Training with `lerobot-train` and evaluation with `lerobot-eval`
|
||||
|
||||
The broader EO-1 project also includes interleaved vision-text-action pretraining and multimodal reasoning workflows. This page focuses on the LeRobot robot-control policy path.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following the [Installation Guide](./installation).
|
||||
2. Install EO-1 dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install -e ".[eo1]"
|
||||
```
|
||||
|
||||
3. If you want to train or evaluate on LIBERO, install the LIBERO dependencies too:
|
||||
|
||||
```bash
|
||||
pip install -e ".[eo1,libero]"
|
||||
```
|
||||
|
||||
EO-1 can use the standard PyTorch scaled-dot-product attention backend through `policy.attn_implementation=sdpa`. If your environment has a compatible `flash_attn` installation, you can request `policy.attn_implementation=flash_attention_2`.
|
||||
|
||||
## Data Requirements
|
||||
|
||||
EO-1 expects a LeRobot dataset with:
|
||||
|
||||
- At least one visual observation, for example `observation.images.image`
|
||||
- `observation.state`
|
||||
- `action`
|
||||
- A language task instruction through the dataset `task` field
|
||||
|
||||
If your dataset uses different observation names, use `rename_map` to align them with the names expected by your training or evaluation setup.
|
||||
|
||||
## Usage
|
||||
|
||||
To use EO-1 in a LeRobot configuration, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=eo1
|
||||
```
|
||||
|
||||
By default, a new EO-1 policy initializes its backbone from:
|
||||
|
||||
```python
|
||||
policy.vlm_base=Qwen/Qwen2.5-VL-3B-Instruct
|
||||
```
|
||||
|
||||
Once a LeRobot-format EO-1 checkpoint is available, load it with:
|
||||
|
||||
```python
|
||||
policy.path=your-org/your-eo1-checkpoint
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Training Command Example
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_org/your_dataset \
|
||||
--policy.type=eo1 \
|
||||
--policy.vlm_base=Qwen/Qwen2.5-VL-3B-Instruct \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.attn_implementation=sdpa \
|
||||
--policy.gradient_checkpointing=false \
|
||||
--output_dir=./outputs/eo1_training \
|
||||
--job_name=eo1_training \
|
||||
--steps=300000 \
|
||||
--batch_size=16 \
|
||||
--policy.device=cuda
|
||||
```
|
||||
|
||||
### Key Training Parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| -------------------------------------- | ----------------------------- | ----------------------------------------------------------------------- |
|
||||
| `policy.vlm_base` | `Qwen/Qwen2.5-VL-3B-Instruct` | Qwen2.5-VL checkpoint used to initialize a new policy |
|
||||
| `policy.dtype` | `auto` | Backbone dtype request: `auto`, `bfloat16`, or `float32` |
|
||||
| `policy.attn_implementation` | `None` | Optional Qwen attention backend, such as `sdpa` |
|
||||
| `policy.gradient_checkpointing` | `false` | Reduces memory usage during training |
|
||||
| `policy.chunk_size` | `8` | Number of future actions predicted per chunk |
|
||||
| `policy.n_action_steps` | `8` | Number of actions consumed from a sampled chunk |
|
||||
| `policy.num_denoise_steps` | `10` | Number of flow-matching denoising steps used during sampling |
|
||||
| `policy.max_state_dim` | `32` | State padding dimension |
|
||||
| `policy.max_action_dim` | `32` | Action padding dimension |
|
||||
| `policy.force_fp32_autocast` | `true` | Keeps the flow head in fp32 even when the backbone uses mixed precision |
|
||||
| `policy.supervise_padding_action_dims` | `true` | Controls whether padded action dimensions are supervised |
|
||||
| `policy.supervise_padding_actions` | `true` | Controls whether padded future action rows are supervised |
|
||||
|
||||
## Evaluation
|
||||
|
||||
EO-1 can be evaluated through `lerobot-eval` once you have a LeRobot-format checkpoint:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=your-org/your-eo1-checkpoint \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=20
|
||||
```
|
||||
|
||||
For datasets or environments whose camera names differ from the checkpoint configuration, pass a `rename_map`:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=your-org/your-eo1-checkpoint \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--rename_map='{"observation.images.image2":"observation.images.wrist_image"}'
|
||||
```
|
||||
|
||||
## Configuration Notes
|
||||
|
||||
### Image Processing
|
||||
|
||||
EO-1 uses the Qwen2.5-VL processor. The `policy.image_min_pixels` and `policy.image_max_pixels` settings control the image resizing bounds before the visual tokens are passed into the backbone.
|
||||
|
||||
### State and Action Dimensions
|
||||
|
||||
The policy pads state and action vectors to `policy.max_state_dim` and `policy.max_action_dim` before the EO-1 flow head. Predictions are cropped back to the original action dimension before being returned by the policy.
|
||||
|
||||
### Attention Backend
|
||||
|
||||
Use `policy.attn_implementation=sdpa` for a portable setup. Use `flash_attention_2` only when `flash_attn` is installed and compatible with your environment.
|
||||
|
||||
## References
|
||||
|
||||
- [EO-1 project](https://github.com/EO-Robotics/EO1)
|
||||
- [EO-1 paper](https://arxiv.org/abs/2508.21112)
|
||||
- [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{eo1,
|
||||
title={EO-1: Interleaved Vision-Text-Action Pretraining for General Robot Control},
|
||||
author={Delin Qu and Haoming Song and Qizhi Chen and Zhaoqing Chen and Xianqiang Gao and Xinyi Ye and Qi Lv and Modi Shi and Guanghui Ren and Cheng Ruan and Maoqing Yao and Haoran Yang and Jiacheng Bao and Bin Zhao and Dong Wang},
|
||||
journal={arXiv preprint},
|
||||
year={2025},
|
||||
url={https://arxiv.org/abs/2508.21112}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This LeRobot integration follows the **Apache 2.0 License** used by LeRobot. Check the upstream EO-1 model and dataset pages for the licenses of released EO-1 checkpoints and data.
|
||||
@@ -105,10 +105,12 @@ These results demonstrate GR00T's strong generalization capabilities across dive
|
||||
|
||||
### Evaluate in your hardware setup
|
||||
|
||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
|
||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Policy Deployment (lerobot-rollout)](./inference). For example:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
lerobot-rollout\
|
||||
--strategy.type=sentry \
|
||||
--strategy.upload_every_n_episodes=5 \
|
||||
--robot.type=bi_so_follower \
|
||||
--robot.left_arm_port=/dev/ttyACM1 \
|
||||
--robot.right_arm_port=/dev/ttyACM0 \
|
||||
@@ -119,14 +121,12 @@ lerobot-record \
|
||||
}' \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=<user>/eval_groot-bimanual \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--policy.path=<user>/groot-bimanual \ # your trained model
|
||||
--dataset.episode_time_s=30 \
|
||||
--dataset.reset_time_s=10
|
||||
--duration=600
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
98
docs/source/hardware_guide.mdx
Normal file
98
docs/source/hardware_guide.mdx
Normal file
@@ -0,0 +1,98 @@
|
||||
# Compute HW Guide for LeRobot Training
|
||||
|
||||
Rough sizing for training a LeRobot policy: how much VRAM each policy needs, what training time looks like, and where to run when local hardware isn't enough.
|
||||
|
||||
The numbers below are **indicative** — order-of-magnitude figures for picking hardware, not exact predictions. Throughput depends heavily on dataset I/O, image resolution, batch size, and number of GPUs.
|
||||
|
||||
## Memory by policy group
|
||||
|
||||
Policies cluster by backbone size; the groupings below give a single VRAM envelope per group instead of repeating numbers per policy. Memory scales roughly linearly with batch size; AdamW (the LeRobot default) carries optimizer state that adds ~30–100% over a forward+backward pass alone.
|
||||
|
||||
| Group | Policies | Peak VRAM (BS 8, AdamW) | Suitable starter GPUs |
|
||||
| ---------- | ------------------------------------------- | ----------------------: | --------------------------------- |
|
||||
| Light BC | `act`, `vqbet`, `tdmpc` | ~2–6GB | Laptop GPU (RTX 3060), L4, A10G |
|
||||
| Diffusion | `diffusion`, `multi_task_dit` | ~8–14GB | RTX 4070+ / L4 / A10G |
|
||||
| Small VLA | `smolvla` | ~10–16GB | RTX 4080+ / L4 / A10G |
|
||||
| Large VLA | `pi0`, `pi0_fast`, `pi05`, `xvla`, `wall_x` | ~24–40GB | A100 40 GB+ (24 GB tight at BS 1) |
|
||||
| Multimodal | `groot`, `eo1` | ~24–40GB | A100 40 GB+ |
|
||||
| RL | `sac` | config-dep. | See [HIL-SERL guide](./hilserl) |
|
||||
|
||||
Memory-bound? Drop the batch size (~linear), use gradient accumulation to recover effective batch, or for SmolVLA leave `freeze_vision_encoder=True`.
|
||||
|
||||
## Training time
|
||||
|
||||
Robotics imitation learning typically converges in **5–10 epochs over the dataset**, not hundreds of thousands of raw steps. Once you know your epoch count, wall-clock is essentially:
|
||||
|
||||
```text
|
||||
total_frames = sum of frames over all episodes # 50 ep × 30 fps × 30 s ≈ 45,000
|
||||
steps_per_epoch = ceil(total_frames / (num_gpus × batch_size))
|
||||
total_steps = epochs × steps_per_epoch
|
||||
wall_clock ≈ total_steps × per_step_time
|
||||
```
|
||||
|
||||
Per-step time depends on the policy and the GPU. The numbers in the table below are anchors — pick the row closest to your setup and scale linearly with `total_steps` if you train longer or shorter.
|
||||
|
||||
### Common scenarios
|
||||
|
||||
Indicative wall-clock for **5 epochs on a ~50-episode dataset (~45k frames at 30 fps × 30 s)**, default optimizer (AdamW), 640×480 images:
|
||||
|
||||
| Setup | Policy | Batch | Wall-clock |
|
||||
| ------------------------------------ | -------------- | ----- | ---------: |
|
||||
| Single RTX 4090 / RTX 3090 (24 GB) | `act` | 8 | ~30–60min |
|
||||
| Single RTX 4090 / RTX 3090 (24 GB) | `diffusion` | 8 | ~2–4h |
|
||||
| Single L4 / A10G (24 GB) | `act` | 8 | ~1–2h |
|
||||
| Single L4 / A10G (24 GB) | `smolvla` | 4 | ~3–6h |
|
||||
| Single A100 40 GB | `smolvla` | 16 | ~1–2h |
|
||||
| Single A100 40 GB | `pi0` / `pi05` | 4 | ~4–8h |
|
||||
| 4× H100 80 GB cluster (`accelerate`) | `diffusion` | 32 | ~30–60min |
|
||||
| 4× H100 80 GB cluster (`accelerate`) | `smolvla` | 32 | ~1–2h |
|
||||
| Apple Silicon M1/M2/M3 Max (MPS) | `act` | 4 | ~6–14h |
|
||||
|
||||
These are order-of-magnitude figures. Real runs deviate by ±50% depending on image resolution, dataset I/O, dataloader threading, and exact GPU SKU. They are useful as "is this run going to take an hour or a day?" intuition, not as SLAs.
|
||||
|
||||
### Multi-GPU matters a lot
|
||||
|
||||
`accelerate launch --num_processes=N` is the easiest way to cut training time. Each optimizer step processes `N × batch_size` samples in roughly the same wall-clock as a single-GPU step, so 4 GPUs ≈ 4× speedup for compute-bound runs. See the [Multi GPU training](./multi_gpu_training) guide for the full setup.
|
||||
|
||||
Reference data points on a 4×H100 80 GB cluster (`accelerate launch --num_processes=4`), 5000 steps, batch 32, AdamW, dataset [`imstevenpmwork/super_poulain_draft`](https://huggingface.co/datasets/imstevenpmwork/super_poulain_draft) (~50 episodes, ~640×480 images):
|
||||
|
||||
| Policy | Wall-clock | `update_s` | `dataloading_s` | GPU util | Notable flags |
|
||||
| ----------- | ---------- | ---------: | --------------: | -------- | ------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `diffusion` | 16m 17s | 0.167 | 0.015 | ~90% | defaults (training from scratch) |
|
||||
| `smolvla` | 27m 49s | 0.312 | 0.011 | ~80% | `--policy.path=lerobot/smolvla_base`, `freeze_vision_encoder=false`, `train_expert_only=false` |
|
||||
| `pi05` | 3h 41m | 2.548 | 0.014 | ~95% | `--policy.pretrained_path=lerobot/pi05_base`, `gradient_checkpointing=true`, `dtype=bfloat16`, vision encoder + expert trained |
|
||||
|
||||
The `dataloading_s` vs. `update_s` ratio is the diagnostic that matters: when `dataloading_s` approaches `update_s`, more GPUs stop helping — your dataloader is the bottleneck and you should look at `--num_workers`, image resolution, and disk speed before adding compute.
|
||||
|
||||
### Schedule and checkpoints
|
||||
|
||||
If you shorten training (e.g. 5k–10k steps on a small dataset), also shorten the LR schedule with `--policy.scheduler_decay_steps≈--steps`. Otherwise the LR stays near its peak and never decays. Same for `--save_freq`.
|
||||
|
||||
## Where to run
|
||||
|
||||
VRAM is the first filter. Within a tier, pick by budget and availability — the `$`–`$$$$` columns are relative; check current pricing on the provider you actually use.
|
||||
|
||||
| Class | VRAM | Tier | Comfortable for |
|
||||
| -------------------------- | ----- | ------ | ----------------------------------------------------------- |
|
||||
| RTX 3090 / 4090 (consumer) | 24 GB | `$` | Light BC, Diffusion, SmolVLA. Tight for VLAs at batch 1. |
|
||||
| L4 / A10G (cloud) | 24 GB | `$–$$` | Same envelope; common on Google Cloud, RunPod, AWS `g5/g6`. |
|
||||
| A100 40 GB | 40 GB | `$$$` | Any policy at reasonable batch sizes. |
|
||||
| A100 80 GB / H100 80 GB | 80 GB | `$$$$` | Multi-GPU clusters; large batches for VLAs. |
|
||||
| **CPU only** | — | — | Don't train. Use Colab or rent a GPU. |
|
||||
|
||||
### Hugging Face Jobs
|
||||
|
||||
[Hugging Face Jobs](https://huggingface.co/docs/hub/jobs) lets you run training on managed HF infrastructure, billed by the second. The repo publishes a ready-to-use image: **`huggingface/lerobot-gpu:latest`**, rebuilt **every night at 02:00 UTC from `main`** ([`docker_publish.yml`](https://github.com/huggingface/lerobot/blob/main/.github/workflows/docker_publish.yml)) — so it tracks the current state of the repo, not a tagged release.
|
||||
|
||||
```bash
|
||||
hf jobs run --flavor a10g-large huggingface/lerobot-gpu:latest \
|
||||
bash -c "nvidia-smi && lerobot-train \
|
||||
--policy.type=act --dataset.repo_id=<USER>/<DATASET> \
|
||||
--policy.repo_id=<USER>/act_<task> --batch_size=8 --steps=50000"
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- The leading `nvidia-smi` is a quick sanity check that CUDA is visible inside the container — useful to fail fast if the flavor or driver mismatched.
|
||||
- The default Job timeout is 30 minutes; pass `--timeout 4h` (or longer) for real training.
|
||||
- `--flavor` maps onto the table above: `t4-small`/`t4-medium` (T4, ACT only), `l4x1`/`l4x4` (L4 24 GB), `a10g-small/large/largex2/largex4` (A10G 24 GB scaled out), `a100-large` (A100). For the current full catalogue + pricing see [https://huggingface.co/docs/hub/jobs](https://huggingface.co/docs/hub/jobs).
|
||||
@@ -50,30 +50,30 @@ This process can be repeated iteratively: deploy, collect, fine-tune, repeat. Ea
|
||||
|
||||
### Teleoperator Requirements
|
||||
|
||||
The `examples/hil` HIL scripts require **teleoperators with active motors** that can:
|
||||
The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with active motors** that can:
|
||||
|
||||
- Enable/disable torque programmatically
|
||||
- Move to target positions (to mirror the robot state when pausing)
|
||||
|
||||
**Compatible teleoperators in the current `examples/hil` scripts:**
|
||||
**Compatible teleoperators:**
|
||||
|
||||
- `openarm_mini` - OpenArm Mini
|
||||
- `so_leader` - SO100 / SO101 leader arm
|
||||
|
||||
> [!IMPORTANT]
|
||||
> The provided `examples/hil` commands default to `bi_openarm_follower` + `openarm_mini`.
|
||||
> The provided commands default to `bi_openarm_follower` + `openarm_mini`.
|
||||
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
|
||||
|
||||
---
|
||||
|
||||
## Script
|
||||
|
||||
A single script handles both synchronous and RTC-based inference. Toggle RTC with `--rtc.enabled=true`:
|
||||
Use `lerobot-rollout` with `--strategy.type=dagger` for HIL data collection. Select the inference backend with `--inference.type=sync|rtc`:
|
||||
|
||||
| Mode | Flag | Models |
|
||||
| ------------------------ | -------------------- | --------------------- |
|
||||
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
|
||||
| Real-Time Chunking (RTC) | `--rtc.enabled=true` | Pi0, Pi0.5, SmolVLA |
|
||||
| Mode | Flag | Models |
|
||||
| ------------------------ | ---------------------- | --------------------- |
|
||||
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
|
||||
| Real-Time Chunking (RTC) | `--inference.type=rtc` | Pi0, Pi0.5, SmolVLA |
|
||||
|
||||
---
|
||||
|
||||
@@ -97,7 +97,7 @@ python src/lerobot/scripts/lerobot_train.py \
|
||||
**Standard inference (ACT, Diffusion Policy):**
|
||||
|
||||
```bash
|
||||
python examples/hil/hil_data_collection.py \
|
||||
lerobot-rollout --strategy.type=dagger \
|
||||
--robot.type=bi_openarm_follower \
|
||||
--robot.left_arm_config.port=can1 \
|
||||
--robot.left_arm_config.side=left \
|
||||
@@ -108,11 +108,10 @@ python examples/hil/hil_data_collection.py \
|
||||
--teleop.port_left=/dev/ttyACM0 \
|
||||
--teleop.port_right=/dev/ttyACM1 \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=your-username/hil-dataset \
|
||||
--dataset.repo_id=your-username/rollout_hil_dataset \
|
||||
--dataset.single_task="Fold the T-shirt properly" \
|
||||
--dataset.fps=30 \
|
||||
--dataset.episode_time_s=1000 \
|
||||
--dataset.num_episodes=50 \
|
||||
--strategy.num_episodes=50 \
|
||||
--interpolation_multiplier=2
|
||||
```
|
||||
|
||||
@@ -121,11 +120,11 @@ python examples/hil/hil_data_collection.py \
|
||||
For models with high inference latency, enable RTC for smooth execution:
|
||||
|
||||
```bash
|
||||
python examples/hil/hil_data_collection.py \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--rtc.max_guidance_weight=5.0 \
|
||||
--rtc.prefix_attention_schedule=LINEAR \
|
||||
lerobot-rollout --strategy.type=dagger \
|
||||
--inference.type=rtc \
|
||||
--inference.rtc.execution_horizon=20 \
|
||||
--inference.rtc.max_guidance_weight=5.0 \
|
||||
--inference.rtc.prefix_attention_schedule=LINEAR \
|
||||
--robot.type=bi_openarm_follower \
|
||||
--robot.left_arm_config.port=can1 \
|
||||
--robot.left_arm_config.side=left \
|
||||
@@ -136,11 +135,10 @@ python examples/hil/hil_data_collection.py \
|
||||
--teleop.port_left=/dev/ttyACM0 \
|
||||
--teleop.port_right=/dev/ttyACM1 \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=your-username/hil-rtc-dataset \
|
||||
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \
|
||||
--dataset.single_task="Fold the T-shirt properly" \
|
||||
--dataset.fps=30 \
|
||||
--dataset.episode_time_s=1000 \
|
||||
--dataset.num_episodes=50 \
|
||||
--strategy.num_episodes=50 \
|
||||
--interpolation_multiplier=3
|
||||
```
|
||||
|
||||
@@ -235,7 +233,7 @@ This HIL data collection approach builds on ideas from interactive imitation lea
|
||||
|
||||
- **HG-DAgger** (Kelly et al., 2019) made this practical for robotics: a human expert monitors the robot and only intervenes when needed, rather than labeling every state. The gating between autonomous and human control is exactly the pause → takeover → return-to-policy loop used in the scripts here.
|
||||
|
||||
- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the HIL scripts in `examples/hil`.
|
||||
- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the DAgger strategy in `lerobot-rollout`.
|
||||
|
||||
- **π0.6/RECAP** (Physical Intelligence, 2025) applies the same iterative collect-and-finetune loop at scale with VLA models, showing that even large pretrained policies benefit substantially from targeted human corrections on their own failure modes. π0.6 is trained using RECAP.
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ pip install -e ".[hilserl]"
|
||||
|
||||
### Understanding Configuration
|
||||
|
||||
The training process begins with proper configuration for the HILSerl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
|
||||
The training process begins with proper configuration for the HILSERl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` (defined in `lerobot/envs/configs.py`) and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
@@ -95,6 +95,7 @@ class HILSerlProcessorConfig:
|
||||
class ObservationConfig:
|
||||
add_joint_velocity_to_observation: bool = False # Add joint velocities to state
|
||||
add_current_to_observation: bool = False # Add motor currents to state
|
||||
add_ee_pose_to_observation: bool = False # Add end-effector pose to state
|
||||
display_cameras: bool = False # Display camera feeds during execution
|
||||
|
||||
class ImagePreprocessingConfig:
|
||||
@@ -326,14 +327,22 @@ lerobot-find-joint-limits \
|
||||
Max joint positions [-20.0, -20.0, -20.0, -20.0, -20.0, -20.0]
|
||||
Min joint positions [50.0, 50.0, 50.0, 50.0, 50.0, 50.0]
|
||||
```
|
||||
3. Use these values in the configuration of your teleoperation device (TeleoperatorConfig) under the `end_effector_bounds` field
|
||||
3. Use these values in your environment configuration under `env.processor.inverse_kinematics.end_effector_bounds` (see `InverseKinematicsConfig` in `lerobot/envs/configs.py`)
|
||||
|
||||
**Example Configuration**
|
||||
|
||||
```json
|
||||
"end_effector_bounds": {
|
||||
"max": [0.24, 0.20, 0.10],
|
||||
"min": [0.16, -0.08, 0.03]
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"inverse_kinematics": {
|
||||
"end_effector_bounds": {
|
||||
"max": [0.24, 0.2, 0.1],
|
||||
"min": [0.16, -0.08, 0.03]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -404,30 +413,24 @@ We support using a gamepad or a keyboard or the leader arm of the robot.
|
||||
|
||||
HIL-Serl learns actions in the end-effector space of the robot. Therefore, the teleoperation will control the end-effector's x,y,z displacements.
|
||||
|
||||
For that we need to define a version of the robot that takes actions in the end-effector space. Check the robot class `SO100FollowerEndEffector` and its configuration `SO100FollowerEndEffectorConfig` for the default parameters related to the end-effector space.
|
||||
The end-effector transformation is applied by the processor pipeline (`InverseKinematicsRLStep`, `EEBoundsAndSafety`, `EEReferenceAndDelta`, `GripperVelocityToJoint`) configured under `env.processor.inverse_kinematics` (`InverseKinematicsConfig`) and `env.processor.gripper` / `env.processor.max_gripper_pos`. The defaults related to the end-effector space are:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
class SO100FollowerEndEffectorConfig(SO100FollowerConfig):
|
||||
"""Configuration for the SO100FollowerEndEffector robot."""
|
||||
class InverseKinematicsConfig:
|
||||
"""Configuration for inverse kinematics processing."""
|
||||
|
||||
# Default bounds for the end-effector position (in meters)
|
||||
end_effector_bounds: dict[str, list[float]] = field( # bounds for the end-effector in x,y,z direction
|
||||
default_factory=lambda: {
|
||||
"min": [-1.0, -1.0, -1.0], # min x, y, z
|
||||
"max": [1.0, 1.0, 1.0], # max x, y, z
|
||||
}
|
||||
)
|
||||
urdf_path: str | None = None
|
||||
target_frame_name: str | None = None
|
||||
# bounds for the end-effector in x,y,z direction
|
||||
end_effector_bounds: dict[str, list[float]] | None = None
|
||||
# maximum step size for the end-effector in x,y,z direction
|
||||
end_effector_step_sizes: dict[str, float] | None = None
|
||||
|
||||
max_gripper_pos: float = 50 # maximum gripper position that the gripper will be open at
|
||||
|
||||
end_effector_step_sizes: dict[str, float] = field( # maximum step size for the end-effector in x,y,z direction
|
||||
default_factory=lambda: {
|
||||
"x": 0.02,
|
||||
"y": 0.02,
|
||||
"z": 0.02,
|
||||
}
|
||||
)
|
||||
class HILSerlProcessorConfig:
|
||||
...
|
||||
# maximum gripper position that the gripper will be open at
|
||||
max_gripper_pos: float | None = 100.0
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -606,11 +609,11 @@ This guide explains how to train a reward classifier for human-in-the-loop reinf
|
||||
|
||||
**Note**: Training a reward classifier is optional. You can start the first round of RL experiments by annotating the success manually with your gamepad or keyboard device.
|
||||
|
||||
The reward classifier implementation in `modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings.
|
||||
The reward classifier implementation in `lerobot/rewards/classifier/modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings.
|
||||
|
||||
**Collecting a Dataset for the reward classifier**
|
||||
|
||||
Before training, you need to collect a dataset with labeled examples. The `record_dataset` function in `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards.
|
||||
Before training, you need to collect a dataset with labeled examples. Setting `mode: "record"` in your config and running `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards.
|
||||
|
||||
To collect a dataset, you need to modify some parameters in the environment configuration based on HILSerlRobotEnvConfig.
|
||||
|
||||
@@ -658,7 +661,7 @@ Example configuration section for data collection:
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "hf_username/dataset_name",
|
||||
"dataset_root": "data/your_dataset",
|
||||
"root": "data/your_dataset",
|
||||
"task": "reward_classifier_task",
|
||||
"num_episodes_to_record": 20,
|
||||
"replay_episode": null,
|
||||
@@ -671,7 +674,7 @@ Example configuration section for data collection:
|
||||
|
||||
**Reward Classifier Configuration**
|
||||
|
||||
The reward classifier is configured using `configuration_classifier.py`. Here are the key parameters:
|
||||
The reward classifier is configured using `lerobot/rewards/classifier/configuration_classifier.py`. Here are the key parameters:
|
||||
|
||||
- **model_name**: Base model architecture (e.g., we mainly use `"helper2424/resnet10"`)
|
||||
- **model_type**: `"cnn"` or `"transformer"`
|
||||
@@ -689,7 +692,7 @@ Example configuration for training the [reward classifier](https://huggingface.c
|
||||
"repo_id": "hf_username/dataset_name",
|
||||
"root": null
|
||||
},
|
||||
"policy": {
|
||||
"reward_model": {
|
||||
"type": "reward_classifier",
|
||||
"model_name": "helper2424/resnet10",
|
||||
"model_type": "cnn",
|
||||
@@ -699,7 +702,6 @@ Example configuration for training the [reward classifier](https://huggingface.c
|
||||
"dropout_rate": 0.1,
|
||||
"learning_rate": 1e-4,
|
||||
"device": "cuda",
|
||||
"use_amp": true,
|
||||
"input_features": {
|
||||
"observation.images.front": {
|
||||
"type": "VISUAL",
|
||||
@@ -818,13 +820,14 @@ The LeRobot system uses a distributed actor-learner architecture for training. T
|
||||
|
||||
**Configuration Setup**
|
||||
|
||||
Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/configs/train.py`.
|
||||
Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/rl/train_rl.py`.
|
||||
|
||||
1. Configure the policy settings (`type="sac"`, `device`, etc.)
|
||||
2. Set `dataset` to your cropped dataset
|
||||
3. Configure environment settings with crop parameters
|
||||
4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/sac/configuration_sac.py#L79).
|
||||
5. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
|
||||
1. Configure the policy settings (`type="gaussian_actor"`, `device`, etc.)
|
||||
2. Configure the algorithm settings under the top-level `algorithm` block (`type="sac"`, learning rates, discount, etc., defined in `lerobot/rl/algorithms/sac/configuration_sac.py`).
|
||||
3. Set `dataset` to your cropped dataset
|
||||
4. Configure environment settings with crop parameters
|
||||
5. Check the other parameters related to the Gaussian Actor in [configuration_gaussian_actor.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/gaussian_actor/configuration_gaussian_actor.py#L79).
|
||||
6. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
|
||||
|
||||
**Starting the Learner**
|
||||
|
||||
@@ -926,7 +929,7 @@ The ideal behaviour is that your intervention rate should drop gradually during
|
||||
|
||||
Some configuration values have a disproportionate impact on training stability and speed:
|
||||
|
||||
- **`temperature_init`** (`policy.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
|
||||
- **`temperature_init`** (`algorithm.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
|
||||
- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) – interval in _seconds_ between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency.
|
||||
- **`storage_device`** (`policy.storage_device`) – device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second.
|
||||
|
||||
|
||||
@@ -232,7 +232,7 @@ lerobot-record \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
@@ -278,6 +278,6 @@ lerobot-record \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
@@ -68,13 +68,13 @@ from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760431541",
|
||||
id="my_red_robot_arm",
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
)
|
||||
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem58760431551",
|
||||
id="my_blue_leader_arm",
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
|
||||
robot = SO101Follower(robot_config)
|
||||
@@ -108,13 +108,13 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
||||
--teleop.type=koch_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=my_awesome_leader_arm \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem5AB90687491 \
|
||||
--robot.id=my_follower_arm \
|
||||
--robot.cameras="{front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/tty.usbmodem5AB90689011 \
|
||||
--teleop.id=my_leader_arm \
|
||||
--display_data=true
|
||||
```
|
||||
</hfoption>
|
||||
@@ -122,34 +122,48 @@ lerobot-teleoperate \
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
import time
|
||||
from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig
|
||||
from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
|
||||
|
||||
camera_config = {
|
||||
"front": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30)
|
||||
}
|
||||
|
||||
robot_config = KochFollowerConfig(
|
||||
port="/dev/tty.usbmodem585A0076841",
|
||||
id="my_red_robot_arm",
|
||||
cameras=camera_config
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
cameras={
|
||||
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||
}
|
||||
)
|
||||
|
||||
teleop_config = KochLeaderConfig(
|
||||
port="/dev/tty.usbmodem58760431551",
|
||||
id="my_blue_leader_arm",
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
|
||||
robot = KochFollower(robot_config)
|
||||
teleop_device = KochLeader(teleop_config)
|
||||
init_rerun(session_name="teleoperation")
|
||||
|
||||
robot = SO101Follower(robot_config)
|
||||
teleop_device = SO101Leader(teleop_config)
|
||||
robot.connect()
|
||||
teleop_device.connect()
|
||||
|
||||
TARGET_HZ = 30
|
||||
TIME_PER_FRAME = 1.0 / TARGET_HZ
|
||||
|
||||
while True:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
observation = robot.get_observation()
|
||||
action = teleop_device.get_action()
|
||||
robot.send_action(action)
|
||||
log_rerun_data(observation=observation, action=action)
|
||||
|
||||
elapsed_time = time.perf_counter() - start_time
|
||||
sleep_time = TIME_PER_FRAME - elapsed_time
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -193,7 +207,7 @@ lerobot-record \
|
||||
--dataset.num_episodes=5 \
|
||||
--dataset.single_task="Grab the black cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
</hfoption>
|
||||
@@ -202,10 +216,11 @@ lerobot-record \
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
from lerobot.teleoperators.so_leader.config_so_leader import SO101LeaderConfig
|
||||
from lerobot.teleoperators.so_leader.so_leader import SO101Leader
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
@@ -218,71 +233,56 @@ EPISODE_TIME_SEC = 60
|
||||
RESET_TIME_SEC = 10
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
|
||||
# Create robot configuration
|
||||
robot_config = SO100FollowerConfig(
|
||||
id="my_awesome_follower_arm",
|
||||
cameras={
|
||||
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error.
|
||||
},
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
)
|
||||
|
||||
teleop_config = SO100LeaderConfig(
|
||||
id="my_awesome_leader_arm",
|
||||
port="/dev/tty.usbmodem585A0077581",
|
||||
)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
teleop = SO100Leader(teleop_config)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/<dataset_repo_id>",
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Create the required processors
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
def main():
|
||||
# Create robot configuration
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
cameras={
|
||||
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||
}
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO101Follower(robot_config)
|
||||
teleop = SO101Leader(teleop_config)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/<dataset_repo_id>",
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Create the required processors
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
@@ -291,26 +291,50 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
teleop.disconnect()
|
||||
dataset.push_to_hub()
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# finalize dataset
|
||||
log_say("Finalizing dataset...")
|
||||
dataset.finalize()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
teleop.disconnect()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -348,7 +372,7 @@ The `record` function provides a suite of tools for capturing and managing data
|
||||
##### 2. Checkpointing and Resuming
|
||||
|
||||
- Checkpoints are automatically created during recording.
|
||||
- If an issue occurs, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset !
|
||||
- If an issue occurs or you want to record additional episodes in the same dataset, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset! Make sure that you also set `--dataset.root="local_path"`, it's a local path to save the new part of the dataset and is required to resume.
|
||||
- To start recording from scratch, **manually delete** the dataset directory.
|
||||
|
||||
##### 3. Recording Parameters
|
||||
@@ -422,7 +446,7 @@ from lerobot.utils.utils import log_say
|
||||
|
||||
episode_idx = 0
|
||||
|
||||
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm")
|
||||
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem5AB90687491", id="my_follower_arm")
|
||||
|
||||
robot = SO100Follower(robot_config)
|
||||
robot.connect()
|
||||
@@ -490,6 +514,83 @@ Additionally you can provide extra `tags` or specify a `license` for your model
|
||||
|
||||
If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
|
||||
|
||||
#### Train using Hugging Face Jobs
|
||||
|
||||
Hugging Face jobs let's you easily select hardware and run the training in the cloud. So if you don't have a powerful GPU or you need more VRAM or just want to train a model much faster use HF Jobs! It's pay as you go and you simply pay for each second of use, you can see the pricing and additional information [here](https://huggingface.co/docs/hub/jobs).
|
||||
|
||||
To run the training use this command:
|
||||
|
||||
<hfoptions id="train_with_hf_jobs">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
hf jobs run \
|
||||
--flavor a10g-small \
|
||||
--timeout 4h \
|
||||
--secrets HF_TOKEN \
|
||||
huggingface/lerobot-gpu:latest \
|
||||
-- \
|
||||
python -m lerobot.scripts.lerobot_train \
|
||||
--dataset.repo_id=username/dataset \
|
||||
--policy.type=act \
|
||||
--steps=5000 \
|
||||
--batch_size=16 \
|
||||
--policy.device=cuda \
|
||||
--policy.repo_id=username/your_policy \
|
||||
--log_freq=100
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from huggingface_hub import run_job, get_token
|
||||
|
||||
run_name = "act_so101_hf_jobs"
|
||||
dataset_id = "username/dataset"
|
||||
user_hub_id = "username"
|
||||
|
||||
command_args = [
|
||||
"python", "-m", "lerobot.scripts.lerobot_train",
|
||||
"--dataset.repo_id", dataset_id,
|
||||
"--policy.type", "act",
|
||||
"--steps", "5000",
|
||||
"--batch_size", "16",
|
||||
"--num_workers", "4",
|
||||
"--policy.device", "cuda",
|
||||
"--log_freq", "100",
|
||||
"--save_freq", "1000",
|
||||
"--save_checkpoint", "true",
|
||||
"--wandb.enable", "false",
|
||||
"--policy.repo_id", f"{user_hub_id}/{run_name}"
|
||||
]
|
||||
|
||||
print(f"Submitting job '{run_name}' to Hugging Face Infrastructure...")
|
||||
|
||||
job_info = run_job(
|
||||
image="huggingface/lerobot-gpu:latest",
|
||||
command=command_args,
|
||||
flavor="a10g-small",
|
||||
timeout="4h",
|
||||
secrets={"HF_TOKEN": get_token()}
|
||||
)
|
||||
|
||||
print("\n🚀 Job successfully launched!")
|
||||
print(f"🔹 Job ID: {job_info.id}")
|
||||
print(f"🔗 Live UI Dashboard & Logs: {job_info.url}")
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
You can modify the `--flavor` to use different hardware, for example: `t4-small`, `a100-large`, `h200`. Use `hf jobs hardware` to see the full list with pricing.
|
||||
Depending on the model you want to train and the hardware you selected you can also modify the `--batch_size` and `--number_of_workers`.
|
||||
For longer training sessions increase the timeout.
|
||||
|
||||
Once the training is started you can go to [Jobs](https://huggingface.co/settings/jobs) and see if your jobs is running as well as all the outputs. Sometimes it takes a few minutes to schedule your job so be patient.
|
||||
|
||||
After training the model will be pushed to hub and you can use it as any other model with LeRobot.
|
||||
|
||||
#### Upload policy checkpoints
|
||||
|
||||
Once training is done, upload the latest checkpoint with:
|
||||
@@ -509,121 +610,42 @@ hf upload ${HF_USER}/act_so101_test${CKPT} \
|
||||
|
||||
## Run inference and evaluate your policy
|
||||
|
||||
You can use the `record` script from [`lerobot-record`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
|
||||
Use `lerobot-rollout` to deploy a trained policy on your robot. You can choose different strategies depending on your needs:
|
||||
|
||||
<hfoptions id="eval">
|
||||
<hfoption id="Command">
|
||||
<hfoption id="Base mode (no recording)">
|
||||
```bash
|
||||
lerobot-record \
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM1 \
|
||||
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
--display_data=false \
|
||||
--dataset.repo_id=${HF_USER}/eval_so100 \
|
||||
--dataset.single_task="Put lego brick into the transparent box" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
||||
# --teleop.type=so100_leader \
|
||||
# --teleop.port=/dev/ttyACM0 \
|
||||
# --teleop.id=my_awesome_leader_arm \
|
||||
--policy.path=${HF_USER}/my_policy
|
||||
--task="Put lego brick into the transparent box" \
|
||||
--duration=60
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from lerobot.policies.act import ACTPolicy
|
||||
from lerobot.policies import make_pre_post_processors
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
|
||||
|
||||
# Create the robot configuration
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", cameras=camera_config
|
||||
)
|
||||
|
||||
# Initialize the robot
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Initialize the policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot
|
||||
robot.connect()
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
)
|
||||
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Run the policy inference loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
dataset.save_episode()
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
dataset.push_to_hub()
|
||||
<hfoption id="Sentry mode (with recording)">
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=sentry \
|
||||
--strategy.upload_every_n_episodes=5 \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM1 \
|
||||
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
|
||||
--dataset.repo_id=${HF_USER}/eval_so100 \
|
||||
--dataset.single_task="Put lego brick into the transparent box" \
|
||||
--duration=600
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
The `--strategy.type` flag selects the execution mode:
|
||||
|
||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so101_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so101_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so101_test`).
|
||||
- `base`: Autonomous rollout with no data recording (useful for quick evaluation)
|
||||
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
|
||||
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
||||
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
||||
|
||||
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
||||
|
||||
261
docs/source/inference.mdx
Normal file
261
docs/source/inference.mdx
Normal file
@@ -0,0 +1,261 @@
|
||||
# Policy Deployment (lerobot-rollout)
|
||||
|
||||
`lerobot-rollout` is the single CLI for deploying trained policies on real robots. It supports multiple execution strategies and inference backends, from quick evaluation to continuous recording and human-in-the-loop data collection.
|
||||
|
||||
## Quick Start
|
||||
|
||||
No extra dependencies are needed beyond your robot and policy extras.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=lerobot/act_koch_real \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--task="pick up cube" \
|
||||
--duration=30
|
||||
```
|
||||
|
||||
This runs the policy for 30 seconds with no recording.
|
||||
|
||||
---
|
||||
|
||||
## Strategies
|
||||
|
||||
Select a strategy with `--strategy.type=<name>`. Each strategy defines a different control loop with its own recording and interaction semantics.
|
||||
|
||||
### Base (`--strategy.type=base`)
|
||||
|
||||
Autonomous policy execution with no data recording. Use this for quick evaluation, demos, or when you only need to observe the robot.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Put lego brick into the box" \
|
||||
--duration=60
|
||||
```
|
||||
|
||||
| Flag | Description |
|
||||
| ---------------- | ------------------------------------------------------ |
|
||||
| `--duration` | Run time in seconds (0 = infinite) |
|
||||
| `--task` | Task description passed to the policy |
|
||||
| `--display_data` | Stream observations/actions to Rerun for visualization |
|
||||
|
||||
### Sentry (`--strategy.type=sentry`)
|
||||
|
||||
Continuous autonomous recording with periodic upload to the Hugging Face Hub. Episode boundaries are auto-computed from camera resolution and FPS so each saved episode produces a complete video file, keeping uploads efficient.
|
||||
|
||||
Policy state (hidden state, RTC queue) persists across episode boundaries: the robot does not reset between episodes.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=sentry \
|
||||
--strategy.upload_every_n_episodes=5 \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--dataset.repo_id=${HF_USER}/rollout_eval_data \
|
||||
--dataset.single_task="Put lego brick into the box" \
|
||||
--duration=3600
|
||||
```
|
||||
|
||||
| Flag | Description |
|
||||
| -------------------------------------- | ----------------------------------------------------------- |
|
||||
| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) |
|
||||
| `--strategy.target_video_file_size_mb` | Target video file size for episode rotation (default: auto) |
|
||||
| `--dataset.repo_id` | **Required.** Hub repository for the recorded dataset |
|
||||
| `--dataset.push_to_hub` | Whether to push to Hub on teardown (default: true) |
|
||||
|
||||
### Highlight (`--strategy.type=highlight`)
|
||||
|
||||
Autonomous rollout with on-demand recording via a memory-bounded ring buffer. The robot runs continuously while the buffer captures the last N seconds of telemetry. Press the save key to flush the buffer and start live recording; press it again to save the episode.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=highlight \
|
||||
--strategy.ring_buffer_seconds=30 \
|
||||
--strategy.save_key=s \
|
||||
--strategy.push_key=h \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--dataset.repo_id=${HF_USER}/rollout_highlight_data \
|
||||
--dataset.single_task="Pick up the red cube"
|
||||
```
|
||||
|
||||
**Keyboard controls:**
|
||||
|
||||
| Key | Action |
|
||||
| ------------------ | -------------------------------------------------------- |
|
||||
| `s` (configurable) | Start recording (flushes buffer) / stop and save episode |
|
||||
| `h` (configurable) | Push dataset to Hub |
|
||||
| `ESC` | Stop the session |
|
||||
|
||||
| Flag | Description |
|
||||
| -------------------------------------- | ---------------------------------------------- |
|
||||
| `--strategy.ring_buffer_seconds` | Duration of buffered telemetry (default: 30) |
|
||||
| `--strategy.ring_buffer_max_memory_mb` | Memory cap for the ring buffer (default: 2048) |
|
||||
| `--strategy.save_key` | Key to toggle recording (default: `s`) |
|
||||
| `--strategy.push_key` | Key to push to Hub (default: `h`) |
|
||||
|
||||
### DAgger (`--strategy.type=dagger`)
|
||||
|
||||
Human-in-the-loop data collection. Alternates between autonomous policy execution and human intervention via a teleoperator. Intervention frames are tagged with `intervention=True`. Requires a teleoperator (`--teleop.type`).
|
||||
|
||||
See the [Human-In-the-Loop Data Collection](./hil_data_collection) guide for a detailed walkthrough.
|
||||
|
||||
**Corrections-only mode** (default): Only human correction windows are recorded. Each correction becomes one episode.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=dagger \
|
||||
--strategy.num_episodes=20 \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--robot.type=bi_openarm_follower \
|
||||
--teleop.type=openarm_mini \
|
||||
--dataset.repo_id=${HF_USER}/rollout_hil_data \
|
||||
--dataset.single_task="Fold the T-shirt"
|
||||
```
|
||||
|
||||
**Continuous recording mode** (`--strategy.record_autonomous=true`): Both autonomous and correction frames are recorded with time-based episode rotation (same as Sentry).
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=dagger \
|
||||
--strategy.record_autonomous=true \
|
||||
--strategy.num_episodes=50 \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/ttyACM1 \
|
||||
--dataset.repo_id=${HF_USER}/rollout_dagger_data \
|
||||
--dataset.single_task="Grasp the block"
|
||||
```
|
||||
|
||||
**Keyboard controls** (default input device):
|
||||
|
||||
| Key | Action |
|
||||
| ------- | ------------------------------------------- |
|
||||
| `Space` | Pause / resume policy execution |
|
||||
| `Tab` | Start / stop human correction |
|
||||
| `Enter` | Push dataset to Hub (corrections-only mode) |
|
||||
| `ESC` | Stop the session |
|
||||
|
||||
Foot pedal input is also supported via `--strategy.input_device=pedal`. Configure pedal codes with `--strategy.pedal.*` flags.
|
||||
|
||||
| Flag | Description |
|
||||
| ------------------------------------ | ------------------------------------------------------- |
|
||||
| `--strategy.num_episodes` | Number of correction episodes to record (default: 10) |
|
||||
| `--strategy.record_autonomous` | Record autonomous frames too (default: false) |
|
||||
| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) |
|
||||
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
|
||||
| `--teleop.type` | **Required.** Teleoperator type |
|
||||
|
||||
---
|
||||
|
||||
## Inference Backends
|
||||
|
||||
Select a backend with `--inference.type=<name>`. All strategies work with both backends.
|
||||
|
||||
### Sync (default)
|
||||
|
||||
One policy call per control tick. The main loop blocks until the action is computed.
|
||||
|
||||
Works with all policies. No extra flags needed.
|
||||
|
||||
### Real-Time Chunking (`--inference.type=rtc`)
|
||||
|
||||
A background thread produces action chunks asynchronously. The main control loop polls for the next ready action while the policy computes the next chunk in parallel.
|
||||
|
||||
Use RTC with large, slow VLA models (Pi0, Pi0.5, SmolVLA) for smooth, continuous motion despite high inference latency.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--inference.type=rtc \
|
||||
--inference.rtc.execution_horizon=10 \
|
||||
--inference.rtc.max_guidance_weight=10.0 \
|
||||
--policy.path=${HF_USER}/pi0_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Pick up the cube" \
|
||||
--duration=60 \
|
||||
--device=cuda
|
||||
```
|
||||
|
||||
| Flag | Description |
|
||||
| ------------------------------------------- | -------------------------------------------------------------- |
|
||||
| `--inference.rtc.execution_horizon` | Steps to blend with previous chunk (default: varies by policy) |
|
||||
| `--inference.rtc.max_guidance_weight` | Consistency enforcement strength (default: varies by policy) |
|
||||
| `--inference.rtc.prefix_attention_schedule` | Blend schedule: `LINEAR`, `EXP`, `ONES`, `ZEROS` |
|
||||
| `--inference.queue_threshold` | Max queue size before backpressure (default: 30) |
|
||||
|
||||
See the [Real-Time Chunking](./rtc) guide for details on tuning RTC parameters.
|
||||
|
||||
---
|
||||
|
||||
## Common Flags
|
||||
|
||||
| Flag | Description | Default |
|
||||
| --------------------------------- | ----------------------------------------------------------------- | ------- |
|
||||
| `--policy.path` | **Required.** HF Hub model ID or local checkpoint path | -- |
|
||||
| `--robot.type` | **Required.** Robot type (e.g. `so100_follower`, `koch_follower`) | -- |
|
||||
| `--robot.port` | Serial port for the robot | -- |
|
||||
| `--robot.cameras` | Camera configuration (JSON dict) | -- |
|
||||
| `--fps` | Control loop frequency | 30 |
|
||||
| `--duration` | Run time in seconds (0 = infinite) | 0 |
|
||||
| `--device` | Torch device (`cpu`, `cuda`, `mps`) | auto |
|
||||
| `--task` | Task description (used when no dataset is provided) | -- |
|
||||
| `--display_data` | Stream telemetry to Rerun visualization | false |
|
||||
| `--display_ip` / `--display_port` | Remote Rerun server address | -- |
|
||||
| `--interpolation_multiplier` | Action interpolation factor | 1 |
|
||||
| `--use_torch_compile` | Enable `torch.compile` for inference | false |
|
||||
| `--resume` | Resume a previous recording session | false |
|
||||
| `--play_sounds` | Vocal synthesis for events | true |
|
||||
|
||||
---
|
||||
|
||||
## Programmatic Usage
|
||||
|
||||
For custom deployments (e.g. with kinematics processors), use the rollout module API directly:
|
||||
|
||||
```python
|
||||
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||
from lerobot.rollout.inference import SyncInferenceConfig
|
||||
from lerobot.rollout.strategies import BaseStrategy
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
|
||||
cfg = RolloutConfig(
|
||||
robot=my_robot_config,
|
||||
policy=my_policy_config,
|
||||
strategy=BaseStrategyConfig(),
|
||||
inference=SyncInferenceConfig(),
|
||||
fps=30,
|
||||
duration=60,
|
||||
task="my task",
|
||||
)
|
||||
|
||||
signal_handler = ProcessSignalHandler(use_threads=True)
|
||||
ctx = build_rollout_context(
|
||||
cfg,
|
||||
signal_handler.shutdown_event,
|
||||
robot_action_processor=my_custom_action_processor, # optional
|
||||
robot_observation_processor=my_custom_obs_processor, # optional
|
||||
)
|
||||
|
||||
strategy = BaseStrategy(cfg.strategy)
|
||||
try:
|
||||
strategy.setup(ctx)
|
||||
strategy.run(ctx)
|
||||
finally:
|
||||
strategy.teardown(ctx)
|
||||
```
|
||||
|
||||
See `examples/so100_to_so100_EE/rollout.py` and `examples/phone_to_so100/rollout.py` for full examples with kinematics processors.
|
||||
@@ -207,6 +207,56 @@ pip install 'lerobot[feetech]' # Feetech motor support
|
||||
|
||||
_Multiple extras can be combined (e.g., `.[core_scripts,pi,pusht]`). For a full list of available extras, refer to `pyproject.toml`._
|
||||
|
||||
### PyTorch CUDA variant (Linux only)
|
||||
|
||||
On Linux, the install path determines which CUDA wheel you get. macOS and Windows installs use the PyPI default (MPS / CPU / CUDA-Windows wheel respectively) and can skip this section.
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
|
||||
<hfoptions id="cuda_variant">
|
||||
<hfoption id="uv-source">
|
||||
|
||||
**Source install via `uv` (`uv sync` or `uv pip install -e .`)**
|
||||
|
||||
`torch` and `torchvision` are pinned by the project to the **CUDA 12.8** PyTorch index (`https://download.pytorch.org/whl/cu128`, driver floor **570.86**) — covers Ampere/Ada/Hopper/Blackwell GPUs. No action needed for typical NVIDIA setups.
|
||||
|
||||
To override for a different CUDA variant:
|
||||
|
||||
```bash
|
||||
uv pip install --force-reinstall torch torchvision \
|
||||
--index-url https://download.pytorch.org/whl/cu126 # older drivers; or cu130 for Blackwell on driver ≥ 580
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="pip-conda">
|
||||
|
||||
**Source install via `pip`/`conda`, or `pip install lerobot` from PyPI**
|
||||
|
||||
PyPI default torch wheel is currently a cu130-bundled Linux wheel, driver floor **580.65**.
|
||||
|
||||
To pick a specific CUDA variant:
|
||||
|
||||
**Using `pip` or `conda`** — install torch first with an explicit index, then lerobot:
|
||||
|
||||
```bash
|
||||
pip install --index-url https://download.pytorch.org/whl/cu128 torch torchvision
|
||||
pip install -e ".[all]" # source
|
||||
# — or —
|
||||
pip install lerobot # from PyPI
|
||||
```
|
||||
|
||||
**Using `uv` to install from PyPI** — one-liner via `--torch-backend` (uv ≥ 0.6):
|
||||
|
||||
```bash
|
||||
uv pip install --torch-backend cu128 lerobot
|
||||
```
|
||||
|
||||
Supported values include `auto`, `cpu`, `cu126`, `cu128`, `cu129`, `cu130`, plus various `rocm*` and `xpu`. Swap as needed for your driver.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
If you encounter build errors, you may need to install additional system dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
||||
|
||||
147
docs/source/language_and_recipes.mdx
Normal file
147
docs/source/language_and_recipes.mdx
Normal file
@@ -0,0 +1,147 @@
|
||||
# Language columns and recipes
|
||||
|
||||
Most LeRobot datasets ship with a single `task` string per episode — fine for
|
||||
short, single-instruction skills, but not enough for the longer-horizon,
|
||||
multi-modal robot policies the field is moving toward (high-level planning,
|
||||
memory, interjections, VQA, tool use). To support those policies without
|
||||
forking the dataset format, LeRobot extends `LeRobotDataset` with two optional
|
||||
language columns and a small recipe layer that turns those rows into
|
||||
chat-style training samples on the fly.
|
||||
|
||||
The design splits cleanly into three layers:
|
||||
|
||||
1. **Data in the dataset** — language annotations stored next to frames in
|
||||
`data/chunk-*/file-*.parquet` as two optional columns (`language_persistent`
|
||||
and `language_events`). Datasets without these columns keep their existing
|
||||
behavior.
|
||||
2. **Recipe** — a YAML file that declares which annotation rows to bind and
|
||||
how to lay them out as chat turns (`role`, `content`, optional images,
|
||||
optional tool calls). Recipes are pure config; no Python required to add a
|
||||
new one.
|
||||
3. **Training format** — at sample time, `RenderMessagesStep` resolves the
|
||||
recipe against the per-frame annotations and emits HF-style `messages` plus
|
||||
LeRobot-specific sidecars (`message_streams`, `target_message_indices`)
|
||||
that policy processors consume.
|
||||
|
||||
This page describes each layer in turn.
|
||||
|
||||
## Layer 1 — language columns in the dataset
|
||||
|
||||
The two optional columns live next to frame data in
|
||||
`data/chunk-*/file-*.parquet`:
|
||||
|
||||
- `language_persistent`: a list of rows broadcast across every frame in an episode for state that remains active, such as `subtask`, `plan`, and `memory`.
|
||||
- `language_events`: a list of rows only on the exact frame where an event was emitted, such as `interjection`, `vqa`, and speech tool calls.
|
||||
|
||||
Both columns share the same row shape (event rows omit `timestamp` because the
|
||||
frame the row sits on already provides it):
|
||||
|
||||
```text
|
||||
role: string
|
||||
content: string | null
|
||||
style: string | null
|
||||
timestamp: float32 # persistent rows only
|
||||
camera: string | null # observation.images.* feature key, view-dependent rows only
|
||||
tool_calls: list[Json] | null
|
||||
```
|
||||
|
||||
The `camera` field tags rows whose `content` is grounded in a specific camera
|
||||
view. Rows of view-dependent styles (`vqa` and `trace`) MUST set `camera` to
|
||||
the matching `observation.images.*` feature key. Rows of every other style —
|
||||
including `motion`, which describes robot-frame primitives in joint / Cartesian
|
||||
terms — MUST leave `camera` as `null`. Pipeline writers and the validator
|
||||
enforce this via `validate_camera_field(style, camera)`.
|
||||
|
||||
`meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations.
|
||||
|
||||
### Architecture
|
||||
|
||||
The language stack itself has three internal modules backing layer 1:
|
||||
|
||||
1. `lerobot.datasets.language` defines the schema, style registry, and `column_for_style`.
|
||||
2. `lerobot.datasets.language_render` resolves rows and renders messages.
|
||||
3. `RenderMessagesStep` turns dataset samples into `messages`, `message_streams`, and `target_message_indices`.
|
||||
|
||||
`LeRobotDataset` stays recipe-agnostic. It passes `language_persistent` and `language_events` through when present, and unannotated datasets keep their existing behavior.
|
||||
|
||||
## Layer 2 — recipe anatomy
|
||||
|
||||
Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`. They
|
||||
declare which annotation rows to pull (via `bindings`) and how to compose them
|
||||
into chat turns (`messages`).
|
||||
|
||||
```yaml
|
||||
messages:
|
||||
- { role: user, content: "${task}", stream: high_level }
|
||||
- { role: assistant, content: "${subtask}", stream: low_level, target: true }
|
||||
```
|
||||
|
||||
A recipe can also branch into a weighted **blend** of sub-recipes. At sample
|
||||
time, exactly one branch is selected deterministically from the sample index,
|
||||
so different frames train different objectives (e.g. memory updates vs.
|
||||
low-level execution vs. VQA) without any Python wiring.
|
||||
|
||||
### Temporal semantics
|
||||
|
||||
Persistent styles are active after emission until replaced:
|
||||
|
||||
- `active_at(t, style=subtask)`
|
||||
- `nth_prev(style=memory, offset=1)`
|
||||
- `nth_next(style=subtask, offset=1)`
|
||||
|
||||
Event styles only exist on their exact timestamp:
|
||||
|
||||
- `emitted_at(t, style=interjection)`
|
||||
- `emitted_at(t, style=vqa, role=user, camera=observation.images.top)`
|
||||
- `emitted_at(t, role=assistant, tool_name=say)`
|
||||
|
||||
Exact event matching has no tolerance window, so writers must stamp event rows with frame timestamps from the parquet data.
|
||||
|
||||
### View-dependent resolution
|
||||
|
||||
For view-dependent styles (`vqa` and `trace`), the resolver gains a
|
||||
`camera=` filter parallel to `role=` and `tool_name=`. Datasets with multiple
|
||||
cameras typically emit one (`vqa`, `user`) + (`vqa`, `assistant`) pair per
|
||||
camera at the same timestamp; without `camera=`, those resolvers see two
|
||||
matches and raise an ambiguity error. Recipes consume each camera through its
|
||||
own binding plus a matching image block, e.g.
|
||||
|
||||
```yaml
|
||||
ask_vqa_top:
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- { type: image, feature: observation.images.top }
|
||||
- { type: text, text: "${vqa_query}" }
|
||||
- {
|
||||
role: assistant,
|
||||
content: "${vqa}",
|
||||
stream: high_level,
|
||||
target: true,
|
||||
if_present: vqa,
|
||||
}
|
||||
```
|
||||
|
||||
Add one such sub-recipe per camera the dataset records.
|
||||
|
||||
## Layer 3 — training format
|
||||
|
||||
Rendered samples use HF-style chat messages plus LeRobot sidecars:
|
||||
|
||||
```python
|
||||
sample["messages"]
|
||||
sample["message_streams"]
|
||||
sample["target_message_indices"]
|
||||
```
|
||||
|
||||
The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone, which keeps the same dataset usable across SmolVLA, Pi0.5, and any future VLM that expects OpenAI-style chat messages.
|
||||
|
||||
## Graceful absence
|
||||
|
||||
If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.
|
||||
If an event-scoped branch is selected on a frame without the required event row, rendering returns `None`, allowing a loader to retry another sample.
|
||||
@@ -10,6 +10,7 @@ This docs will guide you to:
|
||||
- Stream datasets without downloading using `StreamingLeRobotDataset`
|
||||
- Apply image transforms for data augmentation during training
|
||||
- Migrate existing `v2.1` datasets to `v3.0`
|
||||
- Experiment with other `LeRobotDataset` formats and implementations like Lance
|
||||
|
||||
## What’s new in `v3`
|
||||
|
||||
@@ -43,7 +44,7 @@ lerobot-record \
|
||||
--dataset.num_episodes=5 \
|
||||
--dataset.single_task="Grab the black cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
|
||||
@@ -315,3 +316,39 @@ Dataset v3.0 uses incremental parquet writing with buffered metadata for efficie
|
||||
- Ensures the dataset is valid for loading
|
||||
|
||||
Without calling `finalize()`, your parquet files will be incomplete and the dataset won't load properly.
|
||||
|
||||
## Other formats and implementations
|
||||
|
||||
### Lance
|
||||
|
||||
Lance is a useful format for multimodal AI datasets, especially for large-scale training requiring high performance IO and random access.
|
||||
|
||||
The `lerobot-lancedb` package implements `LeRobotLanceDataset` (for JPEG images) and `LeRobotLanceVideoDataset` (for mp4 videos).
|
||||
Those two storage layouts both subclass LeRobotDataset and can provide data loading speed ups.
|
||||
|
||||
`LeRobotLanceDataset` is a drop-in replacement for `LeRobotDataset`:
|
||||
|
||||
```python
|
||||
from lerobot.datasets import LeRobotDatasetMetadata
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot_lancedb import LeRobotLanceDataset, LeRobotLanceVideoDataset
|
||||
|
||||
cfg = DiffusionConfig(...)
|
||||
meta = LeRobotDatasetMetadata(root=local_dataset_path) # or use repo_id=... to load metadata from the Hub
|
||||
delta_timestamps = {...}
|
||||
|
||||
# Use LeRobotLanceDataset for image datasets
|
||||
dataset = LeRobotLanceDataset(
|
||||
root=local_dataset_path, # or use repo_id=... to stream from the Hub
|
||||
delta_timestamps=delta_timestamps,
|
||||
return_uint8=True,
|
||||
)
|
||||
# Or use LeRobotLanceVideoDataset for video datasets:
|
||||
dataset = LeRobotLanceVideoDataset(
|
||||
root=local_dataset_path, # or use repo_id=... to stream from the Hub
|
||||
delta_timestamps=delta_timestamps,
|
||||
return_uint8=True,
|
||||
)
|
||||
```
|
||||
|
||||
Join the discussion on [Github](https://github.com/huggingface/lerobot/issues/3608) and explore the `lerobot-lancedb` documentation [here](https://lancedb.github.io/lerobot-lancedb/).
|
||||
|
||||
@@ -28,13 +28,15 @@ lerobot-train \
|
||||
--steps=100000 \
|
||||
--batch_size=32 \
|
||||
--peft.method_type=LORA \
|
||||
--peft.r=64
|
||||
--peft.r=64 \
|
||||
--peft.lora_alpha=64
|
||||
```
|
||||
|
||||
Note the `--peft.method_type` parameter that let's you select which PEFT method to use. Here we use
|
||||
[LoRA](https://huggingface.co/docs/peft/main/en/package_reference/lora) (Low-Rank Adapter) which is probably the most
|
||||
popular fine-tuning method to date. Low-rank adaption means that we only fine-tune a matrix with comparably low rank
|
||||
instead of the full weight matrix. This rank can be specified using the `--peft.r` parameter. The higher the rank
|
||||
instead of the full weight matrix. This rank can be specified using the `--peft.r` parameter, and the LoRA scaling factor with
|
||||
`--peft.lora_alpha` (where `scaling = lora_alpha / r`). The higher the rank
|
||||
the closer you get to full fine-tuning
|
||||
|
||||
There are more complex methods that have more parameters. These are not yet supported, feel free to raise an issue
|
||||
|
||||
@@ -161,7 +161,7 @@ lerobot-record \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
@@ -203,7 +203,7 @@ lerobot-record \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
|
||||
186
docs/source/rebot_b601.mdx
Normal file
186
docs/source/rebot_b601.mdx
Normal file
@@ -0,0 +1,186 @@
|
||||
# reBot B601-DM
|
||||
|
||||
[reBot B601-DM](https://wiki.seeedstudio.com/rebot_arm_b601_dm_lerobot/) is an open-source, low-cost robot arm from Seeed Studio for embodied-AI and imitation learning. It comes as a **follower** arm (the `B601-DM`, a 6-DOF arm plus gripper driven by Damiao CAN motors) and a **leader** arm (the `StarArm102` / `reBot Arm 102`, driven by FashionStar UART smart servos) used to teleoperate it.
|
||||
|
||||
This page covers **calibration** and **teleoperation** for both single-arm and bimanual (dual-arm) setups.
|
||||
|
||||
<div style="display: flex; align-items: center; gap: 10px;">
|
||||
<img
|
||||
src="https://files.seeedstudio.com/wiki/robotics/projects/lerobot/b601dm_zeroposition.jpg"
|
||||
alt="reBot B601-DM follower arm at its zero position"
|
||||
width="48%"
|
||||
/>
|
||||
<img
|
||||
src="https://files.seeedstudio.com/wiki/robotics/projects/lerobot/102_zeroposition.jpg"
|
||||
alt="reBot Arm 102 leader arm at its zero position"
|
||||
width="48%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
_Left: the B601-DM follower at its zero position. Right: the reBot Arm 102 leader at its zero position. Images courtesy of [Seeed Studio](https://wiki.seeedstudio.com/rebot_arm_b601_dm_lerobot/)._
|
||||
|
||||
## Install LeRobot 🤗
|
||||
|
||||
Follow our [Installation Guide](./installation), then install the reBot support:
|
||||
|
||||
```bash
|
||||
pip install -e ".[rebot]"
|
||||
```
|
||||
|
||||
This pulls in `motorbridge` (CAN motor control for the B601-DM follower) and `motorbridge-smart-servo` (FashionStar UART servos for the reBot Arm 102 leader).
|
||||
|
||||
## Registered device types
|
||||
|
||||
| Type | Kind |
|
||||
| ------------------------ | -------------------------------------------- |
|
||||
| `rebot_b601_follower` | single-arm B601-DM follower robot |
|
||||
| `bi_rebot_b601_follower` | bimanual (dual-arm) follower robot |
|
||||
| `rebot_102_leader` | single-arm reBot Arm 102 leader teleoperator |
|
||||
| `bi_rebot_102_leader` | bimanual (dual-arm) leader teleoperator |
|
||||
|
||||
The bimanual types compose two single-arm instances and namespace each arm's
|
||||
observation/action keys with a `left_` / `right_` prefix. Per-arm settings are
|
||||
passed through nested `left_arm_config.*` / `right_arm_config.*` arguments.
|
||||
|
||||
## Find the USB ports
|
||||
|
||||
For each device, find the USB port associated with its motor bus using:
|
||||
|
||||
```bash
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
On Linux, remove `brltty` (`sudo apt remove brltty`) so it does not hold the
|
||||
leader's USB serial port. You may also need to grant access to the serial
|
||||
devices: `sudo chmod 666 /dev/ttyACM* /dev/ttyUSB*`.
|
||||
</Tip>
|
||||
|
||||
## Calibration
|
||||
|
||||
Neither arm stores a persistent hardware calibration: every time it connects, the motors are re-zeroed against the pose the arm is physically holding. Calibration simply records that zero pose. When prompted, **manually move the arm to its zero position** (the default sit-down pose shown above, gripper fully closed) and press <kbd>ENTER</kbd>.
|
||||
|
||||
### Follower (B601-DM)
|
||||
|
||||
<hfoptions id="calibrate-follower">
|
||||
<hfoption id="Single arm">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--robot.type=rebot_b601_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=follower \
|
||||
--robot.can_adapter=damiao
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Dual arm">
|
||||
|
||||
Connect the bimanual follower; calibration runs for the left arm, then the right arm.
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--robot.type=bi_rebot_b601_follower \
|
||||
--robot.id=bi_follower \
|
||||
--robot.left_arm_config.port=/dev/ttyACM0 \
|
||||
--robot.left_arm_config.can_adapter=damiao \
|
||||
--robot.right_arm_config.port=/dev/ttyACM1 \
|
||||
--robot.right_arm_config.can_adapter=damiao
|
||||
```
|
||||
|
||||
Per-arm calibration files are saved with `_left` / `_right` suffixes on the id.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Leader (reBot Arm 102)
|
||||
|
||||
<hfoptions id="calibrate-leader">
|
||||
<hfoption id="Single arm">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=rebot_102_leader \
|
||||
--teleop.port=/dev/ttyUSB0 \
|
||||
--teleop.id=leader
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Dual arm">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=bi_rebot_102_leader \
|
||||
--teleop.id=bi_leader \
|
||||
--teleop.left_arm_config.port=/dev/ttyUSB0 \
|
||||
--teleop.right_arm_config.port=/dev/ttyUSB1
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Teleoperation
|
||||
|
||||
Once both arms are calibrated, drive the follower with the leader. The follower talks to its CAN bus through a Damiao serial bridge (`can_adapter=damiao`, the default) or a SocketCAN adapter (`can_adapter=socketcan`). See the [OpenArm page](./openarm) for more details on the SocketCAN adapter configuration.
|
||||
|
||||
<hfoptions id="teleoperate">
|
||||
<hfoption id="Single arm">
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=rebot_b601_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=follower \
|
||||
--robot.can_adapter=damiao \
|
||||
--teleop.type=rebot_102_leader \
|
||||
--teleop.port=/dev/ttyUSB0 \
|
||||
--teleop.id=leader
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Dual arm">
|
||||
|
||||
The bimanual leader and follower reuse the single-arm classes; each arm is
|
||||
configured through nested `left_arm_config.*` / `right_arm_config.*` arguments,
|
||||
so a bimanual reBot Arm 102 leader drives a bimanual B601-DM follower.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=bi_rebot_b601_follower \
|
||||
--robot.id=bi_follower \
|
||||
--robot.left_arm_config.port=/dev/ttyACM0 \
|
||||
--robot.left_arm_config.can_adapter=damiao \
|
||||
--robot.right_arm_config.port=/dev/ttyACM1 \
|
||||
--robot.right_arm_config.can_adapter=damiao \
|
||||
--teleop.type=bi_rebot_102_leader \
|
||||
--teleop.id=bi_leader \
|
||||
--teleop.left_arm_config.port=/dev/ttyUSB0 \
|
||||
--teleop.right_arm_config.port=/dev/ttyUSB1
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
<Tip>
|
||||
The leader and follower share the same joint names (`shoulder_pan,
|
||||
shoulder_lift, elbow_flex, wrist_flex, wrist_yaw, wrist_roll, gripper`), so
|
||||
leader actions map directly onto the follower.
|
||||
</Tip>
|
||||
|
||||
If the motion of a joint is reversed, flip its sign in the leader's `joint_directions` (the gripper also carries a scale to widen its range to the follower):
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=rebot_b601_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.can_adapter=damiao \
|
||||
--teleop.type=rebot_102_leader \
|
||||
--teleop.port=/dev/ttyUSB0 \
|
||||
--teleop.joint_directions='{"shoulder_pan":-1,"shoulder_lift":-1,"elbow_flex":1,"wrist_flex":1,"wrist_yaw":1,"wrist_roll":-1,"gripper":-6}'
|
||||
```
|
||||
|
||||
## Recording datasets
|
||||
|
||||
Swap `lerobot-teleoperate` for `lerobot-record` (with the same `--robot.*` / `--teleop.*` arguments, plus `--dataset.*`) to record demonstrations for training. See [Imitation Learning for Robots](./il_robots) for the full workflow.
|
||||
|
||||
For hardware assembly and wiring, see the [Seeed Studio reBot wiki](https://wiki.seeedstudio.com/rebot_arm_b601_dm_lerobot/).
|
||||
@@ -61,17 +61,6 @@ lerobot-eval \
|
||||
--rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}'
|
||||
```
|
||||
|
||||
### Recording
|
||||
|
||||
`lerobot-record` also supports rename maps, nested under the dataset config:
|
||||
|
||||
```bash
|
||||
lerobot-record \ # When running inference
|
||||
--policy.path="<user>/smolVLA_finetuned" \
|
||||
... \
|
||||
--dataset.rename_map='{"observation.images.glove2": "observation.images.image"}'
|
||||
```
|
||||
|
||||
## Alternative: edit the policy config directly
|
||||
|
||||
If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed.
|
||||
@@ -105,10 +94,10 @@ XVLA-base has three visual inputs and `empty_cameras=0` by default. Your dataset
|
||||
|
||||
## Quick reference
|
||||
|
||||
| Goal | What to do |
|
||||
| ----------------------------------------- | --------------------------------------------------------------------------- |
|
||||
| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` |
|
||||
| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` |
|
||||
| Recording with different keys (inference) | `--dataset.rename_map='{"source_key": "policy_key", ...}'`. |
|
||||
| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
|
||||
| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source |
|
||||
| Goal | What to do |
|
||||
| --------------------------------------- | --------------------------------------------------------------------------- |
|
||||
| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` |
|
||||
| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` |
|
||||
| Rollout with different keys (inference) | `--rename_map='{"source_key": "policy_key", ...}'`. |
|
||||
| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
|
||||
| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source |
|
||||
|
||||
@@ -34,7 +34,7 @@ pip install -e ".[smolvla]"
|
||||
|
||||
### Using RTC with Pi0
|
||||
|
||||
You can find a complete reference implementation in [eval_with_real_robot.py](examples/rtc/eval_with_real_robot.py).
|
||||
You can use `lerobot-rollout --strategy.type=base --inference.type=rtc` for RTC deployment on real robots.
|
||||
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
|
||||
|
||||
```python
|
||||
@@ -137,8 +137,12 @@ The script generates a visualization of the denoising process, comparing standar
|
||||
## Testing RTC with a Real Robot
|
||||
|
||||
```bash
|
||||
python examples/rtc/eval_with_real_robot.py \
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USERNAME}/policy_repo_id \
|
||||
--inference.type=rtc \
|
||||
--inference.rtc.execution_horizon=10 \
|
||||
--inference.rtc.max_guidance_weight=10.0 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
@@ -178,7 +182,7 @@ visualizer = RTCDebugVisualizer()
|
||||
# ... create plots
|
||||
```
|
||||
|
||||
See `examples/rtc/eval_dataset.py` for a complete example of visualization.
|
||||
See `examples/rtc/eval_dataset.py` for a complete example of offline RTC visualization.
|
||||
|
||||
## References
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ This ensures identical task states map to consistent progress values, even acros
|
||||
|
||||
## Inputs and Targets (What the new code expects)
|
||||
|
||||
SARM is trained through its processor (`src/lerobot/policies/sarm/processor_sarm.py`), which:
|
||||
SARM is trained through its processor (`src/lerobot/rewards/sarm/processor_sarm.py`), which:
|
||||
|
||||
- **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features`
|
||||
- **Pads/truncates** robot state into `state_features` (up to `max_state_dim`)
|
||||
@@ -347,7 +347,7 @@ Use `compute_rabc_weights.py` with `--visualize-only` to visualize model predict
|
||||
<hfoption id="single_stage">
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
@@ -360,7 +360,7 @@ python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
<hfoption id="dense_only">
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
@@ -373,7 +373,7 @@ python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
<hfoption id="dual">
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
@@ -429,7 +429,7 @@ The weighting follows **Equations 8-9** from the paper:
|
||||
First, run the SARM model on all frames in your dataset to compute progress values:
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--head-mode sparse \
|
||||
@@ -465,15 +465,15 @@ This script:
|
||||
|
||||
### Step 5b: Train Policy with RA-BC
|
||||
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`) if not explicitly provided. Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--use_rabc=true \
|
||||
--rabc_head_mode=sparse \
|
||||
--rabc_kappa=0.01 \
|
||||
--sample_weighting.type=rabc \
|
||||
--sample_weighting.head_mode=sparse \
|
||||
--sample_weighting.kappa=0.01 \
|
||||
--output_dir=outputs/train/policy_rabc \
|
||||
--batch_size=32 \
|
||||
--steps=40000
|
||||
@@ -488,12 +488,13 @@ The training script automatically:
|
||||
|
||||
**RA-BC Arguments:**
|
||||
|
||||
| Argument | Description | Default |
|
||||
| ---------------------- | ---------------------------------------------------------- | ---------------------------------- |
|
||||
| `--use_rabc` | Enable RA-BC sample weighting | `false` |
|
||||
| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset |
|
||||
| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
|
||||
| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` |
|
||||
| Argument | Description | Default |
|
||||
| ---------------------------------- | ------------------------------------------------------ | ----------------------- |
|
||||
| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` |
|
||||
| `--sample_weighting.progress_path` | Path to progress parquet file | `sarm_progress.parquet` |
|
||||
| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
|
||||
| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` |
|
||||
| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` |
|
||||
|
||||
### Tuning RA-BC Kappa
|
||||
|
||||
@@ -511,30 +512,30 @@ The `kappa` parameter is the threshold that determines which samples get full we
|
||||
|
||||
Monitor these WandB metrics during training:
|
||||
|
||||
| Metric | Healthy Range | Problem Indicator |
|
||||
| ------------------ | ------------- | ------------------------- |
|
||||
| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
|
||||
| `rabc_delta_mean` | > 0 | Should be positive |
|
||||
| `rabc_delta_std` | > 0 | Variance in data quality |
|
||||
| Metric | Healthy Range | Problem Indicator |
|
||||
| ----------------------------- | ------------- | ------------------------- |
|
||||
| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
|
||||
| `sample_weighting/delta_mean` | > 0 | Should be positive |
|
||||
| `sample_weighting/delta_std` | > 0 | Variance in data quality |
|
||||
|
||||
**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
|
||||
**If `sample_weight_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
|
||||
|
||||
**Setting kappa based on your data:**
|
||||
|
||||
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`:
|
||||
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `sample_weighting/delta_mean` and `sample_weighting/delta_std`:
|
||||
|
||||
```
|
||||
# If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
|
||||
# Most deltas fall in range [0.01, 0.05]
|
||||
|
||||
# Option 1: Set kappa = delta_mean (medium selectivity)
|
||||
--rabc_kappa=0.03
|
||||
--sample_weighting.kappa=0.03
|
||||
|
||||
# Option 2: Set kappa = delta_mean + delta_std (high selectivity)
|
||||
--rabc_kappa=0.05
|
||||
--sample_weighting.kappa=0.05
|
||||
|
||||
# Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
|
||||
--rabc_kappa=0.07
|
||||
--sample_weighting.kappa=0.07
|
||||
```
|
||||
|
||||
**When RA-BC may not help:**
|
||||
@@ -550,8 +551,8 @@ accelerate launch \
|
||||
src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--use_rabc=true \
|
||||
--rabc_kappa=0.01 \
|
||||
--sample_weighting.type=rabc \
|
||||
--sample_weighting.kappa=0.01 \
|
||||
--output_dir=outputs/train/policy_rabc \
|
||||
--batch_size=32 \
|
||||
--steps=40000
|
||||
@@ -576,7 +577,7 @@ accelerate launch \
|
||||
### RA-BC
|
||||
|
||||
1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
|
||||
2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
|
||||
2. **Monitor `sample_weight_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -97,22 +97,22 @@ Similarly for when recording an episode, it is recommended that you are logged i
|
||||
Once you are logged in, you can run inference in your setup by doing:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \ # <- Use your port
|
||||
--robot.id=my_blue_follower_arm \ # <- Use your robot id
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras
|
||||
--dataset.single_task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
|
||||
--dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub
|
||||
--dataset.episode_time_s=50 \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
|
||||
# <- RTC optional, use when running on low power hardware \
|
||||
# --inference.type=rtc \
|
||||
# --inference.rtc.execution_horizon=10 \
|
||||
# --inference.rtc.max_guidance_weight=10.0 \
|
||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
||||
# --teleop.type=so100_leader \
|
||||
# --teleop.port=/dev/ttyACM0 \
|
||||
# --teleop.id=my_red_leader_arm \
|
||||
# --display_data=true #optional use if you want to see the camera stream \
|
||||
--policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model
|
||||
```
|
||||
|
||||
|
||||
@@ -17,9 +17,9 @@ This makes `save_episode()` near-instant (the video is already encoded by the ti
|
||||
| Parameter | CLI Flag | Type | Default | Description |
|
||||
| ----------------------- | --------------------------------- | ------------- | ------------- | ----------------------------------------------------------------- |
|
||||
| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture |
|
||||
| `vcodec` | `--dataset.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
|
||||
| `vcodec` | `--dataset.camera_encoder.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
|
||||
| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance. `None` will leave the vcoded decide |
|
||||
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM |
|
||||
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `30` | Max buffered frames per camera (~1s at 30fps). Consumes RAM |
|
||||
|
||||
## 3. Performance Considerations
|
||||
|
||||
@@ -48,7 +48,7 @@ This parameter controls how many threads each encoder instance uses internally:
|
||||
|
||||
### Backpressure and Frame Dropping
|
||||
|
||||
Each camera has a bounded queue (`encoder_queue_maxsize`, default 60 frames). When the encoder can't keep up:
|
||||
Each camera has a bounded queue (`encoder_queue_maxsize`, default 30 frames). When the encoder can't keep up:
|
||||
|
||||
1. The queue fills up (consuming RAM)
|
||||
2. New frames are **dropped** (not blocked) — the capture loop continues uninterrupted
|
||||
@@ -82,15 +82,15 @@ Use HW encoding when:
|
||||
|
||||
### Available HW Encoders
|
||||
|
||||
| Encoder | Platform | Hardware | CLI Value |
|
||||
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------ |
|
||||
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=h264_videotoolbox` |
|
||||
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=hevc_videotoolbox` |
|
||||
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=h264_nvenc` |
|
||||
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=hevc_nvenc` |
|
||||
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.vcodec=h264_vaapi` |
|
||||
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.vcodec=h264_qsv` |
|
||||
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.vcodec=auto` |
|
||||
| Encoder | Platform | Hardware | CLI Value |
|
||||
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | --------------------------------------------------- |
|
||||
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder.vcodec=h264_videotoolbox` |
|
||||
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder.vcodec=hevc_videotoolbox` |
|
||||
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder.vcodec=h264_nvenc` |
|
||||
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder.vcodec=hevc_nvenc` |
|
||||
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.camera_encoder.vcodec=h264_vaapi` |
|
||||
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.camera_encoder.vcodec=h264_qsv` |
|
||||
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.camera_encoder.vcodec=auto` |
|
||||
|
||||
> [!NOTE]
|
||||
> In order to use the HW accelerated encoders you might need to upgrade your GPU drivers.
|
||||
@@ -100,15 +100,15 @@ Use HW encoding when:
|
||||
|
||||
## 5. Troubleshooting
|
||||
|
||||
| Symptom | Likely Cause | Fix |
|
||||
| ------------------------------------------------------------------ | -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.vcodec=auto`) |
|
||||
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.vcodec=auto`). |
|
||||
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
|
||||
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
|
||||
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
|
||||
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.vcodec=auto` |
|
||||
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
|
||||
| Symptom | Likely Cause | Fix |
|
||||
| ------------------------------------------------------------------ | -------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.camera_encoder.vcodec=auto`) |
|
||||
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.camera_encoder.vcodec=auto`). |
|
||||
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
|
||||
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
|
||||
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
|
||||
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.camera_encoder.vcodec=auto` |
|
||||
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
|
||||
|
||||
## 6. Recommended Configurations
|
||||
|
||||
@@ -146,7 +146,7 @@ On very constrained systems, streaming encoding may compete too heavily with the
|
||||
# 2camsx 640x480x3 @30fps: Requires some tuning.
|
||||
|
||||
# Use H.264, disable streaming, consider batching encoding
|
||||
lerobot-record --dataset.vcodec=h264 --dataset.streaming_encoding=false ...
|
||||
lerobot-record --dataset.camera_encoder.vcodec=h264 --dataset.streaming_encoding=false ...
|
||||
```
|
||||
|
||||
## 7. Closing note
|
||||
|
||||
210
docs/source/tools.mdx
Normal file
210
docs/source/tools.mdx
Normal file
@@ -0,0 +1,210 @@
|
||||
# Tools
|
||||
|
||||
LeRobot v3.1 supports **tool calls** in policies — assistant messages can
|
||||
emit structured invocations like `say(text="OK, starting now")` that the
|
||||
runtime dispatches to a real implementation (TTS, controller, logger, …).
|
||||
|
||||
This page covers:
|
||||
|
||||
1. Where the tool catalog lives.
|
||||
2. How the annotation pipeline produces tool-call atoms.
|
||||
3. How to add your own tool.
|
||||
|
||||
## Where tools are declared
|
||||
|
||||
Two layers.
|
||||
|
||||
**The catalog** — a list of OpenAI-style function schemas — lives at
|
||||
`meta/info.json["tools"]` on each dataset. Example:
|
||||
|
||||
```json
|
||||
{
|
||||
"features": { "...": "..." },
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"description": "Speak a short utterance to the user via the TTS executor.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The verbatim text to speak."
|
||||
}
|
||||
},
|
||||
"required": ["text"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Read it via the dataset metadata accessor:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
|
||||
meta = LeRobotDatasetMetadata(repo_id="pepijn/super_poulain_final_annotations")
|
||||
tools = meta.tools # list[dict] — OpenAI tool schemas
|
||||
```
|
||||
|
||||
If the dataset's `info.json` doesn't declare any tools, `meta.tools`
|
||||
returns `DEFAULT_TOOLS` from `lerobot.datasets.language` — currently a
|
||||
single-entry list with the canonical `say` schema. So unannotated
|
||||
datasets and chat-template consumers keep working without any
|
||||
configuration:
|
||||
|
||||
```python
|
||||
prompt_str = tokenizer.apply_chat_template(
|
||||
sample["messages"],
|
||||
tools=meta.tools, # works either way
|
||||
add_generation_prompt=False,
|
||||
tokenize=False,
|
||||
)
|
||||
```
|
||||
|
||||
**The implementations** — runnable Python — will live under
|
||||
`src/lerobot/tools/`, one file per tool. The runtime dispatcher and
|
||||
the canonical `say` implementation (wrapping Kyutai's pocket-tts) are
|
||||
not part of the catalog layer described here; today this layer ships
|
||||
only the schema storage and the `DEFAULT_TOOLS` fallback constant.
|
||||
|
||||
## Per-row tool _invocations_
|
||||
|
||||
The catalog above describes _what can be called_. The actual _call_ — the
|
||||
function name plus the argument values — is stored per-row, on the
|
||||
assistant atoms in `language_events`:
|
||||
|
||||
```python
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"style": null,
|
||||
"timestamp": 12.4,
|
||||
"camera": null,
|
||||
"tool_calls": [
|
||||
{ "type": "function",
|
||||
"function": { "name": "say", "arguments": { "text": "On it." } } }
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Recipes splice these into rendered messages via `tool_calls_from`:
|
||||
|
||||
```yaml
|
||||
user_interjection_response:
|
||||
bindings:
|
||||
speech: "emitted_at(t, role=assistant, tool_name=say)"
|
||||
messages:
|
||||
- { role: user, content: "${task}", stream: high_level }
|
||||
- {
|
||||
role: assistant,
|
||||
content: "${current_plan}",
|
||||
stream: high_level,
|
||||
target: true,
|
||||
tool_calls_from: speech,
|
||||
}
|
||||
```
|
||||
|
||||
The model's training target is one assistant turn that carries both the
|
||||
plan text _and_ the `say` tool call. At inference, the runtime parses
|
||||
the generated text back into structured `tool_calls` and dispatches to
|
||||
the matching implementation.
|
||||
|
||||
## How to add your own tool
|
||||
|
||||
> **Note:** Steps 2 and 3 below describe the runtime layer
|
||||
> (`src/lerobot/tools/`, the `Tool` protocol, `TOOL_REGISTRY`,
|
||||
> `get_tools(meta)`) which is not part of the catalog layer shipped
|
||||
> today — those modules don't yet exist in the tree. Step 1 alone is
|
||||
> enough to make the tool visible to the chat template via
|
||||
> `meta.tools` so the model can learn to _generate_ the call;
|
||||
> executing the call at inference requires the runtime layer.
|
||||
|
||||
Three steps. Concrete example: a `record_observation` tool the policy
|
||||
can call to capture an extra observation outside the regular control
|
||||
loop.
|
||||
|
||||
### Step 1 — declare the schema
|
||||
|
||||
Add an entry under `meta/info.json["tools"]`. Either edit the file
|
||||
directly on disk _before_ running the annotation pipeline (it'll be
|
||||
preserved) or hand it to `lerobot-annotate` via a config flag.
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": [
|
||||
{ "type": "function", "function": { "name": "say", "...": "..." } },
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "record_observation",
|
||||
"description": "Capture a high-resolution still image for the user.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"label": {
|
||||
"type": "string",
|
||||
"description": "Short label for the saved image."
|
||||
}
|
||||
},
|
||||
"required": ["label"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
The schema follows OpenAI's function-calling convention exactly, so the
|
||||
chat template can render it natively.
|
||||
|
||||
### Step 2 — implement the call
|
||||
|
||||
Create `src/lerobot/tools/record_observation.py`:
|
||||
|
||||
```python
|
||||
from .base import Tool
|
||||
from typing import Any
|
||||
|
||||
RECORD_OBSERVATION_SCHEMA: dict[str, Any] = { "...": "..." } # mirrors the JSON above
|
||||
|
||||
|
||||
class RecordObservationTool:
|
||||
name = "record_observation"
|
||||
schema = RECORD_OBSERVATION_SCHEMA
|
||||
|
||||
def __init__(self, schema: dict | None = None, output_dir: str = "."):
|
||||
self.output_dir = output_dir
|
||||
|
||||
def call(self, arguments: dict) -> str:
|
||||
label = arguments["label"]
|
||||
# ... save the latest camera frame to <output_dir>/<label>.png ...
|
||||
return f"saved {label}.png"
|
||||
```
|
||||
|
||||
One file per tool keeps dependencies isolated — `record_observation`
|
||||
might pull `pillow`, while `say` pulls `pocket-tts`. Users installing
|
||||
only the tools they need avoid heavy transitive deps.
|
||||
|
||||
### Step 3 — register it
|
||||
|
||||
Add to `src/lerobot/tools/registry.py`:
|
||||
|
||||
```python
|
||||
from .record_observation import RecordObservationTool
|
||||
|
||||
TOOL_REGISTRY["record_observation"] = RecordObservationTool
|
||||
```
|
||||
|
||||
That's it. At runtime `get_tools(meta)` looks up each schema in
|
||||
`meta.tools`, instantiates the matching registered class, and returns
|
||||
a name → instance dict the dispatcher can route into.
|
||||
|
||||
If you want to use a tool _without_ writing an implementation (e.g. for
|
||||
training-time chat-template formatting only), step 1 alone is enough —
|
||||
the model still learns to _generate_ the call. Steps 2 and 3 are only
|
||||
needed to actually _execute_ it at inference.
|
||||
@@ -274,7 +274,8 @@ python src/lerobot/scripts/lerobot_train.py \
|
||||
Once trained, we recommend deploying policies using inference-time RTC:
|
||||
|
||||
```bash
|
||||
python examples/rtc/eval_with_real_robot.py \
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=your-username/your-repo-id \
|
||||
--policy.device=cuda \
|
||||
--robot.type=unitree_g1 \
|
||||
@@ -284,7 +285,7 @@ python examples/rtc/eval_with_real_robot.py \
|
||||
--task="task_description" \
|
||||
--duration=1000 \
|
||||
--fps=30 \
|
||||
--rtc.enabled=true
|
||||
--inference.type=rtc
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
@@ -117,10 +117,10 @@ lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.vcodec libsvtav1 \
|
||||
--operation.pix_fmt yuv420p \
|
||||
--operation.g 2 \
|
||||
--operation.crf 30
|
||||
--operation.camera_encoder.vcodec libsvtav1 \
|
||||
--operation.camera_encoder.pix_fmt yuv420p \
|
||||
--operation.camera_encoder.g 2 \
|
||||
--operation.camera_encoder.crf 30
|
||||
|
||||
# Convert only specific episodes
|
||||
lerobot-edit-dataset \
|
||||
@@ -147,11 +147,7 @@ lerobot-edit-dataset \
|
||||
**Parameters:**
|
||||
|
||||
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
|
||||
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`)
|
||||
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`)
|
||||
- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2)
|
||||
- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30)
|
||||
- `fast_decode`: Fast decode tuning option (default: 0)
|
||||
- `camera_encoder`: Video encoder settings — all sub-fields accessible via `--operation.camera_encoder.<field>. See [Video Encoding Parameters](./video_encoding_parameters) for more details.
|
||||
- `episode_indices`: List of specific episodes to convert (default: all episodes)
|
||||
- `num_workers`: Number of parallel workers for processing (default: 4)
|
||||
|
||||
|
||||
117
docs/source/video_encoding_parameters.mdx
Normal file
117
docs/source/video_encoding_parameters.mdx
Normal file
@@ -0,0 +1,117 @@
|
||||
# Video encoding parameters
|
||||
|
||||
When video storage is enabled, LeRobot stores each camera stream as an **MP4** file instead of saving one image file per timestep. Video encoding compresses across time, which usually cuts dataset size and I/O compared to a pile of PNG, while keeping MP4 — a format every player and loader understands.
|
||||
|
||||
Encoding frames into an MP4 is a full FFmpeg pipeline: choice of encoder, pixel format, GOP/keyframes, quality vs. speed, and optional extra encoder flags. Most of these knobs are user-tunable through `camera_encoder`, a nested `VideoEncoderConfig` (`lerobot.configs.video.VideoEncoderConfig`) passed through PyAV.
|
||||
|
||||
You can set these parameters from the CLI with `--dataset.camera_encoder.<field>` (e.g. with `lerobot-record` or `lerobot-rollout`). The same block applies to every camera video stream in that run.
|
||||
|
||||
<Tip>
|
||||
Video storage must be on for `camera_encoder` to have any effect —
|
||||
`use_videos=True` in Python APIs, or `--dataset.video=true` on the CLI (the
|
||||
recording default). With video off, inputs stay as images and `camera_encoder`
|
||||
is ignored.
|
||||
</Tip>
|
||||
|
||||
For details on **when** frames are written vs. encoded (streaming vs. post-episode), queues, and other top-level `--dataset.*` switches, see [Streaming Video Encoding](./streaming_video_encoding). For an encoding-parameter comparison and experiments, see the [video-benchmark Space](https://huggingface.co/spaces/lerobot/video-benchmark).
|
||||
|
||||
---
|
||||
|
||||
## Example
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--robot.id=black \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=blue \
|
||||
--dataset.repo_id=<my_username>/<my_dataset_name> \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.single_task="Grab the cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
--dataset.camera_encoder.vcodec=h264 \
|
||||
--dataset.camera_encoder.preset=fast \
|
||||
--dataset.camera_encoder.extra_options={"tune": "film", "profile:v": "high", "bf": 2} \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tuning parameters
|
||||
|
||||
<Tip warning={true}>
|
||||
The defaults are tuned to balance **compression ratio**, **visual quality**, and **decoding/seek speed** for typical robotics datasets. Changing them can affect both recording (CPU load, frame drops) and training (decoding throughput, image quality).
|
||||
|
||||
Only override these parameters if you have a specific reason to, and measure the impact on your pipeline before relying on the new settings.
|
||||
|
||||
</Tip>
|
||||
|
||||
All flags below are prefixed with `--dataset.camera_encoder.` on the CLI.
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
| --------------- | ---------------- | ------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `vcodec` | `str` | `"libsvtav1"` | Video codec name. `"auto"` picks the first available hardware encoder from a fixed preference list, falling back to `libsvtav1`. |
|
||||
| `pix_fmt` | `str` | `"yuv420p"` | Output pixel format. Must be supported by the chosen codec in your FFmpeg build. |
|
||||
| `g` | `int` | `2` | GOP size — a keyframe every `g` frames. Emitted as FFmpeg option `g`. |
|
||||
| `crf` | `int` or `float` | `30` | Abstract quality value, mapped per codec (see the [mapping](#mapping-videoencoderconfig--ffmpeg-options) below). Lower → higher quality / larger output where the mapping is monotone. |
|
||||
| `preset` | `int` or `str` | `12` \* | Encoder speed preset; meaning depends on the codec. <br/>\* When unset and `vcodec=libsvtav1`, LeRobot defaults to `12`. |
|
||||
| `fast_decode` | `int` | `0` | `libsvtav1`: `0–2`, passed via `svtav1-params`. <br/>`h264` / `hevc` (software): if `>0`, sets `tune=fastdecode`. <br/>Other codecs: usually unused. |
|
||||
| `video_backend` | `str` | `"pyav"` | Only `"pyav"` is currently implemented for video encoding. |
|
||||
| `extra_options` | `dict` | `{}` | Extra FFmpeg or codec specific options merged after the structured fields above. Cannot override keys already set by those fields. |
|
||||
|
||||
---
|
||||
|
||||
## Persistence in dataset metadata
|
||||
|
||||
After the first episode of a video stream is encoded, the encoder configuration is **persisted into the dataset metadata** (`meta/info.json`) under each video feature, alongside the values probed from the file itself. For a video feature `observation.images.<camera>`, the layout in `info.json` is:
|
||||
|
||||
```json
|
||||
{
|
||||
"features": {
|
||||
"observation.images.laptop": {
|
||||
"dtype": "video",
|
||||
"shape": [480, 640, 3],
|
||||
"info": {
|
||||
"video.height": 480,
|
||||
"video.width": 640,
|
||||
"video.codec": "h264",
|
||||
"video.pix_fmt": "yuv420p",
|
||||
"video.fps": 30,
|
||||
"video.channels": 3,
|
||||
"video.is_depth_map": false,
|
||||
"video.g": 2,
|
||||
"video.crf": 30,
|
||||
"video.preset": "fast",
|
||||
"video.fast_decode": 0,
|
||||
"video.video_backend": "pyav",
|
||||
"video.extra_options": { "tune": "film", "profile:v": "high", "bf": 2 }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Two sources contribute to the `info` block:
|
||||
|
||||
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `video.is_depth_map`, plus `audio.*` if an audio stream is present.
|
||||
- **Encoder-derived** (taken from `VideoEncoderConfig`): `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.video_backend`, `video.extra_options`.
|
||||
|
||||
<Tip>
|
||||
This block is populated **once**, from the **first** episode. It assumes every
|
||||
episode in the dataset was encoded with the same `camera_encoder`. Changing
|
||||
encoder settings partway through a recording is not supported — the
|
||||
`info.json` will only reflect the parameters used for the first episode.
|
||||
</Tip>
|
||||
|
||||
---
|
||||
|
||||
## Merging datasets
|
||||
|
||||
When aggregating datasets with `merge_datasets`, video files are concatenated as-is (no re-encoding), and encoder fields in `info.json` are merged per-key:
|
||||
|
||||
- **Stream-derived fields must match** across sources: `video.codec`, `video.pix_fmt`, `video.height`, `video.width`, `video.fps`. Otherwise FFmpeg's concat demuxer fails.
|
||||
- **Encoder-tuning fields are merged loosely**: `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.extra_options`. If every source agrees, the value is kept; if not, it's set to `null` (or `{}` for `video.extra_options`) and a warning is logged.
|
||||
@@ -220,7 +220,7 @@ REAL_DIM = 12
|
||||
# Postprocessing: Trim 20D predictions to 12D for deployment
|
||||
```
|
||||
|
||||
See the [action_hub.py](/home/jade_choghari/robot/lerobot/src/lerobot/policies/xvla/action_hub.py) implementation for details.
|
||||
See the [action_hub.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/action_hub.py) implementation for details.
|
||||
|
||||
#### Auto Action Mode (Recommended)
|
||||
|
||||
@@ -519,9 +519,9 @@ If you use X-VLA in your research, please cite:
|
||||
|
||||
- [X-VLA Paper](https://arxiv.org/pdf/2510.10274)
|
||||
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
|
||||
- [Action Registry Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/action_hub.py)
|
||||
- [Processor Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/processor_xvla.py)
|
||||
- [Model Configuration](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/configuration_xvla.py)
|
||||
- [Action Registry Implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/action_hub.py)
|
||||
- [Processor Implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/processor_xvla.py)
|
||||
- [Model Configuration](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/configuration_xvla.py)
|
||||
|
||||
## Contributing
|
||||
|
||||
|
||||
89
examples/annotations/run_hf_job.py
Normal file
89
examples/annotations/run_hf_job.py
Normal file
@@ -0,0 +1,89 @@
|
||||
#!/usr/bin/env python
|
||||
"""Launch ``lerobot-annotate`` on a Hugging Face job (vllm + Qwen3.6 MoE).
|
||||
|
||||
Spawns one ``h200x2`` job that:
|
||||
|
||||
1. installs this branch of ``lerobot`` plus the annotation extras,
|
||||
2. boots two vllm servers (one per GPU) with Qwen3.6-35B-A3B-FP8,
|
||||
3. runs the plan / interjections / vqa modules across the dataset
|
||||
in free-form mode (phase 0 canonical-vocabulary discovery is
|
||||
disabled — each episode generates its own subtasks + memory),
|
||||
4. uploads the annotated dataset to ``--dest_repo_id`` (when set)
|
||||
or back to ``--repo_id``.
|
||||
|
||||
Re-enable phase 0 with ``--vocabulary.enabled=true`` (optionally
|
||||
``--vocabulary.sample_episodes=N``) when the dataset is homogeneous
|
||||
enough to share one subtask + memory vocabulary across all episodes.
|
||||
|
||||
Usage:
|
||||
|
||||
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||
|
||||
Adjust ``CMD`` below to point at your own dataset / target hub repo.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from huggingface_hub import get_token, run_job
|
||||
|
||||
token = os.environ.get("HF_TOKEN") or get_token()
|
||||
if not token:
|
||||
raise RuntimeError("No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`")
|
||||
|
||||
CMD = (
|
||||
"apt-get update -qq && apt-get install -y -qq git ffmpeg && "
|
||||
"pip install --no-deps "
|
||||
"'lerobot @ git+https://github.com/huggingface/lerobot.git@feat/language-annotation-pipeline' && "
|
||||
"pip install --upgrade-strategy only-if-needed "
|
||||
"datasets pyarrow av jsonlines draccus gymnasium torchcodec mergedeep pyyaml-include toml typing-inspect "
|
||||
"openai && "
|
||||
"export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && "
|
||||
"export VLLM_VIDEO_BACKEND=pyav && "
|
||||
"lerobot-annotate "
|
||||
"--repo_id=imstevenpmwork/super_poulain_draft "
|
||||
"--dest_repo_id=pepijn223/super_poulain_vocab "
|
||||
"--push_to_hub=true "
|
||||
"--vlm.backend=openai "
|
||||
"--vlm.model_id=Qwen/Qwen3.6-35B-A3B-FP8 "
|
||||
"--vlm.parallel_servers=2 "
|
||||
"--vlm.num_gpus=2 "
|
||||
'--vlm.serve_command="vllm serve Qwen/Qwen3.6-35B-A3B-FP8 '
|
||||
"--tensor-parallel-size 1 --max-model-len 32768 "
|
||||
'--gpu-memory-utilization 0.8 --uvicorn-log-level warning --port {port}" '
|
||||
"--vlm.serve_ready_timeout_s=1800 "
|
||||
"--vlm.client_concurrency=128 "
|
||||
"--vlm.max_new_tokens=512 "
|
||||
"--vlm.temperature=0.7 "
|
||||
"--executor.episode_parallelism=16 "
|
||||
"--vlm.chat_template_kwargs='{\"enable_thinking\": false}' "
|
||||
"--vlm.camera_key=observation.images.wrist "
|
||||
# Phase 0 — canonical vocabulary discovery DISABLED by default.
|
||||
# Heterogeneous datasets (different tasks/scenes across episodes)
|
||||
# don't share a single small subtask + memory vocabulary, so each
|
||||
# episode generates its subtasks + memory free-form. Flip to
|
||||
# ``--vocabulary.enabled=true`` (optionally ``--vocabulary.sample_episodes=N``)
|
||||
# for homogeneous datasets where a shared canonical vocabulary
|
||||
# helps the downstream policy.
|
||||
"--vocabulary.enabled=false "
|
||||
# Phase 1 — plan module (subtasks + plan + memory + task_aug).
|
||||
"--plan.frames_per_second=1.0 "
|
||||
"--plan.use_video_url=true "
|
||||
"--plan.use_video_url_fps=1.0 "
|
||||
"--plan.derive_task_from_video=always "
|
||||
"--plan.n_task_rephrasings=30 "
|
||||
# Phase 2 — interjections + speech.
|
||||
"--interjections.max_interjections_per_episode=6 "
|
||||
# Phase 4 — general VQA.
|
||||
"--vqa.K=3 "
|
||||
"--vqa.vqa_emission_hz=1.0"
|
||||
)
|
||||
|
||||
job = run_job(
|
||||
image="vllm/vllm-openai:latest",
|
||||
command=["bash", "-c", CMD],
|
||||
flavor="h200x2",
|
||||
secrets={"HF_TOKEN": token},
|
||||
timeout="2h",
|
||||
)
|
||||
print(f"Job URL: {job.url}")
|
||||
print(f"Job ID: {job.id}")
|
||||
@@ -15,10 +15,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes.
|
||||
Create MP4 (or GIF) videos with per-frame progress overlay for specified episodes.
|
||||
|
||||
Downloads datasets from HuggingFace, seeks directly into the episode segment
|
||||
of the source video, draws a progress line on each frame, and writes the result.
|
||||
The progress data is read from a parquet file that lives alongside the dataset
|
||||
(configurable via ``--progress-file``).
|
||||
|
||||
Usage:
|
||||
python examples/dataset/create_progress_videos.py \
|
||||
@@ -56,22 +58,26 @@ SCORE_FONT_SCALE = 0.8
|
||||
TASK_FONT_SCALE = 0.55
|
||||
|
||||
|
||||
def download_episode_metadata(repo_id: str, episode: int) -> Path:
|
||||
"""Download only the metadata and sarm_progress files for a dataset.
|
||||
def download_episode_metadata(
|
||||
repo_id: str, episode: int, progress_file: str = "sarm_progress.parquet"
|
||||
) -> Path:
|
||||
"""Download only the metadata and per-frame progress file for a dataset.
|
||||
|
||||
Args:
|
||||
repo_id: HuggingFace dataset repository ID.
|
||||
episode: Episode index (used for logging only; all meta is fetched).
|
||||
progress_file: Filename of the per-frame progress parquet inside the
|
||||
dataset repo.
|
||||
|
||||
Returns:
|
||||
Local cache path for the downloaded snapshot.
|
||||
"""
|
||||
logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode)
|
||||
logging.info("[1/4] Downloading metadata + %s for %s (episode %d) ...", progress_file, repo_id, episode)
|
||||
local_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**", "sarm_progress.parquet"],
|
||||
allow_patterns=["meta/**", progress_file],
|
||||
ignore_patterns=["*.mp4"],
|
||||
)
|
||||
)
|
||||
@@ -215,25 +221,28 @@ def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
|
||||
return video_path
|
||||
|
||||
|
||||
def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None:
|
||||
"""Load sarm_progress values for an episode.
|
||||
def load_progress_data(
|
||||
local_path: Path, episode: int, progress_file: str = "sarm_progress.parquet"
|
||||
) -> np.ndarray | None:
|
||||
"""Load per-frame progress values for an episode.
|
||||
|
||||
Args:
|
||||
local_path: Dataset cache root.
|
||||
episode: Episode index.
|
||||
progress_file: Filename of the per-frame progress parquet.
|
||||
|
||||
Returns:
|
||||
Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
|
||||
"""
|
||||
parquet_path = local_path / "sarm_progress.parquet"
|
||||
parquet_path = local_path / progress_file
|
||||
if not parquet_path.exists():
|
||||
logging.warning("sarm_progress.parquet not found")
|
||||
logging.warning("%s not found", progress_file)
|
||||
return None
|
||||
df = pd.read_parquet(parquet_path)
|
||||
logging.info(" sarm_progress.parquet columns: %s", list(df.columns))
|
||||
logging.info(" %s columns: %s", progress_file, list(df.columns))
|
||||
episode_df = df[df["episode_index"] == episode].copy()
|
||||
if episode_df.empty:
|
||||
logging.warning("No sarm_progress rows for episode %d", episode)
|
||||
logging.warning("No progress rows for episode %d in %s", episode, progress_file)
|
||||
return None
|
||||
episode_df = episode_df.sort_values("frame_index")
|
||||
|
||||
@@ -576,6 +585,7 @@ def process_dataset(
|
||||
camera_key: str | None,
|
||||
output_dir: Path,
|
||||
create_gif: bool = False,
|
||||
progress_file: str = "sarm_progress.parquet",
|
||||
) -> Path | None:
|
||||
"""Full pipeline: download, extract metadata, composite progress, write output.
|
||||
|
||||
@@ -585,6 +595,8 @@ def process_dataset(
|
||||
camera_key: Camera key to use, or None for auto-selection.
|
||||
output_dir: Directory to write output files.
|
||||
create_gif: If True, also generate a GIF from the MP4.
|
||||
progress_file: Filename of the per-frame progress parquet inside the
|
||||
dataset repo.
|
||||
|
||||
Returns:
|
||||
Path to the final output file, or None on failure.
|
||||
@@ -592,7 +604,7 @@ def process_dataset(
|
||||
safe_name = repo_id.replace("/", "_")
|
||||
logging.info("Processing: %s | episode %d", repo_id, episode)
|
||||
|
||||
local_path = download_episode_metadata(repo_id, episode)
|
||||
local_path = download_episode_metadata(repo_id, episode, progress_file)
|
||||
logging.info(" Local cache: %s", local_path)
|
||||
|
||||
episode_meta = load_episode_meta(local_path, episode, camera_key)
|
||||
@@ -600,9 +612,9 @@ def process_dataset(
|
||||
|
||||
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
|
||||
|
||||
progress_data = load_progress_data(local_path, episode)
|
||||
progress_data = load_progress_data(local_path, episode, progress_file)
|
||||
if progress_data is None:
|
||||
logging.error("Could not load sarm_progress data. Skipping overlay.")
|
||||
logging.error("Could not load progress data from %s. Skipping overlay.", progress_file)
|
||||
return None
|
||||
|
||||
logging.info(" Progress frames: %d", len(progress_data))
|
||||
@@ -627,7 +639,7 @@ def process_dataset(
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes."
|
||||
description="Create MP4/GIF videos with per-frame progress overlay for dataset episodes."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
@@ -658,6 +670,15 @@ def main() -> None:
|
||||
action="store_true",
|
||||
help="Also generate a GIF from the MP4 output.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--progress-file",
|
||||
type=str,
|
||||
default="sarm_progress.parquet",
|
||||
help=(
|
||||
"Filename of the per-frame progress parquet inside the dataset repo "
|
||||
"(default: 'sarm_progress.parquet')."
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
@@ -670,6 +691,7 @@ def main() -> None:
|
||||
camera_key=args.camera_key,
|
||||
output_dir=args.output_dir,
|
||||
create_gif=args.gif,
|
||||
progress_file=args.progress_file,
|
||||
)
|
||||
|
||||
if result:
|
||||
|
||||
@@ -69,7 +69,7 @@ class ComputeProgressShards(PipelineStep):
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.policies.sarm.compute_rabc_weights import (
|
||||
from lerobot.rewards.sarm.compute_rabc_weights import (
|
||||
generate_all_frame_indices,
|
||||
interpolate_progress,
|
||||
load_sarm_resources,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,226 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Shared utilities for Human-in-the-Loop data collection scripts."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.common.control_utils import is_headless
|
||||
from lerobot.processor import (
|
||||
IdentityProcessorStep,
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots import Robot
|
||||
from lerobot.teleoperators import Teleoperator
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HILDatasetConfig:
|
||||
repo_id: str
|
||||
single_task: str
|
||||
root: str | Path | None = None
|
||||
fps: int = 30
|
||||
episode_time_s: float = 120
|
||||
num_episodes: int = 50
|
||||
video: bool = True
|
||||
push_to_hub: bool = True
|
||||
private: bool = False
|
||||
tags: list[str] | None = None
|
||||
num_image_writer_processes: int = 0
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
video_encoding_batch_size: int = 1
|
||||
vcodec: str = "auto"
|
||||
streaming_encoding: bool = True
|
||||
encoder_queue_maxsize: int = 30
|
||||
encoder_threads: int | None = None
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
def teleop_has_motor_control(teleop: Teleoperator) -> bool:
|
||||
"""Check if teleoperator has motor control capabilities."""
|
||||
return all(hasattr(teleop, attr) for attr in ("enable_torque", "disable_torque", "write_goal_positions"))
|
||||
|
||||
|
||||
def teleop_disable_torque(teleop: Teleoperator) -> None:
|
||||
"""Disable teleop torque if supported."""
|
||||
if hasattr(teleop, "disable_torque"):
|
||||
teleop.disable_torque()
|
||||
|
||||
|
||||
def teleop_enable_torque(teleop: Teleoperator) -> None:
|
||||
"""Enable teleop torque if supported."""
|
||||
if hasattr(teleop, "enable_torque"):
|
||||
teleop.enable_torque()
|
||||
|
||||
|
||||
def teleop_smooth_move_to(teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50):
|
||||
"""Smoothly move teleop to target position if motor control is available."""
|
||||
if not teleop_has_motor_control(teleop):
|
||||
logger.warning("Teleop does not support motor control - cannot mirror robot position")
|
||||
return
|
||||
|
||||
teleop_enable_torque(teleop)
|
||||
current = teleop.get_action()
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {}
|
||||
for k in current:
|
||||
if k in target_pos:
|
||||
interp[k] = current[k] * (1 - t) + target_pos[k] * t
|
||||
else:
|
||||
interp[k] = current[k]
|
||||
teleop.write_goal_positions(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
"""Initialize keyboard listener with HIL controls."""
|
||||
events = {
|
||||
"exit_early": False,
|
||||
"rerecord_episode": False,
|
||||
"stop_recording": False,
|
||||
"policy_paused": False,
|
||||
"correction_active": False,
|
||||
"resume_policy": False,
|
||||
"in_reset": False,
|
||||
"start_next_episode": False,
|
||||
}
|
||||
|
||||
if is_headless():
|
||||
logger.warning("Headless environment - keyboard controls unavailable")
|
||||
return None, events
|
||||
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if events["in_reset"]:
|
||||
if key in [keyboard.Key.space, keyboard.Key.right]:
|
||||
logger.info("[HIL] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
elif hasattr(key, "char") and key.char == "c":
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
logger.info("[HIL] ESC - Stop recording, pushing to hub...")
|
||||
events["stop_recording"] = True
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
if key == keyboard.Key.space:
|
||||
if not events["policy_paused"] and not events["correction_active"]:
|
||||
logger.info("[HIL] PAUSED - Press 'c' to take control or 'p' to resume policy")
|
||||
events["policy_paused"] = True
|
||||
elif hasattr(key, "char") and key.char == "c":
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
logger.info("[HIL] Taking control...")
|
||||
events["start_next_episode"] = True
|
||||
elif hasattr(key, "char") and key.char == "p":
|
||||
if events["policy_paused"] or events["correction_active"]:
|
||||
logger.info("[HIL] Resuming policy...")
|
||||
events["resume_policy"] = True
|
||||
elif key == keyboard.Key.right:
|
||||
logger.info("[HIL] End episode")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
logger.info("[HIL] Re-record episode")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
logger.info("[HIL] ESC - Stop recording...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
except Exception as e:
|
||||
logger.info(f"Key error: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
return listener, events
|
||||
|
||||
|
||||
def make_identity_processors():
|
||||
"""Create identity processors for recording."""
|
||||
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
return teleop_proc, obs_proc
|
||||
|
||||
|
||||
def reset_loop(robot: Robot, teleop: Teleoperator, events: dict, fps: int):
|
||||
"""Reset period where human repositions environment."""
|
||||
logger.info("[HIL] RESET")
|
||||
|
||||
events["in_reset"] = True
|
||||
events["start_next_episode"] = False
|
||||
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
||||
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
|
||||
logger.info("Press any key to enable teleoperation")
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
precise_sleep(0.05)
|
||||
|
||||
if events["stop_recording"]:
|
||||
return
|
||||
|
||||
events["start_next_episode"] = False
|
||||
teleop_disable_torque(teleop)
|
||||
logger.info("Teleop enabled - press any key to start episode")
|
||||
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
loop_start = time.perf_counter()
|
||||
action = teleop.get_action()
|
||||
robot.send_action(action)
|
||||
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
|
||||
|
||||
events["in_reset"] = False
|
||||
events["start_next_episode"] = False
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
events["resume_policy"] = False
|
||||
|
||||
|
||||
def print_controls(rtc: bool = False):
|
||||
"""Print control instructions."""
|
||||
mode = "Human-in-the-Loop Data Collection" + (" (RTC)" if rtc else "")
|
||||
logger.info(
|
||||
"%s\n Controls:\n"
|
||||
" SPACE - Pause policy\n"
|
||||
" c - Take control\n"
|
||||
" p - Resume policy after pause/correction\n"
|
||||
" → - End episode\n"
|
||||
" ESC - Stop and push to hub",
|
||||
mode,
|
||||
)
|
||||
@@ -14,17 +14,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
import logging
|
||||
import time
|
||||
|
||||
from lerobot.common.control_utils import init_keyboard_listener, predict_action
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.policies import make_pre_post_processors
|
||||
from lerobot.policies.act import ACTPolicy
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
NUM_EPISODES = 2
|
||||
FPS = 30
|
||||
@@ -35,6 +39,9 @@ HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
|
||||
|
||||
|
||||
def main():
|
||||
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
|
||||
# This script provides a self-contained example for educational purposes.
|
||||
|
||||
# Create the robot configuration & robot
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
|
||||
@@ -83,43 +90,67 @@ def main():
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
control_interval = 1 / FPS
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
# Inline evaluation loop: predict actions and send to robot
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < EPISODE_TIME_SEC:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
# Get robot observation
|
||||
obs = robot.get_observation()
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
|
||||
|
||||
# Predict action using the policy
|
||||
action_tensor = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=policy.config.device,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.device.type == "cuda",
|
||||
task=TASK_DESCRIPTION,
|
||||
robot_type=robot.name,
|
||||
)
|
||||
|
||||
# Convert policy output to robot action dict
|
||||
action_values = make_robot_action(action_tensor, dataset.features)
|
||||
|
||||
# Process and send action to robot
|
||||
robot_action_to_send = robot_action_processor((action_values, obs))
|
||||
robot.send_action(robot_action_to_send)
|
||||
|
||||
# Write to dataset
|
||||
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
|
||||
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
log_rerun_data(observation=obs_processed, action=action_values)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
sleep_time_s = control_interval - dt_s
|
||||
if sleep_time_s < 0:
|
||||
logging.warning(
|
||||
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
|
||||
)
|
||||
precise_sleep(max(sleep_time_s, 0.0))
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
log_say("Waiting for environment reset, press right arrow key when ready...")
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
|
||||
@@ -45,9 +45,6 @@ def main():
|
||||
leader_arm = SO100Leader(leader_arm_config)
|
||||
keyboard = KeyboardTeleop(keyboard_config)
|
||||
|
||||
# TODO(Steven): Update this example to use pipelines
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
@@ -77,6 +74,10 @@ def main():
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = (
|
||||
make_default_processors()
|
||||
)
|
||||
|
||||
print("Starting record loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
@@ -87,14 +88,14 @@ def main():
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
dataset=dataset,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
@@ -106,13 +107,13 @@ def main():
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
77
examples/lekiwi/rollout.py
Normal file
77
examples/lekiwi/rollout.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Run a trained policy on LeKiwi without recording (base rollout).
|
||||
|
||||
Uses the rollout engine's :class:`BaseStrategy` (autonomous execution,
|
||||
no dataset) with :class:`SyncInferenceConfig` (inline policy call per
|
||||
control tick). For a CLI entry point with the same capabilities plus
|
||||
recording, upload, and human-in-the-loop variants, see ``lerobot-rollout``.
|
||||
"""
|
||||
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.robots.lekiwi import LeKiwiClientConfig
|
||||
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||
from lerobot.rollout.inference import SyncInferenceConfig
|
||||
from lerobot.rollout.strategies import BaseStrategy
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
FPS = 30
|
||||
DURATION_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
|
||||
|
||||
def main():
|
||||
init_logging()
|
||||
|
||||
# Robot: LeKiwi client — make sure lekiwi_host is already running on the robot.
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
|
||||
# Policy: load the pretrained config. ``pretrained_path`` is read downstream
|
||||
# by ``build_rollout_context`` to reload the full model.
|
||||
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||
policy_config.pretrained_path = HF_MODEL_ID
|
||||
|
||||
# Assemble the rollout config: base strategy (no recording) + sync inference.
|
||||
cfg = RolloutConfig(
|
||||
robot=robot_config,
|
||||
policy=policy_config,
|
||||
strategy=BaseStrategyConfig(),
|
||||
inference=SyncInferenceConfig(),
|
||||
fps=FPS,
|
||||
duration=DURATION_SEC,
|
||||
task=TASK_DESCRIPTION,
|
||||
)
|
||||
|
||||
# Graceful Ctrl-C: the strategy loop exits when shutdown_event is set.
|
||||
signal_handler = ProcessSignalHandler(use_threads=True)
|
||||
|
||||
# Build the context (connects robot, loads policy, wires the inference strategy).
|
||||
# No custom processors here — LeKiwi runs on raw joint features.
|
||||
ctx = build_rollout_context(cfg, signal_handler.shutdown_event)
|
||||
|
||||
strategy = BaseStrategy(cfg.strategy)
|
||||
try:
|
||||
strategy.setup(ctx)
|
||||
strategy.run(ctx)
|
||||
finally:
|
||||
strategy.teardown(ctx)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -80,7 +80,7 @@
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Dataset\n",
|
||||
"HF_USER = \"your_hf_username\" # `huggingface-cli whoami` to find your username\n",
|
||||
"HF_USER = \"your_hf_username\" # `hf auth whoami` to find your username\n",
|
||||
"DATASET_NAME = \"my_so101_dataset\"\n",
|
||||
"TASK_DESCRIPTION = \"pick and place the block\"\n",
|
||||
"NUM_EPISODES = 10\n",
|
||||
@@ -291,7 +291,34 @@
|
||||
"\n",
|
||||
"Uses `POLICY_PATH` from the Configuration cell (defaults to the Hub repo ID). You can also put there the `LAST_CHECKPOINT_PATH`.\n",
|
||||
"\n",
|
||||
"See the [inference docs](https://huggingface.co/docs/lerobot/il_robots#run-inference-and-evaluate-your-policy) for details."
|
||||
"See the [inference docs](https://huggingface.co/docs/lerobot/il_robots#run-inference-and-evaluate-your-policy) for details.\n",
|
||||
"\n",
|
||||
"Recently ```lerobot-rollout``` was introduced, you can [read more about it here](https://huggingface.co/docs/lerobot/main/en/il_robots?eval=Base+mode+%28no+recording%29#run-inference-and-evaluate-your-policy)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print_cmd(\n",
|
||||
" \"lerobot-rollout\",\n",
|
||||
" \"--strategy.type=base\",\n",
|
||||
" f\"--policy.path={POLICY_PATH}\",\n",
|
||||
" f\"--robot.type={ROBOT_TYPE}\",\n",
|
||||
" f\"--robot.port={ROBOT_PORT}\",\n",
|
||||
" CAMERAS_FLAG,\n",
|
||||
" f'--task=\"{TASK_DESCRIPTION}\"',\n",
|
||||
" \"--duration=60\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"if you are using the V0.5.1 release you should use ```lerobot-record``` instead of rollout"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
136
examples/omx/README.md
Normal file
136
examples/omx/README.md
Normal file
@@ -0,0 +1,136 @@
|
||||
# OMX Follower — Cube Pick And Place Example
|
||||
|
||||
This is an example of what is possible to do with LeRobot on a physical setup.
|
||||
It is a WIP and being used internally at LeRobot and specific to our setup, but we hope it can be a useful reference for how to use LeRobot APIs and CLIs.
|
||||
|
||||
It includes an end-to-end example for the **OMX Follower** robot arm: pick and place a cube dataset, train a policy, and deploy it autonomously.
|
||||
|
||||
## Hardware
|
||||
|
||||
| Component | Value |
|
||||
| --------- | ------------------------------------ |
|
||||
| Robot | OMX Follower |
|
||||
| Cameras | 2× OpenCV cameras (wrist + top-down) |
|
||||
|
||||
## Scripts
|
||||
|
||||
| Script | Purpose |
|
||||
| ---------------------- | --------------------------------------------------------------- |
|
||||
| `reset_environment.py` | Standalone utility: sweep workspace, grab cube, place cube |
|
||||
| `record_grab.py` | Automated data collection: reset → place → record grab episodes |
|
||||
|
||||
## Setup
|
||||
|
||||
Make sure you have LeRobot installed in your env. (See [the installation guide](https://huggingface.co/docs/lerobot/installation))
|
||||
|
||||
Next, we will declare some environment variables for convenience. Adjust the camera indices and robot port to match your system configuration.
|
||||
|
||||
```bash
|
||||
export ROBOT_PORT=/dev/ttyACM0
|
||||
export TELEOP_PORT=/dev/ttyACM1
|
||||
export HF_USERNAME=<your_hf_username>
|
||||
export ROBOT_CAMERAS="{ wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30, fourcc: MJPG}, top: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30, fourcc: MJPG} }"
|
||||
```
|
||||
|
||||
## Step 1 — Collect Data
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=$ROBOT_PORT \
|
||||
--robot.id=omx_follower \
|
||||
--robot.cameras="$ROBOT_CAMERAS" \
|
||||
--teleop.type=omx_leader \
|
||||
--teleop.port=$TELEOP_PORT \
|
||||
--teleop.id=omx_leader \
|
||||
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
|
||||
--dataset.root=data/omx_pickandplace \
|
||||
--dataset.num_episodes=50 \
|
||||
--dataset.single_task="Pick the cube and place it in the blue square" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.push_to_hub=true
|
||||
```
|
||||
|
||||
### Bonus Auto-Collect script
|
||||
|
||||
/!\ This is specific to our setup and the task of picking and placing a cube. It is not a general-purpose data collection script. As you may notice, it doesn't require a teleop.
|
||||
|
||||
```bash
|
||||
python -m examples.omx.record_grab \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=$ROBOT_PORT \
|
||||
--robot.id=omx_follower \
|
||||
--robot.cameras="$ROBOT_CAMERAS" \
|
||||
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
|
||||
--dataset.root=data/omx_pickandplace \
|
||||
--dataset.num_episodes=50 \
|
||||
--dataset.single_task="Pick the cube and place it in the blue square" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.push_to_hub=true
|
||||
```
|
||||
|
||||
Each episode:
|
||||
|
||||
1. The arm grabs the cube from the center of the workspace and places it at a random position.
|
||||
2. The arm returns to HOME.
|
||||
3. A targeted grab is recorded: HOME → approach raised → lower onto cube → grasp → lift → carry → drop → HOME.
|
||||
|
||||
A dataset is already available here [`maximellerbach/omx_pickandplace`](https://huggingface.co/datasets/maximellerbach/omx_pickandplace), so you can skip directly to training if you want.
|
||||
|
||||
## Step 2 — Train
|
||||
|
||||
To train a simple `ACT` policy on the collected dataset, you can use the `lerobot-train` CLI:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/omx_pickandplace_act \
|
||||
--policy.device=cuda \
|
||||
--policy.repo_id=$HF_USERNAME/omx_pickandplace_act \
|
||||
--steps=20000 \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
A pretrained `ACT` policy is already available here [`maximellerbach/omx_pickandplace_act`](https://huggingface.co/maximellerbach/omx_pickandplace_act).
|
||||
|
||||
## Step 3 — Rollout
|
||||
|
||||
Use the `lerobot-rollout` CLI with base strategy:
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=$ROBOT_PORT \
|
||||
--robot.id=omx_follower \
|
||||
--robot.cameras="$ROBOT_CAMERAS" \
|
||||
--policy.path=$HF_USERNAME/omx_pickandplace_act \
|
||||
```
|
||||
|
||||
For continuous recording with automatic upload (sentry mode):
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=sentry \
|
||||
--strategy.upload_every_n_episodes=10 \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=$ROBOT_PORT \
|
||||
--robot.id=omx_follower \
|
||||
--robot.cameras="$ROBOT_CAMERAS" \
|
||||
--policy.path=$HF_USERNAME/omx_pickandplace_act \
|
||||
--dataset.repo_id=$HF_USERNAME/rollout_omx_pickandplace_act \
|
||||
```
|
||||
|
||||
## Environment Reset Utility
|
||||
|
||||
Those are specific to this particular physical setup. Those are scripts that execute hardcoded sequences of actions on the robot to reset the environment, which is useful for data collection and evaluation. They are not general-purpose scripts.
|
||||
|
||||
`reset_environment.py` can be run standalone to prepare the workspace:
|
||||
|
||||
```bash
|
||||
# Grab cube + place it at a random position on the left side
|
||||
python -m examples.omx.reset_environment --port $ROBOT_PORT --mode grab_and_place
|
||||
```
|
||||
|
||||
It also exposes `grab_cube(robot)` and `place_cube(robot)` for use in custom scripts.
|
||||
422
examples/omx/record_grab.py
Normal file
422
examples/omx/record_grab.py
Normal file
@@ -0,0 +1,422 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Auto-record grab episodes for the OMX robot arm.
|
||||
|
||||
Each episode cycle:
|
||||
1. grab_and_place — grab cube from workspace center and place at a random (pan, reach) position
|
||||
2. HOME — return arm to home with gripper open
|
||||
3. record_grab — execute a targeted grab to the stored position while recording
|
||||
observations + actions to a LeRobotDataset
|
||||
|
||||
Usage (run from repo root):
|
||||
python -m examples.omx.record_grab \\
|
||||
--robot.type=omx_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--robot.id=omx_follower \\
|
||||
--robot.cameras="{ wrist: {type: opencv, index_or_path: 6, width: 640, height: 480, fps: 30, fourcc: MJPG}, top: {type: opencv, index_or_path: 4, width: 640, height: 480, fps: 30, fourcc: MJPG} }" \\
|
||||
--dataset.repo_id=<hf_username>/<dataset_name> \\
|
||||
--dataset.root=data/omx_grab \\
|
||||
--dataset.num_episodes=50 \\
|
||||
--dataset.single_task="Grab the cube" \\
|
||||
--dataset.streaming_encoding=true
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pprint import pformat
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.cameras import CameraConfig # noqa: F401
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.dataset import DatasetRecordConfig
|
||||
from lerobot.datasets import (
|
||||
LeRobotDataset,
|
||||
VideoEncodingManager,
|
||||
aggregate_pipeline_dataset_features,
|
||||
create_initial_features,
|
||||
)
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.robots import RobotConfig, make_robot_from_config
|
||||
from lerobot.robots.omx_follower import OmxFollower
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
from .reset_environment import (
|
||||
APPROACH_SPEED,
|
||||
GRIPPER_CLOSE_POS,
|
||||
HOME_POSE,
|
||||
PUSH_END_ELBOW_FLEX,
|
||||
PUSH_END_SHOULDER_LIFT,
|
||||
PUSH_START_ELBOW_FLEX,
|
||||
PUSH_START_SHOULDER_LIFT,
|
||||
array_to_pose,
|
||||
grab_cube,
|
||||
horizontal_wrist_flex,
|
||||
move_to_pose,
|
||||
place_cube,
|
||||
pose_to_array,
|
||||
)
|
||||
|
||||
# ── Grab-episode motion parameters ────────────────────────────────────────────
|
||||
|
||||
# Shoulder-lift offset for the raised approach phase (subtracted from the target sl, arm is higher).
|
||||
GRAB_RAISE_SL_OFFSET = 20.0
|
||||
GRAB_LOWER_SPEED = 20.0
|
||||
RECORD_SPEED = 30.0
|
||||
|
||||
# Pose the arm travels to after closing the gripper (cube held).
|
||||
GRAB_CARRY_POSE = {
|
||||
"shoulder_pan.pos": -23.0,
|
||||
"shoulder_lift.pos": 5.0,
|
||||
"elbow_flex.pos": 18.0,
|
||||
"wrist_flex.pos": -14.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||
}
|
||||
|
||||
# Per-joint jitter limits (degrees) applied to transit waypoints for human-like variation.
|
||||
# Cube-approach and carry poses are never jittered to preserve precision.
|
||||
_JITTER_LIMITS: dict[str, float] = {
|
||||
"shoulder_pan.pos": 5.0,
|
||||
"shoulder_lift.pos": 4.0,
|
||||
"elbow_flex.pos": 4.0,
|
||||
"wrist_flex.pos": 3.0,
|
||||
"wrist_roll.pos": 2.0,
|
||||
"gripper.pos": 0.0,
|
||||
}
|
||||
|
||||
|
||||
def _jitter_pose(pose: dict, rng: np.random.Generator) -> dict:
|
||||
"""Return a copy of pose with independent per-joint random perturbations."""
|
||||
return {
|
||||
k: v + rng.uniform(-_JITTER_LIMITS.get(k, 0.0), _JITTER_LIMITS.get(k, 0.0)) for k, v in pose.items()
|
||||
}
|
||||
|
||||
|
||||
def _random_stuck_pose(rng: np.random.Generator) -> dict:
|
||||
"""Return a physically plausible stuck pose (failed grasp), gripper closed.
|
||||
|
||||
ef bounds are piecewise-linear in sl so the arm stays in a reachable,
|
||||
table-safe envelope across the full sl range:
|
||||
sl=-50 → ef ∈ [ 0, 50] (arm raised, can be bent forward)
|
||||
sl= 0 → ef ∈ [-25, 25] (mid reach)
|
||||
sl= 30 → ef ∈ [-20, 0] (arm extended, little room to flex)
|
||||
wrist_flex is randomly offset from the horizontal value.
|
||||
"""
|
||||
pan = float(rng.uniform(-5.0, 35.0))
|
||||
sl = float(rng.uniform(-50.0, 30.0))
|
||||
|
||||
if sl <= 0.0:
|
||||
alpha = (sl + 50.0) / 50.0 # 0 at sl=-50, 1 at sl=0
|
||||
ef_lo = alpha * -25.0 # 0 → -25
|
||||
ef_hi = 50.0 + alpha * -25.0 # 50 → 25
|
||||
else:
|
||||
alpha = sl / 30.0 # 0 at sl=0, 1 at sl=30
|
||||
ef_lo = -25.0 + alpha * 5.0 # -25 → -20
|
||||
ef_hi = 25.0 + alpha * -25.0 # 25 → 0
|
||||
|
||||
ef = float(rng.uniform(ef_lo, ef_hi))
|
||||
wf = horizontal_wrist_flex(sl, ef) + float(rng.uniform(-15.0, 15.0))
|
||||
return {
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl,
|
||||
"elbow_flex.pos": ef,
|
||||
"wrist_flex.pos": wf,
|
||||
"wrist_roll.pos": float(rng.uniform(-15.0, 15.0)),
|
||||
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||
}
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OmxRecordGrabConfig:
|
||||
robot: RobotConfig
|
||||
dataset: DatasetRecordConfig
|
||||
# Resume recording on an existing dataset.
|
||||
resume: bool = False
|
||||
# Fraction of episodes that start from a random stuck pose (gripper closed) to
|
||||
# generate recovery data. 0.0 = disabled, 1.0 = all episodes are recovery starts.
|
||||
recovery_prob: float = 0.5
|
||||
|
||||
|
||||
def record_episode_spline(
|
||||
robot: OmxFollower,
|
||||
waypoints: list[dict],
|
||||
speeds: list[float],
|
||||
dataset: LeRobotDataset,
|
||||
task: str,
|
||||
) -> None:
|
||||
"""Execute a Catmull-Rom-style spline through waypoints, recording each frame.
|
||||
|
||||
Segment durations are parameterized from the maximum absolute joint delta
|
||||
between consecutive waypoints divided by the requested segment speed,
|
||||
producing non-uniform timing in joint space. Interior tangents are derived
|
||||
from the adjacent per-segment velocities, with clamped (zero-velocity)
|
||||
endpoints so the arm starts and stops smoothly. Each segment is cubic
|
||||
Hermite, giving C1 continuity at every waypoint.
|
||||
"""
|
||||
pts = [pose_to_array(w) for w in waypoints]
|
||||
n = len(pts)
|
||||
|
||||
# Steps and duration per segment
|
||||
n_steps_list = []
|
||||
timestamps = []
|
||||
for i in range(n - 1):
|
||||
max_dist = float(np.max(np.abs(pts[i + 1] - pts[i])))
|
||||
ns = max(1, int(max_dist / speeds[i] * dataset.fps)) if max_dist >= 0.5 else 0
|
||||
n_steps_list.append(ns)
|
||||
timestamps.append(ns / dataset.fps)
|
||||
|
||||
# Velocity tangents (deg/sec) — clamped at endpoints, Catmull-Rom for interior
|
||||
vels = [np.zeros_like(pts[0])]
|
||||
for i in range(1, n - 1):
|
||||
v_prev = (pts[i] - pts[i - 1]) / timestamps[i - 1] if timestamps[i - 1] > 0 else np.zeros_like(pts[0])
|
||||
v_next = (pts[i + 1] - pts[i]) / timestamps[i] if timestamps[i] > 0 else np.zeros_like(pts[0])
|
||||
vels.append(0.5 * (v_prev + v_next))
|
||||
vels.append(np.zeros_like(pts[0]))
|
||||
|
||||
dt = 1.0 / dataset.fps
|
||||
for seg in range(n - 1):
|
||||
ns = n_steps_list[seg]
|
||||
if ns == 0:
|
||||
continue
|
||||
p0, p1 = pts[seg], pts[seg + 1]
|
||||
# Scale velocity (deg/sec) to t-space tangent (deg/t-unit, where t: 0→1 over ns steps)
|
||||
m0 = vels[seg] * timestamps[seg]
|
||||
m1 = vels[seg + 1] * timestamps[seg]
|
||||
|
||||
for step in range(1, ns + 1):
|
||||
t = step / ns
|
||||
h00 = 2 * t**3 - 3 * t**2 + 1
|
||||
h10 = t**3 - 2 * t**2 + t
|
||||
h01 = -2 * t**3 + 3 * t**2
|
||||
h11 = t**3 - t**2
|
||||
commanded = h00 * p0 + h10 * m0 + h01 * p1 + h11 * m1
|
||||
|
||||
action = array_to_pose(commanded)
|
||||
robot.send_action(action)
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(dataset.features, action, prefix=ACTION)
|
||||
dataset.add_frame({**obs_frame, **action_frame, "task": task})
|
||||
precise_sleep(dt)
|
||||
|
||||
|
||||
def record_grab_episode(
|
||||
robot: OmxFollower,
|
||||
dataset: LeRobotDataset,
|
||||
pan: float,
|
||||
t: float,
|
||||
task: str,
|
||||
recovery_start: bool = False,
|
||||
) -> None:
|
||||
"""Execute a targeted grab to the stored (pan, t) position, recording every frame.
|
||||
|
||||
Normal sequence (initial HOME move is NOT recorded):
|
||||
HOME → raised approach above cube → lower → close gripper
|
||||
→ raise [jittered] → retract [jittered] → GRAB_CARRY_POSE → drop → HOME
|
||||
|
||||
Recovery sequence (recovery_start=True): arm is moved to a random stuck pose
|
||||
(gripper closed) without recording, then recording begins from there:
|
||||
stuck_pose → raised approach above cube → [normal grab sequence from there]
|
||||
|
||||
All segments are joined by a Catmull-Rom spline (C1-continuous velocities).
|
||||
"""
|
||||
sl = PUSH_START_SHOULDER_LIFT + t * (PUSH_END_SHOULDER_LIFT - PUSH_START_SHOULDER_LIFT)
|
||||
ef = PUSH_START_ELBOW_FLEX + t * (PUSH_END_ELBOW_FLEX - PUSH_START_ELBOW_FLEX)
|
||||
sl_raised = sl - GRAB_RAISE_SL_OFFSET
|
||||
wf_horizontal = horizontal_wrist_flex(sl, ef)
|
||||
|
||||
rng = np.random.default_rng()
|
||||
|
||||
if recovery_start:
|
||||
stuck_pose = _random_stuck_pose(rng)
|
||||
logger.info(f"Recovery start: {stuck_pose}")
|
||||
move_to_pose(robot, stuck_pose, APPROACH_SPEED)
|
||||
first_waypoints = [stuck_pose]
|
||||
first_speeds = []
|
||||
else:
|
||||
jittery_start = _jitter_pose(HOME_POSE, rng)
|
||||
move_to_pose(robot, jittery_start, APPROACH_SPEED)
|
||||
first_waypoints = [jittery_start]
|
||||
first_speeds = []
|
||||
|
||||
waypoints = first_waypoints + [
|
||||
{ # raised approach: arm above cube
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl_raised,
|
||||
"elbow_flex.pos": ef,
|
||||
"wrist_flex.pos": horizontal_wrist_flex(sl_raised, ef),
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
},
|
||||
{ # lower onto cube — no jitter: precision needed
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl,
|
||||
"elbow_flex.pos": ef,
|
||||
"wrist_flex.pos": wf_horizontal,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
},
|
||||
{ # close gripper — no jitter: precision needed
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl,
|
||||
"elbow_flex.pos": ef,
|
||||
"wrist_flex.pos": wf_horizontal,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||
},
|
||||
_jitter_pose(
|
||||
{ # raise with cube
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl_raised,
|
||||
"elbow_flex.pos": ef,
|
||||
"wrist_flex.pos": horizontal_wrist_flex(sl_raised, ef),
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||
},
|
||||
rng,
|
||||
),
|
||||
_jitter_pose(
|
||||
{ # retract: fold arm toward HOME before sweeping to carry zone
|
||||
"shoulder_pan.pos": pan * 0.25,
|
||||
"shoulder_lift.pos": HOME_POSE["shoulder_lift.pos"] + 5.0,
|
||||
"elbow_flex.pos": HOME_POSE["elbow_flex.pos"] - 5.0,
|
||||
"wrist_flex.pos": 0.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||
},
|
||||
rng,
|
||||
),
|
||||
GRAB_CARRY_POSE, # no jitter: target drop zone
|
||||
{**GRAB_CARRY_POSE, "gripper.pos": 60.0}, # drop cube
|
||||
HOME_POSE,
|
||||
]
|
||||
speeds = first_speeds + [
|
||||
RECORD_SPEED, # (HOME →) raised approach
|
||||
GRAB_LOWER_SPEED, # raised approach → lower
|
||||
GRAB_LOWER_SPEED, # lower → close gripper
|
||||
RECORD_SPEED, # close gripper → raise
|
||||
RECORD_SPEED, # raise → retract
|
||||
RECORD_SPEED, # retract → carry pose
|
||||
RECORD_SPEED, # carry pose → drop
|
||||
RECORD_SPEED, # drop → HOME
|
||||
]
|
||||
|
||||
record_episode_spline(robot, waypoints, speeds, dataset, task)
|
||||
|
||||
# Dwell at HOME for ~0.5 s before next episode
|
||||
home_action = build_dataset_frame(dataset.features, HOME_POSE, prefix=ACTION)
|
||||
dt = 1.0 / dataset.fps
|
||||
for _ in range(int(dataset.fps * 0.5)):
|
||||
robot.send_action(HOME_POSE)
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
|
||||
dataset.add_frame({**obs_frame, **home_action, "task": task})
|
||||
precise_sleep(dt)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def record_grab(cfg: OmxRecordGrabConfig) -> LeRobotDataset:
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
logger.info(pformat(cfg))
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
use_videos = cfg.dataset.video
|
||||
|
||||
teleop_action_processor, _, robot_obs_processor = make_default_processors()
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_action_processor,
|
||||
initial_features=create_initial_features(action=robot.action_features),
|
||||
use_videos=use_videos,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_obs_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=use_videos,
|
||||
),
|
||||
)
|
||||
|
||||
num_cameras = len(robot.cameras) if hasattr(robot, "cameras") else 0
|
||||
dataset = None
|
||||
|
||||
try:
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset.resume(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
|
||||
if num_cameras > 0
|
||||
else 0,
|
||||
)
|
||||
else:
|
||||
cfg.dataset.stamp_repo_id()
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
root=cfg.dataset.root,
|
||||
robot_type=robot.name,
|
||||
features=dataset_features,
|
||||
use_videos=use_videos,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
|
||||
if num_cameras > 0
|
||||
else 0,
|
||||
)
|
||||
|
||||
robot.connect(calibrate=True)
|
||||
|
||||
rng = np.random.default_rng()
|
||||
with VideoEncodingManager(dataset):
|
||||
for episode_idx in range(cfg.dataset.num_episodes):
|
||||
logger.info(f"=== Episode {episode_idx + 1}/{cfg.dataset.num_episodes} ===")
|
||||
|
||||
logger.info("Step 1: grabbing and placing cube...")
|
||||
grab_cube(robot)
|
||||
pan, t = place_cube(robot)
|
||||
logger.info(f"Cube placed at pan={pan:.1f}, reach={t:.2f}")
|
||||
|
||||
recovery_start = cfg.recovery_prob > 0 and float(rng.random()) < cfg.recovery_prob
|
||||
logger.info(f"Step 2: recording {'recovery ' if recovery_start else ''}grab episode...")
|
||||
record_grab_episode(
|
||||
robot,
|
||||
dataset,
|
||||
pan,
|
||||
t,
|
||||
cfg.dataset.single_task,
|
||||
recovery_start=recovery_start,
|
||||
)
|
||||
|
||||
dataset.save_episode()
|
||||
logger.info(f"Episode {episode_idx + 1} saved.")
|
||||
|
||||
finally:
|
||||
if dataset:
|
||||
dataset.finalize()
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
|
||||
if cfg.dataset.push_to_hub and dataset and dataset.num_episodes > 0:
|
||||
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
record_grab()
|
||||
267
examples/omx/reset_environment.py
Normal file
267
examples/omx/reset_environment.py
Normal file
@@ -0,0 +1,267 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Auto-reset and cube-grab utility for the OMX robot arm.
|
||||
|
||||
Provides:
|
||||
- grab_cube(robot): sweep workspace, center cube, close gripper
|
||||
- place_cube(robot): carry cube to a random position, release
|
||||
|
||||
Standalone usage (run from repo root):
|
||||
python -m examples.omx.reset_environment --port /dev/ttyACM1 --mode grab
|
||||
python -m examples.omx.reset_environment --port /dev/ttyACM1 --mode grab_and_place
|
||||
|
||||
Joint range: -100 to 100 for arm joints; gripper: 50 = closed, 80 = open.
|
||||
|
||||
To read current joint values for calibration, add after robot.connect():
|
||||
obs = robot.get_observation()
|
||||
print({k: round(obs[k], 1) for k in JOINT_NAMES})
|
||||
robot.disconnect(); raise SystemExit
|
||||
|
||||
Parallel-to-ground IK: wrist_flex = WRIST_HORIZONTAL_OFFSET - shoulder_lift - elbow_flex.
|
||||
Linear interpolation preserves this constraint between any two poses that satisfy it.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.robots.omx_follower import OmxFollower, OmxFollowerConfig
|
||||
from lerobot.robots.robot import Robot
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Poses ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
HOME_POSE = {
|
||||
"shoulder_pan.pos": 0.0,
|
||||
"shoulder_lift.pos": -50.0,
|
||||
"elbow_flex.pos": 50.0,
|
||||
"wrist_flex.pos": 0.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
}
|
||||
|
||||
SWEEP_WAYPOINTS = [
|
||||
{
|
||||
"shoulder_pan.pos": -60.0,
|
||||
"shoulder_lift.pos": 50.0,
|
||||
"elbow_flex.pos": -60.0,
|
||||
"wrist_flex.pos": -20.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
},
|
||||
{
|
||||
"shoulder_pan.pos": -30.0,
|
||||
"shoulder_lift.pos": 50.0,
|
||||
"elbow_flex.pos": -60.0,
|
||||
"wrist_flex.pos": -5.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
},
|
||||
{
|
||||
"shoulder_pan.pos": 20.0,
|
||||
"shoulder_lift.pos": 50.0,
|
||||
"elbow_flex.pos": -55.0,
|
||||
"wrist_flex.pos": -5.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
},
|
||||
]
|
||||
|
||||
# ── Motion parameters ─────────────────────────────────────────────────────────
|
||||
|
||||
CONTROL_HZ = 30
|
||||
APPROACH_SPEED = 50.0
|
||||
SWEEP_SPEED = 40.0
|
||||
|
||||
# ── Grab-sequence parameters ──────────────────────────────────────────────────
|
||||
|
||||
GRAB_PAN = 0.0
|
||||
SWEEP_LEFT_PAN = -60.0
|
||||
SWEEP_RIGHT_PAN = 60.0
|
||||
SWEEP_END_OFFSET = 5.0 # stop before center so the cube isn't pushed past GRAB_PAN
|
||||
SWEEP_END_PAN_RANGE = (15.0, 20.0)
|
||||
|
||||
SWEEP_LOW_SHOULDER_LIFT = 50.0
|
||||
SWEEP_LOW_ELBOW_FLEX_START = -60.0
|
||||
SWEEP_LOW_ELBOW_FLEX_END = -55.0
|
||||
|
||||
SWEEP_HIGH_WRIST_FLEX = -20.0 # wrist tilted up during high approach to clear obstacles
|
||||
|
||||
PUSH_START_SHOULDER_LIFT = 0.0
|
||||
PUSH_START_ELBOW_FLEX = 45.0
|
||||
PUSH_END_SHOULDER_LIFT = 50.0
|
||||
PUSH_END_ELBOW_FLEX = -50.0
|
||||
# Subtracted from shoulder_lift during the push sweep to clear the platform surface.
|
||||
# Does not affect the grab-target interpolation in record_grab.py.
|
||||
PUSH_RAISE_OFFSET = 5.0
|
||||
|
||||
WRIST_HORIZONTAL_OFFSET = 0.0 # tune if gripper tilts during push: + tilts nose up, - down
|
||||
GRIPPER_CLOSE_POS = 50.0
|
||||
|
||||
PLACE_LEFT_PAN_RANGE = (5.0, 30.0) # random pan range for cube placement on the left side
|
||||
PLACE_REACH_RANGE = (0.1, 0.7) # 0 = arm retracted (PUSH_START), 1 = fully extended (PUSH_END)
|
||||
|
||||
JOINT_NAMES = [
|
||||
"shoulder_pan.pos",
|
||||
"shoulder_lift.pos",
|
||||
"elbow_flex.pos",
|
||||
"wrist_flex.pos",
|
||||
"wrist_roll.pos",
|
||||
"gripper.pos",
|
||||
]
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def pose_to_array(pose: dict) -> np.ndarray:
|
||||
return np.array([pose[k] for k in JOINT_NAMES])
|
||||
|
||||
|
||||
def array_to_pose(arr: np.ndarray) -> dict:
|
||||
return {k: float(arr[i]) for i, k in enumerate(JOINT_NAMES)}
|
||||
|
||||
|
||||
def horizontal_wrist_flex(shoulder_lift: float, elbow_flex: float) -> float:
|
||||
return WRIST_HORIZONTAL_OFFSET - shoulder_lift - elbow_flex
|
||||
|
||||
|
||||
def _low_sweep_pose(pan: float, elbow_flex: float, wrist_flex: float | None = None) -> dict:
|
||||
sl = SWEEP_LOW_SHOULDER_LIFT
|
||||
return {
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl,
|
||||
"elbow_flex.pos": elbow_flex,
|
||||
"wrist_flex.pos": horizontal_wrist_flex(sl, elbow_flex) if wrist_flex is None else wrist_flex,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
}
|
||||
|
||||
|
||||
def _high_sweep_pose(pan: float) -> dict:
|
||||
return {**HOME_POSE, "shoulder_pan.pos": pan, "wrist_flex.pos": SWEEP_HIGH_WRIST_FLEX}
|
||||
|
||||
|
||||
def _push_pose(shoulder_lift: float, elbow_flex: float, pan: float = GRAB_PAN, gripper: float = 70.0) -> dict:
|
||||
return {
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": shoulder_lift,
|
||||
"elbow_flex.pos": elbow_flex,
|
||||
"wrist_flex.pos": horizontal_wrist_flex(shoulder_lift, elbow_flex),
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": gripper,
|
||||
}
|
||||
|
||||
|
||||
def move_to_pose(robot: Robot, target: dict, speed: float) -> None:
|
||||
"""Interpolate from current position to target at the given speed (units/s)."""
|
||||
obs = robot.get_observation()
|
||||
current = np.array([obs[k] for k in JOINT_NAMES])
|
||||
goal = pose_to_array(target)
|
||||
|
||||
max_distance = float(np.max(np.abs(goal - current)))
|
||||
if max_distance < 0.5:
|
||||
return
|
||||
|
||||
n_steps = max(1, int(max_distance / speed * CONTROL_HZ))
|
||||
dt = 1.0 / CONTROL_HZ
|
||||
for step in range(1, n_steps + 1):
|
||||
t = step / n_steps
|
||||
robot.send_action(array_to_pose(current + t * (goal - current)))
|
||||
precise_sleep(dt)
|
||||
|
||||
|
||||
# ── Sequences ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def grab_cube(robot: Robot) -> None:
|
||||
"""Left sweep → right sweep → extend arm parallel to ground → close gripper."""
|
||||
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
|
||||
|
||||
for pan, end_pan in [
|
||||
(SWEEP_LEFT_PAN, GRAB_PAN - SWEEP_END_OFFSET),
|
||||
(SWEEP_RIGHT_PAN, GRAB_PAN + SWEEP_END_OFFSET),
|
||||
]:
|
||||
logger.info(f"Sweeping {'left' if pan < 0 else 'right'} → center...")
|
||||
move_to_pose(robot, _high_sweep_pose(pan), APPROACH_SPEED)
|
||||
move_to_pose(
|
||||
robot, _low_sweep_pose(pan, SWEEP_LOW_ELBOW_FLEX_START, wrist_flex=-20.0), APPROACH_SPEED
|
||||
)
|
||||
move_to_pose(robot, _low_sweep_pose(end_pan, SWEEP_LOW_ELBOW_FLEX_END, wrist_flex=0.0), SWEEP_SPEED)
|
||||
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
|
||||
|
||||
logger.info("Extending to push cube into gripper...")
|
||||
move_to_pose(
|
||||
robot,
|
||||
_push_pose(PUSH_START_SHOULDER_LIFT - PUSH_RAISE_OFFSET, PUSH_START_ELBOW_FLEX),
|
||||
APPROACH_SPEED,
|
||||
)
|
||||
move_to_pose(
|
||||
robot,
|
||||
_push_pose(PUSH_END_SHOULDER_LIFT - PUSH_RAISE_OFFSET, PUSH_END_ELBOW_FLEX),
|
||||
SWEEP_SPEED,
|
||||
)
|
||||
|
||||
logger.info("Closing gripper...")
|
||||
move_to_pose(
|
||||
robot,
|
||||
_push_pose(PUSH_END_SHOULDER_LIFT, PUSH_END_ELBOW_FLEX, gripper=GRIPPER_CLOSE_POS),
|
||||
APPROACH_SPEED,
|
||||
)
|
||||
|
||||
logger.info("Grab complete.")
|
||||
|
||||
|
||||
def place_cube(robot: Robot) -> tuple[float, float]:
|
||||
"""Carry the cube (gripper closed) to a random position on the left side, then release.
|
||||
|
||||
Returns:
|
||||
(pan, t): pan angle and reach scalar [0, 1] of the placement position.
|
||||
"""
|
||||
pan = float(np.random.uniform(*PLACE_LEFT_PAN_RANGE))
|
||||
t = float(np.random.uniform(*PLACE_REACH_RANGE))
|
||||
sl = PUSH_START_SHOULDER_LIFT + t * (PUSH_END_SHOULDER_LIFT - PUSH_START_SHOULDER_LIFT)
|
||||
ef = PUSH_START_ELBOW_FLEX + t * (PUSH_END_ELBOW_FLEX - PUSH_START_ELBOW_FLEX)
|
||||
logger.info(f"Placing cube at pan={pan:.1f}, reach={t:.2f}...")
|
||||
|
||||
move_to_pose(robot, {**HOME_POSE, "gripper.pos": GRIPPER_CLOSE_POS}, APPROACH_SPEED)
|
||||
move_to_pose(
|
||||
robot, {**HOME_POSE, "shoulder_pan.pos": pan, "gripper.pos": GRIPPER_CLOSE_POS}, APPROACH_SPEED
|
||||
)
|
||||
move_to_pose(robot, _push_pose(sl, ef, pan=pan, gripper=GRIPPER_CLOSE_POS), APPROACH_SPEED)
|
||||
move_to_pose(robot, _push_pose(sl, ef, pan=pan, gripper=80.0), APPROACH_SPEED)
|
||||
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
|
||||
logger.info("Place complete.")
|
||||
return pan, t
|
||||
|
||||
|
||||
# ── Entry point ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="OMX arm reset / grab script")
|
||||
parser.add_argument("--port", default="/dev/ttyACM1")
|
||||
parser.add_argument("--robot_id", default="omx_follower")
|
||||
parser.add_argument("--mode", choices=["grab", "grab_and_place"], default="grab_and_place")
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
|
||||
robot = OmxFollower(OmxFollowerConfig(port=args.port, id=args.robot_id))
|
||||
robot.connect(calibrate=True)
|
||||
|
||||
try:
|
||||
if args.mode == "grab":
|
||||
grab_cube(robot)
|
||||
elif args.mode == "grab_and_place":
|
||||
grab_cube(robot)
|
||||
place_cube(robot)
|
||||
|
||||
finally:
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -14,13 +14,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
from lerobot.common.control_utils import init_keyboard_listener, predict_action
|
||||
from lerobot.configs import FeatureType, PolicyFeature
|
||||
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies import make_pre_post_processors
|
||||
from lerobot.policies.act import ACTPolicy
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import (
|
||||
RobotProcessorPipeline,
|
||||
make_default_teleop_action_processor,
|
||||
@@ -34,11 +38,12 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.feature_utils import combine_feature_dicts
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
@@ -49,6 +54,9 @@ HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
|
||||
def main():
|
||||
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
|
||||
# This script provides a self-contained example for educational purposes.
|
||||
|
||||
# Create the robot configuration & robot
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
@@ -143,43 +151,67 @@ def main():
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
control_interval = 1 / FPS
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
# Inline evaluation loop: predict actions and send to robot
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < EPISODE_TIME_SEC:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
# Get robot observation
|
||||
obs = robot.get_observation()
|
||||
obs_processed = robot_joints_to_ee_pose_processor(obs)
|
||||
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
|
||||
|
||||
# Predict action using the policy
|
||||
action_tensor = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=policy.config.device,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.device.type == "cuda",
|
||||
task=TASK_DESCRIPTION,
|
||||
robot_type=robot.name,
|
||||
)
|
||||
|
||||
# Convert policy output to robot action dict
|
||||
action_values = make_robot_action(action_tensor, dataset.features)
|
||||
|
||||
# Process and send action to robot (EE -> joints via IK)
|
||||
robot_action_to_send = robot_ee_to_joints_processor((action_values, obs))
|
||||
robot.send_action(robot_action_to_send)
|
||||
|
||||
# Write to dataset
|
||||
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
|
||||
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
log_rerun_data(observation=obs_processed, action=action_values)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
sleep_time_s = control_interval - dt_s
|
||||
if sleep_time_s < 0:
|
||||
logging.warning(
|
||||
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
|
||||
)
|
||||
precise_sleep(max(sleep_time_s, 0.0))
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
log_say("Waiting for environment reset, press right arrow key when ready...")
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
@@ -190,7 +222,6 @@ def main():
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
|
||||
@@ -65,14 +65,15 @@ def main():
|
||||
robot = SO100Follower(robot_config)
|
||||
phone = Phone(teleop_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert phone action to EE action
|
||||
# Build pipeline to convert phone action to EE action (with gripper velocity mapped to joint).
|
||||
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[
|
||||
tuple[RobotAction, RobotObservation], RobotAction
|
||||
](
|
||||
@@ -94,7 +95,7 @@ def main():
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
# Build pipeline to convert EE action to joints action (IK).
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
@@ -107,7 +108,7 @@ def main():
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joint observation to EE observation
|
||||
# Build pipeline to convert joint observation to EE observation (FK).
|
||||
robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
@@ -118,13 +119,12 @@ def main():
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
# Create the dataset, deriving features from the pipelines so the on-disk schema
|
||||
# matches exactly what the pipelines produce at runtime.
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
|
||||
# This tells you how the features would look like after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=phone_to_robot_ee_pose_processor,
|
||||
initial_features=create_initial_features(action=phone.action_features),
|
||||
@@ -163,14 +163,14 @@ def main():
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
teleop=phone,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
@@ -182,13 +182,13 @@ def main():
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
teleop=phone,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
126
examples/phone_to_so100/rollout.py
Normal file
126
examples/phone_to_so100/rollout.py
Normal file
@@ -0,0 +1,126 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Run a trained EE-space policy on SO100 (phone-trained) without recording.
|
||||
|
||||
Mirrors ``examples/so100_to_so100_EE/rollout.py`` — the model was trained
|
||||
with phone teleoperation in EE space, so at deployment we only need the
|
||||
joint↔EE conversion on the robot side; the phone is not used.
|
||||
|
||||
Uses :class:`BaseStrategy` (no recording) + :class:`SyncInferenceConfig`
|
||||
(inline policy call). For recording during rollout, switch to Sentry,
|
||||
Highlight, or DAgger via ``lerobot-rollout --strategy.type=...``.
|
||||
"""
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import (
|
||||
RobotProcessorPipeline,
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||
from lerobot.rollout.inference import SyncInferenceConfig
|
||||
from lerobot.rollout.strategies import BaseStrategy
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
FPS = 30
|
||||
DURATION_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
|
||||
|
||||
def main():
|
||||
init_logging()
|
||||
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
# Peek at motor names once to build the kinematic solver.
|
||||
temp_robot = SO100Follower(robot_config)
|
||||
motor_names = list(temp_robot.bus.motors.keys())
|
||||
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=motor_names,
|
||||
)
|
||||
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=motor_names,
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||
policy_config.pretrained_path = HF_MODEL_ID
|
||||
|
||||
cfg = RolloutConfig(
|
||||
robot=robot_config,
|
||||
policy=policy_config,
|
||||
strategy=BaseStrategyConfig(),
|
||||
inference=SyncInferenceConfig(),
|
||||
fps=FPS,
|
||||
duration=DURATION_SEC,
|
||||
task=TASK_DESCRIPTION,
|
||||
)
|
||||
|
||||
signal_handler = ProcessSignalHandler(use_threads=True)
|
||||
|
||||
ctx = build_rollout_context(
|
||||
cfg,
|
||||
signal_handler.shutdown_event,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
strategy = BaseStrategy(cfg.strategy)
|
||||
try:
|
||||
strategy.setup(ctx)
|
||||
strategy.run(ctx)
|
||||
finally:
|
||||
strategy.teardown(ctx)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,673 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies on real robots.
|
||||
|
||||
This script demonstrates:
|
||||
1. Creating a robot and policy (SmolVLA, Pi0, etc.) with RTC
|
||||
2. Consuming actions from the policy while the robot executes
|
||||
3. Periodically requesting new action chunks in the background using threads
|
||||
4. Managing action buffers and timing for real-time operation
|
||||
|
||||
For simulation environments, see eval_with_simulation.py
|
||||
|
||||
Usage:
|
||||
# Run RTC with Real robot with RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=120
|
||||
|
||||
# Run RTC with Real robot without RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=false \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=120
|
||||
|
||||
# Run RTC with Real robot with pi0.5 policy
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=<USER>/pi05_check_rtc \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=120
|
||||
|
||||
# Run RTC with bi_openarm_follower (dual-arm OpenArms) and pi0.5 policy
|
||||
python examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=lerobot-data-collection/folding_final \
|
||||
--robot.type=bi_openarm_follower \
|
||||
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}}' \
|
||||
--robot.left_arm_config.port=can0 \
|
||||
--robot.left_arm_config.side=left \
|
||||
--robot.left_arm_config.can_interface=socketcan \
|
||||
--robot.left_arm_config.disable_torque_on_disconnect=true \
|
||||
--robot.left_arm_config.max_relative_target=8.0 \
|
||||
--robot.right_arm_config.port=can1 \
|
||||
--robot.right_arm_config.side=right \
|
||||
--robot.right_arm_config.can_interface=socketcan \
|
||||
--robot.right_arm_config.disable_torque_on_disconnect=true \
|
||||
--robot.right_arm_config.max_relative_target=8.0 \
|
||||
--task="Fold the T-shirt properly" \
|
||||
--fps=30 \
|
||||
--duration=2000 \
|
||||
--interpolation_multiplier=3 \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--rtc.max_guidance_weight=5.0 \
|
||||
--rtc.prefix_attention_schedule=LINEAR \
|
||||
--device=cuda
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Event, Lock, Thread
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401
|
||||
from lerobot.configs import PreTrainedConfig, RTCAttentionSchedule, parser
|
||||
from lerobot.policies import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
|
||||
from lerobot.processor import (
|
||||
NormalizerProcessorStep,
|
||||
RelativeActionsProcessorStep,
|
||||
TransitionKey,
|
||||
create_transition,
|
||||
make_default_robot_action_processor,
|
||||
make_default_robot_observation_processor,
|
||||
to_relative_actions,
|
||||
)
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_openarm_follower,
|
||||
bi_so_follower,
|
||||
koch_follower,
|
||||
so_follower,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.robots.utils import make_robot_from_config
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RobotWrapper:
|
||||
def __init__(self, robot: Robot):
|
||||
self.robot = robot
|
||||
self.lock = Lock()
|
||||
|
||||
def get_observation(self) -> dict[str, Tensor]:
|
||||
with self.lock:
|
||||
return self.robot.get_observation()
|
||||
|
||||
def send_action(self, action: Tensor):
|
||||
with self.lock:
|
||||
self.robot.send_action(action)
|
||||
|
||||
def observation_features(self) -> list[str]:
|
||||
with self.lock:
|
||||
return self.robot.observation_features
|
||||
|
||||
def action_features(self) -> list[str]:
|
||||
with self.lock:
|
||||
return self.robot.action_features
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTCDemoConfig(HubMixin):
|
||||
"""Configuration for RTC demo with action chunking policies and real robots."""
|
||||
|
||||
# Policy configuration
|
||||
policy: PreTrainedConfig | None = None
|
||||
|
||||
# Robot configuration
|
||||
robot: RobotConfig | None = None
|
||||
|
||||
# RTC configuration
|
||||
rtc: RTCConfig = field(
|
||||
default_factory=lambda: RTCConfig(
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=1.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
)
|
||||
)
|
||||
|
||||
# Demo parameters
|
||||
duration: float = 30.0 # Duration to run the demo (seconds)
|
||||
fps: float = 10.0 # Action execution frequency (Hz)
|
||||
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
|
||||
|
||||
# Compute device
|
||||
device: str | None = None # Device to run on (cuda, cpu, auto)
|
||||
|
||||
# Get new actions horizon. The amount of executed steps after which will be requested new actions.
|
||||
# It should be higher than inference delay + execution horizon.
|
||||
action_queue_size_to_get_new_actions: int = 30
|
||||
|
||||
# Task to execute
|
||||
task: str = field(default="", metadata={"help": "Task to execute"})
|
||||
|
||||
# Torch compile configuration
|
||||
use_torch_compile: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
|
||||
)
|
||||
|
||||
torch_compile_backend: str = field(
|
||||
default="inductor",
|
||||
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
|
||||
)
|
||||
|
||||
torch_compile_mode: str = field(
|
||||
default="default",
|
||||
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
|
||||
)
|
||||
|
||||
torch_compile_disable_cudagraphs: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
|
||||
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
else:
|
||||
raise ValueError("Policy path is required")
|
||||
|
||||
# Validate that robot configuration is provided
|
||||
if self.robot is None:
|
||||
raise ValueError("Robot configuration must be provided")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
|
||||
|
||||
def is_image_key(k: str) -> bool:
|
||||
return k.startswith(OBS_IMAGES)
|
||||
|
||||
|
||||
def _reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute: Tensor,
|
||||
current_state: Tensor,
|
||||
relative_step: RelativeActionsProcessorStep,
|
||||
normalizer_step: NormalizerProcessorStep | None,
|
||||
policy_device: torch.device | str,
|
||||
) -> Tensor:
|
||||
"""Convert absolute leftovers into model-space for relative-action RTC policies.
|
||||
|
||||
When a policy uses relative actions, the RTC prefix (leftover actions from
|
||||
the previous chunk) is stored in absolute space. Before feeding it back to
|
||||
the policy we need to re-express it relative to the *current* robot state
|
||||
and then re-normalize.
|
||||
"""
|
||||
state = current_state.detach().cpu()
|
||||
if state.dim() == 1:
|
||||
state = state.unsqueeze(0)
|
||||
|
||||
action_cpu = prev_actions_absolute.detach().cpu()
|
||||
mask = relative_step._build_mask(action_cpu.shape[-1])
|
||||
relative_actions = to_relative_actions(action_cpu, state, mask)
|
||||
|
||||
transition = create_transition(action=relative_actions)
|
||||
if normalizer_step is not None:
|
||||
transition = normalizer_step(transition)
|
||||
|
||||
return transition[TransitionKey.ACTION].to(policy_device)
|
||||
|
||||
|
||||
def get_actions(
|
||||
policy,
|
||||
robot: RobotWrapper,
|
||||
robot_observation_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: RTCDemoConfig,
|
||||
):
|
||||
"""Thread function to request action chunks from the policy.
|
||||
|
||||
Args:
|
||||
policy: The policy instance (SmolVLA, Pi0, etc.)
|
||||
robot: The robot instance for getting observations
|
||||
robot_observation_processor: Processor for raw robot observations
|
||||
action_queue: Queue to put new action chunks
|
||||
shutdown_event: Event to signal shutdown
|
||||
cfg: Demo configuration
|
||||
"""
|
||||
try:
|
||||
logger.info("[GET_ACTIONS] Starting get actions thread")
|
||||
|
||||
latency_tracker = LatencyTracker() # Track latency of action chunks
|
||||
fps = cfg.fps
|
||||
time_per_chunk = 1.0 / fps
|
||||
|
||||
# Only keep .pos joints + camera streams if the policy was trained on positions,
|
||||
# not the full pos/vel/torque state the robot exposes.
|
||||
observation_features_hw = {
|
||||
key: value
|
||||
for key, value in robot.observation_features().items()
|
||||
if key.endswith(".pos") or isinstance(value, tuple)
|
||||
}
|
||||
|
||||
dataset_features = hw_to_dataset_features(observation_features_hw, "observation")
|
||||
policy_device = policy.config.device
|
||||
|
||||
# Load preprocessor and postprocessor from pretrained files
|
||||
# The stats are embedded in the processor .safetensors files
|
||||
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=None, # Will load from pretrained processor files
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.policy.device},
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
|
||||
|
||||
relative_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
|
||||
None,
|
||||
)
|
||||
normalizer_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
|
||||
None,
|
||||
)
|
||||
if relative_step is not None:
|
||||
if relative_step.action_names is None:
|
||||
cfg_names = getattr(cfg.policy, "action_feature_names", None)
|
||||
if cfg_names:
|
||||
relative_step.action_names = list(cfg_names)
|
||||
else:
|
||||
relative_step.action_names = [
|
||||
k for k in robot.robot.action_features if k.endswith(".pos")
|
||||
]
|
||||
logger.info("[GET_ACTIONS] Relative actions enabled: will re-anchor RTC prefix")
|
||||
|
||||
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
|
||||
|
||||
if not cfg.rtc.enabled:
|
||||
get_actions_threshold = 0
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if action_queue.qsize() <= get_actions_threshold:
|
||||
current_time = time.perf_counter()
|
||||
action_index_before_inference = action_queue.get_action_index()
|
||||
prev_actions = action_queue.get_left_over()
|
||||
|
||||
inference_latency = latency_tracker.max()
|
||||
inference_delay = math.ceil(inference_latency / time_per_chunk)
|
||||
|
||||
obs = robot.get_observation()
|
||||
|
||||
# Apply robot observation processor
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
|
||||
obs_with_policy_features = build_dataset_frame(
|
||||
dataset_features, obs_processed, prefix="observation"
|
||||
)
|
||||
|
||||
for name in obs_with_policy_features:
|
||||
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
|
||||
if "image" in name:
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].type(torch.float32) / 255
|
||||
)
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
|
||||
)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
|
||||
|
||||
obs_with_policy_features["task"] = [cfg.task] # Task should be a list, not a string!
|
||||
obs_with_policy_features["robot_type"] = (
|
||||
robot.robot.name if hasattr(robot.robot, "name") else ""
|
||||
)
|
||||
|
||||
preproceseded_obs = preprocessor(obs_with_policy_features)
|
||||
|
||||
# Re-anchor leftover actions for relative-action policies.
|
||||
# We need the *postprocessed* (absolute) leftover, not the original
|
||||
# (normalized/relative) one that get_left_over() returns.
|
||||
if (
|
||||
prev_actions is not None
|
||||
and relative_step is not None
|
||||
and OBS_STATE in obs_with_policy_features
|
||||
):
|
||||
with action_queue.lock:
|
||||
if action_queue.queue is not None:
|
||||
prev_actions_abs = action_queue.queue[action_queue.last_index :].clone()
|
||||
else:
|
||||
prev_actions_abs = None
|
||||
if prev_actions_abs is not None and prev_actions_abs.numel() > 0:
|
||||
prev_actions = _reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute=prev_actions_abs,
|
||||
current_state=obs_with_policy_features[OBS_STATE],
|
||||
relative_step=relative_step,
|
||||
normalizer_step=normalizer_step,
|
||||
policy_device=policy_device,
|
||||
)
|
||||
|
||||
# Generate actions WITH RTC
|
||||
actions = policy.predict_action_chunk(
|
||||
preproceseded_obs,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_actions,
|
||||
)
|
||||
|
||||
# Store original actions (before postprocessing) for RTC
|
||||
original_actions = actions.squeeze(0).clone()
|
||||
|
||||
postprocessed_actions = postprocessor(actions)
|
||||
|
||||
postprocessed_actions = postprocessed_actions.squeeze(0)
|
||||
|
||||
new_latency = time.perf_counter() - current_time
|
||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||
latency_tracker.add(new_latency)
|
||||
|
||||
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
|
||||
logger.warning(
|
||||
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon."
|
||||
)
|
||||
|
||||
action_queue.merge(
|
||||
original_actions, postprocessed_actions, new_delay, action_index_before_inference
|
||||
)
|
||||
else:
|
||||
# Small sleep to prevent busy waiting
|
||||
time.sleep(0.1)
|
||||
|
||||
logger.info("[GET_ACTIONS] get actions thread shutting down")
|
||||
except Exception as e:
|
||||
logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def actor_control(
|
||||
robot: RobotWrapper,
|
||||
robot_action_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: RTCDemoConfig,
|
||||
):
|
||||
"""Thread function to execute actions on the robot.
|
||||
|
||||
Args:
|
||||
robot: The robot instance
|
||||
action_queue: Queue to get actions from
|
||||
shutdown_event: Event to signal shutdown
|
||||
cfg: Demo configuration
|
||||
"""
|
||||
try:
|
||||
logger.info("[ACTOR] Starting actor thread")
|
||||
|
||||
action_keys = [k for k in robot.action_features() if k.endswith(".pos")]
|
||||
|
||||
action_count = 0
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
action_interval = interpolator.get_control_interval(cfg.fps)
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if interpolator.needs_new_action():
|
||||
new_action = action_queue.get()
|
||||
if new_action is not None:
|
||||
interpolator.add(new_action.cpu())
|
||||
|
||||
action = interpolator.get()
|
||||
if action is not None:
|
||||
action = action.cpu()
|
||||
action_dict = {key: action[i].item() for i, key in enumerate(action_keys)}
|
||||
action_processed = robot_action_processor((action_dict, None))
|
||||
robot.send_action(action_processed)
|
||||
action_count += 1
|
||||
|
||||
dt_s = time.perf_counter() - start_time
|
||||
time.sleep(max(0, (action_interval - dt_s) - 0.001))
|
||||
|
||||
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
|
||||
except Exception as e:
|
||||
logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
|
||||
"""Apply torch.compile to the policy's predict_action_chunk method.
|
||||
|
||||
Args:
|
||||
policy: Policy instance to compile
|
||||
cfg: Configuration containing torch compile settings
|
||||
|
||||
Returns:
|
||||
Policy with compiled predict_action_chunk method
|
||||
"""
|
||||
|
||||
# PI models handle their own compilation
|
||||
if policy.type == "pi05" or policy.type == "pi0":
|
||||
return policy
|
||||
|
||||
try:
|
||||
# Check if torch.compile is available (PyTorch 2.0+)
|
||||
if not hasattr(torch, "compile"):
|
||||
logger.warning(
|
||||
f"torch.compile is not available. Requires PyTorch 2.0+. "
|
||||
f"Current version: {torch.__version__}. Skipping compilation."
|
||||
)
|
||||
return policy
|
||||
|
||||
logger.info("Applying torch.compile to predict_action_chunk...")
|
||||
logger.info(f" Backend: {cfg.torch_compile_backend}")
|
||||
logger.info(f" Mode: {cfg.torch_compile_mode}")
|
||||
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
|
||||
|
||||
# Compile the predict_action_chunk method
|
||||
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
|
||||
compile_kwargs = {
|
||||
"backend": cfg.torch_compile_backend,
|
||||
"mode": cfg.torch_compile_mode,
|
||||
}
|
||||
|
||||
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
|
||||
if cfg.torch_compile_disable_cudagraphs:
|
||||
compile_kwargs["options"] = {"triton.cudagraphs": False}
|
||||
|
||||
original_method = policy.predict_action_chunk
|
||||
compiled_method = torch.compile(original_method, **compile_kwargs)
|
||||
policy.predict_action_chunk = compiled_method
|
||||
logger.info("✓ Successfully compiled predict_action_chunk")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to apply torch.compile: {e}")
|
||||
logger.warning("Continuing without torch.compile")
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def demo_cli(cfg: RTCDemoConfig):
|
||||
"""Main entry point for RTC demo with draccus configuration."""
|
||||
|
||||
# Initialize logging
|
||||
init_logging()
|
||||
|
||||
logger.info(f"Using device: {cfg.device}")
|
||||
|
||||
# Setup signal handler for graceful shutdown
|
||||
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
|
||||
shutdown_event = signal_handler.shutdown_event
|
||||
|
||||
policy = None
|
||||
robot = None
|
||||
get_actions_thread = None
|
||||
actor_thread = None
|
||||
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
|
||||
# Load config and set compile_model for pi0/pi05 models
|
||||
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
|
||||
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
|
||||
config.compile_model = cfg.use_torch_compile
|
||||
|
||||
if config.use_peft:
|
||||
from peft import PeftConfig, PeftModel
|
||||
|
||||
peft_pretrained_path = cfg.policy.pretrained_path
|
||||
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
|
||||
|
||||
policy = policy_class.from_pretrained(
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=config
|
||||
)
|
||||
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
|
||||
else:
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
||||
|
||||
# Turn on RTC
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
|
||||
# Init RTC processort, as by default if RTC disabled in the config
|
||||
# The processor won't be created
|
||||
policy.init_rtc_processor()
|
||||
|
||||
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
|
||||
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
|
||||
# Apply torch.compile to predict_action_chunk method if enabled
|
||||
if cfg.use_torch_compile:
|
||||
policy = _apply_torch_compile(policy, cfg)
|
||||
|
||||
# Create robot
|
||||
logger.info(f"Initializing robot: {cfg.robot.type}")
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
robot.connect()
|
||||
robot_wrapper = RobotWrapper(robot)
|
||||
|
||||
# Create robot observation processor
|
||||
robot_observation_processor = make_default_robot_observation_processor()
|
||||
robot_action_processor = make_default_robot_action_processor()
|
||||
|
||||
# Create action queue for communication between threads
|
||||
action_queue = ActionQueue(cfg.rtc)
|
||||
|
||||
# Start chunk requester thread
|
||||
get_actions_thread = Thread(
|
||||
target=get_actions,
|
||||
args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
|
||||
daemon=True,
|
||||
name="GetActions",
|
||||
)
|
||||
get_actions_thread.start()
|
||||
logger.info("Started get actions thread")
|
||||
|
||||
# Start action executor thread
|
||||
actor_thread = Thread(
|
||||
target=actor_control,
|
||||
args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
|
||||
daemon=True,
|
||||
name="Actor",
|
||||
)
|
||||
actor_thread.start()
|
||||
logger.info("Started actor thread")
|
||||
|
||||
logger.info("Started stop by duration thread")
|
||||
|
||||
# Main thread monitors for duration or shutdown
|
||||
logger.info(f"Running demo for {cfg.duration} seconds...")
|
||||
start_time = time.time()
|
||||
|
||||
while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
|
||||
time.sleep(10)
|
||||
|
||||
# Log queue status periodically
|
||||
if int(time.time() - start_time) % 5 == 0:
|
||||
logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
|
||||
|
||||
if time.time() - start_time > cfg.duration:
|
||||
break
|
||||
|
||||
logger.info("Demo duration reached or shutdown requested")
|
||||
|
||||
# Signal shutdown
|
||||
shutdown_event.set()
|
||||
|
||||
# Wait for threads to finish
|
||||
if get_actions_thread and get_actions_thread.is_alive():
|
||||
logger.info("Waiting for chunk requester thread to finish...")
|
||||
get_actions_thread.join()
|
||||
|
||||
if actor_thread and actor_thread.is_alive():
|
||||
logger.info("Waiting for action executor thread to finish...")
|
||||
actor_thread.join()
|
||||
|
||||
# Cleanup robot
|
||||
if robot:
|
||||
robot.disconnect()
|
||||
logger.info("Robot disconnected")
|
||||
|
||||
logger.info("Cleanup completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo_cli()
|
||||
logging.info("RTC demo finished")
|
||||
@@ -14,13 +14,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
from lerobot.common.control_utils import init_keyboard_listener, predict_action
|
||||
from lerobot.configs import FeatureType, PolicyFeature
|
||||
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies import make_pre_post_processors
|
||||
from lerobot.policies.act import ACTPolicy
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import (
|
||||
RobotProcessorPipeline,
|
||||
make_default_teleop_action_processor,
|
||||
@@ -34,11 +38,12 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.feature_utils import combine_feature_dicts
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
@@ -49,6 +54,9 @@ HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
|
||||
def main():
|
||||
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
|
||||
# This script provides a self-contained example for educational purposes.
|
||||
|
||||
# Create the robot configuration & robot
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
@@ -143,43 +151,67 @@ def main():
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
control_interval = 1 / FPS
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
# Inline evaluation loop: predict actions and send to robot
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < EPISODE_TIME_SEC:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
# Get robot observation
|
||||
obs = robot.get_observation()
|
||||
obs_processed = robot_joints_to_ee_pose_processor(obs)
|
||||
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
|
||||
|
||||
# Predict action using the policy
|
||||
action_tensor = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=policy.config.device,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.device.type == "cuda",
|
||||
task=TASK_DESCRIPTION,
|
||||
robot_type=robot.name,
|
||||
)
|
||||
|
||||
# Convert policy output to robot action dict
|
||||
action_values = make_robot_action(action_tensor, dataset.features)
|
||||
|
||||
# Process and send action to robot (EE -> joints via IK)
|
||||
robot_action_to_send = robot_ee_to_joints_processor((action_values, obs))
|
||||
robot.send_action(robot_action_to_send)
|
||||
|
||||
# Write to dataset
|
||||
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
|
||||
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
log_rerun_data(observation=obs_processed, action=action_values)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
sleep_time_s = control_interval - dt_s
|
||||
if sleep_time_s < 0:
|
||||
logging.warning(
|
||||
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
|
||||
)
|
||||
precise_sleep(max(sleep_time_s, 0.0))
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
log_say("Waiting for environment reset, press right arrow key when ready...")
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
@@ -190,7 +222,6 @@ def main():
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
|
||||
@@ -62,21 +62,20 @@ def main():
|
||||
follower = SO100Follower(follower_config)
|
||||
leader = SO100Leader(leader_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
follower_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(follower.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
leader_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(leader.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert follower joints to EE observation
|
||||
# Build pipeline to convert follower joints to EE observation.
|
||||
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
@@ -87,7 +86,7 @@ def main():
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Build pipeline to convert leader joints to EE action
|
||||
# Build pipeline to convert leader joints to EE action.
|
||||
leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
@@ -98,9 +97,9 @@ def main():
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to follower joints
|
||||
# Build pipeline to convert EE action to follower joints (with safety bounds).
|
||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
[
|
||||
steps=[
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
@@ -115,13 +114,12 @@ def main():
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
# Create the dataset, deriving features from the pipelines so the on-disk schema
|
||||
# matches exactly what the pipelines produce at runtime.
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
|
||||
# This tells you how the features would look like after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=leader_joints_to_ee,
|
||||
initial_features=create_initial_features(action=leader.action_features),
|
||||
@@ -144,7 +142,7 @@ def main():
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording_phone")
|
||||
init_rerun(session_name="recording_so100_ee")
|
||||
|
||||
try:
|
||||
if not leader.is_connected or not follower.is_connected:
|
||||
@@ -160,14 +158,14 @@ def main():
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
teleop=leader,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
@@ -179,13 +177,13 @@ def main():
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
teleop=leader,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
134
examples/so100_to_so100_EE/rollout.py
Normal file
134
examples/so100_to_so100_EE/rollout.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Run a trained EE-space policy on SO100 without recording (base rollout).
|
||||
|
||||
Uses the rollout engine's :class:`BaseStrategy` (autonomous execution,
|
||||
no dataset) with :class:`SyncInferenceConfig` (inline policy call per
|
||||
control tick). The custom observation/action processors convert between
|
||||
joint space (robot hardware) and end-effector space (policy I/O) via
|
||||
forward/inverse kinematics.
|
||||
"""
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import (
|
||||
RobotProcessorPipeline,
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||
from lerobot.rollout.inference import SyncInferenceConfig
|
||||
from lerobot.rollout.strategies import BaseStrategy
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
FPS = 30
|
||||
DURATION_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
|
||||
|
||||
def main():
|
||||
init_logging()
|
||||
|
||||
# Robot configuration — the rollout engine will connect it inside build_rollout_context.
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
# Kinematic solver: we need the motor-name list, so peek at the robot once.
|
||||
# (The rollout engine owns the connected instance; we only use this for introspection.)
|
||||
temp_robot = SO100Follower(robot_config)
|
||||
motor_names = list(temp_robot.bus.motors.keys())
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=motor_names,
|
||||
)
|
||||
|
||||
# Joint-space observation → EE-space observation (consumed by the policy).
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# EE-space action (produced by the policy) → joint-space action (sent to robot).
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=motor_names,
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Policy config (full model is loaded inside build_rollout_context).
|
||||
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||
policy_config.pretrained_path = HF_MODEL_ID
|
||||
|
||||
cfg = RolloutConfig(
|
||||
robot=robot_config,
|
||||
policy=policy_config,
|
||||
strategy=BaseStrategyConfig(),
|
||||
inference=SyncInferenceConfig(),
|
||||
fps=FPS,
|
||||
duration=DURATION_SEC,
|
||||
task=TASK_DESCRIPTION,
|
||||
)
|
||||
|
||||
signal_handler = ProcessSignalHandler(use_threads=True)
|
||||
|
||||
# Pass the EE kinematic processors via kwargs; the defaults (identity) would
|
||||
# otherwise skip the joint↔EE conversion and the policy would receive the
|
||||
# wrong observation/action space.
|
||||
ctx = build_rollout_context(
|
||||
cfg,
|
||||
signal_handler.shutdown_event,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
strategy = BaseStrategy(cfg.strategy)
|
||||
try:
|
||||
strategy.setup(ctx)
|
||||
strategy.run(ctx)
|
||||
finally:
|
||||
strategy.teardown(ctx)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -4,13 +4,13 @@ from pathlib import Path
|
||||
from queue import Empty, Full
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.policies import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.policies import GaussianActorConfig
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
from lerobot.rewards.classifier.modeling_classifier import Classifier
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
@@ -28,7 +28,7 @@ def run_learner(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_learner: SACPolicy,
|
||||
policy_learner: GaussianActorPolicy,
|
||||
online_buffer: ReplayBuffer,
|
||||
offline_buffer: ReplayBuffer,
|
||||
lr: float = 3e-4,
|
||||
@@ -40,8 +40,9 @@ def run_learner(
|
||||
policy_learner.train()
|
||||
policy_learner.to(device)
|
||||
|
||||
# Create Adam optimizer from scratch - simple and clean
|
||||
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(policy_learner.config)
|
||||
algorithm = SACAlgorithm(policy=policy_learner, config=algo_config)
|
||||
algorithm.make_optimizers_and_scheduler()
|
||||
|
||||
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
|
||||
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
|
||||
@@ -83,24 +84,26 @@ def run_learner(
|
||||
else:
|
||||
batch[key] = online_batch[key]
|
||||
|
||||
loss, _ = policy_learner.forward(batch)
|
||||
def batch_iter(b=batch):
|
||||
while True:
|
||||
yield b
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
stats = algorithm.update(batch_iter())
|
||||
training_step += 1
|
||||
|
||||
if training_step % LOG_EVERY == 0:
|
||||
log_dict = stats.to_log_dict()
|
||||
print(
|
||||
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
|
||||
f"[LEARNER] Training step {training_step}, "
|
||||
f"critic_loss: {log_dict.get('critic', 'N/A'):.4f}, "
|
||||
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
|
||||
)
|
||||
|
||||
# Send updated parameters to actor every 10 training steps
|
||||
if training_step % SEND_EVERY == 0:
|
||||
try:
|
||||
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
|
||||
parameters_queue.put_nowait(state_dict)
|
||||
weights = algorithm.get_weights()
|
||||
parameters_queue.put_nowait(weights)
|
||||
print("[LEARNER] Sent updated parameters to actor")
|
||||
except Full:
|
||||
# Missing write due to queue not being consumed (should happen rarely)
|
||||
@@ -113,7 +116,7 @@ def run_actor(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_actor: SACPolicy,
|
||||
policy_actor: GaussianActorPolicy,
|
||||
reward_classifier: Classifier,
|
||||
env_cfg: HILSerlRobotEnvConfig,
|
||||
device: torch.device = "mps",
|
||||
@@ -144,15 +147,15 @@ def run_actor(
|
||||
|
||||
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
|
||||
try:
|
||||
new_params = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_params)
|
||||
new_weights = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_weights)
|
||||
print("[ACTOR] Updated policy parameters from learner")
|
||||
except Empty: # No new updated parameters available from learner, waiting
|
||||
pass
|
||||
|
||||
# Get action from policy
|
||||
# Get action from policy (returns full action: continuous + discrete)
|
||||
policy_obs = make_policy_obs(obs, device=device)
|
||||
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
|
||||
action_tensor = policy_actor.select_action(policy_obs)
|
||||
action = action_tensor.squeeze(0).cpu().numpy()
|
||||
|
||||
# Step environment
|
||||
@@ -261,14 +264,14 @@ def main():
|
||||
action_features = hw_to_dataset_features(env.robot.action_features, "action")
|
||||
|
||||
# Create SAC policy for action selection
|
||||
policy_cfg = SACConfig(
|
||||
policy_cfg = GaussianActorConfig(
|
||||
device=device,
|
||||
input_features=obs_features,
|
||||
output_features=action_features,
|
||||
)
|
||||
|
||||
policy_actor = SACPolicy(policy_cfg)
|
||||
policy_learner = SACPolicy(policy_cfg)
|
||||
policy_actor = GaussianActorPolicy(policy_cfg)
|
||||
policy_learner = GaussianActorPolicy(policy_cfg)
|
||||
|
||||
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
|
||||
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.policies import RewardClassifierConfig, make_policy, make_pre_post_processors
|
||||
from lerobot.rewards import RewardClassifierConfig, make_reward_model, make_reward_pre_post_processors
|
||||
|
||||
|
||||
def main():
|
||||
@@ -22,10 +22,10 @@ def main():
|
||||
model_name="microsoft/resnet-18",
|
||||
)
|
||||
|
||||
# Make policy, preprocessor, and optimizer
|
||||
policy = make_policy(config, ds_meta=dataset.meta)
|
||||
optimizer = config.get_optimizer_preset().build(policy.parameters())
|
||||
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
|
||||
# Make reward model, preprocessor, and optimizer
|
||||
reward_model = make_reward_model(config, dataset_stats=dataset.meta.stats)
|
||||
optimizer = config.get_optimizer_preset().build(reward_model.parameters())
|
||||
preprocessor, _ = make_reward_pre_post_processors(config, dataset_stats=dataset.meta.stats)
|
||||
|
||||
classifier_id = "<user>/reward_classifier_hil_serl_example"
|
||||
|
||||
@@ -42,7 +42,7 @@ def main():
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Forward pass
|
||||
loss, output_dict = policy.forward(batch)
|
||||
loss, output_dict = reward_model.forward(batch)
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
@@ -58,8 +58,8 @@ def main():
|
||||
|
||||
print("Training finished!")
|
||||
|
||||
# You can now save the trained policy.
|
||||
policy.push_to_hub(classifier_id)
|
||||
# You can now save the trained reward model.
|
||||
reward_model.push_to_hub(classifier_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -59,8 +59,8 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
|
||||
|
||||
dependencies = [
|
||||
# Core ML
|
||||
"torch>=2.7,<2.11.0",
|
||||
"torchvision>=0.22.0,<0.26.0",
|
||||
"torch>=2.7,<2.12.0",
|
||||
"torchvision>=0.22.0,<0.27.0",
|
||||
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
|
||||
"opencv-python-headless>=4.9.0,<4.14.0",
|
||||
"Pillow>=10.0.0,<13.0.0",
|
||||
@@ -95,11 +95,22 @@ dependencies = [
|
||||
|
||||
# ── Feature-scoped extras ──────────────────────────────────
|
||||
dataset = [
|
||||
"datasets>=4.0.0,<5.0.0",
|
||||
"datasets>=4.7.0,<5.0.0",
|
||||
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
|
||||
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
|
||||
"lerobot[av-dep]",
|
||||
"torchcodec>=0.3.0,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10).
|
||||
|
||||
# NOTE: torchcodec wheel availability matrix (PyPI):
|
||||
# - linux x86_64/amd64 + macOS arm64 : wheels since 0.3.0 (the historic supported set).
|
||||
# - win32 x86_64 : wheels since 0.7.0 (needs torch>=2.8).
|
||||
# - linux aarch64/arm64 : wheels since 0.11.0 (needs torch>=2.11).
|
||||
# - macOS x86_64 (Intel) and linux armv7l: no wheels in any released version -> fall through to the PyAV decoder.
|
||||
# Each platform gets its own line so the resolver picks the minimum version that has a wheel for it.
|
||||
|
||||
# Other torch/torchcodec pairings (informational): 0.8.1 = ffmpeg>=8 support, 0.10 = system-wide ffmpeg support, 0.12 needs torch==2.12.
|
||||
"torchcodec>=0.3.0,<0.12.0; (sys_platform == 'linux' and (platform_machine == 'x86_64' or platform_machine == 'AMD64')) or (sys_platform == 'darwin' and platform_machine == 'arm64')",
|
||||
"torchcodec>=0.7.0,<0.12.0; sys_platform == 'win32'",
|
||||
"torchcodec>=0.11.0,<0.12.0; sys_platform == 'linux' and (platform_machine == 'aarch64' or platform_machine == 'arm64')",
|
||||
"jsonlines>=4.0.0,<5.0.0",
|
||||
]
|
||||
training = [
|
||||
@@ -127,8 +138,10 @@ dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
|
||||
# Common
|
||||
av-dep = ["av>=15.0.0,<16.0.0"]
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||
transformers-dep = ["transformers==5.3.0"] # TODO(Steven): https://github.com/huggingface/lerobot/pull/3249
|
||||
# NOTE: 0.9.16 links against liburdfdom_sensor.so.4, which is unavailable on Ubuntu 24.04
|
||||
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
|
||||
placo-dep = ["placo>=0.9.6,<0.9.16"]
|
||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
@@ -140,6 +153,8 @@ pyserial-dep = ["pyserial>=3.5,<4.0"]
|
||||
deepdiff-dep = ["deepdiff>=7.0.1,<9.0.0"]
|
||||
pynput-dep = ["pynput>=1.7.8,<1.9.0"]
|
||||
pyzmq-dep = ["pyzmq>=26.2.1,<28.0.0"]
|
||||
motorbridge-dep = ["motorbridge>=0.3.2,<0.4.0"]
|
||||
motorbridge-smart-servo-dep = ["motorbridge-smart-servo>=0.0.4,<0.1.0"]
|
||||
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0", "lerobot[pyserial-dep]", "lerobot[deepdiff-dep]"]
|
||||
@@ -163,6 +178,9 @@ unitree_g1 = [
|
||||
"lerobot[pygame-dep]",
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||
# Seeed Studio reBot B601-DM follower (motorbridge / CAN) + StarArm102 / reBot Arm 102
|
||||
# leader (motorbridge-smart-servo / FashionStar UART servos).
|
||||
rebot = ["lerobot[motorbridge-dep]", "lerobot[motorbridge-smart-servo-dep]"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
|
||||
@@ -194,12 +212,25 @@ groot = [
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||
|
||||
# Annotation pipeline (lerobot-annotate). vllm is the preferred backend
|
||||
# on Linux, with a transformers fallback elsewhere; openai is the default
|
||||
# backend and talks to any OpenAI-compatible server (``vllm serve`` /
|
||||
# ``transformers serve`` / hosted endpoints). Distributed execution is
|
||||
# delegated to Hugging Face Jobs (see examples/annotations/run_hf_job.py).
|
||||
annotations = [
|
||||
"lerobot[dataset]",
|
||||
"lerobot[transformers-dep]",
|
||||
"openai>=1.40,<2.0",
|
||||
"vllm>=0.6.0,<1.0.0; sys_platform == 'linux'",
|
||||
]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
||||
@@ -248,6 +279,7 @@ all = [
|
||||
"lerobot[lekiwi]",
|
||||
"lerobot[openarms]",
|
||||
"lerobot[reachy2]",
|
||||
"lerobot[rebot]",
|
||||
"lerobot[kinematics]",
|
||||
"lerobot[intelrealsense]",
|
||||
"lerobot[diffusion]",
|
||||
@@ -289,10 +321,26 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
|
||||
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
|
||||
# cu128 wheels keep broad hardware reach; the driver floor is 570.86.
|
||||
# To use a different CUDA variant, reinstall torch with an explicit index, e.g.:
|
||||
# uv pip install --force-reinstall torch torchvision \
|
||||
# --index-url https://download.pytorch.org/whl/cu130
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json"]
|
||||
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
15
src/lerobot/annotations/__init__.py
Normal file
15
src/lerobot/annotations/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
50
src/lerobot/annotations/steerable_pipeline/__init__.py
Normal file
50
src/lerobot/annotations/steerable_pipeline/__init__.py
Normal file
@@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Steerable annotation pipeline producing ``language_persistent`` and
|
||||
``language_events`` columns for LeRobot datasets.
|
||||
|
||||
The pipeline is decomposed into three independently runnable modules whose
|
||||
outputs are staged per-episode before a final parquet rewrite:
|
||||
|
||||
- :mod:`.modules.plan_subtasks_memory` (the ``plan`` module) — persistent styles
|
||||
- :mod:`.modules.interjections_and_speech` (the ``interjections`` module) — event styles + speech
|
||||
- :mod:`.modules.general_vqa` (the ``vqa`` module) — event-style VQA pairs
|
||||
"""
|
||||
|
||||
from .config import AnnotationPipelineConfig
|
||||
from .validator import StagingValidator, ValidationReport
|
||||
from .vocabulary import (
|
||||
VOCABULARY_FILENAME,
|
||||
Vocabulary,
|
||||
VocabularyDiscoveryModule,
|
||||
load_vocabulary,
|
||||
save_vocabulary,
|
||||
vocabulary_path,
|
||||
)
|
||||
from .writer import LanguageColumnsWriter
|
||||
|
||||
__all__ = [
|
||||
"VOCABULARY_FILENAME",
|
||||
"AnnotationPipelineConfig",
|
||||
"LanguageColumnsWriter",
|
||||
"StagingValidator",
|
||||
"ValidationReport",
|
||||
"Vocabulary",
|
||||
"VocabularyDiscoveryModule",
|
||||
"load_vocabulary",
|
||||
"save_vocabulary",
|
||||
"vocabulary_path",
|
||||
]
|
||||
251
src/lerobot/annotations/steerable_pipeline/config.py
Normal file
251
src/lerobot/annotations/steerable_pipeline/config.py
Normal file
@@ -0,0 +1,251 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class VocabularyConfig:
|
||||
"""Phase 0 — dataset-level canonical vocabulary discovery.
|
||||
|
||||
Watches the first ``sample_episodes`` episode videos and asks the VLM
|
||||
to derive a small canonical vocabulary (subtask labels + memory
|
||||
milestones) that every episode in the dataset will reuse. The VLM
|
||||
decides the count itself from what it sees in the clips — short
|
||||
pick-and-place demos get ~6 labels, longer multi-step recipes more.
|
||||
The output lands at ``meta/canonical_vocabulary.json`` and feeds
|
||||
phase 1's subtask + memory generation as both a prompt-side
|
||||
constraint and a post-VLM validation gate.
|
||||
|
||||
Why this exists: free-form LLM rephrasing per episode produces near-
|
||||
unique subtask strings, which makes the downstream low-level policy's
|
||||
conditioning effectively noise — at inference the policy generates a
|
||||
*new* paraphrase the action expert has never seen and produces tiny
|
||||
cautious actions. Forcing every episode onto the same small set of
|
||||
canonical strings gives the action expert dense supervision per
|
||||
string and a small target distribution to learn against.
|
||||
|
||||
Set ``enabled=False`` to fall back to free-form generation (original
|
||||
behaviour). ``reuse_existing=True`` keeps a hand-edited vocabulary
|
||||
file from being clobbered on re-runs.
|
||||
"""
|
||||
|
||||
enabled: bool = True
|
||||
sample_episodes: int = 3
|
||||
max_video_frames_per_episode: int = 32
|
||||
# When True (default), an existing meta/canonical_vocabulary.json is
|
||||
# loaded as-is and no VLM call is made — lets operators hand-edit the
|
||||
# file. Set False to always rediscover from the sample episodes.
|
||||
reuse_existing: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlanConfig:
|
||||
"""``plan`` module: plan + subtasks + memory + task augmentation.
|
||||
|
||||
The ``plan`` module attaches the whole episode as one Qwen-VL video
|
||||
block; ``max_video_frames`` only caps the frames packed in (a
|
||||
model-capacity bound, not an annotation-logic knob).
|
||||
"""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
# Number of ``task_aug`` rephrasings emitted at ``t=0``. The renderer's
|
||||
# ``${task}`` binding rotates among them per ``sample_idx``. ``0`` disables.
|
||||
n_task_rephrasings: int = 10
|
||||
|
||||
# When to derive the task from the video instead of using
|
||||
# ``record.episode_task``: ``off``, ``if_short`` (short / placeholder /
|
||||
# missing canonical task), or ``always``. The derived task replaces the
|
||||
# canonical one for every ``plan``-module prompt; ``meta/tasks.parquet``
|
||||
# is never modified.
|
||||
derive_task_from_video: str = "if_short"
|
||||
derive_task_min_words: int = 3
|
||||
|
||||
# Frame sampling for the subtask-decomposition prompt.
|
||||
frames_per_second: float = 1.0
|
||||
max_video_frames: int = 128
|
||||
|
||||
min_subtask_seconds: float = 1.5
|
||||
plan_max_steps: int = 8
|
||||
|
||||
# When True (and backend supports it, e.g. ``openai``), the ``plan``
|
||||
# module sends a ``video_url`` block pointing at a per-episode mp4
|
||||
# subclip and lets the server sample frames at ``use_video_url_fps``.
|
||||
use_video_url: bool = False
|
||||
use_video_url_fps: float = 1.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterjectionsConfig:
|
||||
"""``interjections`` module: interjections + paired speech."""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
# Each interjection emits a paired ``(interjection, speech)`` event row
|
||||
# and triggers a ``plan`` refresh at the same timestamp via the
|
||||
# ``plan`` module.
|
||||
max_interjections_per_episode: int = 3
|
||||
interjection_min_t: float = 2.0
|
||||
|
||||
# Visual context attached to the interjection prompt: a short window
|
||||
# of frames centered on the chosen timestamp so the VLM sees the
|
||||
# ongoing motion rather than a single frozen frame.
|
||||
interjection_window_seconds: float = 2.0
|
||||
interjection_window_frames: int = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class VqaConfig:
|
||||
"""``vqa`` module: general VQA."""
|
||||
|
||||
enabled: bool = True
|
||||
vqa_emission_hz: float = 1.0
|
||||
K: int = 1
|
||||
"""How many *consecutive* frames each emission tick anchors a VQA pair
|
||||
to. The VLM grounds its answer (bbox / keypoint coordinates, count, …)
|
||||
against the *first* anchored frame's image, so anchoring K>1 frames
|
||||
copies that same answer onto later frames where the scene has already
|
||||
moved — stale labels. Default ``1``: a VQA pair lands on exactly its
|
||||
emission frame, no temporal smear. Raise it only to trade label
|
||||
precision for more (noisier) VQA frames."""
|
||||
question_types: tuple[str, ...] = ("bbox", "keypoint", "count", "attribute", "spatial")
|
||||
|
||||
|
||||
@dataclass
|
||||
class VlmConfig:
|
||||
"""Shared Qwen-VL client configuration."""
|
||||
|
||||
# One of ``vllm``, ``transformers``, ``openai``, or ``stub`` (tests).
|
||||
# ``openai`` talks to a local OpenAI-compatible server; the CLI
|
||||
# auto-spawns one when ``auto_serve=True``.
|
||||
backend: str = "openai"
|
||||
model_id: str = "Qwen/Qwen3.6-35B-A3B-FP8"
|
||||
|
||||
# OpenAI-compatible server endpoint; ``EMPTY`` works for local servers.
|
||||
api_base: str = "http://localhost:8000/v1"
|
||||
api_key: str = "EMPTY"
|
||||
|
||||
# When True with ``backend=openai``, the CLI probes ``api_base`` and
|
||||
# spawns a server if none answers (default: ``transformers serve``).
|
||||
# Set to False to fail fast when pointing at a remote endpoint.
|
||||
auto_serve: bool = True
|
||||
serve_port: int = 8000
|
||||
# Override the auto-serve command. ``{port}`` is substituted per replica
|
||||
# when ``parallel_servers > 1``.
|
||||
serve_command: str | None = None
|
||||
|
||||
# Run multiple independent inference servers for round-robin client
|
||||
# routing (each pinned to a GPU via ``CUDA_VISIBLE_DEVICES`` and bound
|
||||
# to ``serve_port + i``). ``num_gpus=0`` means one GPU per replica.
|
||||
parallel_servers: int = 1
|
||||
num_gpus: int = 0
|
||||
client_concurrency: int = 16
|
||||
serve_ready_timeout_s: float = 600.0
|
||||
|
||||
max_new_tokens: int = 512
|
||||
temperature: float = 0.2
|
||||
json_mode: bool = True
|
||||
batch_size: int = 4
|
||||
tensor_parallel_size: int = 1
|
||||
|
||||
# Fraction of GPU memory vllm allocates for weights + KV cache.
|
||||
gpu_memory_utilization: float = 0.9
|
||||
# Cap context length (None = model default). On 80 GB H100 a 30B BF16
|
||||
# model often needs <= 8192 to leave KV-cache headroom.
|
||||
max_model_len: int | None = None
|
||||
trust_remote_code: bool = False
|
||||
|
||||
# Override the camera stream used for keyframe attachment. None picks
|
||||
# the first ``observation.images.*`` key the dataset declares.
|
||||
camera_key: str | None = None
|
||||
# Forwarded as ``extra_body.chat_template_kwargs`` on every chat call;
|
||||
# use to pass model-specific flags such as ``{"enable_thinking": false}``.
|
||||
chat_template_kwargs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutorConfig:
|
||||
"""Executor settings.
|
||||
|
||||
Distributed execution is provided by Hugging Face Jobs (see
|
||||
``examples/annotation/run_hf_job.py``); this config only controls
|
||||
intra-process episode concurrency.
|
||||
"""
|
||||
|
||||
# Episodes processed concurrently within each module phase. Each
|
||||
# in-flight episode dispatches 3-5 dependent VLM calls, so this is the
|
||||
# main knob for saturating ``parallel_servers`` and ``client_concurrency``.
|
||||
episode_parallelism: int = 16
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnnotationPipelineConfig:
|
||||
"""Top-level config for ``lerobot-annotate``.
|
||||
|
||||
The writer rewrites ``data/chunk-*/file-*.parquet`` in place. Multiple
|
||||
revisions of the same dataset live in separate copies.
|
||||
"""
|
||||
|
||||
# Hub dataset id. Used as the download source when ``root`` is unset,
|
||||
# and as the destination repo when ``push_to_hub`` is enabled and
|
||||
# ``dest_repo_id`` is unset.
|
||||
repo_id: str | None = None
|
||||
|
||||
# Optional separate Hub dataset id to push the annotated result to. When
|
||||
# unset, ``push_to_hub`` uploads back to ``repo_id`` (annotate in place);
|
||||
# when set, the source ``repo_id`` is left untouched.
|
||||
dest_repo_id: str | None = None
|
||||
|
||||
root: Path | None = None
|
||||
|
||||
# Defaults to ``<root>/.annotate_staging/`` when unset.
|
||||
staging_dir: Path | None = None
|
||||
|
||||
seed: int = 1729
|
||||
|
||||
vocabulary: VocabularyConfig = field(default_factory=VocabularyConfig)
|
||||
plan: PlanConfig = field(default_factory=PlanConfig)
|
||||
interjections: InterjectionsConfig = field(default_factory=InterjectionsConfig)
|
||||
vqa: VqaConfig = field(default_factory=VqaConfig)
|
||||
|
||||
vlm: VlmConfig = field(default_factory=VlmConfig)
|
||||
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
|
||||
|
||||
skip_validation: bool = False
|
||||
only_episodes: tuple[int, ...] | None = None
|
||||
|
||||
# Keyframe decode backend. When unset, the pipeline decodes with the
|
||||
# ffmpeg CLI: it decodes AV1 and runs each decode as an isolated child
|
||||
# process, which is both crash-safe and safe under the concurrent
|
||||
# decode the executor performs (torchcodec is not thread-safe and
|
||||
# SIGSEGVs there). Set to ``"torchcodec"`` or ``"pyav"`` to pin an
|
||||
# in-process decoder when its build is known thread-safe.
|
||||
video_backend: str | None = None
|
||||
|
||||
# When True, upload the annotated dataset to the Hugging Face Hub:
|
||||
# to ``dest_repo_id`` if set, otherwise back to ``repo_id``. One of
|
||||
# the two must be set for this to take effect.
|
||||
push_to_hub: bool = False
|
||||
push_private: bool = False
|
||||
push_commit_message: str | None = None
|
||||
|
||||
def resolved_staging_dir(self, root: Path) -> Path:
|
||||
return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging"
|
||||
322
src/lerobot/annotations/steerable_pipeline/executor.py
Normal file
322
src/lerobot/annotations/steerable_pipeline/executor.py
Normal file
@@ -0,0 +1,322 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""In-process executor that runs the annotation phases.
|
||||
|
||||
The executor plans **seven phases** in the dependency order from the plan:
|
||||
|
||||
phase 0: vocabulary discovery — derive a small canonical vocabulary
|
||||
from the first few sample-episode videos (subtask labels +
|
||||
memory milestones) and persist it next to the dataset; the
|
||||
``plan`` module then constrains every per-episode generation
|
||||
to those strings, so the downstream policy sees a small,
|
||||
repeatable conditioning distribution
|
||||
phase 1: ``plan`` module (plan + subtasks + memory)
|
||||
phase 2: ``interjections`` module (interjections + speech)
|
||||
phase 3: ``plan`` plan-update pass — re-runs plan emission at every
|
||||
interjection timestamp produced by phase 2
|
||||
phase 4: ``vqa`` module (VQA)
|
||||
phase 5: validator
|
||||
phase 6: writer
|
||||
|
||||
Phase 3 is why the ``plan`` module must be re-entered after the
|
||||
``interjections`` module — to refresh ``plan`` rows at interjection
|
||||
timestamps.
|
||||
|
||||
Distributed execution is provided by Hugging Face Jobs (see
|
||||
``examples/annotations/run_hf_job.py``); the runner inside the job
|
||||
invokes ``lerobot-annotate`` which uses this in-process executor.
|
||||
Episode-level concurrency is controlled by
|
||||
``ExecutorConfig.episode_parallelism``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .config import AnnotationPipelineConfig
|
||||
from .reader import EpisodeRecord, iter_episodes
|
||||
from .staging import EpisodeStaging
|
||||
from .validator import StagingValidator
|
||||
from .writer import LanguageColumnsWriter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PhaseResult:
|
||||
"""Summary of one pipeline phase across all episodes."""
|
||||
|
||||
name: str
|
||||
episodes_processed: int
|
||||
episodes_skipped: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineRunSummary:
|
||||
"""Aggregated result returned by :meth:`Executor.run`."""
|
||||
|
||||
phases: list[PhaseResult]
|
||||
written_paths: list[Path]
|
||||
validation_report: Any # ValidationReport, kept Any to avoid import cycle
|
||||
|
||||
|
||||
@dataclass
|
||||
class Executor:
|
||||
"""Run all six phases over a dataset root in-process.
|
||||
|
||||
Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism``
|
||||
(a thread pool); cluster-level concurrency comes from running this
|
||||
executor inside a Hugging Face Job. Tests construct the executor
|
||||
directly with stub modules.
|
||||
"""
|
||||
|
||||
config: AnnotationPipelineConfig
|
||||
plan: Any # PlanSubtasksMemoryModule
|
||||
interjections: Any # InterjectionsAndSpeechModule
|
||||
vqa: Any # GeneralVqaModule
|
||||
writer: LanguageColumnsWriter
|
||||
validator: StagingValidator
|
||||
vocabulary: Any = None # VocabularyDiscoveryModule | None
|
||||
|
||||
def run(self, root: Path) -> PipelineRunSummary:
|
||||
records = list(iter_episodes(root, only_episodes=self.config.only_episodes))
|
||||
n = len(records)
|
||||
if n == 0:
|
||||
raise ValueError(f"No episodes found under {root}/data/")
|
||||
|
||||
print(f"[annotate] {n} episodes total", flush=True)
|
||||
|
||||
staging_dir = self.config.resolved_staging_dir(root)
|
||||
staging_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
phases: list[PhaseResult] = []
|
||||
|
||||
# Phase 0: vocabulary discovery. Mutates ``self.plan.vocabulary``
|
||||
# so subsequent per-episode plan calls see the canonical labels.
|
||||
phases.append(self._run_vocabulary_phase(records, root))
|
||||
|
||||
# Phase 1: ``plan`` module (plan + subtasks + memory)
|
||||
phases.append(self._run_module_phase("plan", records, staging_dir, self.plan))
|
||||
# Phase 2: ``interjections`` module (interjections + speech). It
|
||||
# reads the ``plan`` module's subtask rows from the same staging
|
||||
# tree to ground the interjection prompt in the correct local subtask.
|
||||
phases.append(self._run_module_phase("interjections", records, staging_dir, self.interjections))
|
||||
# Phase 3: ``plan`` plan-update pass at interjection timestamps.
|
||||
phases.append(self._run_plan_update_phase(records, staging_dir))
|
||||
# Phase 4: ``vqa`` module (VQA)
|
||||
phases.append(self._run_module_phase("vqa", records, staging_dir, self.vqa))
|
||||
|
||||
print("[annotate] running validator...", flush=True)
|
||||
report = self.validator.validate(records, staging_dir)
|
||||
if not report.ok and not self.config.skip_validation:
|
||||
raise RuntimeError(f"Staging validation failed: {report.summary()}")
|
||||
print(f"[annotate] validator: {report.summary()}", flush=True)
|
||||
|
||||
print(f"[annotate] writing parquet shards into {root}/data/...", flush=True)
|
||||
written = self.writer.write_all(records, staging_dir, root)
|
||||
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
|
||||
|
||||
# Keep meta/info.json aligned with the parquet schema we just wrote.
|
||||
# Idempotent and additive: existing user metadata is preserved.
|
||||
self._ensure_annotation_metadata_in_info(root)
|
||||
|
||||
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_annotation_metadata_in_info(root: Path) -> None:
|
||||
"""Write language features and canonical tools to ``meta/info.json``.
|
||||
|
||||
``LanguageColumnsWriter`` adds ``language_persistent`` and
|
||||
``language_events`` to parquet shards. The metadata must advertise
|
||||
those columns too, otherwise non-streaming ``LeRobotDataset`` loads
|
||||
cast against the old schema and fail on the extra parquet columns.
|
||||
"""
|
||||
from lerobot.datasets.io_utils import load_info, write_info # noqa: PLC0415
|
||||
from lerobot.datasets.language import SAY_TOOL_SCHEMA, language_feature_info # noqa: PLC0415
|
||||
|
||||
info_path = root / "meta" / "info.json"
|
||||
if not info_path.exists():
|
||||
return
|
||||
try:
|
||||
info = load_info(root)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f"[annotate] could not read {info_path}: {exc}", flush=True)
|
||||
return
|
||||
|
||||
changed = False
|
||||
|
||||
merged_features = {**info.features, **language_feature_info()}
|
||||
if merged_features != info.features:
|
||||
info.features = merged_features
|
||||
changed = True
|
||||
|
||||
existing = info.tools or []
|
||||
names = {(t.get("function") or {}).get("name") for t in existing if isinstance(t, dict)}
|
||||
if SAY_TOOL_SCHEMA["function"]["name"] not in names:
|
||||
info.tools = [*existing, SAY_TOOL_SCHEMA]
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
write_info(info, root)
|
||||
print(
|
||||
"[annotate] meta/info.json: "
|
||||
f"language_features={list(language_feature_info())}, "
|
||||
f"tools={[t['function']['name'] for t in (info.tools or [])]}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
def _run_vocabulary_phase(
|
||||
self, records: list[EpisodeRecord], root: Path
|
||||
) -> PhaseResult:
|
||||
"""Discover (or load) the canonical vocabulary, wire it into ``self.plan``.
|
||||
|
||||
Returns a ``PhaseResult`` whose ``episodes_processed`` is the number
|
||||
of sample episodes consulted (0 when disabled or no VLM call was
|
||||
needed); ``episodes_skipped`` is always ``0`` because vocabulary is
|
||||
a once-per-dataset artifact, not a per-episode product.
|
||||
"""
|
||||
from .vocabulary import load_vocabulary, save_vocabulary # noqa: PLC0415
|
||||
|
||||
if self.vocabulary is None or not getattr(self.vocabulary, "enabled", False):
|
||||
print(
|
||||
"[annotate] phase=vocabulary skipped (module disabled or unset)",
|
||||
flush=True,
|
||||
)
|
||||
return PhaseResult(name="vocabulary", episodes_processed=0, episodes_skipped=0)
|
||||
|
||||
existing = load_vocabulary(root)
|
||||
if existing is not None and self.config.vocabulary.reuse_existing:
|
||||
print(
|
||||
f"[annotate] phase=vocabulary reusing {root / 'meta' / 'canonical_vocabulary.json'} "
|
||||
f"({len(existing.subtasks)} subtask labels, "
|
||||
f"{len(existing.memory_milestones)} memory milestones)",
|
||||
flush=True,
|
||||
)
|
||||
self.plan.vocabulary = existing
|
||||
return PhaseResult(name="vocabulary", episodes_processed=0, episodes_skipped=0)
|
||||
|
||||
sample_n = max(1, min(int(self.config.vocabulary.sample_episodes), len(records)))
|
||||
print(
|
||||
f"[annotate] phase=vocabulary discovering from {sample_n} sample episode(s)...",
|
||||
flush=True,
|
||||
)
|
||||
t0 = time.time()
|
||||
vocab = self.vocabulary.discover(records[:sample_n], existing=existing)
|
||||
if vocab is None:
|
||||
print(
|
||||
"[annotate] phase=vocabulary returned no vocabulary — "
|
||||
"plan module will fall back to free-form generation",
|
||||
flush=True,
|
||||
)
|
||||
return PhaseResult(name="vocabulary", episodes_processed=0, episodes_skipped=0)
|
||||
|
||||
save_path = save_vocabulary(root, vocab)
|
||||
print(
|
||||
f"[annotate] phase=vocabulary wrote {save_path} "
|
||||
f"({len(vocab.subtasks)} subtask labels, "
|
||||
f"{len(vocab.memory_milestones)} memory milestones) in "
|
||||
f"{time.time() - t0:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
self.plan.vocabulary = vocab
|
||||
return PhaseResult(name="vocabulary", episodes_processed=sample_n, episodes_skipped=0)
|
||||
|
||||
def _run_module_phase(
|
||||
self,
|
||||
name: str,
|
||||
records: list[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
module: Any,
|
||||
) -> PhaseResult:
|
||||
if not module.enabled:
|
||||
print(f"[annotate] phase={name} skipped (module disabled)", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
|
||||
n = len(records)
|
||||
parallelism = max(1, min(self.config.executor.episode_parallelism, n))
|
||||
print(
|
||||
f"[annotate] phase={name} starting on {n} episode(s) (parallelism={parallelism})",
|
||||
flush=True,
|
||||
)
|
||||
t0 = time.time()
|
||||
|
||||
def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]:
|
||||
i, record = idx_record
|
||||
ep_start = time.time()
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
return i, record.episode_index, time.time() - ep_start
|
||||
|
||||
processed = 0
|
||||
if parallelism == 1:
|
||||
for i, record in enumerate(records, 1):
|
||||
_, ep_idx, elapsed = _do((i, record))
|
||||
processed += 1
|
||||
print(
|
||||
f"[annotate] {name} episode {i}/{n} (idx={ep_idx}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=parallelism) as pool:
|
||||
futures = [pool.submit(_do, (i, r)) for i, r in enumerate(records, 1)]
|
||||
for fut in as_completed(futures):
|
||||
i, ep_idx, elapsed = fut.result()
|
||||
processed += 1
|
||||
print(
|
||||
f"[annotate] {name} episode {processed}/{n} "
|
||||
f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
total = time.time() - t0
|
||||
print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
|
||||
|
||||
def _run_plan_update_phase( # noqa: PLR0915
|
||||
self, records: list[EpisodeRecord], staging_dir: Path
|
||||
) -> PhaseResult:
|
||||
"""Re-emit ``plan`` rows at each timestamp the ``interjections`` module produced.
|
||||
|
||||
The ``plan`` module owns the prompt; the ``interjections`` module
|
||||
produced the timestamps. This phase therefore calls back into the
|
||||
``plan`` module with the interjection timestamps so its existing
|
||||
prompt path is reused.
|
||||
"""
|
||||
if not self.plan.enabled or not self.interjections.enabled:
|
||||
return PhaseResult(
|
||||
name="plan_update", episodes_processed=0, episodes_skipped=len(records)
|
||||
)
|
||||
processed = 0
|
||||
for record in records:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
interjection_rows = [
|
||||
row for row in staging.read("interjections") if row.get("style") == "interjection"
|
||||
]
|
||||
interjection_times = [float(row["timestamp"]) for row in interjection_rows]
|
||||
interjection_texts = [str(row.get("content") or "") for row in interjection_rows]
|
||||
if interjection_times:
|
||||
self.plan.run_plan_updates(record, staging, interjection_times, interjection_texts)
|
||||
processed += 1
|
||||
# Episodes without any interjections are skipped (no plan refresh
|
||||
# needed); count them so the summary's processed+skipped == total.
|
||||
return PhaseResult(
|
||||
name="plan_update",
|
||||
episodes_processed=processed,
|
||||
episodes_skipped=len(records) - processed,
|
||||
)
|
||||
483
src/lerobot/annotations/steerable_pipeline/frames.py
Normal file
483
src/lerobot/annotations/steerable_pipeline/frames.py
Normal file
@@ -0,0 +1,483 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Keyframe extraction for the annotation pipeline.
|
||||
|
||||
Modules attach decoded camera frames to their VLM prompts so the model can
|
||||
ground subtask decomposition, interjection scenarios, and VQA in actual
|
||||
visual content. The pipeline shares one provider across modules and one
|
||||
episode at a time, with a small per-episode cache so multiple modules
|
||||
querying the same timestamp pay decode cost once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.video_utils import decode_video_frames
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FrameProvider(Protocol):
|
||||
"""Decodes camera frames at episode-relative timestamps."""
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` feature keys this provider can decode."""
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return one decoded frame per timestamp from ``camera_key`` (or default).
|
||||
|
||||
Frames are ``torch.Tensor`` (``C, H, W`` uint8) — the shape
|
||||
:func:`lerobot.datasets.video_utils.decode_video_frames` returns.
|
||||
:func:`to_image_blocks` converts them to PIL only at the VLM-message
|
||||
boundary.
|
||||
|
||||
Empty list if the camera is unavailable. ``camera_key=None`` falls back
|
||||
to the provider's default camera so existing single-camera callers
|
||||
(the ``plan`` and ``interjections`` modules) keep working unchanged.
|
||||
"""
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` decoded frames covering the whole episode.
|
||||
|
||||
Sampling is uniform across the episode duration. Frames are
|
||||
``torch.Tensor`` (``C, H, W`` uint8); :func:`to_video_block` wraps
|
||||
them into one ``{"type":"video", "video":<list>}`` block for a
|
||||
Qwen-VL-compatible model that pools temporally itself. Empty list if
|
||||
no camera available.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class _NullProvider:
|
||||
"""No-op provider used when the dataset has no video keys or in tests."""
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
return []
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
|
||||
def null_provider() -> FrameProvider:
|
||||
return _NullProvider()
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoFrameProvider:
|
||||
"""Decodes frames from the dataset's ``observation.images.*`` streams.
|
||||
|
||||
By default the *first* camera key is used for the ``plan`` module
|
||||
(subtask decomposition) and the ``interjections`` module (interjection
|
||||
scenarios) — those prompts care about *what is happening*, not which
|
||||
angle. The ``vqa`` module instead iterates over every camera in
|
||||
:attr:`camera_keys` so each frame's
|
||||
grounded answer (bbox/keypoint/...) is tagged with the camera it was
|
||||
grounded against.
|
||||
|
||||
``camera_key`` overrides the default-camera choice but does not restrict
|
||||
:attr:`camera_keys`. Pass ``camera_key`` explicitly to ``frames_at`` /
|
||||
``video_for_episode`` to read a non-default stream.
|
||||
|
||||
Caches up to ``cache_size`` decoded frames per process to keep
|
||||
co-timestamped ``interjections`` + ``plan`` plan-update calls cheap.
|
||||
"""
|
||||
|
||||
root: Path
|
||||
camera_key: str | None = None
|
||||
tolerance_s: float = 1e-2
|
||||
cache_size: int = 256
|
||||
# Keyframe decode backend. ``None`` uses the ffmpeg CLI — the
|
||||
# concurrency- and crash-safe default for the pipeline's threaded
|
||||
# decode. Set to ``"torchcodec"`` or ``"pyav"`` to pin an in-process
|
||||
# decoder when the build is known thread-safe.
|
||||
video_backend: str | None = None
|
||||
_meta: Any = field(default=None, init=False, repr=False)
|
||||
_cache: dict = field(default_factory=dict, init=False, repr=False)
|
||||
_camera_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||
# Pipeline runs the three module phases under a ThreadPoolExecutor (see
|
||||
# ``ExecutorConfig.episode_parallelism``); guard the dict cache and the
|
||||
# one-shot warn flag against concurrent updates from worker threads.
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
|
||||
|
||||
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
|
||||
# ``camera_keys`` covers both image- and video-stored cameras and is
|
||||
# always defined on the metadata (``[]`` in the worst case), so it is
|
||||
# the single source we need here.
|
||||
keys = list(self._meta.camera_keys)
|
||||
# Last-resort fallback: if metadata didn't surface anything but the
|
||||
# caller explicitly named a camera (``--vlm.camera_key=...``), trust
|
||||
# them — the key is by definition known to exist on the dataset.
|
||||
if not keys and self.camera_key:
|
||||
keys = [self.camera_key]
|
||||
self._camera_keys = keys
|
||||
if self.camera_key is None:
|
||||
self.camera_key = keys[0] if keys else None
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` keys available on this dataset."""
|
||||
return list(self._camera_keys)
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if not timestamps or target is None:
|
||||
return []
|
||||
|
||||
out: list[Any] = []
|
||||
misses: list[float] = []
|
||||
miss_indices: list[int] = []
|
||||
with self._lock:
|
||||
for i, ts in enumerate(timestamps):
|
||||
key = (record.episode_index, target, round(float(ts), 6))
|
||||
cached = self._cache.get(key)
|
||||
if cached is not None:
|
||||
out.append(cached)
|
||||
else:
|
||||
out.append(None)
|
||||
misses.append(float(ts))
|
||||
miss_indices.append(i)
|
||||
|
||||
if misses:
|
||||
decoded = self._decode(record.episode_index, misses, target)
|
||||
# ``_decode`` returns exactly one frame per requested timestamp,
|
||||
# or an empty list if decoding failed wholesale. A partial list
|
||||
# would mean a frame/timestamp misalignment, so only pair them up
|
||||
# when the counts match (``strict=True`` then guards regressions).
|
||||
if len(decoded) == len(miss_indices):
|
||||
with self._lock:
|
||||
for i, frame in zip(miss_indices, decoded, strict=True):
|
||||
out[i] = frame
|
||||
key = (record.episode_index, target, round(float(timestamps[i]), 6))
|
||||
if len(self._cache) >= self.cache_size:
|
||||
self._cache.pop(next(iter(self._cache)))
|
||||
self._cache[key] = frame
|
||||
# filter out any None left over from decode failures
|
||||
return [frame for frame in out if frame is not None]
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` frames uniformly sampled across the episode.
|
||||
|
||||
The whole episode duration is covered; the model picks subtask
|
||||
boundaries from the temporal pooling it does internally. Frames are
|
||||
``torch.Tensor`` (see :meth:`frames_at`).
|
||||
"""
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if max_frames <= 0 or target is None or not record.frame_timestamps:
|
||||
return []
|
||||
n_frames = min(max_frames, len(record.frame_timestamps))
|
||||
if n_frames == len(record.frame_timestamps):
|
||||
timestamps = list(record.frame_timestamps)
|
||||
else:
|
||||
t0 = record.frame_timestamps[0]
|
||||
t_last = record.frame_timestamps[-1]
|
||||
if t_last <= t0:
|
||||
timestamps = [float(t0)] * n_frames
|
||||
else:
|
||||
step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0
|
||||
timestamps = [float(t0 + i * step) for i in range(n_frames)]
|
||||
return self.frames_at(record, timestamps, camera_key=target)
|
||||
|
||||
def episode_clip_path(self, record: EpisodeRecord, cache_dir: Path) -> Path | None:
|
||||
"""Extract the episode's subclip to ``cache_dir/ep_{idx:06d}.mp4``.
|
||||
|
||||
Returns ``None`` if the dataset has no video tracks. Skips
|
||||
re-extract when the cached clip already exists. Re-encodes to
|
||||
H.264 (libx264) so the resulting mp4 is decodable by every
|
||||
downstream video processor — stream-copy would inherit the
|
||||
source codec (often AV1 in modern LeRobot datasets), which
|
||||
vllm's libav build cannot decode.
|
||||
"""
|
||||
import subprocess # noqa: PLC0415
|
||||
|
||||
if self.camera_key is None:
|
||||
return None
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_path = cache_dir / f"ep_{record.episode_index:06d}.mp4"
|
||||
if out_path.exists() and out_path.stat().st_size > 0:
|
||||
return out_path
|
||||
ep = self._meta.episodes[record.episode_index]
|
||||
from_timestamp = float(ep[f"videos/{self.camera_key}/from_timestamp"])
|
||||
to_timestamp = float(ep[f"videos/{self.camera_key}/to_timestamp"])
|
||||
src = self.root / self._meta.get_video_file_path(record.episode_index, self.camera_key)
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-loglevel",
|
||||
"error",
|
||||
"-ss",
|
||||
f"{from_timestamp:.3f}",
|
||||
"-to",
|
||||
f"{to_timestamp:.3f}",
|
||||
"-i",
|
||||
str(src),
|
||||
"-c:v",
|
||||
"libx264",
|
||||
"-preset",
|
||||
"ultrafast",
|
||||
"-crf",
|
||||
"23",
|
||||
"-pix_fmt",
|
||||
"yuv420p",
|
||||
"-an",
|
||||
str(out_path),
|
||||
]
|
||||
try:
|
||||
subprocess.run(cmd, check=True, timeout=300)
|
||||
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError):
|
||||
return None
|
||||
return out_path if out_path.exists() and out_path.stat().st_size > 0 else None
|
||||
|
||||
def _decode(self, episode_index: int, timestamps: list[float], camera_key: str) -> list[Any]:
|
||||
"""Decode ``timestamps`` from the episode's video as ``(C, H, W)`` tensors.
|
||||
|
||||
Delegates to :func:`lerobot.datasets.video_utils.decode_video_frames`
|
||||
(torchcodec by default, PyAV fallback) rather than a bespoke decoder.
|
||||
Returns one frame per requested timestamp, or ``[]`` if decoding
|
||||
failed wholesale — callers treat ``[]`` as "no frames available".
|
||||
"""
|
||||
ep = self._meta.episodes[episode_index]
|
||||
from_timestamp = ep[f"videos/{camera_key}/from_timestamp"]
|
||||
shifted = [from_timestamp + ts for ts in timestamps]
|
||||
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
|
||||
|
||||
# Default to the ffmpeg CLI. The pipeline decodes under a 16-wide
|
||||
# ThreadPoolExecutor and the in-process decoders are unsafe there:
|
||||
# torchcodec is not thread-safe and SIGSEGVs under concurrent decode
|
||||
# (a crash no try/except can catch), PyAV can likewise segfault on
|
||||
# AV1, and lerobot's ``pyav`` backend routes through the removed
|
||||
# ``torchvision.io.VideoReader``. ``_decode_frames_ffmpeg`` shells
|
||||
# out per frame: each decode is an isolated child process, so it is
|
||||
# both crash-safe and concurrency-safe. ``video_backend`` can pin
|
||||
# ``torchcodec`` / ``pyav`` explicitly for callers that know their
|
||||
# build is safe.
|
||||
chain = [self.video_backend] if self.video_backend else ["ffmpeg"]
|
||||
|
||||
exc: Exception | None = None
|
||||
for backend in chain:
|
||||
try:
|
||||
if backend == "ffmpeg":
|
||||
return _decode_frames_ffmpeg(video_path, shifted)
|
||||
if backend in ("pyav", "av"):
|
||||
return _decode_frames_av(video_path, shifted)
|
||||
# Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp.
|
||||
decoded = decode_video_frames(
|
||||
video_path, shifted, self.tolerance_s, backend=backend, return_uint8=True
|
||||
)
|
||||
return list(decoded)
|
||||
except Exception as e: # noqa: PERF203
|
||||
exc = e
|
||||
|
||||
# Every backend raised. Log loudly the first time so a silent
|
||||
# vqa-module no-op (every prompt skipped because frames_at returned
|
||||
# []) is debuggable from the job log instead of post-hoc parquet
|
||||
# inspection. Subsequent failures stay quiet.
|
||||
with self._lock:
|
||||
already_warned = getattr(self, "_warned_decode_fail", False)
|
||||
if not already_warned:
|
||||
self._warned_decode_fail = True
|
||||
if not already_warned:
|
||||
logger.warning(
|
||||
"VideoFrameProvider._decode failed for episode=%s camera=%s "
|
||||
"video_path=%s backends=%s: %s",
|
||||
episode_index,
|
||||
camera_key,
|
||||
video_path,
|
||||
chain,
|
||||
exc,
|
||||
exc_info=exc,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def make_frame_provider(
|
||||
root: Path, camera_key: str | None = None, video_backend: str | None = None
|
||||
) -> FrameProvider:
|
||||
"""Build a :class:`VideoFrameProvider` if videos are present, else null."""
|
||||
try:
|
||||
provider = VideoFrameProvider(root=root, camera_key=camera_key, video_backend=video_backend)
|
||||
except Exception:
|
||||
return null_provider()
|
||||
if provider.camera_key is None:
|
||||
return null_provider()
|
||||
return provider
|
||||
|
||||
|
||||
def _decode_frames_ffmpeg(video_path: Path, timestamps: list[float]) -> list[Any]:
|
||||
"""Decode the frames nearest to ``timestamps`` via the ffmpeg CLI.
|
||||
|
||||
Runs one ``ffmpeg`` process per timestamp, seeking with ``-ss`` and
|
||||
piping a single PNG to stdout. Unlike the in-process decoders this
|
||||
survives a hostile container: a full ffmpeg build decodes AV1 (the codec
|
||||
modern LeRobot datasets use) where torchcodec raises and PyAV can
|
||||
SIGSEGV, and a crash stays isolated to the child process — a non-zero
|
||||
exit is a catchable error, not a segfault of the whole job. Returns one
|
||||
``(C, H, W)`` uint8 tensor per timestamp.
|
||||
"""
|
||||
import io # noqa: PLC0415
|
||||
import subprocess # noqa: PLC0415
|
||||
|
||||
import numpy as np # noqa: PLC0415
|
||||
|
||||
frames: list[Any] = []
|
||||
for ts in timestamps:
|
||||
proc = subprocess.run(
|
||||
[
|
||||
"ffmpeg", "-nostdin", "-loglevel", "error",
|
||||
"-ss", f"{max(ts, 0.0):.3f}",
|
||||
"-i", str(video_path),
|
||||
"-frames:v", "1",
|
||||
"-f", "image2pipe", "-vcodec", "png", "pipe:1",
|
||||
],
|
||||
capture_output=True,
|
||||
check=True,
|
||||
timeout=120,
|
||||
)
|
||||
if not proc.stdout:
|
||||
raise RuntimeError(f"ffmpeg returned no frame for t={ts:.3f}s of {video_path}")
|
||||
img = PIL.Image.open(io.BytesIO(proc.stdout)).convert("RGB")
|
||||
frames.append(torch.from_numpy(np.asarray(img).copy()).permute(2, 0, 1).contiguous())
|
||||
return frames
|
||||
|
||||
|
||||
def _decode_frames_av(video_path: Path, timestamps: list[float]) -> list[Any]:
|
||||
"""Decode the frames nearest to ``timestamps`` using PyAV directly.
|
||||
|
||||
lerobot's ``decode_video_frames(backend="pyav")`` routes through
|
||||
``torchvision.io.VideoReader``, removed in torchvision 0.23+. This helper
|
||||
talks to the ``av`` package directly. Note PyAV can SIGSEGV on AV1
|
||||
streams in some builds — prefer ``_decode_frames_ffmpeg`` as the default
|
||||
fallback; this stays available behind ``video_backend="pyav"``. Returns
|
||||
one ``(C, H, W)`` uint8 tensor per timestamp.
|
||||
"""
|
||||
import av # noqa: PLC0415
|
||||
|
||||
first_ts = min(timestamps)
|
||||
last_ts = max(timestamps)
|
||||
loaded_frames: list[torch.Tensor] = []
|
||||
loaded_ts: list[float] = []
|
||||
with av.open(str(video_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
# Seek to the keyframe at or before the first requested timestamp.
|
||||
offset = max(int(first_ts / stream.time_base), 0) if stream.time_base else 0
|
||||
container.seek(offset, stream=stream, backward=True, any_frame=False)
|
||||
for idx, frame in enumerate(container.decode(stream)):
|
||||
ts = frame.time
|
||||
if ts is None:
|
||||
ts = float(frame.pts * stream.time_base) if frame.pts is not None else float(idx)
|
||||
loaded_ts.append(ts)
|
||||
loaded_frames.append(
|
||||
torch.from_numpy(frame.to_ndarray(format="rgb24")).permute(2, 0, 1).contiguous()
|
||||
)
|
||||
if ts >= last_ts:
|
||||
break
|
||||
if not loaded_frames:
|
||||
raise RuntimeError(f"PyAV decoded no frames from {video_path}")
|
||||
ts_tensor = torch.tensor(loaded_ts)
|
||||
return [loaded_frames[int(torch.argmin((ts_tensor - q).abs()))] for q in timestamps]
|
||||
|
||||
|
||||
def _frame_to_pil(frame: Any) -> Any:
|
||||
"""Materialise a decoded frame as a ``PIL.Image`` for the VLM message.
|
||||
|
||||
Frames flow through the provider as ``torch.Tensor`` (``C, H, W`` uint8,
|
||||
straight from :func:`decode_video_frames`); PIL is only created here, at
|
||||
the VLM-message boundary, because the chat backends expect PIL images /
|
||||
data URLs. Non-tensor inputs (e.g. test stubs) pass through untouched.
|
||||
"""
|
||||
if not isinstance(frame, torch.Tensor):
|
||||
return frame
|
||||
array = frame.detach().cpu()
|
||||
if array.ndim == 3 and array.shape[0] in (1, 3):
|
||||
array = array.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
|
||||
if array.shape[-1] == 1:
|
||||
array = array.squeeze(-1)
|
||||
return PIL.Image.fromarray(array.to(torch.uint8).numpy())
|
||||
|
||||
|
||||
def to_image_blocks(frames: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Convert decoded frames to Qwen-VL-compatible image content blocks."""
|
||||
return [{"type": "image", "image": _frame_to_pil(frame)} for frame in frames]
|
||||
|
||||
|
||||
def to_video_block(frames: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Wrap a list of decoded frames as one Qwen-VL video block.
|
||||
|
||||
Returns ``[]`` when the list is empty, so the caller can splat the result
|
||||
into a content array without a separate emptiness check.
|
||||
"""
|
||||
if not frames:
|
||||
return []
|
||||
return [{"type": "video", "video": [_frame_to_pil(frame) for frame in frames]}]
|
||||
|
||||
|
||||
def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]:
|
||||
"""Wrap a video file URL as one ``video_url`` block.
|
||||
|
||||
Used by the ``openai`` backend (transformers serve / vllm serve /
|
||||
ktransformers serve), where the server handles frame sampling.
|
||||
Returns ``[]`` when ``url`` is ``None`` so the caller can splat.
|
||||
"""
|
||||
if not url:
|
||||
return []
|
||||
return [{"type": "video_url", "video_url": {"url": url}, "fps": fps}]
|
||||
@@ -0,0 +1,25 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .general_vqa import GeneralVqaModule
|
||||
from .interjections_and_speech import InterjectionsAndSpeechModule
|
||||
from .plan_subtasks_memory import PlanSubtasksMemoryModule
|
||||
|
||||
__all__ = [
|
||||
"GeneralVqaModule",
|
||||
"InterjectionsAndSpeechModule",
|
||||
"PlanSubtasksMemoryModule",
|
||||
]
|
||||
@@ -0,0 +1,228 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``vqa`` module: general VQA at a timed cadence.
|
||||
|
||||
Every ``1/hz`` seconds an emission tick fires; each tick anchors ``K``
|
||||
consecutive frames, and every anchored frame gets its own VQA pair. Each
|
||||
pair is grounded on that single anchor frame — there is no per-pair frame
|
||||
window. For datasets with multiple cameras, every anchored frame produces
|
||||
one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
|
||||
generated against that camera's frame and stamped with the matching
|
||||
``camera`` field on the emitted rows. The resolver disambiguates via
|
||||
``camera=...``; recipes that consume VQA do so through one sub-recipe
|
||||
per camera (see ``recipes/pi05_hirobot.yaml``).
|
||||
|
||||
Within a single (frame, camera) we still emit at most one ``(vqa, user)``
|
||||
and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.
|
||||
|
||||
Question types covered (per the plan's ``vqa`` table): bbox, keypoint,
|
||||
count, attribute, spatial. The assistant's ``content`` is a JSON string
|
||||
whose schema depends on the question type. Malformed JSON triggers one
|
||||
retry inside :meth:`VlmClient.generate_json`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import VqaConfig
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord
|
||||
from ..staging import EpisodeStaging
|
||||
from ..validator import classify_vqa_answer
|
||||
from ..vlm_client import VlmClient
|
||||
|
||||
|
||||
def _emission_anchor_indices(frame_timestamps: Sequence[float], hz: float, k: int) -> list[int]:
|
||||
"""Return the relative frame indices to anchor VQA emissions to.
|
||||
|
||||
For each emission tick (every ``1/hz`` seconds), we anchor ``k``
|
||||
consecutive frames starting at the tick. Ticks fall on the nearest
|
||||
available source frame timestamp.
|
||||
"""
|
||||
if hz <= 0 or k <= 0 or not frame_timestamps:
|
||||
return []
|
||||
t0 = frame_timestamps[0]
|
||||
t_last = frame_timestamps[-1]
|
||||
period = 1.0 / hz
|
||||
indices: list[int] = []
|
||||
t = t0
|
||||
while t <= t_last + 1e-9:
|
||||
# find the index of the nearest frame to t
|
||||
nearest_i = min(range(len(frame_timestamps)), key=lambda i: abs(frame_timestamps[i] - t))
|
||||
for offset in range(k):
|
||||
j = nearest_i + offset
|
||||
if j >= len(frame_timestamps):
|
||||
break
|
||||
if not indices or indices[-1] != j:
|
||||
indices.append(j)
|
||||
t += period
|
||||
# dedupe while preserving order
|
||||
seen: set[int] = set()
|
||||
deduped: list[int] = []
|
||||
for i in indices:
|
||||
if i in seen:
|
||||
continue
|
||||
seen.add(i)
|
||||
deduped.append(i)
|
||||
return deduped
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneralVqaModule:
|
||||
"""Emit grounded VQA pairs at a timed cadence."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: VqaConfig
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
if not record.frame_timestamps:
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
rng = random.Random(f"{self.seed}:{record.episode_index}:vqa")
|
||||
anchor_idx = _emission_anchor_indices(
|
||||
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
|
||||
)
|
||||
cameras = self._target_cameras()
|
||||
if not cameras:
|
||||
# No camera available — emit nothing rather than producing
|
||||
# untagged rows that would fail validation. Surface a loud one-
|
||||
# time warning so this is never silently a no-op.
|
||||
if not getattr(self, "_warned_no_camera", False):
|
||||
logging.getLogger(__name__).warning(
|
||||
"vqa module found no cameras on the frame provider — "
|
||||
"every episode will emit zero VQA rows. Check that the "
|
||||
"dataset declares observation.images.* features in "
|
||||
"meta/info.json; passing --vlm.camera_key=<key> at the "
|
||||
"CLI now also seeds the cameras list as a fallback."
|
||||
)
|
||||
self._warned_no_camera = True
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
|
||||
# Build all messages first (one per (frame, camera)), then issue them
|
||||
# as a single batched generate_json call so the client can fan them
|
||||
# out concurrently.
|
||||
per_call: list[tuple[float, str, str, list[dict[str, Any]]]] = []
|
||||
for idx in anchor_idx:
|
||||
ts = float(record.frame_timestamps[idx])
|
||||
qtype = rng.choice(self.config.question_types)
|
||||
for camera in cameras:
|
||||
messages = self._build_messages(record, qtype, ts, camera)
|
||||
# Skip cameras that decoded to zero frames at this ts: no point
|
||||
# asking the VLM to ground a bbox without an image.
|
||||
if not _has_image_block(messages):
|
||||
continue
|
||||
per_call.append((ts, camera, qtype, messages))
|
||||
|
||||
if not per_call:
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
|
||||
results = self.vlm.generate_json([m for _, _, _, m in per_call])
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
for (ts, camera, _qtype, _messages), result in zip(per_call, results, strict=True):
|
||||
qa = self._postprocess(result)
|
||||
if qa is None:
|
||||
continue
|
||||
question, answer = qa
|
||||
rows.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": question,
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": json.dumps(answer, sort_keys=True),
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("vqa", rows)
|
||||
|
||||
def _target_cameras(self) -> list[str]:
|
||||
"""Return the cameras the ``vqa`` module should iterate per anchored frame.
|
||||
|
||||
Defaults to every camera the provider exposes. Datasets with no
|
||||
cameras (or test/null providers) yield an empty list, which makes
|
||||
``run_episode`` a no-op.
|
||||
"""
|
||||
return list(getattr(self.frame_provider, "camera_keys", []) or [])
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
question_type: str,
|
||||
frame_timestamp: float,
|
||||
camera_key: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
prompt = load_prompt("module_3_vqa").format(
|
||||
episode_task=record.episode_task,
|
||||
question_type=question_type,
|
||||
)
|
||||
images = self.frame_provider.frames_at(
|
||||
record, [frame_timestamp], camera_key=camera_key
|
||||
)
|
||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
|
||||
if not isinstance(result, dict):
|
||||
return None
|
||||
question = result.get("question")
|
||||
answer = result.get("answer")
|
||||
if not isinstance(question, str) or not question.strip():
|
||||
return None
|
||||
if not isinstance(answer, dict):
|
||||
return None
|
||||
# The validator will enforce shape; here we just sanity-check that the
|
||||
# answer matches *some* known shape so we can drop garbage early.
|
||||
if classify_vqa_answer(answer) is None:
|
||||
return None
|
||||
return question.strip(), answer
|
||||
|
||||
|
||||
def _has_image_block(messages: list[dict[str, Any]]) -> bool:
|
||||
"""Return True if any user content block is a populated image block."""
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "image":
|
||||
return True
|
||||
return False
|
||||
@@ -0,0 +1,210 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``interjections`` module: interjections + paired speech (EVENT styles + speech atoms).
|
||||
|
||||
Two sub-passes:
|
||||
|
||||
1. At ``t=0``, emit ONLY a speech tool-call atom (acknowledgement of the
|
||||
canonical task). No interjection row — the canonical task is already the
|
||||
user utterance from ``meta/tasks.parquet``.
|
||||
|
||||
2. For mid-episode interruptions, emit a co-timestamped pair:
|
||||
{role:user, style:interjection, content:<text>}
|
||||
speech atom (role:assistant, style:None, tool_calls=[say(...)])
|
||||
Both rows go in ``language_events`` at the same timestamp.
|
||||
|
||||
The ``plan`` module's :meth:`run_plan_updates` reuses this module's
|
||||
interjection timestamps to refresh the ``plan`` row at the same instant.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import InterjectionsConfig
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
from ..writer import speech_atom
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterjectionsAndSpeechModule:
|
||||
"""Generate task-start speech and mid-episode interjection/speech pairs."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: InterjectionsConfig
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
rows: list[dict[str, Any]] = []
|
||||
if record.frame_timestamps:
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
initial = self._initial_speech(record)
|
||||
if initial:
|
||||
rows.append(speech_atom(t0, initial))
|
||||
# Pull the ``plan`` module's subtask spans for this episode so the
|
||||
# interjection prompt can ground itself in the actual current
|
||||
# subtask at each chosen timestamp. The ``plan`` module ran first.
|
||||
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||
subtask_spans = reconstruct_subtask_spans(staging.read("plan"), episode_end_t=episode_end_t)
|
||||
rows.extend(self._mid_episode_interjections(record, subtask_spans))
|
||||
staging.write("interjections", rows)
|
||||
|
||||
@staticmethod
|
||||
def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None:
|
||||
current: str | None = None
|
||||
for span in spans:
|
||||
if float(span["start"]) <= t:
|
||||
current = span.get("text")
|
||||
else:
|
||||
break
|
||||
return current
|
||||
|
||||
def _initial_speech(self, record: EpisodeRecord) -> str | None:
|
||||
prompt = load_prompt("module_2_initial_speech").format(
|
||||
episode_task=record.episode_task,
|
||||
)
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict) and isinstance(result.get("text"), str):
|
||||
text = result["text"].strip()
|
||||
if text:
|
||||
return text
|
||||
return None
|
||||
|
||||
def _mid_episode_interjections(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
subtask_spans: Sequence[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Generate interjections aligned with the actual demo trajectory.
|
||||
|
||||
Teleop data is frozen — the robot already executed every step in
|
||||
the video. A *counterfactual* interjection like "actually skip
|
||||
the wipe" contradicts what then happens in the video, which is
|
||||
what qwen36moe-10/11 surfaced as low-quality interjections.
|
||||
|
||||
Instead, anchor every interjection at a subtask boundary and
|
||||
write it as a natural user request for the *upcoming* subtask.
|
||||
The robot's visible next behavior IS the interjection's effect,
|
||||
so the training signal stays consistent: interjection text →
|
||||
plan refresh → action stream all line up.
|
||||
"""
|
||||
if self.config.max_interjections_per_episode <= 0:
|
||||
return []
|
||||
if len(subtask_spans) < 2:
|
||||
# Need at least one transition (subtask 0 → subtask 1).
|
||||
return []
|
||||
# Deterministic per-episode RNG so reruns are stable across SLURM jobs.
|
||||
rng = random.Random(f"{self.seed}:{record.episode_index}:interjection")
|
||||
|
||||
# Boundaries: the start time of every subtask except the first
|
||||
# (which is just t0 and is covered by the initial-task speech atom).
|
||||
boundaries: list[tuple[float, str, str]] = []
|
||||
for i in range(1, len(subtask_spans)):
|
||||
ts = float(subtask_spans[i]["start"])
|
||||
if ts < self.config.interjection_min_t:
|
||||
continue
|
||||
prev_text = (subtask_spans[i - 1].get("text") or "").strip()
|
||||
next_text = (subtask_spans[i].get("text") or "").strip()
|
||||
if not next_text:
|
||||
continue
|
||||
boundaries.append((ts, prev_text, next_text))
|
||||
if not boundaries:
|
||||
return []
|
||||
|
||||
n = min(self.config.max_interjections_per_episode, len(boundaries))
|
||||
chosen = sorted(rng.sample(boundaries, n), key=lambda b: b[0])
|
||||
|
||||
out: list[dict[str, Any]] = []
|
||||
for t, prev_subtask, next_subtask in chosen:
|
||||
t_snap = snap_to_frame(t, record.frame_timestamps)
|
||||
# Window straddles the boundary so the VLM sees the end of the
|
||||
# previous subtask and the start of the next one — same
|
||||
# conditioning the policy will see at training time.
|
||||
window_ts = self._window_timestamps(t_snap, record.frame_timestamps)
|
||||
prompt = load_prompt("module_2_interjection").format(
|
||||
episode_task=record.episode_task,
|
||||
prev_subtask=prev_subtask or "(starting from initial state)",
|
||||
next_subtask=next_subtask,
|
||||
timestamp=t_snap,
|
||||
window_seconds=self.config.interjection_window_seconds,
|
||||
)
|
||||
images = self.frame_provider.frames_at(record, window_ts)
|
||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||
messages = [{"role": "user", "content": content}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if not isinstance(result, dict):
|
||||
continue
|
||||
interjection_text = result.get("interjection")
|
||||
speech_text = result.get("speech")
|
||||
if not isinstance(interjection_text, str) or not interjection_text.strip():
|
||||
continue
|
||||
if not isinstance(speech_text, str) or not speech_text.strip():
|
||||
continue
|
||||
out.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": interjection_text.strip(),
|
||||
"style": "interjection",
|
||||
"timestamp": t_snap,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
out.append(speech_atom(t_snap, speech_text.strip()))
|
||||
return out
|
||||
|
||||
def _window_timestamps(self, t_anchor: float, frame_timestamps: Sequence[float]) -> list[float]:
|
||||
"""Return a small set of frame timestamps centered on ``t_anchor``.
|
||||
|
||||
The window straddles the subtask boundary the interjection sits
|
||||
on: roughly half the frames cover the end of the previous
|
||||
subtask, half cover the start of the next one. The VLM therefore
|
||||
sees BOTH what just finished AND what's about to start, which is
|
||||
the conditioning we need to write a natural "now please do X"
|
||||
request that matches the visible upcoming behavior.
|
||||
"""
|
||||
if not frame_timestamps:
|
||||
return [t_anchor]
|
||||
n = max(1, int(self.config.interjection_window_frames))
|
||||
if n == 1:
|
||||
return [t_anchor]
|
||||
window = float(self.config.interjection_window_seconds)
|
||||
step = window / max(1, n - 1)
|
||||
# Center the window on the anchor so half lands before, half after.
|
||||
start_offset = -window / 2.0
|
||||
targets = [t_anchor + start_offset + step * i for i in range(n)]
|
||||
last_ts = float(frame_timestamps[-1])
|
||||
snapped: list[float] = []
|
||||
seen: set[float] = set()
|
||||
for tgt in targets:
|
||||
clamped = min(last_ts, max(0.0, tgt))
|
||||
t = snap_to_frame(clamped, frame_timestamps)
|
||||
if t not in seen:
|
||||
seen.add(t)
|
||||
snapped.append(t)
|
||||
return snapped or [t_anchor]
|
||||
@@ -0,0 +1,617 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``plan`` module: subtask decomposition + plan + memory (PERSISTENT styles)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ..config import PlanConfig
|
||||
from ..frames import (
|
||||
FrameProvider,
|
||||
VideoFrameProvider,
|
||||
null_provider,
|
||||
to_video_block,
|
||||
to_video_url_block,
|
||||
)
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
from ..vocabulary import Vocabulary
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlanSubtasksMemoryModule:
|
||||
"""Generate subtask spans, plan, and memory rows.
|
||||
|
||||
All output is persistent (lives in ``language_persistent``):
|
||||
|
||||
- ``subtask`` rows: one per span, stamped at the span's *start* timestamp
|
||||
(snapped to an exact frame).
|
||||
- ``plan`` rows: emitted at ``t=0``; refreshed at every interjection
|
||||
timestamp via :meth:`run_plan_updates` (called by the executor after
|
||||
the ``interjections`` module completes).
|
||||
- ``memory`` rows: emitted at each subtask boundary (= subtask start
|
||||
timestamp from the second subtask onward).
|
||||
"""
|
||||
|
||||
vlm: VlmClient
|
||||
config: PlanConfig
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
vocabulary: Vocabulary | None = None
|
||||
"""When set, the module constrains subtask + memory generation to the
|
||||
canonical strings in ``vocabulary``. Phase 0 (vocabulary discovery)
|
||||
populates this once per dataset; ``None`` falls back to free-form
|
||||
generation (original behaviour)."""
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
rows: list[dict[str, Any]] = []
|
||||
# Resolve the task that drives every other ``plan``-module prompt.
|
||||
# May be the canonical ``record.episode_task`` (default), or a fresh
|
||||
# description derived from the video when the canonical task is
|
||||
# empty / placeholder / forced-off (see PlanConfig.derive_task_*).
|
||||
effective_task = self._resolve_effective_task(record)
|
||||
# ``task_aug`` rows at t=0 (role=user), one per rephrasing — the
|
||||
# message renderer rotates ``${task}`` deterministically through
|
||||
# them so the policy sees diverse phrasings during training.
|
||||
t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
|
||||
if self.config.n_task_rephrasings > 0 and effective_task:
|
||||
rephrasings = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
|
||||
# Always include the effective task itself as the first variant
|
||||
# so the rotation is guaranteed to cover the source-of-truth
|
||||
# phrasing, not just synthetic alternatives.
|
||||
seen: set[str] = set()
|
||||
ordered = [effective_task, *rephrasings]
|
||||
for phrasing in ordered:
|
||||
key = phrasing.strip()
|
||||
if not key or key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
rows.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": key,
|
||||
"style": "task_aug",
|
||||
"timestamp": t0,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
|
||||
subtask_spans = self._generate_subtasks(record, task=effective_task)
|
||||
# subtask rows
|
||||
for span in subtask_spans:
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": span["text"],
|
||||
"style": "subtask",
|
||||
"timestamp": snap_to_frame(span["start"], record.frame_timestamps),
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
# Plan rows at every subtask boundary — including t=0 (start of
|
||||
# the first subtask). Because the plan is just a numbered list
|
||||
# of *still-todo* subtasks, re-emitting at each boundary makes
|
||||
# the active plan shrink as work progresses: at frame t the
|
||||
# rendered ``${plan}`` is the most recent emission, which
|
||||
# contains exactly the subtasks that started at or after the
|
||||
# current span. Saves the runtime from having to derive
|
||||
# "what's still left" at inference time.
|
||||
for span in subtask_spans:
|
||||
boundary_t = snap_to_frame(span["start"], record.frame_timestamps)
|
||||
plan_text = self._generate_plan(
|
||||
record, subtask_spans, refresh_t=boundary_t, task=effective_task
|
||||
)
|
||||
if plan_text is not None:
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": plan_text,
|
||||
"style": "plan",
|
||||
"timestamp": float(boundary_t),
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
# memory rows at every subtask boundary except the very first start
|
||||
prior_memory = ""
|
||||
for i, span in enumerate(subtask_spans[1:], start=1):
|
||||
completed = subtask_spans[i - 1]["text"]
|
||||
remaining = [s["text"] for s in subtask_spans[i:]]
|
||||
mem_text = self._generate_memory(record, prior_memory, completed, remaining, task=effective_task)
|
||||
if mem_text:
|
||||
ts = snap_to_frame(span["start"], record.frame_timestamps)
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": mem_text,
|
||||
"style": "memory",
|
||||
"timestamp": ts,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
prior_memory = mem_text
|
||||
staging.write("plan", rows)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Task derivation + rephrasings
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_PLACEHOLDER_TASKS: frozenset[str] = frozenset(
|
||||
{
|
||||
"debug",
|
||||
"test",
|
||||
"tbd",
|
||||
"todo",
|
||||
"n/a",
|
||||
"na",
|
||||
"untitled",
|
||||
"unnamed",
|
||||
"default",
|
||||
"placeholder",
|
||||
}
|
||||
)
|
||||
|
||||
def _resolve_effective_task(self, record: EpisodeRecord) -> str:
|
||||
"""Decide which task string drives the ``plan`` module for this episode.
|
||||
|
||||
Returns the user-supplied ``record.episode_task`` unless
|
||||
``derive_task_from_video`` says otherwise (see config docstring).
|
||||
Falls back gracefully to the canonical task if video derivation
|
||||
fails.
|
||||
"""
|
||||
canonical = (record.episode_task or "").strip()
|
||||
mode = (self.config.derive_task_from_video or "off").strip().lower()
|
||||
if mode == "always":
|
||||
derived = self._derive_task_from_video(record)
|
||||
return derived or canonical
|
||||
if mode == "if_short" and self._task_seems_bad(canonical):
|
||||
derived = self._derive_task_from_video(record)
|
||||
if derived:
|
||||
return derived
|
||||
return canonical
|
||||
|
||||
def _task_seems_bad(self, task: str) -> bool:
|
||||
if not task:
|
||||
return True
|
||||
if len(task.split()) < int(self.config.derive_task_min_words):
|
||||
return True
|
||||
return task.lower() in self._PLACEHOLDER_TASKS
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# VLM call helpers (factored out: every ``plan``-module prompt below follows
|
||||
# the same "build messages → single VLM call → pull a named field"
|
||||
# shape, only differing in field name + post-processing).
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _vlm_field(self, messages: list[dict[str, Any]], field: str) -> Any:
|
||||
"""Run a single VLM call and return ``result[field]`` or ``None``.
|
||||
|
||||
Centralizes the ``vlm.generate_json([m])[0]`` + ``isinstance(dict)``
|
||||
dance every prompt-call site needs.
|
||||
"""
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict):
|
||||
return result.get(field)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _text_message(text: str) -> list[dict[str, Any]]:
|
||||
"""One-shot text-only user message wrapped for ``generate_json``."""
|
||||
return [{"role": "user", "content": [{"type": "text", "text": text}]}]
|
||||
|
||||
def _video_message(self, record: EpisodeRecord, prompt: str) -> list[dict[str, Any]]:
|
||||
"""User message combining the episode video block with ``prompt``."""
|
||||
content = [*self._episode_video_block(record), {"type": "text", "text": prompt}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
|
||||
"""Ask the VLM "what is this video about" with no task hint at all."""
|
||||
text = self._vlm_field(self._video_message(record, load_prompt("module_1_video_task")), "task")
|
||||
return text.strip() if isinstance(text, str) and text.strip() else None
|
||||
|
||||
def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]:
|
||||
"""Generate ``n`` text-only paraphrases of ``base_task``."""
|
||||
if n <= 0 or not base_task:
|
||||
return []
|
||||
prompt = load_prompt("module_1_task_rephrasings").format(base_task=base_task, n=n)
|
||||
raw = self._vlm_field(self._text_message(prompt), "rephrasings")
|
||||
if not isinstance(raw, list):
|
||||
return []
|
||||
out = [item.strip().strip('"').strip("'") for item in raw if isinstance(item, str)]
|
||||
return [s for s in out if s][:n]
|
||||
|
||||
def _episode_video_block(self, record: EpisodeRecord) -> list[dict[str, Any]]:
|
||||
"""Same video block ``_generate_subtasks`` builds — extracted helper."""
|
||||
if not record.frame_timestamps:
|
||||
return []
|
||||
if self.config.use_video_url and isinstance(self.frame_provider, VideoFrameProvider):
|
||||
cache_dir = Path(self.frame_provider.root) / ".annotate_staging" / ".video_clips"
|
||||
clip = self.frame_provider.episode_clip_path(record, cache_dir)
|
||||
return (
|
||||
to_video_url_block(f"file://{clip}", fps=self.config.use_video_url_fps)
|
||||
if clip is not None
|
||||
else []
|
||||
)
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
target_count = max(1, int(round(episode_duration * self.config.frames_per_second)))
|
||||
target_count = min(target_count, self.config.max_video_frames)
|
||||
video_frames = self.frame_provider.video_for_episode(record, target_count)
|
||||
return to_video_block(video_frames)
|
||||
|
||||
def run_plan_updates(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
staging: EpisodeStaging,
|
||||
interjection_times: Sequence[float],
|
||||
interjection_texts: Sequence[str] | None = None,
|
||||
) -> None:
|
||||
"""Append additional ``plan`` rows at every interjection timestamp.
|
||||
|
||||
Plans refresh ONLY on user interjections — subtask generation
|
||||
runs ~1 Hz at inference, but plan re-emission is event-driven.
|
||||
Now also forwards the interjection's own text into the prompt so
|
||||
the refreshed plan can actually reflect the user's correction
|
||||
(the previous version told the model "an interjection happened"
|
||||
without telling it what the user said).
|
||||
"""
|
||||
existing = staging.read("plan")
|
||||
# Pass the episode's last frame timestamp so the final subtask
|
||||
# span is closed (otherwise its ``end`` equals its ``start``,
|
||||
# zero duration, and the "current subtask at refresh_t" lookup
|
||||
# in ``_generate_plan`` misses any refresh that lands inside it).
|
||||
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||
spans = reconstruct_subtask_spans(existing, episode_end_t=episode_end_t)
|
||||
already_planned: set[float] = {float(r["timestamp"]) for r in existing if r.get("style") == "plan"}
|
||||
new_rows = list(existing)
|
||||
|
||||
texts: list[str | None] = (
|
||||
[None] * len(interjection_times)
|
||||
if interjection_texts is None
|
||||
else [str(t) if t else None for t in interjection_texts]
|
||||
)
|
||||
for raw_t, inter_text in zip(interjection_times, texts, strict=True):
|
||||
t = snap_to_frame(raw_t, record.frame_timestamps)
|
||||
if t in already_planned:
|
||||
continue
|
||||
already_planned.add(t)
|
||||
plan_text = self._generate_plan(record, spans, refresh_t=t, interjection=inter_text)
|
||||
if plan_text is not None:
|
||||
new_rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": plan_text,
|
||||
"style": "plan",
|
||||
"timestamp": t,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("plan", new_rows)
|
||||
|
||||
def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
|
||||
if record.row_count == 0 or not record.frame_timestamps:
|
||||
return []
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
prompt = load_prompt("module_1_subtasks").format(
|
||||
episode_task=(task if task is not None else record.episode_task),
|
||||
min_subtask_seconds=self.config.min_subtask_seconds,
|
||||
max_steps=self.config.plan_max_steps,
|
||||
episode_duration=f"{episode_duration:.3f}",
|
||||
vocabulary_block=self._subtask_vocabulary_block(),
|
||||
)
|
||||
messages = self._video_message(record, prompt)
|
||||
spans = self._vlm_field(messages, "subtasks")
|
||||
# When a vocabulary is in force, do a single targeted retry if
|
||||
# any returned subtask is off-vocab — strict exact-match only,
|
||||
# no fuzzy snapping. The retry includes the offending strings
|
||||
# and the full canonical list so the VLM can correct itself.
|
||||
if self.vocabulary is not None and self.vocabulary.subtasks and spans:
|
||||
invalid = self._invalid_subtasks(spans)
|
||||
if invalid:
|
||||
logger.info(
|
||||
"episode %d: VLM emitted %d off-vocab subtask(s) (%s); retrying once",
|
||||
record.episode_index,
|
||||
len(invalid),
|
||||
invalid,
|
||||
)
|
||||
retry_msg = self._build_subtask_retry_message(messages, invalid)
|
||||
retried = self._vlm_field(retry_msg, "subtasks")
|
||||
if retried:
|
||||
spans = retried
|
||||
|
||||
if not spans:
|
||||
return []
|
||||
# clamp to [t0, t_last] and sort
|
||||
t0 = record.frame_timestamps[0]
|
||||
t_last = record.frame_timestamps[-1]
|
||||
cleaned: list[dict[str, Any]] = []
|
||||
for span in spans:
|
||||
try:
|
||||
start = float(span["start"])
|
||||
end = float(span["end"])
|
||||
text = str(span["text"]).strip()
|
||||
except (KeyError, ValueError, TypeError):
|
||||
continue
|
||||
start = max(t0, min(start, t_last))
|
||||
end = max(t0, min(end, t_last))
|
||||
if end < start:
|
||||
start, end = end, start
|
||||
if not text:
|
||||
continue
|
||||
text = self._canonicalize_subtask(text)
|
||||
if not text:
|
||||
continue
|
||||
cleaned.append({"text": text, "start": start, "end": end})
|
||||
cleaned.sort(key=lambda s: s["start"])
|
||||
cleaned = self._dedupe_starts_to_distinct_frames(cleaned, record)
|
||||
if self.vocabulary is not None and self.vocabulary.subtasks and not cleaned:
|
||||
logger.warning(
|
||||
"episode %d: every VLM subtask was off-vocab even after retry — "
|
||||
"episode left empty (extend meta/canonical_vocabulary.json to "
|
||||
"cover the missing phase)",
|
||||
record.episode_index,
|
||||
)
|
||||
return cleaned
|
||||
|
||||
@staticmethod
|
||||
def _dedupe_starts_to_distinct_frames(
|
||||
spans: list[dict[str, Any]], record: EpisodeRecord
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Bump same-frame subtask starts onto distinct frames.
|
||||
|
||||
Two consecutive VLM spans whose ``start`` rounds to the same
|
||||
source frame (after :func:`snap_to_frame`) would otherwise emit
|
||||
two ``style=subtask`` rows at the identical persistent
|
||||
timestamp. The training-time renderer's ``active_at(t,
|
||||
style=subtask)`` resolver can't disambiguate that and raises
|
||||
``Ambiguous resolver for style='subtask'``.
|
||||
|
||||
Walk the (sorted-by-start) spans, snap each to its frame, and
|
||||
if the snapped frame is already taken push the span onto the
|
||||
next unused frame so both subtasks survive on distinct
|
||||
timestamps. If the episode ends before a free frame is found,
|
||||
the trailing span is dropped with a warning — better than
|
||||
poisoning the render.
|
||||
"""
|
||||
if not spans:
|
||||
return spans
|
||||
frames = record.frame_timestamps
|
||||
if not frames:
|
||||
return spans
|
||||
used: set[float] = set()
|
||||
out: list[dict[str, Any]] = []
|
||||
for span in spans:
|
||||
ts = snap_to_frame(span["start"], frames)
|
||||
if ts in used:
|
||||
next_ts = next((f for f in frames if f > ts and f not in used), None)
|
||||
if next_ts is None:
|
||||
logger.warning(
|
||||
"episode %d: subtask %r snapped to occupied frame "
|
||||
"%.3f and no free later frame exists — dropping",
|
||||
record.episode_index,
|
||||
span.get("text"),
|
||||
ts,
|
||||
)
|
||||
continue
|
||||
ts = next_ts
|
||||
used.add(ts)
|
||||
new_span = {**span, "start": ts}
|
||||
if float(new_span.get("end", ts)) < ts:
|
||||
new_span["end"] = ts
|
||||
out.append(new_span)
|
||||
return out
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Canonical-vocabulary helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _subtask_vocabulary_block(self) -> str:
|
||||
"""Bullet-list of canonical subtasks the VLM must pick from.
|
||||
|
||||
Returns an empty string when no vocabulary is configured —
|
||||
``module_1_subtasks.txt`` then falls back to its free-form
|
||||
rules (original behaviour).
|
||||
"""
|
||||
if self.vocabulary is None or not self.vocabulary.subtasks:
|
||||
return ""
|
||||
bullets = "\n".join(f"- {s}" for s in self.vocabulary.subtasks)
|
||||
return (
|
||||
"You MUST choose each subtask label verbatim from this canonical "
|
||||
"vocabulary — pick the closest match for each phase of the demo, "
|
||||
"and reuse the SAME string every time that phase recurs. The "
|
||||
"low-level policy is conditioned on these exact strings; any "
|
||||
"novel paraphrase you invent will make its conditioning OOD.\n"
|
||||
"Canonical subtask labels:\n"
|
||||
f"{bullets}\n\n"
|
||||
)
|
||||
|
||||
def _memory_vocabulary_block(self) -> str:
|
||||
"""Bullet-list of canonical memory milestones the VLM must pick from."""
|
||||
if self.vocabulary is None or not self.vocabulary.memory_milestones:
|
||||
return ""
|
||||
bullets = "\n".join(f"- {m}" for m in self.vocabulary.memory_milestones)
|
||||
return (
|
||||
"Compose the memory by picking ONLY from this canonical milestone "
|
||||
"list — append a milestone (or rewrite the running memory to "
|
||||
"compress past ones) using these exact phrases. Do not invent new "
|
||||
"wording: every paraphrase weakens the downstream conditioning.\n"
|
||||
"Canonical memory milestones:\n"
|
||||
f"{bullets}\n\n"
|
||||
)
|
||||
|
||||
_NORMALIZE_STRIP_TOKENS: frozenset[str] = frozenset({"the", "a", "an"})
|
||||
|
||||
def _canonicalize_subtask(self, text: str) -> str:
|
||||
"""Validate ``text`` against the canonical vocabulary; no fuzzy snap.
|
||||
|
||||
Without a vocabulary, the original text passes through. With a
|
||||
vocabulary, accept the span only if its normalised form (lower-
|
||||
cased, articles stripped, whitespace collapsed) matches a
|
||||
canonical entry exactly — the canonical wording is returned so
|
||||
the supervised string is byte-identical across episodes.
|
||||
|
||||
Off-vocab spans are dropped (empty string). Upstream
|
||||
``_generate_subtasks`` triggers a targeted retry before reaching
|
||||
the drop path; this function never snaps or warps a span into
|
||||
a different label.
|
||||
"""
|
||||
if self.vocabulary is None or not self.vocabulary.subtasks:
|
||||
return text.strip()
|
||||
normalised = self._normalize(text)
|
||||
if not normalised:
|
||||
return ""
|
||||
for candidate in self.vocabulary.subtasks:
|
||||
if self._normalize(candidate) == normalised:
|
||||
return candidate
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
def _normalize(cls, text: str) -> str:
|
||||
"""Lowercase, strip articles, collapse whitespace, drop punctuation."""
|
||||
words = [
|
||||
w.strip(".,:;\"'!?()")
|
||||
for w in text.lower().replace(",", " ").split()
|
||||
]
|
||||
return " ".join(w for w in words if w and w not in cls._NORMALIZE_STRIP_TOKENS)
|
||||
|
||||
def _invalid_subtasks(self, spans: list[dict[str, Any]]) -> list[str]:
|
||||
"""Return the unique off-vocab subtask strings the VLM produced."""
|
||||
seen: list[str] = []
|
||||
for span in spans:
|
||||
text = str((span or {}).get("text") or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
if self._canonicalize_subtask(text):
|
||||
continue
|
||||
if text not in seen:
|
||||
seen.append(text)
|
||||
return seen
|
||||
|
||||
def _build_subtask_retry_message(
|
||||
self, original_messages: list[dict[str, Any]], invalid: list[str]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Compose a one-shot correction prompt naming the off-vocab strings."""
|
||||
assert self.vocabulary is not None
|
||||
canonical = "\n".join(f"- {s}" for s in self.vocabulary.subtasks)
|
||||
invalid_list = "\n".join(f"- {s!r}" for s in invalid)
|
||||
correction = (
|
||||
"Your previous response included subtask labels that are NOT in "
|
||||
"the canonical vocabulary:\n"
|
||||
f"{invalid_list}\n\n"
|
||||
"Re-emit the same segmentation (same number of spans, same start/end "
|
||||
"timestamps where they were valid) but replace every off-vocab "
|
||||
"label with the EXACT canonical string for that phase, copied "
|
||||
"verbatim from this list:\n"
|
||||
f"{canonical}\n\n"
|
||||
"Strict rules:\n"
|
||||
"- Output strings must be byte-for-byte identical to entries above.\n"
|
||||
"- No articles, no adverbs, no extra words.\n"
|
||||
"- If a phase truly has no canonical match, omit that span entirely.\n"
|
||||
"Return the same JSON shape as before."
|
||||
)
|
||||
# Append the correction as an additional user turn; the model
|
||||
# sees the original prompt + its prior output is implied by the
|
||||
# conversation context (the VLM client is stateless, so we
|
||||
# re-send the original content plus this correction).
|
||||
retry_messages = [
|
||||
{
|
||||
"role": m.get("role", "user"),
|
||||
"content": (
|
||||
m.get("content")
|
||||
if isinstance(m.get("content"), str)
|
||||
else list(m.get("content") or [])
|
||||
),
|
||||
}
|
||||
for m in original_messages
|
||||
]
|
||||
retry_messages.append({"role": "user", "content": correction})
|
||||
return retry_messages
|
||||
|
||||
def _generate_plan(
|
||||
self,
|
||||
record: EpisodeRecord, # noqa: ARG002 (kept for signature stability)
|
||||
subtask_spans: Sequence[dict[str, Any]],
|
||||
*,
|
||||
refresh_t: float | None = None,
|
||||
interjection: str | None = None, # noqa: ARG002
|
||||
task: str | None = None, # noqa: ARG002
|
||||
) -> str | None:
|
||||
"""Deterministic plan = numbered list of *still-todo* subtasks.
|
||||
|
||||
Previously this called the VLM with a prompt that asked it to
|
||||
compress the subtasks into a "compact hierarchical plan". That
|
||||
produced longer-than-necessary plans, cost an extra VLM round-trip
|
||||
per episode (plus one per interjection on refresh), and could
|
||||
diverge from the actual subtask sequence the model is going to
|
||||
execute. Replacing it with a plain summarisation keeps the plan
|
||||
tightly aligned with the upcoming subtasks and removes the VLM
|
||||
call entirely.
|
||||
|
||||
Layout — short imperative fragments prefixed by "N. ":
|
||||
|
||||
1. <subtask 1>
|
||||
2. <subtask 2>
|
||||
...
|
||||
|
||||
On a refresh at ``refresh_t`` (called from ``run_plan_updates``
|
||||
on interjection events, and from ``run_episode`` at every subtask
|
||||
boundary), only subtasks whose start is at or after ``refresh_t``
|
||||
are included — the plan shrinks as work progresses, so it always
|
||||
describes what's left.
|
||||
"""
|
||||
if not subtask_spans:
|
||||
return None
|
||||
remaining = [
|
||||
s
|
||||
for s in subtask_spans
|
||||
if refresh_t is None or float(s.get("start", 0.0)) >= float(refresh_t)
|
||||
]
|
||||
if not remaining:
|
||||
# Past the last subtask boundary on a late refresh — nothing
|
||||
# left to plan; emit None so the caller skips the row.
|
||||
return None
|
||||
return "\n".join(
|
||||
f"{i}. {span.get('text', '').strip()}" for i, span in enumerate(remaining, start=1)
|
||||
)
|
||||
|
||||
def _generate_memory(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
prior_memory: str,
|
||||
completed: str,
|
||||
remaining: Sequence[str],
|
||||
*,
|
||||
task: str | None = None,
|
||||
) -> str:
|
||||
prompt = load_prompt("module_1_memory").format(
|
||||
episode_task=(task if task is not None else record.episode_task),
|
||||
prior_memory=prior_memory or "(none)",
|
||||
completed_subtask=completed,
|
||||
remaining_subtasks=", ".join(remaining) if remaining else "(none)",
|
||||
vocabulary_block=self._memory_vocabulary_block(),
|
||||
)
|
||||
memory = self._vlm_field(self._text_message(prompt), "memory")
|
||||
return memory.strip() if isinstance(memory, str) else ""
|
||||
@@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Prompt templates loaded as plain text.
|
||||
|
||||
One file per use site. Templates use ``str.format(**vars)`` substitution; we
|
||||
intentionally avoid jinja2 here so the templates remain inspectable in
|
||||
plain editors and roundtrip cleanly through ``ruff format``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
def load(name: str) -> str:
|
||||
"""Read prompt template ``name.txt`` from the ``prompts/`` directory."""
|
||||
path = _DIR / f"{name}.txt"
|
||||
return path.read_text(encoding="utf-8")
|
||||
@@ -0,0 +1,53 @@
|
||||
You are inspecting {n_episodes} sample episode video(s) from a teleoperated
|
||||
robot dataset. Every episode in the dataset performs the SAME task; the
|
||||
user originally asked: "{episode_task}".
|
||||
|
||||
Watch all the clips and produce a SHORT canonical vocabulary that every
|
||||
episode in this dataset will reuse. The downstream low-level policy is
|
||||
conditioned on these strings — duplicate phrasings (e.g. "grasp blue
|
||||
cube" vs "pick up the blue cube") would destroy the conditioning, so
|
||||
pick one wording per concept and reuse it everywhere.
|
||||
|
||||
Decide how many entries each list needs YOURSELF based on what you see —
|
||||
the smallest set that still covers every recurring phase in the demos.
|
||||
A simple two-object pick-and-place might need ~6 subtask labels and 2
|
||||
memory milestones; a long multi-step recipe needs more. Err on the side
|
||||
of FEWER — extra entries that don't recur across episodes weaken the
|
||||
conditioning.
|
||||
|
||||
You output two lists:
|
||||
|
||||
1. `subtasks`: imperative, telegraphic commands the robot can execute.
|
||||
- Verb-first. Drop articles, adverbs, qualifiers.
|
||||
- Consistent object nouns (if the task says "cube", every subtask says
|
||||
"cube" — never "block" / "object").
|
||||
- Atomic — one skill per subtask (gripper-open events, contact, regrasps,
|
||||
transitions all become cut points).
|
||||
- Each label must recur across the demos. If you see a motion only
|
||||
once across all sample clips, it probably isn't a canonical phase.
|
||||
- Good: "move to blue cube", "grasp blue cube", "lift blue cube",
|
||||
"place blue cube in box", "release blue cube", "retract arm".
|
||||
- Bad: "the robot arm moves towards the blue cube" (third person,
|
||||
too long), "carefully pick up the cube" (adverb, article),
|
||||
"carrying the yellow cube over the green basket" (gerund — should
|
||||
be imperative "transport yellow cube to green basket").
|
||||
|
||||
2. `memory_milestones`: first-person past-tense sentences the running
|
||||
memory composes from. Each subtask phase that produces a lasting
|
||||
change should have a milestone; transient motions (move, retract)
|
||||
should NOT.
|
||||
- First person, past tense. Start with "I".
|
||||
- One sentence. Functional outcome only — no grasp / motion detail.
|
||||
- Good: "I picked up the blue cube.", "I placed the blue cube in
|
||||
the green box.", "I wiped the counter."
|
||||
- Bad: "The robot arm grasped the blue cube." (third person),
|
||||
"I carefully grasped the blue cube with the parallel gripper."
|
||||
(irrelevant detail), "I moved towards the blue cube." (transient
|
||||
motion — should be omitted, not memorialised).
|
||||
|
||||
Output strictly valid JSON of shape:
|
||||
|
||||
{{
|
||||
"subtasks": ["<verb phrase>", ...],
|
||||
"memory_milestones": ["I <past-tense sentence>.", ...]
|
||||
}}
|
||||
@@ -0,0 +1,36 @@
|
||||
You are updating the robot's compressed semantic memory at the boundary of
|
||||
a completed subtask.
|
||||
|
||||
Reference (verbatim from MEM, Torne 2026):
|
||||
"Remove or compress information in the language memory whenever
|
||||
appropriate. Keep ONLY the minimal set of relevant information for future
|
||||
task execution. Specific object attributes (colors, precise quantities of
|
||||
each item) get discarded when their details won't affect subsequent
|
||||
actions. Functional outcomes (where items went, how many) are preserved."
|
||||
|
||||
Episode task: "{episode_task}"
|
||||
Previous memory: {prior_memory}
|
||||
Just-completed subtask: "{completed_subtask}"
|
||||
Remaining subtasks (for relevance judgement only): {remaining_subtasks}
|
||||
|
||||
{vocabulary_block}Write the memory as a short FIRST-PERSON, PAST-TENSE narrative of what the
|
||||
robot has accomplished so far — the running story it would tell itself.
|
||||
|
||||
Authoring rules:
|
||||
- First person, past tense. Every sentence starts with "I": "I picked
|
||||
up...", "I opened...", "I moved to...".
|
||||
- One or two short sentences. Extend the previous memory with the
|
||||
just-completed subtask; do not rewrite it from scratch.
|
||||
- Keep WHAT happened (functional outcomes — where items went, how many),
|
||||
drop HOW (grasp details, motions).
|
||||
- Compress completed steps and drop object attributes (colors, exact
|
||||
counts) once they no longer affect the remaining subtasks.
|
||||
|
||||
Example (MEM, Torne 2026):
|
||||
Before: "I prepared the pot and got the potatoes, milk, and butter. I
|
||||
moved to the drawer."
|
||||
After: "I prepared the pot and got the ingredients. I opened the
|
||||
drawer with the masher."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{ "memory": "<one or two short first-person past-tense sentences>" }}
|
||||
@@ -0,0 +1,80 @@
|
||||
You are labeling a teleoperated robot demonstration.
|
||||
|
||||
The user originally asked: "{episode_task}"
|
||||
|
||||
You are shown the entire demonstration as a single video. Watch the
|
||||
whole clip, then segment it into a list of consecutive atomic subtasks
|
||||
the robot performs.
|
||||
|
||||
{vocabulary_block}Authoring rules — Hi Robot atom granularity, pi0.7-style short prompts:
|
||||
|
||||
- Each subtask = one COMPOSITE atomic skill the low-level policy can
|
||||
execute end-to-end. A "skill" bundles its own approach motion with
|
||||
its terminal action — do NOT split the approach off as its own
|
||||
subtask. The whole-arm policy already learns to reach as part of
|
||||
every manipulation primitive.
|
||||
- Write each subtask as an IMPERATIVE COMMAND, starting with one of
|
||||
these verbs (extend only when none fits):
|
||||
pick up <obj> — approach + grasp + lift in one subtask
|
||||
put <obj> on/in <loc> — transport + release in one subtask
|
||||
place <obj> on/in <loc> — synonym of "put"; pick one and stay consistent
|
||||
push <obj> — contact + linear shove
|
||||
pull <obj> — contact + linear retract
|
||||
turn <knob/dial/handle> — rotary actuation
|
||||
press <button> — single-press contact
|
||||
open <drawer/door/lid> — full open motion
|
||||
close <drawer/door/lid> — full close motion
|
||||
pour <src> into <dst> — tilt + flow
|
||||
insert <obj> into <slot>— alignment + push-fit
|
||||
go to <loc> — ONLY when no grasp / actuation follows
|
||||
(e.g. a pure relocation between phases).
|
||||
If the next subtask grasps something at
|
||||
that location, drop "go to ..." and just
|
||||
write "pick up ..." instead.
|
||||
- Forbidden ultra-fine splits — the VLM is NOT allowed to emit these
|
||||
as standalone subtasks; fold them into the parent composite:
|
||||
"move to X" → fold into "pick up X" (or whatever follows)
|
||||
"reach for X" → fold into "pick up X"
|
||||
"grasp X" → fold into "pick up X"
|
||||
"lift X" → fold into "pick up X" (or "put X on Y" if it's
|
||||
the transport phase of a place)
|
||||
"release X" → fold into "put X on Y" (or "place X in Y")
|
||||
- Keep it SHORT — a verb phrase, not a sentence. Drop articles
|
||||
("the", "a") and adverbs ("carefully", "slowly"). Add a "how"
|
||||
detail (which hand, which grasp point) ONLY when it is needed to
|
||||
disambiguate. Every subtask must begin with one of the verbs
|
||||
above (no leading nouns, no "then", no "first").
|
||||
- NEVER use third person. Never write "the robot", "the arm", "the
|
||||
gripper moves", "it picks up" — the robot is implied. Command it,
|
||||
do not describe it.
|
||||
- Use the exact object nouns from the task above. If the task says
|
||||
"cube", every subtask says "cube" — never switch to "block". If it
|
||||
says "box", never switch to "bin"/"container". Keep vocabulary
|
||||
consistent across the whole episode.
|
||||
- Good: "pick up blue cube", "put blue cube in box", "open drawer",
|
||||
"turn red knob", "press start button", "go to sink".
|
||||
- Bad: "move to blue cube" (approach as its own subtask — forbidden,
|
||||
must be folded into "pick up blue cube"); "the robot arm moves
|
||||
towards the blue cube" (third person, too long); "carefully pick
|
||||
up the cube" (adverb, article); "release the yellow block"
|
||||
("block" when the task said "cube", and "release" must be folded
|
||||
into a "put"/"place" subtask).
|
||||
- Subtasks are non-overlapping and cover the full episode in order.
|
||||
Choose the cut points yourself based on what you see in the video
|
||||
(gripper open/close events, contact, regrasps, transitions).
|
||||
- Each subtask spans at least {min_subtask_seconds} seconds. If a
|
||||
candidate span would be shorter, merge it into its neighbour
|
||||
rather than emitting it.
|
||||
- Do not exceed {max_steps} subtasks total. Fewer, larger composites
|
||||
are preferred over many micro-steps.
|
||||
- Every subtask's [start_time, end_time] must lie within
|
||||
[0.0, {episode_duration}] seconds.
|
||||
|
||||
Output strictly valid JSON of shape:
|
||||
|
||||
{{
|
||||
"subtasks": [
|
||||
{{"text": "<short imperative verb phrase>", "start": <float>, "end": <float>}},
|
||||
...
|
||||
]
|
||||
}}
|
||||
@@ -0,0 +1,32 @@
|
||||
You are generating training data for a Hi Robot-style policy. We need
|
||||
{n} alternative phrasings of the same robot task so the policy sees
|
||||
diverse user prompts during training instead of the same canonical
|
||||
string repeated every frame.
|
||||
|
||||
Original task:
|
||||
"{base_task}"
|
||||
|
||||
Generate exactly {n} alternative phrasings of the same task. Vary:
|
||||
|
||||
- formality (casual / polite / curt)
|
||||
- verbosity (mostly short imperative; occasional polite request)
|
||||
- word choice (synonyms, different verbs)
|
||||
- sentence structure (imperative / question / suggestion)
|
||||
|
||||
Hard rules:
|
||||
- Each phrasing MUST preserve the exact meaning of the original task.
|
||||
Do not change which object is involved, the destination, or the
|
||||
action. Do not add extra steps. Do not invent new objects.
|
||||
- Each phrasing must be a short phrase or sentence, plain prose, no
|
||||
markdown, no quotes, no list numbers.
|
||||
- Phrasings must be distinct — no near-duplicates.
|
||||
- Output exactly {n} entries.
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"rephrasings": [
|
||||
"<phrasing 1>",
|
||||
"<phrasing 2>",
|
||||
...
|
||||
]
|
||||
}}
|
||||
@@ -0,0 +1,17 @@
|
||||
The video above shows a robot manipulation episode in full. Look at
|
||||
the entire video and describe in ONE concise sentence what the robot
|
||||
is doing.
|
||||
|
||||
Rules:
|
||||
- One sentence, in natural English, like a user instruction.
|
||||
- Capture the goal of the demonstration, not low-level motions.
|
||||
Example: "place the yellow cube into the red bin" — not "move the
|
||||
end-effector down 5cm and close the gripper".
|
||||
- 4 to 15 words. Plain prose, no markdown, no bullets, no quotes.
|
||||
- Do not invent objects or actions that aren't visible.
|
||||
- Do not output anything other than the JSON object below.
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"task": "<single concise sentence describing what the robot does in this video>"
|
||||
}}
|
||||
@@ -0,0 +1,12 @@
|
||||
The user just asked the robot: "{episode_task}".
|
||||
|
||||
Generate a short verbal acknowledgement the robot would speak back before
|
||||
beginning the task. Style: compact, confident, friendly.
|
||||
|
||||
Examples (Hi Robot, Shi 2025): "Sure, I won't put cheese on it.",
|
||||
"OK, starting with the sponge.", "Got it.".
|
||||
|
||||
Prefer very short replies: "Got it.", "On it.", "OK."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{ "text": "<the spoken acknowledgement>" }}
|
||||
@@ -0,0 +1,46 @@
|
||||
You are generating training data for a Hi Robot-style hierarchical
|
||||
robot policy. The robot in this demonstration has ALREADY executed
|
||||
every step shown in the video — we cannot retroactively change the
|
||||
action stream. To keep training data consistent with the video, the
|
||||
"interjection" must align with what the robot is *about to do next* in
|
||||
the demonstration, framed as a natural mid-task user request.
|
||||
|
||||
The episode's overall task: "{episode_task}".
|
||||
|
||||
The images above show roughly {window_seconds:.1f} seconds straddling a
|
||||
subtask boundary in the demonstration:
|
||||
|
||||
- Subtask the robot just finished: "{prev_subtask}"
|
||||
- Subtask the robot is about to start: "{next_subtask}"
|
||||
- Time into episode: {timestamp:.2f}s
|
||||
|
||||
Write ONE compact interjection the user would naturally say at this
|
||||
moment to prompt / confirm / encourage the robot to do "{next_subtask}".
|
||||
Keep it like a mid-task coaching cue, not a full instruction paragraph.
|
||||
Also write the robot's compact verbal acknowledgement.
|
||||
|
||||
Hard rules:
|
||||
|
||||
- The interjection MUST be consistent with the next subtask. The user
|
||||
cannot ask for something different from what the robot then does in
|
||||
the video. If you're tempted to say "actually skip X" or "do Y
|
||||
instead", DO NOT — those would contradict the demonstration.
|
||||
- The interjection must reference an object, location, or action that
|
||||
is plausible given the visible scene and the next subtask text.
|
||||
- One short phrase or sentence each. Conversational, not robotic.
|
||||
- Prefer direct cues: "{next_subtask}, please."; "Now {next_subtask}."
|
||||
- Keep robot speech very short: "OK.", "On it.", "Doing that."
|
||||
|
||||
Style examples (vary the phrasing — don't reuse these verbatim):
|
||||
- "Now go ahead and {next_subtask}."
|
||||
- "Great, can you {next_subtask} next?"
|
||||
- "{next_subtask}, please."
|
||||
- "Before you continue, please {next_subtask}."
|
||||
- "Looking good — {next_subtask} now."
|
||||
- "Okay, {next_subtask}."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"interjection": "<short cue from the user, asking for the next subtask>",
|
||||
"speech": "<short robot acknowledgement>"
|
||||
}}
|
||||
@@ -0,0 +1,32 @@
|
||||
You are generating a frame-grounded visual question/answer pair for
|
||||
chain-of-thought training. Reference: ECoT (Zawalski 2024) and Steerable
|
||||
Policies — both train policies on grounded features such as bounding box
|
||||
pixel coordinates, keypoints, counts, attributes, and spatial relations.
|
||||
|
||||
The frame shows a robot working on: "{episode_task}".
|
||||
|
||||
Question types and the EXACT answer JSON shape required for each:
|
||||
|
||||
bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
|
||||
"bbox": [x1, y1, x2, y2]}}, ...]}}
|
||||
bbox is in pixel coordinates (x_min, y_min, x_max, y_max).
|
||||
ECoT example: "a white cup [124, 25, 176, 113]".
|
||||
|
||||
keypoint => {{"label": "<point>", "point_format": "xy",
|
||||
"point": [x, y]}}
|
||||
|
||||
count => {{"label": "<obj>", "count": <int>,
|
||||
"note": "<optional short note>"}}
|
||||
|
||||
attribute => {{"label": "<obj>", "attribute": "<color|shape|state|...>",
|
||||
"value": "<observed value>"}}
|
||||
|
||||
spatial => {{"subject": "<obj>", "relation": "<left_of|right_of|on|in|"
|
||||
"above|below|near>", "object": "<obj>"}}
|
||||
|
||||
Generate a question of type "{question_type}". Output strictly valid JSON:
|
||||
|
||||
{{
|
||||
"question": "<short, frame-grounded question>",
|
||||
"answer": <object whose shape matches the schema above>
|
||||
}}
|
||||
274
src/lerobot/annotations/steerable_pipeline/reader.py
Normal file
274
src/lerobot/annotations/steerable_pipeline/reader.py
Normal file
@@ -0,0 +1,274 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Datatrove-shaped reader.
|
||||
|
||||
The reader walks ``data/chunk-*/file-*.parquet`` and yields one record per
|
||||
episode containing:
|
||||
|
||||
- ``episode_index``: int
|
||||
- ``frame_timestamps``: tuple[float, ...]
|
||||
- ``frame_indices``: tuple[int, ...]
|
||||
- ``episode_task``: str (canonical task from ``meta/tasks.parquet``)
|
||||
- ``data_path``: pathlib.Path of the source parquet shard
|
||||
- ``frames_df``: pandas.DataFrame slice for the episode (only loaded on demand)
|
||||
|
||||
This shape lets each module operate per-episode without loading all parquet
|
||||
rows into memory at once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.io_utils import load_tasks
|
||||
from lerobot.datasets.utils import DEFAULT_TASKS_PATH
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeRecord:
|
||||
"""Per-episode record yielded by the reader."""
|
||||
|
||||
episode_index: int
|
||||
episode_task: str
|
||||
frame_timestamps: tuple[float, ...]
|
||||
frame_indices: tuple[int, ...]
|
||||
data_path: Path
|
||||
row_offset: int # row offset within the parquet file where this episode starts
|
||||
row_count: int # number of rows for this episode
|
||||
|
||||
# Memoized parquet slice — populated on first ``frames_df()`` call so
|
||||
# repeat queries from different modules don't re-read the whole shard.
|
||||
_frames_df_cache: Any = field(default=None, init=False, repr=False, compare=False)
|
||||
|
||||
def frames_df(self): # type: ignore[no-untyped-def]
|
||||
"""Lazy-load the pandas slice for this episode (memoized)."""
|
||||
if self._frames_df_cache is None:
|
||||
import pandas as pd # noqa: PLC0415 - deferred for optional dataset extra
|
||||
|
||||
table = pq.read_table(self.data_path)
|
||||
df: pd.DataFrame = table.to_pandas()
|
||||
self._frames_df_cache = df.iloc[self.row_offset : self.row_offset + self.row_count].reset_index(
|
||||
drop=True
|
||||
)
|
||||
return self._frames_df_cache
|
||||
|
||||
|
||||
def reconstruct_subtask_spans(
|
||||
rows: Sequence[dict[str, Any]],
|
||||
*,
|
||||
episode_end_t: float | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Turn ``style="subtask"`` rows into ``{text, start, end}`` spans.
|
||||
|
||||
Each span's ``end`` is the next span's ``start``. The final span's
|
||||
``end`` defaults to its own ``start`` (zero-duration) — pass
|
||||
``episode_end_t`` to extend it to the episode's last frame instead,
|
||||
which is what downstream consumers (memory, interjection boundary
|
||||
selection) expect.
|
||||
|
||||
Used by the ``plan`` module (plan-update pass) and the
|
||||
``interjections`` module (interjection anchoring), which both need the
|
||||
same span shape.
|
||||
"""
|
||||
sorted_rows = sorted(
|
||||
(r for r in rows if r.get("style") == "subtask"),
|
||||
key=lambda r: float(r["timestamp"]),
|
||||
)
|
||||
spans: list[dict[str, Any]] = []
|
||||
for r in sorted_rows:
|
||||
t = float(r["timestamp"])
|
||||
if spans:
|
||||
spans[-1]["end"] = t
|
||||
spans.append({"text": r.get("content") or "", "start": t, "end": t})
|
||||
if spans and episode_end_t is not None and float(episode_end_t) > spans[-1]["start"]:
|
||||
spans[-1]["end"] = float(episode_end_t)
|
||||
return spans
|
||||
|
||||
|
||||
def snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
|
||||
"""Snap an arbitrary float to the nearest exact source frame timestamp.
|
||||
|
||||
Modules use this when emitting event-style rows so the row's
|
||||
timestamp matches a real parquet frame: event rows must land on an
|
||||
exact frame, otherwise the per-frame event lookup the writer does
|
||||
would never match them.
|
||||
"""
|
||||
if not frame_timestamps:
|
||||
return float(t)
|
||||
nearest = min(frame_timestamps, key=lambda f: abs(f - t))
|
||||
return float(nearest)
|
||||
|
||||
|
||||
def _load_tasks_lookup(root: Path) -> dict[int, str]:
|
||||
"""Map ``task_index -> task`` from ``meta/tasks.parquet``.
|
||||
|
||||
Returns an empty dict when the file is absent — the task description is
|
||||
derived later from the video if needed. Reuses the library-level
|
||||
:func:`lerobot.datasets.io_utils.load_tasks`, which returns the tasks
|
||||
frame indexed by task string with a ``task_index`` column.
|
||||
"""
|
||||
if not (root / DEFAULT_TASKS_PATH).exists():
|
||||
return {}
|
||||
tasks = load_tasks(root)
|
||||
return {int(idx): str(task) for task, idx in zip(tasks.index, tasks["task_index"], strict=True)}
|
||||
|
||||
|
||||
def iter_episodes(root: Path, *, only_episodes: tuple[int, ...] | None = None) -> Iterator[EpisodeRecord]:
|
||||
"""Yield :class:`EpisodeRecord` for every episode under ``root/data/``.
|
||||
|
||||
Episodes are yielded in ascending ``episode_index`` order. The reader does
|
||||
not assume a specific chunk/file layout: it scans every ``*.parquet``
|
||||
under ``data/`` and groups by ``episode_index``.
|
||||
"""
|
||||
tasks = _load_tasks_lookup(root)
|
||||
data_dir = root / "data"
|
||||
parquet_files = sorted(data_dir.rglob("*.parquet"))
|
||||
|
||||
only_set = set(only_episodes) if only_episodes is not None else None
|
||||
|
||||
for path in parquet_files:
|
||||
yield from _iter_one_path(path, tasks, only_set)
|
||||
|
||||
|
||||
def _iter_one_path(path: Path, tasks: dict[int, str], only_set: set[int] | None) -> Iterator[EpisodeRecord]:
|
||||
table = pq.read_table(path)
|
||||
names = table.column_names
|
||||
if "episode_index" not in names:
|
||||
return
|
||||
episode_col = table.column("episode_index").to_pylist()
|
||||
timestamp_col = (
|
||||
table.column("timestamp").to_pylist() if "timestamp" in names else [0.0] * len(episode_col)
|
||||
)
|
||||
frame_col = (
|
||||
table.column("frame_index").to_pylist() if "frame_index" in names else list(range(len(episode_col)))
|
||||
)
|
||||
task_col = table.column("task_index").to_pylist() if "task_index" in names else None
|
||||
|
||||
def _build(
|
||||
ep: int,
|
||||
start: int,
|
||||
end: int,
|
||||
task_idx: int | None,
|
||||
ts_buf: list[float],
|
||||
fi_buf: list[int],
|
||||
) -> EpisodeRecord | None:
|
||||
if only_set is not None and ep not in only_set:
|
||||
return None
|
||||
task = tasks.get(task_idx, "") if task_idx is not None else ""
|
||||
return EpisodeRecord(
|
||||
episode_index=ep,
|
||||
episode_task=task,
|
||||
frame_timestamps=tuple(ts_buf),
|
||||
frame_indices=tuple(fi_buf),
|
||||
data_path=path,
|
||||
row_offset=start,
|
||||
row_count=end - start,
|
||||
)
|
||||
|
||||
cur_ep: int | None = None
|
||||
start_offset = 0
|
||||
ts_buf: list[float] = []
|
||||
fi_buf: list[int] = []
|
||||
cur_task_idx: int | None = None
|
||||
|
||||
for i, ep in enumerate(episode_col):
|
||||
if cur_ep is None:
|
||||
cur_ep = ep
|
||||
start_offset = i
|
||||
ts_buf = [timestamp_col[i]]
|
||||
fi_buf = [frame_col[i]]
|
||||
cur_task_idx = task_col[i] if task_col is not None else None
|
||||
continue
|
||||
if ep != cur_ep:
|
||||
rec = _build(cur_ep, start_offset, i, cur_task_idx, ts_buf, fi_buf)
|
||||
if rec is not None:
|
||||
yield rec
|
||||
cur_ep = ep
|
||||
start_offset = i
|
||||
ts_buf = [timestamp_col[i]]
|
||||
fi_buf = [frame_col[i]]
|
||||
cur_task_idx = task_col[i] if task_col is not None else None
|
||||
else:
|
||||
ts_buf.append(timestamp_col[i])
|
||||
fi_buf.append(frame_col[i])
|
||||
|
||||
if cur_ep is not None:
|
||||
rec = _build(cur_ep, start_offset, len(episode_col), cur_task_idx, ts_buf, fi_buf)
|
||||
if rec is not None:
|
||||
yield rec
|
||||
|
||||
|
||||
def gather_data_paths(root: Path) -> list[Path]:
|
||||
"""Return every ``data/chunk-*/file-*.parquet`` path under ``root``."""
|
||||
return sorted((root / "data").rglob("*.parquet"))
|
||||
|
||||
|
||||
def episode_offsets_per_path(path: Path) -> dict[int, tuple[int, int]]:
|
||||
"""Return ``{episode_index: (row_offset, row_count)}`` for one parquet."""
|
||||
table = pq.read_table(path, columns=["episode_index"])
|
||||
episode_col = table.column("episode_index").to_pylist()
|
||||
out: dict[int, tuple[int, int]] = {}
|
||||
cur_ep: int | None = None
|
||||
start = 0
|
||||
for i, ep in enumerate(episode_col):
|
||||
if cur_ep is None:
|
||||
cur_ep = ep
|
||||
start = i
|
||||
continue
|
||||
if ep != cur_ep:
|
||||
out[cur_ep] = (start, i - start)
|
||||
cur_ep = ep
|
||||
start = i
|
||||
if cur_ep is not None:
|
||||
out[cur_ep] = (start, len(episode_col) - start)
|
||||
return out
|
||||
|
||||
|
||||
def keyframe_indices(record: EpisodeRecord, k: int) -> list[int]:
|
||||
"""Return ``k`` evenly spaced row indices into the episode (relative)."""
|
||||
n = record.row_count
|
||||
if k <= 0 or n == 0:
|
||||
return []
|
||||
if k >= n:
|
||||
return list(range(n))
|
||||
step = (n - 1) / (k - 1) if k > 1 else 0.0
|
||||
return [int(round(i * step)) for i in range(k)] if k > 1 else [n // 2]
|
||||
|
||||
|
||||
def lookup_data_path(root: Path, episode_index: int) -> tuple[Path, int, int] | None:
|
||||
"""Find the parquet file containing ``episode_index`` and its slice bounds."""
|
||||
for path in gather_data_paths(root):
|
||||
offsets = episode_offsets_per_path(path)
|
||||
if episode_index in offsets:
|
||||
start, count = offsets[episode_index]
|
||||
return path, start, count
|
||||
return None
|
||||
|
||||
|
||||
def episode_frame_timestamps(root: Path, episode_index: int) -> tuple[Any, list[float]]:
|
||||
"""Return the parquet path and per-frame timestamps for ``episode_index``."""
|
||||
found = lookup_data_path(root, episode_index)
|
||||
if found is None:
|
||||
raise ValueError(f"Episode {episode_index} not found under {root}/data/")
|
||||
path, start, count = found
|
||||
table = pq.read_table(path, columns=["timestamp"])
|
||||
timestamps = table.column("timestamp").to_pylist()[start : start + count]
|
||||
return path, [float(t) for t in timestamps]
|
||||
104
src/lerobot/annotations/steerable_pipeline/staging.py
Normal file
104
src/lerobot/annotations/steerable_pipeline/staging.py
Normal file
@@ -0,0 +1,104 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Per-episode staging.
|
||||
|
||||
Each module writes its raw output as a JSONL file under
|
||||
``<staging_dir>/episode_{ep:06d}/<module>.jsonl``. The writer reads back this
|
||||
staging tree and partitions rows into the two language columns.
|
||||
|
||||
JSONL is preferred over parquet here because the staging artifact is meant to
|
||||
be human-inspectable, easy to diff between prompt iterations, and trivially
|
||||
appended to. The final dataset format is parquet; staging is just an
|
||||
intermediate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Iterable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
ModuleName = str
|
||||
|
||||
_MODULES: tuple[ModuleName, ...] = (
|
||||
"plan",
|
||||
"interjections",
|
||||
"vqa",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeStaging:
|
||||
"""Filesystem layout for a single episode's staged module outputs."""
|
||||
|
||||
root: Path
|
||||
episode_index: int
|
||||
|
||||
@property
|
||||
def episode_dir(self) -> Path:
|
||||
return self.root / f"episode_{self.episode_index:06d}"
|
||||
|
||||
def path_for(self, module: ModuleName) -> Path:
|
||||
if module not in _MODULES:
|
||||
raise ValueError(f"Unknown module {module!r}; expected one of {_MODULES}")
|
||||
return self.episode_dir / f"{module}.jsonl"
|
||||
|
||||
def write(self, module: ModuleName, rows: Iterable[dict[str, Any]]) -> Path:
|
||||
path = self.path_for(module)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Atomic replace: a crash mid-write would otherwise leave a
|
||||
# half-written JSONL file that ``read()`` would then fail to
|
||||
# parse. Write to a sibling .tmp and rename so the target path
|
||||
# only ever points at a complete file.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
with tmp_path.open("w", encoding="utf-8") as f:
|
||||
for row in rows:
|
||||
f.write(json.dumps(row, ensure_ascii=False, sort_keys=True))
|
||||
f.write("\n")
|
||||
tmp_path.replace(path)
|
||||
return path
|
||||
|
||||
def read(self, module: ModuleName) -> list[dict[str, Any]]:
|
||||
path = self.path_for(module)
|
||||
if not path.exists():
|
||||
return []
|
||||
out: list[dict[str, Any]] = []
|
||||
with path.open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
out.append(json.loads(line))
|
||||
return out
|
||||
|
||||
def read_all(self) -> dict[ModuleName, list[dict[str, Any]]]:
|
||||
return {m: self.read(m) for m in _MODULES}
|
||||
|
||||
def has(self, module: ModuleName) -> bool:
|
||||
return self.path_for(module).exists()
|
||||
|
||||
|
||||
def iter_staged_episodes(root: Path) -> Iterator[int]:
|
||||
"""Yield episode indices for which any staging artifact exists."""
|
||||
if not root.exists():
|
||||
return
|
||||
for child in sorted(root.iterdir()):
|
||||
if child.is_dir() and child.name.startswith("episode_"):
|
||||
try:
|
||||
yield int(child.name.removeprefix("episode_"))
|
||||
except ValueError:
|
||||
continue
|
||||
334
src/lerobot/annotations/steerable_pipeline/validator.py
Normal file
334
src/lerobot/annotations/steerable_pipeline/validator.py
Normal file
@@ -0,0 +1,334 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Pre-write validation against staged outputs.
|
||||
|
||||
Runs after all three modules have written their per-episode artifacts but
|
||||
*before* the writer rewrites parquet shards. The validator never touches
|
||||
parquet; it only inspects the staging tree and the source frame timestamps
|
||||
exposed by :class:`EpisodeRecord`.
|
||||
|
||||
Checks (per the plan's "Intermediate staging and validation" section):
|
||||
|
||||
- exact timestamp alignment against source frame timestamps
|
||||
- no orphan speech / interjection pairs
|
||||
- plan / memory emission consistency (events have a paired persistent row)
|
||||
- VQA assistant ``content`` is valid JSON (one of bbox / keypoint / count /
|
||||
attribute / spatial)
|
||||
- every row maps to its correct column under :func:`column_for_style`
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from lerobot.datasets.language import (
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
column_for_style,
|
||||
is_view_dependent_style,
|
||||
validate_camera_field,
|
||||
)
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
from .staging import EpisodeStaging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationReport:
|
||||
"""Outcome of one validation pass across all episodes."""
|
||||
|
||||
errors: list[str] = field(default_factory=list)
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
episodes_checked: int = 0
|
||||
|
||||
@property
|
||||
def ok(self) -> bool:
|
||||
return not self.errors
|
||||
|
||||
def add_error(self, message: str) -> None:
|
||||
self.errors.append(message)
|
||||
|
||||
def add_warning(self, message: str) -> None:
|
||||
self.warnings.append(message)
|
||||
|
||||
def summary(self) -> str:
|
||||
return f"checked={self.episodes_checked} errors={len(self.errors)} warnings={len(self.warnings)}"
|
||||
|
||||
|
||||
VQA_ANSWER_SHAPES: dict[str, set[str]] = {
|
||||
"bbox": {"detections"},
|
||||
"keypoint": {"label", "point_format", "point"},
|
||||
"count": {"label", "count"},
|
||||
"attribute": {"label", "attribute", "value"},
|
||||
"spatial": {"subject", "relation", "object"},
|
||||
}
|
||||
|
||||
|
||||
def classify_vqa_answer(payload: Any) -> str | None:
|
||||
"""Best-effort classification of a VQA answer payload to a question type."""
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
keys = set(payload.keys())
|
||||
for kind, required in VQA_ANSWER_SHAPES.items():
|
||||
if required.issubset(keys):
|
||||
return kind
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StagingValidator:
|
||||
"""Walks the staging tree and produces a :class:`ValidationReport`."""
|
||||
|
||||
timestamp_atol: float = 0.0 # exact-match by default
|
||||
dataset_camera_keys: tuple[str, ...] | None = None
|
||||
"""Known ``observation.images.*`` keys on the dataset. When set, the
|
||||
validator additionally enforces that every view-dependent row's
|
||||
``camera`` field references one of these keys. Pass ``None`` (default)
|
||||
to skip that cross-check (e.g. in unit tests with no real dataset)."""
|
||||
|
||||
def validate(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
) -> ValidationReport:
|
||||
report = ValidationReport()
|
||||
for record in records:
|
||||
self._validate_episode(record, staging_dir, report)
|
||||
report.episodes_checked += 1
|
||||
return report
|
||||
|
||||
def _validate_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
staging_dir: Path,
|
||||
report: ValidationReport,
|
||||
) -> None:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
staged = staging.read_all()
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
for module_name, rows in staged.items():
|
||||
for row in rows:
|
||||
row = {**row, "_module": module_name}
|
||||
all_rows.append(row)
|
||||
|
||||
frame_ts = set(record.frame_timestamps)
|
||||
|
||||
events: list[dict[str, Any]] = []
|
||||
persistent: list[dict[str, Any]] = []
|
||||
for row in all_rows:
|
||||
self._check_column_routing(row, report, record.episode_index)
|
||||
self._check_camera_field(
|
||||
row, report, record.episode_index, self.dataset_camera_keys
|
||||
)
|
||||
if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT:
|
||||
persistent.append(row)
|
||||
else:
|
||||
events.append(row)
|
||||
|
||||
for row in events:
|
||||
self._check_event_timestamp_alignment(row, frame_ts, report, record.episode_index)
|
||||
|
||||
self._check_speech_interjection_pairs(events, report, record.episode_index)
|
||||
self._check_plan_memory_consistency(persistent, events, report, record.episode_index)
|
||||
self._check_vqa_json(events, report, record.episode_index)
|
||||
self._check_vqa_uniqueness_per_frame_camera(events, report, record.episode_index)
|
||||
|
||||
def _check_camera_field(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
dataset_camera_keys: Sequence[str] | None,
|
||||
) -> None:
|
||||
"""Enforce the camera invariant + that the key matches the dataset's cameras."""
|
||||
style = row.get("style")
|
||||
camera = row.get("camera")
|
||||
try:
|
||||
validate_camera_field(style, camera)
|
||||
except ValueError as exc:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={row.get('_module')}: {exc}"
|
||||
)
|
||||
return
|
||||
if (
|
||||
is_view_dependent_style(style)
|
||||
and dataset_camera_keys
|
||||
and camera not in dataset_camera_keys
|
||||
):
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={row.get('_module')}: camera {camera!r} on style "
|
||||
f"{style!r} is not one of the dataset's video keys {sorted(dataset_camera_keys)!r}"
|
||||
)
|
||||
|
||||
def _check_vqa_uniqueness_per_frame_camera(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
"""Ensure at most one (vqa, user) and one (vqa, assistant) per (t, camera)."""
|
||||
counts: dict[tuple[float, str, str], int] = {}
|
||||
for row in events:
|
||||
if row.get("style") != "vqa":
|
||||
continue
|
||||
ts = row.get("timestamp")
|
||||
camera = row.get("camera")
|
||||
role = row.get("role")
|
||||
if ts is None or camera is None or role is None:
|
||||
continue # other validators flag these
|
||||
key = (float(ts), str(camera), str(role))
|
||||
counts[key] = counts.get(key, 0) + 1
|
||||
for (ts, camera, role), n in counts.items():
|
||||
if n > 1:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: {n} duplicate vqa rows at t={ts} "
|
||||
f"camera={camera!r} role={role!r}; expected at most one per (t, camera, role)"
|
||||
)
|
||||
|
||||
def _check_column_routing(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
style = row.get("style")
|
||||
module = row.get("_module")
|
||||
try:
|
||||
target_col = column_for_style(style)
|
||||
except ValueError:
|
||||
report.add_error(f"ep={episode_index} module={module}: unknown style {style!r}")
|
||||
return
|
||||
if module == "plan" and target_col != LANGUAGE_PERSISTENT:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module=plan emitted style {style!r} that routes to {target_col} (must be persistent)"
|
||||
)
|
||||
if module in {"interjections", "vqa"} and target_col != LANGUAGE_EVENTS:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={module} emitted style {style!r} that routes to {target_col} (must be events)"
|
||||
)
|
||||
|
||||
def _check_event_timestamp_alignment(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
frame_ts: set[float],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
ts = row.get("timestamp")
|
||||
if ts is None:
|
||||
report.add_error(f"ep={episode_index}: event row missing timestamp: {row!r}")
|
||||
return
|
||||
if self.timestamp_atol == 0.0:
|
||||
if float(ts) not in frame_ts:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: event row timestamp {ts!r} does not match any source frame timestamp"
|
||||
)
|
||||
else:
|
||||
if not any(abs(float(ts) - f) <= self.timestamp_atol for f in frame_ts):
|
||||
report.add_error(
|
||||
f"ep={episode_index}: event row timestamp {ts!r} not within {self.timestamp_atol}s of any frame"
|
||||
)
|
||||
|
||||
def _check_speech_interjection_pairs(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
speech_ts: dict[float, int] = {}
|
||||
interjection_ts: dict[float, int] = {}
|
||||
for row in events:
|
||||
ts = row.get("timestamp")
|
||||
if ts is None:
|
||||
continue
|
||||
ts_f = float(ts)
|
||||
if row.get("style") is None and row.get("role") == "assistant":
|
||||
speech_ts[ts_f] = speech_ts.get(ts_f, 0) + 1
|
||||
if row.get("style") == "interjection":
|
||||
interjection_ts[ts_f] = interjection_ts.get(ts_f, 0) + 1
|
||||
|
||||
for ts in interjection_ts:
|
||||
if ts not in speech_ts:
|
||||
report.add_error(f"ep={episode_index}: interjection at t={ts} has no paired speech atom")
|
||||
|
||||
def _check_plan_memory_consistency(
|
||||
self,
|
||||
persistent: Sequence[dict[str, Any]],
|
||||
events: Sequence[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
plan_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "plan"})
|
||||
memory_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "memory"})
|
||||
subtask_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "subtask"})
|
||||
interjection_ts = sorted(
|
||||
{
|
||||
float(r["timestamp"])
|
||||
for r in events
|
||||
if r.get("style") == "interjection" and r.get("timestamp") is not None
|
||||
}
|
||||
)
|
||||
|
||||
if persistent and not plan_ts:
|
||||
report.add_warning(f"ep={episode_index}: persistent rows present but no plan emitted")
|
||||
# every interjection should have a same-timestamp plan refresh
|
||||
for ts in interjection_ts:
|
||||
if ts not in set(plan_ts):
|
||||
report.add_error(
|
||||
f"ep={episode_index}: interjection at t={ts} has no co-timestamped plan update"
|
||||
)
|
||||
# memory should be emitted at subtask boundaries (subset relation)
|
||||
if memory_ts and subtask_ts:
|
||||
mem_set = set(memory_ts)
|
||||
sub_set = set(subtask_ts)
|
||||
stray = sorted(mem_set - sub_set)
|
||||
if stray:
|
||||
report.add_warning(f"ep={episode_index}: memory rows at {stray} not at any subtask boundary")
|
||||
|
||||
def _check_vqa_json(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
for row in events:
|
||||
if row.get("style") != "vqa" or row.get("role") != "assistant":
|
||||
continue
|
||||
content = row.get("content")
|
||||
if content is None:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant row at t={row.get('timestamp')} has null content"
|
||||
)
|
||||
continue
|
||||
try:
|
||||
payload = json.loads(content)
|
||||
except (TypeError, ValueError) as exc:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant content not valid JSON at t={row.get('timestamp')}: {exc}"
|
||||
)
|
||||
continue
|
||||
shape = classify_vqa_answer(payload)
|
||||
if shape is None:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant payload at t={row.get('timestamp')} does not match any known shape: keys={list(payload) if isinstance(payload, dict) else type(payload).__name__}"
|
||||
)
|
||||
703
src/lerobot/annotations/steerable_pipeline/vlm_client.py
Normal file
703
src/lerobot/annotations/steerable_pipeline/vlm_client.py
Normal file
@@ -0,0 +1,703 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Shared Qwen-VL client.
|
||||
|
||||
The pipeline uses a single shared VLM across modules. vLLM is preferred when
|
||||
available (high throughput, JSON-guided decoding); transformers is the
|
||||
fallback. A ``stub`` backend is used for unit tests so fixtures never call
|
||||
into a real model.
|
||||
|
||||
The client speaks one method, :meth:`VlmClient.generate_json`, which:
|
||||
|
||||
- accepts a list of OpenAI/HF-style multimodal messages,
|
||||
- requests JSON output (``json_mode=True`` enables guided decoding when the
|
||||
backend supports it),
|
||||
- batches requests transparently,
|
||||
- and reprompts once on a JSON parse failure with an inline correction
|
||||
message before raising.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import shlex
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import urllib.request
|
||||
from collections.abc import Callable, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
|
||||
from .config import VlmConfig
|
||||
|
||||
|
||||
class VlmClient(Protocol):
|
||||
"""Protocol every backend must implement."""
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
"""Generate one JSON-decoded response per messages list."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubVlmClient:
|
||||
"""Deterministic stub used in unit tests.
|
||||
|
||||
A test passes a callable that maps the *last user message text* (or, if
|
||||
that is empty, the full message list) to a JSON-serializable response.
|
||||
"""
|
||||
|
||||
responder: Callable[[Sequence[dict[str, Any]]], Any]
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
return [self.responder(list(messages)) for messages in messages_batch]
|
||||
|
||||
|
||||
def _strip_to_json(text: str) -> Any:
|
||||
text = text.strip()
|
||||
# Strip <think>...</think> blocks (Qwen3 Thinking style)
|
||||
while "<think>" in text and "</think>" in text:
|
||||
start = text.find("<think>")
|
||||
end = text.find("</think>", start) + len("</think>")
|
||||
text = (text[:start] + text[end:]).strip()
|
||||
# Strip ```json ... ``` fences from chat-tuned backbones
|
||||
if text.startswith("```"):
|
||||
first = text.find("\n")
|
||||
last = text.rfind("```")
|
||||
if first != -1 and last != -1 and last > first:
|
||||
text = text[first + 1 : last].strip()
|
||||
try:
|
||||
return json.loads(text)
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
pass
|
||||
# Fall back to extracting the first balanced {...} block.
|
||||
obj_text = _extract_first_json_object(text)
|
||||
if obj_text is None:
|
||||
raise json.JSONDecodeError("No JSON object found", text, 0)
|
||||
return json.loads(obj_text)
|
||||
|
||||
|
||||
def _extract_first_json_object(text: str) -> str | None:
|
||||
"""Return the first balanced ``{...}`` substring, ignoring braces in
|
||||
string literals. Returns ``None`` if no balanced block is found."""
|
||||
start = text.find("{")
|
||||
if start < 0:
|
||||
return None
|
||||
depth = 0
|
||||
in_string = False
|
||||
escape = False
|
||||
for i in range(start, len(text)):
|
||||
ch = text[i]
|
||||
if escape:
|
||||
escape = False
|
||||
continue
|
||||
if ch == "\\":
|
||||
escape = True
|
||||
continue
|
||||
# Note: ``escape`` is always False here — the ``if escape`` branch
|
||||
# above already handled and reset it.
|
||||
if ch == '"':
|
||||
in_string = not in_string
|
||||
continue
|
||||
if in_string:
|
||||
continue
|
||||
if ch == "{":
|
||||
depth += 1
|
||||
elif ch == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[start : i + 1]
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _GenericTextClient:
|
||||
"""Wraps any text-generation callable in JSON-mode + one-retry semantics."""
|
||||
|
||||
generate_text: Callable[[Sequence[Sequence[dict[str, Any]]], int, float], list[str]]
|
||||
config: VlmConfig
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
max_tok = max_new_tokens if max_new_tokens is not None else self.config.max_new_tokens
|
||||
temp = temperature if temperature is not None else self.config.temperature
|
||||
raw = self.generate_text(messages_batch, max_tok, temp)
|
||||
out: list[Any] = []
|
||||
for messages, text in zip(messages_batch, raw, strict=True):
|
||||
try:
|
||||
out.append(_strip_to_json(text))
|
||||
continue
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
pass
|
||||
retry = list(messages) + [
|
||||
{"role": "assistant", "content": text},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Your previous reply was not valid JSON. "
|
||||
"Reply with strictly valid JSON, no prose, no fences."
|
||||
),
|
||||
},
|
||||
]
|
||||
retry_text = self.generate_text([retry], max_tok, temp)[0]
|
||||
try:
|
||||
out.append(_strip_to_json(retry_text))
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
# After retry: log preview and return None instead of crashing
|
||||
# the whole pipeline. Modules treat None as "skip".
|
||||
preview = retry_text.strip().replace("\n", " ")[:200]
|
||||
print(
|
||||
f"[vlm] WARNING: failed to parse JSON after retry; preview: {preview!r}",
|
||||
flush=True,
|
||||
)
|
||||
out.append(None)
|
||||
return out
|
||||
|
||||
|
||||
def make_vlm_client(config: VlmConfig) -> VlmClient:
|
||||
"""Build the shared VLM client per the configured backend.
|
||||
|
||||
For ``stub``, callers should construct :class:`StubVlmClient` directly with
|
||||
a responder callable. ``stub`` here is rejected to make accidental misuse
|
||||
obvious.
|
||||
"""
|
||||
if config.backend == "stub":
|
||||
raise ValueError(
|
||||
"Use StubVlmClient(...) directly for the stub backend; make_vlm_client builds real clients."
|
||||
)
|
||||
if config.backend == "vllm":
|
||||
return _make_vllm_client(config)
|
||||
if config.backend == "transformers":
|
||||
return _make_transformers_client(config)
|
||||
if config.backend == "openai":
|
||||
return _make_openai_client(config)
|
||||
raise ValueError(f"Unknown VLM backend: {config.backend!r}")
|
||||
|
||||
|
||||
def _make_vllm_client(config: VlmConfig) -> VlmClient:
|
||||
try:
|
||||
from vllm import LLM, SamplingParams # type: ignore[import-not-found]
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"vllm is required for backend='vllm'. Install with `pip install lerobot[annotations]`."
|
||||
) from exc
|
||||
# Workaround for cuDNN 9.x + torch 2.8 conv3d regression that surfaces
|
||||
# as CUDNN_STATUS_NOT_INITIALIZED in Qwen-VL vision-tower patch
|
||||
# embedders. Setting LEROBOT_DISABLE_CUDNN=1 forces native PyTorch
|
||||
# convolution kernels — slower but functional.
|
||||
if os.environ.get("LEROBOT_DISABLE_CUDNN", "").lower() in {"1", "true", "yes"}:
|
||||
import torch as _torch # noqa: PLC0415 - optional GPU dep, deferred
|
||||
|
||||
_torch.backends.cudnn.enabled = False
|
||||
llm_kwargs: dict[str, Any] = {
|
||||
"model": config.model_id,
|
||||
"tensor_parallel_size": config.tensor_parallel_size,
|
||||
"gpu_memory_utilization": config.gpu_memory_utilization,
|
||||
"trust_remote_code": config.trust_remote_code,
|
||||
}
|
||||
if config.max_model_len is not None:
|
||||
llm_kwargs["max_model_len"] = config.max_model_len
|
||||
llm = LLM(**llm_kwargs)
|
||||
|
||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||
# ``guided_decoding`` would speed up parsing but its API differs across
|
||||
# vllm releases (dict vs GuidedDecodingParams). The _GenericTextClient
|
||||
# wrapper already has a one-retry JSON-recovery path, so we skip it.
|
||||
params = SamplingParams(max_tokens=max_tok, temperature=temp)
|
||||
# ``llm.chat`` handles chat-template application + multimodal input
|
||||
# extraction (image/video blocks) internally, which ``llm.generate``
|
||||
# does not.
|
||||
outputs = llm.chat([list(m) for m in batch], params)
|
||||
return [o.outputs[0].text for o in outputs]
|
||||
|
||||
return _GenericTextClient(_gen, config)
|
||||
|
||||
|
||||
def _make_transformers_client(config: VlmConfig) -> VlmClient:
|
||||
try:
|
||||
import torch # type: ignore[import-not-found]
|
||||
import transformers # type: ignore[import-not-found]
|
||||
from transformers import AutoProcessor # type: ignore[import-not-found]
|
||||
except ImportError as exc:
|
||||
raise ImportError("transformers + torch are required for backend='transformers'.") from exc
|
||||
auto_cls = getattr(transformers, "AutoModelForImageTextToText", None) or getattr(
|
||||
transformers, "AutoModelForVision2Seq", None
|
||||
)
|
||||
if auto_cls is None:
|
||||
raise ImportError(
|
||||
"Neither AutoModelForImageTextToText nor AutoModelForVision2Seq is available in this "
|
||||
"transformers version. Install transformers>=4.45 (which has AutoModelForImageTextToText) "
|
||||
"for VL models."
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(config.model_id, trust_remote_code=config.trust_remote_code)
|
||||
use_accelerate = os.environ.get("LEROBOT_TRANSFORMERS_DEVICE_MAP", "manual") != "manual"
|
||||
# ``device_map='auto'`` triggers a known std::bad_alloc on the Qwen3-VL
|
||||
# post-load dispatch path (the alloc fails in accelerate's hook setup
|
||||
# even with TBs of host RAM). Default to manual: load on CPU with
|
||||
# ``low_cpu_mem_usage=True``, then ``.to("cuda")``. Set
|
||||
# ``LEROBOT_TRANSFORMERS_DEVICE_MAP=auto`` to opt back into the old path.
|
||||
if use_accelerate:
|
||||
model = auto_cls.from_pretrained(
|
||||
config.model_id,
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=config.trust_remote_code,
|
||||
)
|
||||
else:
|
||||
import torch as _torch # noqa: PLC0415 - optional GPU dep, deferred
|
||||
|
||||
model = auto_cls.from_pretrained(
|
||||
config.model_id,
|
||||
torch_dtype=_torch.bfloat16,
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=config.trust_remote_code,
|
||||
)
|
||||
model = model.to("cuda")
|
||||
model.eval()
|
||||
|
||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||
outs: list[str] = []
|
||||
for messages in batch:
|
||||
text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
||||
inputs = processor(text=[text], return_tensors="pt").to(model.device)
|
||||
with torch.no_grad():
|
||||
gen = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_tok,
|
||||
temperature=temp,
|
||||
do_sample=temp > 0.0,
|
||||
)
|
||||
decoded = processor.batch_decode(
|
||||
gen[:, inputs["input_ids"].shape[-1] :], skip_special_tokens=True
|
||||
)[0]
|
||||
outs.append(decoded)
|
||||
return outs
|
||||
|
||||
return _GenericTextClient(_gen, config)
|
||||
|
||||
|
||||
def _make_openai_client(config: VlmConfig) -> VlmClient:
|
||||
"""Backend that talks to any OpenAI-compatible server.
|
||||
|
||||
Compatible with ``vllm serve``, ``transformers serve``,
|
||||
``ktransformers serve``, and hosted endpoints. By default the server
|
||||
is expected to be already running. Set ``auto_serve=True`` to have
|
||||
this client spawn one (default: ``transformers serve``), wait until
|
||||
it's ready, and tear it down on process exit.
|
||||
|
||||
Image blocks ``{"type":"image", "image":<PIL.Image>}`` are
|
||||
auto-converted to ``image_url`` data-URLs. Video blocks
|
||||
``{"type":"video", "video":[<PIL>...]}`` are forwarded as
|
||||
multi-frame ``video_url`` items where supported.
|
||||
"""
|
||||
try:
|
||||
from openai import OpenAI # type: ignore[import-not-found]
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"openai package is required for backend='openai'. Install with `pip install openai`."
|
||||
) from exc
|
||||
|
||||
api_base = config.api_base
|
||||
api_key = config.api_key
|
||||
auto_serve = config.auto_serve
|
||||
api_bases: list[str] = [api_base]
|
||||
|
||||
print(
|
||||
f"[lerobot-annotate] backend=openai model={config.model_id} "
|
||||
f"api_base={api_base} auto_serve={auto_serve}",
|
||||
flush=True,
|
||||
)
|
||||
if auto_serve:
|
||||
if config.parallel_servers > 1:
|
||||
print(
|
||||
f"[lerobot-annotate] spawning {config.parallel_servers} parallel servers",
|
||||
flush=True,
|
||||
)
|
||||
api_bases = _spawn_parallel_inference_servers(config)
|
||||
elif _server_is_up(api_base):
|
||||
print(f"[lerobot-annotate] reusing server already up at {api_base}", flush=True)
|
||||
else:
|
||||
print("[lerobot-annotate] no server reachable; spawning one", flush=True)
|
||||
api_base = _spawn_inference_server(config)
|
||||
api_bases = [api_base]
|
||||
print(f"[lerobot-annotate] server ready at {api_base}", flush=True)
|
||||
|
||||
clients = [OpenAI(base_url=base, api_key=api_key) for base in api_bases]
|
||||
# round-robin counter for parallel mode
|
||||
rr_counter = {"i": 0}
|
||||
|
||||
# ``mm_processor_kwargs`` is a vllm-specific extra; transformers serve
|
||||
# rejects it with HTTP 422. Send it only when explicitly opted in via
|
||||
# an env var (e.g. ``LEROBOT_OPENAI_SEND_MM_KWARGS=1`` for vllm).
|
||||
send_mm_kwargs = os.environ.get("LEROBOT_OPENAI_SEND_MM_KWARGS", "").lower() in {"1", "true", "yes"}
|
||||
|
||||
rr_lock = threading.Lock()
|
||||
|
||||
def _one_call(messages: Sequence[dict[str, Any]], max_tok: int, temp: float) -> str:
|
||||
api_messages, mm_kwargs = _to_openai_messages(messages)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": config.model_id,
|
||||
"messages": api_messages,
|
||||
"max_tokens": max_tok,
|
||||
"temperature": temp,
|
||||
}
|
||||
extra_body: dict[str, Any] = {}
|
||||
if send_mm_kwargs and mm_kwargs:
|
||||
extra_body["mm_processor_kwargs"] = {**mm_kwargs, "do_sample_frames": True}
|
||||
if config.chat_template_kwargs:
|
||||
extra_body["chat_template_kwargs"] = config.chat_template_kwargs
|
||||
if extra_body:
|
||||
kwargs["extra_body"] = extra_body
|
||||
with rr_lock:
|
||||
chosen = clients[rr_counter["i"] % len(clients)]
|
||||
rr_counter["i"] += 1
|
||||
response = chosen.chat.completions.create(**kwargs)
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||
if len(batch) <= 1 or config.client_concurrency <= 1:
|
||||
return [_one_call(messages, max_tok, temp) for messages in batch]
|
||||
# Parallel fan-out — vllm batches these on the server side.
|
||||
max_workers = min(config.client_concurrency, len(batch))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
futures = [pool.submit(_one_call, messages, max_tok, temp) for messages in batch]
|
||||
return [f.result() for f in futures]
|
||||
|
||||
return _GenericTextClient(_gen, config)
|
||||
|
||||
|
||||
def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
|
||||
"""Spawn ``config.parallel_servers`` independent vllm replicas.
|
||||
|
||||
Each replica:
|
||||
- is pinned to a single GPU via ``CUDA_VISIBLE_DEVICES``
|
||||
- listens on ``serve_port + i``
|
||||
- is shut down via the same atexit hook as the single-server path
|
||||
|
||||
Returns the list of ``api_base`` URLs the client should round-robin
|
||||
across.
|
||||
"""
|
||||
n = config.parallel_servers
|
||||
api_bases: list[str] = []
|
||||
procs: list[subprocess.Popen] = []
|
||||
ready_events: list[threading.Event] = []
|
||||
# Multiple readiness signals — uvicorn's own banner is suppressed at
|
||||
# ``--uvicorn-log-level warning``, so we also accept vllm's own
|
||||
# "Starting vLLM API server" line and the route-listing line. The
|
||||
# HTTP probe below is the ultimate fallback.
|
||||
ready_markers = (
|
||||
"Uvicorn running",
|
||||
"Application startup complete",
|
||||
"Starting vLLM API server",
|
||||
"Available routes are",
|
||||
)
|
||||
# Single lock for all server-stream threads so multibyte chars from
|
||||
# different servers don't interleave and tear UTF-8 sequences.
|
||||
print_lock = threading.Lock()
|
||||
|
||||
base_cmd = config.serve_command or (
|
||||
f"vllm serve {shlex.quote(config.model_id)} "
|
||||
f"--tensor-parallel-size 1 "
|
||||
f"--max-model-len {config.max_model_len or 32768} "
|
||||
f"--uvicorn-log-level warning"
|
||||
)
|
||||
|
||||
num_gpus = config.num_gpus if config.num_gpus > 0 else n
|
||||
for i in range(n):
|
||||
port = config.serve_port + i
|
||||
gpu = i % num_gpus
|
||||
env = os.environ.copy()
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(gpu)
|
||||
cmd = base_cmd.replace("{port}", str(port)) if "{port}" in base_cmd else f"{base_cmd} --port {port}"
|
||||
api_base = f"http://localhost:{port}/v1"
|
||||
api_bases.append(api_base)
|
||||
print(f"[server-{i}] launching on GPU {gpu} port {port}: {cmd}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
shlex.split(cmd),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
env=env,
|
||||
)
|
||||
procs.append(proc)
|
||||
ready = threading.Event()
|
||||
ready_events.append(ready)
|
||||
|
||||
def _stream(idx: int, p: subprocess.Popen, ev: threading.Event) -> None:
|
||||
# Read whole lines and emit each line atomically under the
|
||||
# shared print_lock so output from N servers stays readable.
|
||||
assert p.stdout is not None
|
||||
for line in iter(p.stdout.readline, ""):
|
||||
with print_lock:
|
||||
sys.stdout.write(f"[server-{idx}] {line}")
|
||||
if not line.endswith(("\n", "\r")):
|
||||
sys.stdout.write("\n")
|
||||
sys.stdout.flush()
|
||||
if any(m in line for m in ready_markers):
|
||||
ev.set()
|
||||
|
||||
threading.Thread(target=_stream, args=(i, proc, ready), daemon=True).start()
|
||||
|
||||
def _probe(idx: int, base: str, ev: threading.Event, p: subprocess.Popen) -> None:
|
||||
while not ev.is_set() and p.poll() is None:
|
||||
if _server_is_up(base):
|
||||
print(f"[server-{idx}] ready (http probe)", flush=True)
|
||||
ev.set()
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
threading.Thread(target=_probe, args=(i, api_base, ready, proc), daemon=True).start()
|
||||
|
||||
def _shutdown() -> None:
|
||||
for i, p in enumerate(procs):
|
||||
if p.poll() is None:
|
||||
print(f"[server-{i}] stopping pid={p.pid}", flush=True)
|
||||
p.send_signal(signal.SIGINT)
|
||||
for p in procs:
|
||||
try:
|
||||
p.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
p.kill()
|
||||
p.wait(timeout=5)
|
||||
|
||||
atexit.register(_shutdown)
|
||||
|
||||
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||
while any(not ev.is_set() for ev in ready_events) and time.monotonic() < deadline:
|
||||
for i, p in enumerate(procs):
|
||||
if p.poll() is not None:
|
||||
raise RuntimeError(
|
||||
f"[server-{i}] inference server exited unexpectedly with rc={p.returncode}"
|
||||
)
|
||||
time.sleep(2)
|
||||
if any(not ev.is_set() for ev in ready_events):
|
||||
raise RuntimeError(f"[server] not all replicas became ready within {config.serve_ready_timeout_s}s")
|
||||
print(f"[lerobot-annotate] all {n} servers ready: {api_bases}", flush=True)
|
||||
return api_bases
|
||||
|
||||
|
||||
def _server_is_up(api_base: str) -> bool:
|
||||
"""Return True if ``api_base/models`` answers 200 within 2 seconds."""
|
||||
url = api_base.rstrip("/") + "/models"
|
||||
# ``api_base`` is the user-configured local-server URL we just spawned
|
||||
# or the user passed in via ``--vlm.api_base``; the bandit B310 warning
|
||||
# is for arbitrary user-controlled URLs with file:/ schemes which
|
||||
# cannot reach this code path.
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=2) as resp: # noqa: S310 # nosec B310
|
||||
return resp.status == 200
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
|
||||
def _spawn_inference_server(config: VlmConfig) -> str:
|
||||
"""Spawn ``transformers serve`` (or ``serve_command``), wait until it
|
||||
accepts ``/v1/models``, and register a shutdown hook.
|
||||
|
||||
Streams the server's stdout/stderr to the parent terminal in
|
||||
real-time on a background thread so users can see model-load
|
||||
progress and errors as they happen.
|
||||
|
||||
Returns the full ``api_base`` URL the OpenAI client should use.
|
||||
"""
|
||||
cmd = config.serve_command
|
||||
if not cmd:
|
||||
cmd = (
|
||||
f"transformers serve {shlex.quote(config.model_id)} "
|
||||
f"--port {config.serve_port} --continuous-batching"
|
||||
)
|
||||
api_base = f"http://localhost:{config.serve_port}/v1"
|
||||
print(f"[server] launching: {cmd}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
shlex.split(cmd),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
|
||||
# Watch the server output for the uvicorn readiness banner. This is
|
||||
# more reliable than polling /v1/models because transformers serve
|
||||
# rescans its cache on every model-list request, which can exceed
|
||||
# the urllib timeout and trigger an infinite probe loop.
|
||||
ready_event = threading.Event()
|
||||
# See _spawn_parallel_inference_servers for why we accept these.
|
||||
ready_markers = (
|
||||
"Uvicorn running",
|
||||
"Application startup complete",
|
||||
"Starting vLLM API server",
|
||||
"Available routes are",
|
||||
)
|
||||
|
||||
def _probe() -> None:
|
||||
while not ready_event.is_set() and proc.poll() is None:
|
||||
if _server_is_up(api_base):
|
||||
print("[server] ready (http probe)", flush=True)
|
||||
ready_event.set()
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
threading.Thread(target=_probe, daemon=True).start()
|
||||
|
||||
def _stream_output() -> None:
|
||||
# Read raw chunks instead of iterating lines so tqdm progress
|
||||
# bars (which overwrite using \r) flush in real time.
|
||||
assert proc.stdout is not None
|
||||
buf = ""
|
||||
prefix_started = False
|
||||
while True:
|
||||
ch = proc.stdout.read(1)
|
||||
if ch == "":
|
||||
# process exited; flush any tail
|
||||
if buf:
|
||||
sys.stdout.write(buf)
|
||||
sys.stdout.flush()
|
||||
return
|
||||
if not prefix_started:
|
||||
sys.stdout.write("[server] ")
|
||||
prefix_started = True
|
||||
sys.stdout.write(ch)
|
||||
sys.stdout.flush()
|
||||
buf += ch
|
||||
if ch in ("\n", "\r"):
|
||||
if any(marker in buf for marker in ready_markers):
|
||||
ready_event.set()
|
||||
buf = ""
|
||||
prefix_started = False
|
||||
|
||||
threading.Thread(target=_stream_output, daemon=True).start()
|
||||
|
||||
def _shutdown() -> None:
|
||||
if proc.poll() is None:
|
||||
print(f"[server] stopping pid={proc.pid}", flush=True)
|
||||
proc.send_signal(signal.SIGINT)
|
||||
try:
|
||||
proc.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
proc.wait(timeout=5)
|
||||
|
||||
atexit.register(_shutdown)
|
||||
|
||||
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||
while time.monotonic() < deadline:
|
||||
if proc.poll() is not None:
|
||||
raise RuntimeError(
|
||||
f"[server] inference server exited unexpectedly with rc={proc.returncode}. "
|
||||
f"See [server] log lines above for the cause."
|
||||
)
|
||||
if ready_event.wait(timeout=2):
|
||||
return api_base
|
||||
proc.terminate()
|
||||
raise RuntimeError(f"[server] did not become ready within {config.serve_ready_timeout_s}s")
|
||||
|
||||
|
||||
def _to_openai_messages(
|
||||
messages: Sequence[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||
"""Convert internal messages to OpenAI chat format.
|
||||
|
||||
Returns ``(api_messages, mm_kwargs)``. Multimodal-processor kwargs
|
||||
(``fps`` from ``video_url`` blocks) are extracted out so the caller
|
||||
can pass them via ``extra_body.mm_processor_kwargs`` rather than
|
||||
inside the content blocks (which transformers serve rejects).
|
||||
|
||||
File-URL video blocks are inlined as base64 data URLs.
|
||||
"""
|
||||
out_messages: list[dict[str, Any]] = []
|
||||
mm_kwargs: dict[str, Any] = {}
|
||||
for message in messages:
|
||||
content = message.get("content")
|
||||
if not isinstance(content, list):
|
||||
out_messages.append({"role": message["role"], "content": content})
|
||||
continue
|
||||
out_blocks: list[dict[str, Any]] = []
|
||||
for block in content:
|
||||
block_type = block.get("type") if isinstance(block, dict) else None
|
||||
if block_type == "text":
|
||||
out_blocks.append({"type": "text", "text": block.get("text", "")})
|
||||
elif block_type == "image":
|
||||
out_blocks.append(
|
||||
{"type": "image_url", "image_url": {"url": _pil_to_data_url(block["image"])}}
|
||||
)
|
||||
elif block_type == "video":
|
||||
frames = block.get("video", [])
|
||||
for img in frames:
|
||||
out_blocks.append({"type": "image_url", "image_url": {"url": _pil_to_data_url(img)}})
|
||||
elif block_type == "video_url":
|
||||
video_url = dict(block["video_url"])
|
||||
url = video_url.get("url", "")
|
||||
if url.startswith("file://"):
|
||||
video_url["url"] = _file_to_data_url(url[len("file://") :])
|
||||
out_blocks.append({"type": "video_url", "video_url": video_url})
|
||||
fps = block.get("fps")
|
||||
if fps is not None:
|
||||
mm_kwargs["fps"] = fps
|
||||
else:
|
||||
out_blocks.append(block)
|
||||
out_messages.append({"role": message["role"], "content": out_blocks})
|
||||
return out_messages, mm_kwargs
|
||||
|
||||
|
||||
def _file_to_data_url(path: str) -> str:
|
||||
"""Read a local video file and return a base64 ``data:video/mp4`` URL."""
|
||||
with open(path, "rb") as f:
|
||||
b64 = base64.b64encode(f.read()).decode("ascii")
|
||||
return f"data:video/mp4;base64,{b64}"
|
||||
|
||||
|
||||
def _pil_to_data_url(image: Any) -> str:
|
||||
"""Encode a PIL.Image as a base64 data URL."""
|
||||
buf = io.BytesIO()
|
||||
image.save(buf, format="PNG")
|
||||
b64 = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
return f"data:image/png;base64,{b64}"
|
||||
|
||||
|
||||
def _messages_to_prompt(messages: Sequence[dict[str, Any]]) -> Any:
|
||||
"""Pass-through hook used by the vllm backend.
|
||||
|
||||
vllm exposes its own multimodal entry points that vary by version; for the
|
||||
base flow we simply forward the raw message list and let the caller's
|
||||
custom backend handle templating. Real deployments override this.
|
||||
"""
|
||||
return list(messages)
|
||||
222
src/lerobot/annotations/steerable_pipeline/vocabulary.py
Normal file
222
src/lerobot/annotations/steerable_pipeline/vocabulary.py
Normal file
@@ -0,0 +1,222 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Dataset-level canonical vocabulary discovery (Phase 0).
|
||||
|
||||
The downstream consumer of these annotations is a low-level action expert
|
||||
conditioned on the ``subtask`` string. Free-form per-episode LLM rephrasing
|
||||
gives near-unique strings per occurrence, which collapses the action
|
||||
expert's conditioning to noise and makes runtime subtask-paraphrase drift
|
||||
catastrophic. The Hi-Robot / π0.6-MEM recipe ships a small canonical
|
||||
vocabulary per environment (~10 strings) that every episode reuses; this
|
||||
module derives that vocabulary automatically from the first few episode
|
||||
videos and persists it next to the dataset.
|
||||
|
||||
Pipeline-level flow:
|
||||
|
||||
Phase 0 (here): watch N sample episodes → produce vocabulary.json
|
||||
Phase 1 (plan module): reuse vocabulary on every episode, both as
|
||||
prompt-side constraint *and* post-VLM validation
|
||||
|
||||
The vocabulary is JSON, lives at ``<root>/meta/canonical_vocabulary.json``,
|
||||
and is human-inspectable / hand-editable — if the discovered set is wrong,
|
||||
operators edit the file and re-run the pipeline without phase 0.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .config import VocabularyConfig
|
||||
from .frames import FrameProvider, null_provider, to_video_block
|
||||
from .prompts import load as load_prompt
|
||||
from .reader import EpisodeRecord
|
||||
from .vlm_client import VlmClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VOCABULARY_FILENAME = "canonical_vocabulary.json"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Vocabulary:
|
||||
"""Canonical phrasings shared across every episode of one dataset.
|
||||
|
||||
Both lists are strict: per-episode subtask + memory generation pick
|
||||
from these strings only; the downstream policy then has a small,
|
||||
repeatable target distribution to learn instead of thousands of
|
||||
LLM paraphrases.
|
||||
"""
|
||||
|
||||
subtasks: tuple[str, ...]
|
||||
"""Imperative subtask labels — what the low-level policy is conditioned
|
||||
on. Verb-first, telegraphic, consistent object nouns. Example:
|
||||
``("move to blue cube", "grasp blue cube", "lift blue cube",
|
||||
"place blue cube in box", "retract arm")``.
|
||||
"""
|
||||
|
||||
memory_milestones: tuple[str, ...]
|
||||
"""First-person past-tense milestone sentences — building blocks for
|
||||
the running memory string. Example: ``("I picked up the blue cube.",
|
||||
"I placed the blue cube in the green box.")``. Each milestone maps
|
||||
1:1 onto a completed subtask phase; ``memory_at_step_k`` is the
|
||||
concatenation of milestones for completed phases.
|
||||
"""
|
||||
|
||||
def to_json(self) -> dict[str, list[str]]:
|
||||
return {
|
||||
"subtasks": list(self.subtasks),
|
||||
"memory_milestones": list(self.memory_milestones),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, payload: dict[str, Any]) -> Vocabulary:
|
||||
subtasks = tuple(
|
||||
str(s).strip() for s in (payload.get("subtasks") or []) if str(s).strip()
|
||||
)
|
||||
memory_milestones = tuple(
|
||||
str(s).strip() for s in (payload.get("memory_milestones") or []) if str(s).strip()
|
||||
)
|
||||
return cls(subtasks=subtasks, memory_milestones=memory_milestones)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return not self.subtasks and not self.memory_milestones
|
||||
|
||||
|
||||
def vocabulary_path(root: Path) -> Path:
|
||||
"""Return the canonical on-disk location for the vocabulary file."""
|
||||
return root / "meta" / VOCABULARY_FILENAME
|
||||
|
||||
|
||||
def load_vocabulary(root: Path) -> Vocabulary | None:
|
||||
"""Read ``<root>/meta/canonical_vocabulary.json`` if present.
|
||||
|
||||
Returns ``None`` when the file does not exist — callers fall back to
|
||||
free-form (unconstrained) subtask + memory generation, preserving the
|
||||
pipeline's behaviour on datasets that never ran phase 0.
|
||||
"""
|
||||
path = vocabulary_path(root)
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
payload = json.loads(path.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError) as exc:
|
||||
logger.warning("could not read %s: %s — proceeding without vocabulary", path, exc)
|
||||
return None
|
||||
if not isinstance(payload, dict):
|
||||
logger.warning("%s is not a JSON object — ignoring", path)
|
||||
return None
|
||||
vocab = Vocabulary.from_json(payload)
|
||||
if vocab.is_empty():
|
||||
return None
|
||||
return vocab
|
||||
|
||||
|
||||
def save_vocabulary(root: Path, vocab: Vocabulary) -> Path:
|
||||
"""Atomically persist ``vocab`` to ``<root>/meta/canonical_vocabulary.json``."""
|
||||
path = vocabulary_path(root)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = path.with_suffix(path.suffix + ".tmp")
|
||||
tmp.write_text(
|
||||
json.dumps(vocab.to_json(), indent=2, ensure_ascii=False) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
tmp.replace(path)
|
||||
return path
|
||||
|
||||
|
||||
@dataclass
|
||||
class VocabularyDiscoveryModule:
|
||||
"""Derive a dataset-level canonical vocabulary from sample episodes.
|
||||
|
||||
Phase 0 of the executor: pulls ``config.sample_episodes`` episode
|
||||
videos, packs them into one Qwen-VL multi-video prompt, and asks the
|
||||
model to enumerate the small set of canonical subtask labels +
|
||||
memory milestones that recur across them. The output is persisted
|
||||
to ``meta/canonical_vocabulary.json`` and consumed by phase 1.
|
||||
"""
|
||||
|
||||
vlm: VlmClient
|
||||
config: VocabularyConfig
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def discover(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
*,
|
||||
existing: Vocabulary | None = None,
|
||||
) -> Vocabulary | None:
|
||||
"""Run vocabulary discovery against the first N sample episodes.
|
||||
|
||||
``existing`` short-circuits the VLM call when ``config.reuse_existing``
|
||||
is True and an on-disk vocabulary is already present — keeps re-runs
|
||||
cheap and lets operators hand-edit the file without it getting
|
||||
overwritten.
|
||||
"""
|
||||
if existing is not None and self.config.reuse_existing:
|
||||
logger.info(
|
||||
"vocabulary: reusing existing (%d subtasks, %d memory milestones)",
|
||||
len(existing.subtasks),
|
||||
len(existing.memory_milestones),
|
||||
)
|
||||
return existing
|
||||
|
||||
sample = list(records[: max(1, int(self.config.sample_episodes))])
|
||||
if not sample:
|
||||
return None
|
||||
|
||||
task_hint = next((r.episode_task for r in sample if r.episode_task), "")
|
||||
prompt = load_prompt("module_0_vocabulary").format(
|
||||
episode_task=task_hint or "(unspecified)",
|
||||
n_episodes=len(sample),
|
||||
)
|
||||
# Pack one video block per sample episode so the VLM sees the
|
||||
# variation across episodes (different starting poses, different
|
||||
# object placements) rather than overfitting to one trajectory.
|
||||
content: list[dict[str, Any]] = []
|
||||
for record in sample:
|
||||
video_frames = self.frame_provider.video_for_episode(
|
||||
record, int(self.config.max_video_frames_per_episode)
|
||||
)
|
||||
if video_frames:
|
||||
content.extend(to_video_block(video_frames))
|
||||
content.append({"type": "text", "text": prompt})
|
||||
messages = [{"role": "user", "content": content}]
|
||||
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if not isinstance(result, dict):
|
||||
logger.warning("vocabulary: VLM did not return a JSON object — skipping")
|
||||
return None
|
||||
|
||||
vocab = Vocabulary.from_json(result)
|
||||
if vocab.is_empty():
|
||||
logger.warning("vocabulary: VLM returned an empty vocabulary — skipping")
|
||||
return None
|
||||
logger.info(
|
||||
"vocabulary: discovered %d subtask labels + %d memory milestones from %d episodes",
|
||||
len(vocab.subtasks),
|
||||
len(vocab.memory_milestones),
|
||||
len(sample),
|
||||
)
|
||||
return vocab
|
||||
356
src/lerobot/annotations/steerable_pipeline/writer.py
Normal file
356
src/lerobot/annotations/steerable_pipeline/writer.py
Normal file
@@ -0,0 +1,356 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Final parquet rewrite.
|
||||
|
||||
For every episode the writer:
|
||||
|
||||
1. reads the staged module outputs,
|
||||
2. partitions them into a persistent slice (PERSISTENT_STYLES) and an event
|
||||
slice (EVENT_ONLY_STYLES + style=None tool-call atoms),
|
||||
3. sorts each slice deterministically,
|
||||
4. broadcasts the persistent slice across every frame in the episode,
|
||||
5. for each frame, materializes the sublist of event rows whose timestamp
|
||||
exactly equals that frame's timestamp,
|
||||
6. drops the legacy ``subtask_index`` column,
|
||||
7. writes the parquet shard back in place.
|
||||
|
||||
The writer does NOT add a dataset-level ``tools`` column. Tool *calls* are
|
||||
emitted per-row via the existing ``tool_calls`` field on the v3.1 row
|
||||
struct for every speech atom. The tool *schema* (the description
|
||||
of the ``say`` function and its parameters) is a fixed code constant —
|
||||
``SAY_TOOL_SCHEMA`` below — and downstream chat-template consumers import
|
||||
it directly rather than reading a redundant per-row column.
|
||||
|
||||
Invariants enforced here (and re-checked by the validator):
|
||||
|
||||
- per-episode persistent slice is byte-identical across every frame;
|
||||
- ``language_events`` rows on a frame all have ``timestamp == frame_ts``
|
||||
(timestamps come straight from the source parquet — never recomputed);
|
||||
- every row passes ``column_for_style(style)``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
column_for_style,
|
||||
validate_camera_field,
|
||||
)
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
from .staging import EpisodeStaging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Tool schema constants live in lerobot.datasets.language — single
|
||||
# source of truth. Re-exported here so existing imports
|
||||
# (``from lerobot.annotations.steerable_pipeline.writer import SAY_TOOL_SCHEMA``)
|
||||
# keep working.
|
||||
from lerobot.datasets.language import DEFAULT_TOOLS, SAY_TOOL_SCHEMA # noqa: F401, E402
|
||||
|
||||
|
||||
def _row_persistent_sort_key(row: dict[str, Any]) -> tuple:
|
||||
return (float(row["timestamp"]), row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _row_event_sort_key(row: dict[str, Any]) -> tuple:
|
||||
# events are bucketed per-frame, but within a frame we still want determinism
|
||||
return (
|
||||
row.get("style") or "",
|
||||
row.get("role") or "",
|
||||
row.get("camera") or "",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the persistent column's struct shape."""
|
||||
style = row.get("style")
|
||||
if style not in PERSISTENT_STYLES:
|
||||
raise ValueError(
|
||||
f"persistent slice contains row with non-persistent style {style!r}; "
|
||||
"row would be misrouted under column_for_style()"
|
||||
)
|
||||
if "timestamp" not in row:
|
||||
raise ValueError(f"persistent row missing timestamp: {row!r}")
|
||||
if "role" not in row:
|
||||
# Surface a friendly error from the writer rather than letting
|
||||
# the raw KeyError bubble out of the dict access below — modules
|
||||
# are expected to always emit ``role``, but the validator
|
||||
# currently doesn't check this so a future bug would otherwise
|
||||
# be hard to triage.
|
||||
raise ValueError(f"persistent row missing role: {row!r}")
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
return {
|
||||
"role": str(row["role"]),
|
||||
"content": None if row.get("content") is None else str(row["content"]),
|
||||
"style": style,
|
||||
"timestamp": float(row["timestamp"]),
|
||||
"camera": None if camera is None else str(camera),
|
||||
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the event column's struct shape (no timestamp)."""
|
||||
style = row.get("style")
|
||||
if style is not None and style not in EVENT_ONLY_STYLES:
|
||||
raise ValueError(
|
||||
f"event slice contains row with style {style!r}; expected None or one of {EVENT_ONLY_STYLES}"
|
||||
)
|
||||
if column_for_style(style) != LANGUAGE_EVENTS:
|
||||
raise ValueError(f"event row with style {style!r} would not route to language_events")
|
||||
if "role" not in row:
|
||||
raise ValueError(f"event row missing role: {row!r}")
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
return {
|
||||
"role": str(row["role"]),
|
||||
"content": None if row.get("content") is None else str(row["content"]),
|
||||
"style": style,
|
||||
"camera": None if camera is None else str(camera),
|
||||
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_tool_calls(value: Any) -> list[Any] | None:
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, list):
|
||||
raise ValueError(f"tool_calls must be a list or None, got {type(value).__name__}")
|
||||
return list(value)
|
||||
|
||||
|
||||
def _validate_atom_invariants(row: dict[str, Any]) -> None:
|
||||
"""At-least-one of content/tool_calls; style=None implies tool_calls."""
|
||||
has_content = row.get("content") is not None
|
||||
has_tools = row.get("tool_calls") is not None
|
||||
if not (has_content or has_tools):
|
||||
raise ValueError(f"row has neither content nor tool_calls: {row!r}")
|
||||
if row.get("style") is None and not has_tools:
|
||||
raise ValueError(f"style=None requires tool_calls: {row!r}")
|
||||
|
||||
|
||||
def _validate_speech_atom(row: dict[str, Any]) -> None:
|
||||
"""Speech atoms: role=assistant, style=None, content=None, say tool call."""
|
||||
if row.get("style") is not None:
|
||||
return # not a speech atom
|
||||
if row.get("role") != "assistant":
|
||||
raise ValueError(f"speech atom must have role=assistant: {row!r}")
|
||||
if row.get("content") is not None:
|
||||
raise ValueError(f"speech atom must have content=null: {row!r}")
|
||||
tool_calls = row.get("tool_calls")
|
||||
if not tool_calls or not isinstance(tool_calls, list):
|
||||
raise ValueError(f"speech atom must have non-empty tool_calls list: {row!r}")
|
||||
first = tool_calls[0]
|
||||
if not isinstance(first, dict):
|
||||
raise ValueError(f"speech atom tool_calls[0] must be a dict: {row!r}")
|
||||
if first.get("type") != "function":
|
||||
raise ValueError(f"speech atom tool_calls[0].type must be 'function': {row!r}")
|
||||
fn = first.get("function") or {}
|
||||
if fn.get("name") != "say":
|
||||
raise ValueError(f"speech atom tool_calls[0].function.name must be 'say': {row!r}")
|
||||
args = fn.get("arguments") or {}
|
||||
if not isinstance(args, dict) or "text" not in args or not isinstance(args["text"], str):
|
||||
raise ValueError(f"speech atom must carry 'text' string in arguments: {row!r}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class LanguageColumnsWriter:
|
||||
"""Rewrite ``data/chunk-*/file-*.parquet`` with the two language columns."""
|
||||
|
||||
drop_existing_subtask_index: bool = True
|
||||
|
||||
def write_all(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
root: Path,
|
||||
) -> list[Path]:
|
||||
episodes_by_path: dict[Path, list[EpisodeRecord]] = defaultdict(list)
|
||||
for record in records:
|
||||
episodes_by_path[record.data_path].append(record)
|
||||
|
||||
written: list[Path] = []
|
||||
for path, eps in episodes_by_path.items():
|
||||
self._rewrite_one(path, eps, staging_dir, root)
|
||||
written.append(path)
|
||||
return written
|
||||
|
||||
def _rewrite_one(
|
||||
self,
|
||||
path: Path,
|
||||
episodes: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
root: Path,
|
||||
) -> None:
|
||||
table = pq.read_table(path)
|
||||
n_rows = table.num_rows
|
||||
|
||||
# Ensure we cover every episode in the file. Episodes that don't have
|
||||
# staging artifacts are passed through with empty annotation lists —
|
||||
# this keeps the writer idempotent and safe for partial reruns.
|
||||
staged_per_ep: dict[int, dict[str, list[dict[str, Any]]]] = {}
|
||||
for record in episodes:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
staged_per_ep[record.episode_index] = staging.read_all()
|
||||
|
||||
persistent_by_ep: dict[int, list[dict[str, Any]]] = {}
|
||||
events_by_ep_ts: dict[int, dict[float, list[dict[str, Any]]]] = {}
|
||||
|
||||
for ep_index, ep_staged in staged_per_ep.items():
|
||||
persistent_rows: list[dict[str, Any]] = []
|
||||
event_rows: list[dict[str, Any]] = [] # carry timestamp until bucketed
|
||||
for _module_name, rows in ep_staged.items():
|
||||
for row in rows:
|
||||
style = row.get("style")
|
||||
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||
persistent_rows.append(row)
|
||||
else:
|
||||
event_rows.append(row)
|
||||
|
||||
persistent_rows.sort(key=_row_persistent_sort_key)
|
||||
normalized_persistent = []
|
||||
for r in persistent_rows:
|
||||
_validate_atom_invariants(r)
|
||||
_validate_speech_atom(r)
|
||||
normalized_persistent.append(_normalize_persistent_row(r))
|
||||
persistent_by_ep[ep_index] = normalized_persistent
|
||||
|
||||
buckets: dict[float, list[dict[str, Any]]] = defaultdict(list)
|
||||
for r in event_rows:
|
||||
_validate_atom_invariants(r)
|
||||
_validate_speech_atom(r)
|
||||
ts = float(r["timestamp"])
|
||||
buckets[ts].append(_normalize_event_row(r))
|
||||
for ts in list(buckets.keys()):
|
||||
buckets[ts].sort(key=_row_event_sort_key)
|
||||
events_by_ep_ts[ep_index] = buckets
|
||||
|
||||
episode_col = (
|
||||
table.column("episode_index").to_pylist() if "episode_index" in table.column_names else None
|
||||
)
|
||||
ts_col = table.column("timestamp").to_pylist() if "timestamp" in table.column_names else None
|
||||
if episode_col is None or ts_col is None:
|
||||
raise ValueError(f"{path} is missing 'episode_index' or 'timestamp' — required by the writer.")
|
||||
|
||||
per_row_persistent: list[list[dict[str, Any]]] = []
|
||||
per_row_events: list[list[dict[str, Any]]] = []
|
||||
for i in range(n_rows):
|
||||
ep = episode_col[i]
|
||||
ts = float(ts_col[i])
|
||||
per_row_persistent.append(persistent_by_ep.get(ep, []))
|
||||
buckets = events_by_ep_ts.get(ep, {})
|
||||
per_row_events.append(buckets.get(ts, []))
|
||||
|
||||
new_table = self._materialize_table(
|
||||
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
|
||||
)
|
||||
# Atomic replace: write to a sibling tmp path and rename so a crash
|
||||
# mid-write can't leave a half-written shard that ``pq.read_table``
|
||||
# would then fail to open. ``Path.replace`` is atomic on POSIX +
|
||||
# Windows when source and target sit on the same filesystem.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
pq.write_table(new_table, tmp_path)
|
||||
tmp_path.replace(path)
|
||||
|
||||
def _materialize_table(
|
||||
self,
|
||||
table: pa.Table,
|
||||
persistent: list[list[dict[str, Any]]],
|
||||
events: list[list[dict[str, Any]]],
|
||||
*,
|
||||
drop_old: bool,
|
||||
) -> pa.Table:
|
||||
cols = []
|
||||
names = []
|
||||
for name in table.column_names:
|
||||
if drop_old and name == "subtask_index":
|
||||
continue
|
||||
if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS):
|
||||
continue # we'll re-add canonical versions
|
||||
# Strip any legacy ``tools`` column previously emitted by older
|
||||
# writers — the schema no longer uses it (constant lives in
|
||||
# SAY_TOOL_SCHEMA / DEFAULT_TOOLS).
|
||||
if name == "tools":
|
||||
continue
|
||||
cols.append(table.column(name))
|
||||
names.append(name)
|
||||
|
||||
# We let pyarrow infer struct/list schema rather than passing the
|
||||
# canonical type from `lerobot.datasets.language` directly: that type
|
||||
# uses `pa.json_()` for the `tool_calls` element type, which
|
||||
# `pa.array(..., type=...)` cannot materialize from Python lists on
|
||||
# current pyarrow versions. The inferred schema round-trips through
|
||||
# parquet and `LeRobotDataset` correctly — `tests/datasets/test_language.py`
|
||||
# exercises the same flow.
|
||||
persistent_arr = pa.array(persistent)
|
||||
events_arr = pa.array(events)
|
||||
|
||||
cols.extend([persistent_arr, events_arr])
|
||||
names.extend([LANGUAGE_PERSISTENT, LANGUAGE_EVENTS])
|
||||
|
||||
return pa.Table.from_arrays(cols, names=names)
|
||||
|
||||
|
||||
def speech_atom(timestamp: float, text: str) -> dict[str, Any]:
|
||||
"""Build a canonical speech tool-call atom for the events column."""
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"style": None,
|
||||
"timestamp": float(timestamp),
|
||||
"camera": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"arguments": {"text": text},
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def normalize_rows_for_writer(
|
||||
rows: Iterable[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
"""Helper used by tests/validators to partition a flat row list into
|
||||
(persistent_rows, event_rows) using ``column_for_style``.
|
||||
"""
|
||||
persistent: list[dict[str, Any]] = []
|
||||
events: list[dict[str, Any]] = []
|
||||
for row in rows:
|
||||
if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT:
|
||||
persistent.append(row)
|
||||
else:
|
||||
events.append(row)
|
||||
return persistent, events
|
||||
@@ -199,12 +199,13 @@ class OpenCVCamera(Camera):
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
"""
|
||||
|
||||
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
|
||||
if self.config.fourcc is not None:
|
||||
self._validate_fourcc()
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
set_fourcc_after_size_and_fps = platform.system() == "Windows"
|
||||
if self.config.fourcc is not None and not set_fourcc_after_size_and_fps:
|
||||
self._validate_fourcc()
|
||||
|
||||
default_width = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_WIDTH)))
|
||||
default_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
|
||||
|
||||
@@ -222,6 +223,11 @@ class OpenCVCamera(Camera):
|
||||
else:
|
||||
self._validate_fps()
|
||||
|
||||
if self.config.fourcc is not None and set_fourcc_after_size_and_fps:
|
||||
# On Windows with DSHOW, changing the resolution can silently override the FOURCC setting.
|
||||
# Set FOURCC last to make sure the requested pixel format is actually enforced.
|
||||
self._validate_fourcc()
|
||||
|
||||
def _validate_fps(self) -> None:
|
||||
"""Validates and sets the camera's frames per second (FPS)."""
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ Provides the RealSenseCamera class for capturing frames from Intel RealSense cam
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@@ -41,6 +42,7 @@ from ..utils import get_cv2_rotation
|
||||
from .configuration_realsense import RealSenseCameraConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
pkg_name = "pyrealsense2-macosx" if sys.platform == "darwin" else "pyrealsense2"
|
||||
|
||||
|
||||
class RealSenseCamera(Camera):
|
||||
@@ -114,7 +116,7 @@ class RealSenseCamera(Camera):
|
||||
Args:
|
||||
config: The configuration settings for the camera.
|
||||
"""
|
||||
require_package("pyrealsense2", extra="intelrealsense")
|
||||
require_package(pkg_name, extra="intelrealsense", import_name="pyrealsense2")
|
||||
super().__init__(config)
|
||||
|
||||
self.config = config
|
||||
|
||||
@@ -99,6 +99,7 @@ def save_checkpoint(
|
||||
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
|
||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||
postprocessor: The postprocessor/pipeline to save. Defaults to None.
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
|
||||
@@ -41,8 +41,12 @@ def cfg_to_group(
|
||||
return tag
|
||||
return tag[:max_tag_length]
|
||||
|
||||
if cfg.is_reward_model_training:
|
||||
trainable_tag = f"reward_model:{cfg.reward_model.type}"
|
||||
else:
|
||||
trainable_tag = f"policy:{cfg.policy.type}"
|
||||
lst = [
|
||||
f"policy:{cfg.policy.type}",
|
||||
trainable_tag,
|
||||
f"seed:{cfg.seed}",
|
||||
]
|
||||
if cfg.dataset is not None:
|
||||
|
||||
@@ -21,8 +21,10 @@ are intentionally NOT re-exported here to avoid circular dependencies
|
||||
Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
|
||||
"""
|
||||
|
||||
from .dataset import DatasetRecordConfig
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .recipe import MessageTurn, TrainingRecipe, load_recipe
|
||||
from .types import (
|
||||
FeatureType,
|
||||
NormalizationMode,
|
||||
@@ -30,6 +32,12 @@ from .types import (
|
||||
PolicyFeature,
|
||||
RTCAttentionSchedule,
|
||||
)
|
||||
from .video import (
|
||||
VALID_VIDEO_CODECS,
|
||||
VIDEO_ENCODER_INFO_KEYS,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Types
|
||||
@@ -39,9 +47,19 @@ __all__ = [
|
||||
"PolicyFeature",
|
||||
"RTCAttentionSchedule",
|
||||
# Config classes
|
||||
"DatasetRecordConfig",
|
||||
"DatasetConfig",
|
||||
"EvalConfig",
|
||||
"MessageTurn",
|
||||
"PeftConfig",
|
||||
"PreTrainedConfig",
|
||||
"TrainingRecipe",
|
||||
"WandBConfig",
|
||||
"load_recipe",
|
||||
"VideoEncoderConfig",
|
||||
# Defaults
|
||||
"camera_encoder_defaults",
|
||||
# Constants
|
||||
"VALID_VIDEO_CODECS",
|
||||
"VIDEO_ENCODER_INFO_KEYS",
|
||||
]
|
||||
|
||||
81
src/lerobot/configs/dataset.py
Normal file
81
src/lerobot/configs/dataset.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from .video import VideoEncoderConfig, camera_encoder_defaults
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetRecordConfig:
|
||||
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
|
||||
repo_id: str = ""
|
||||
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
|
||||
single_task: str = ""
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
root: str | Path | None = None
|
||||
# Limit the frames per second.
|
||||
fps: int = 30
|
||||
# Number of seconds for data recording for each episode.
|
||||
episode_time_s: int | float = 60
|
||||
# Number of seconds for resetting the environment after each episode.
|
||||
reset_time_s: int | float = 60
|
||||
# Number of episodes to record.
|
||||
num_episodes: int = 50
|
||||
# Encode frames in the dataset into video
|
||||
video: bool = True
|
||||
# Upload dataset to Hugging Face hub.
|
||||
push_to_hub: bool = True
|
||||
# Upload on private repository on the Hugging Face hub.
|
||||
private: bool = False
|
||||
# Add tags to your dataset on the hub.
|
||||
tags: list[str] | None = None
|
||||
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
|
||||
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
|
||||
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
|
||||
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
|
||||
num_image_writer_processes: int = 0
|
||||
# Number of threads writing the frames as png images on disk, per camera.
|
||||
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
|
||||
# Not enough threads might cause low camera fps.
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
# Number of episodes to record before batch encoding videos
|
||||
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||
video_encoding_batch_size: int = 1
|
||||
# Video encoder settings for camera MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys,
|
||||
# e.g. ``--dataset.camera_encoder.vcodec=h264`` (see ``VideoEncoderConfig``).
|
||||
camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
# Enable streaming video encoding: encode frames in real-time during capture instead
|
||||
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
|
||||
streaming_encoding: bool = False
|
||||
# Maximum number of frames to buffer per camera when using streaming encoding.
|
||||
# ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up.
|
||||
encoder_queue_maxsize: int = 30
|
||||
# Number of threads per encoder instance. None = auto (codec default).
|
||||
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
|
||||
encoder_threads: int | None = None
|
||||
|
||||
def stamp_repo_id(self) -> None:
|
||||
"""Append a date-time tag to ``repo_id`` so each recording session gets a unique name.
|
||||
|
||||
Must be called explicitly at dataset *creation* time — not on resume,
|
||||
where the existing ``repo_id`` (already stamped) must be preserved.
|
||||
"""
|
||||
if self.repo_id:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
self.repo_id = f"{self.repo_id}_{timestamp}"
|
||||
@@ -17,7 +17,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.transforms import ImageTransformsConfig
|
||||
from lerobot.utils.import_utils import get_safe_default_codec
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -34,7 +34,7 @@ class DatasetConfig:
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||
video_backend: str = field(default_factory=get_safe_default_video_backend)
|
||||
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
@@ -117,3 +117,9 @@ class PeftConfig:
|
||||
# the rank used for the adapter. In general a higher rank means more trainable parameters and closer to full
|
||||
# fine-tuning.
|
||||
r: int = 16
|
||||
|
||||
# Alpha parameter for LoRA scaling (scaling = lora_alpha / r).
|
||||
# In general, a higher alpha means stronger adaptation signal.
|
||||
# If None, the PEFT library defaults to alpha=8, which may dampen high-rank adapters.
|
||||
# Common values are r (alpha == rank) or 2*r.
|
||||
lora_alpha: int | None = None
|
||||
|
||||
@@ -18,8 +18,8 @@ from logging import getLogger
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot import envs, policies # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
|
||||
from . import parser
|
||||
from .default import EvalConfig
|
||||
from .policies import PreTrainedConfig
|
||||
|
||||
@@ -46,8 +46,11 @@ class EvalPipelineConfig:
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
yaml_overrides = parser.get_yaml_overrides("policy")
|
||||
cli_overrides = parser.get_cli_overrides("policy") or []
|
||||
self.policy = PreTrainedConfig.from_pretrained(
|
||||
policy_path, cli_overrides=yaml_overrides + cli_overrides
|
||||
)
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
|
||||
else:
|
||||
|
||||
@@ -13,8 +13,10 @@
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import pkgutil
|
||||
import sys
|
||||
import tempfile
|
||||
from argparse import ArgumentError
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from functools import wraps
|
||||
@@ -24,6 +26,7 @@ from types import ModuleType
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
import draccus
|
||||
import yaml # type: ignore[import-untyped]
|
||||
|
||||
from lerobot.utils.utils import has_method
|
||||
|
||||
@@ -32,6 +35,29 @@ F = TypeVar("F", bound=Callable[..., object])
|
||||
PATH_KEY = "path"
|
||||
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
|
||||
|
||||
# Storage for path args extracted from YAML/JSON config files, so that
|
||||
# get_path_arg() can find them even when they weren't passed via CLI.
|
||||
_config_path_args: dict[str, str] = {}
|
||||
|
||||
# Storage for non-path YAML overrides so validate() can pass them to from_pretrained.
|
||||
_config_yaml_overrides: dict[str, list[str]] = {}
|
||||
|
||||
|
||||
def _flatten_to_cli_args(d: dict, prefix: str = "") -> list[str]:
|
||||
"""Recursively flatten a nested dict to CLI-style args (e.g. {"lr": 1e-4} -> ["--lr=0.0001"])."""
|
||||
args = []
|
||||
for key, value in d.items():
|
||||
if key in (PATH_KEY, draccus.CHOICE_TYPE_KEY):
|
||||
continue
|
||||
full_key = f"{prefix}.{key}" if prefix else key
|
||||
if isinstance(value, bool):
|
||||
value = str(value).lower()
|
||||
if isinstance(value, dict):
|
||||
args.extend(_flatten_to_cli_args(value, full_key))
|
||||
elif value is not None and not isinstance(value, list):
|
||||
args.append(f"--{full_key}={value}")
|
||||
return args
|
||||
|
||||
|
||||
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
|
||||
"""Parses arguments from cli at a given nested attribute level.
|
||||
@@ -145,7 +171,14 @@ def load_plugin(plugin_path: str) -> None:
|
||||
|
||||
|
||||
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||
return parse_arg(f"{field_name}.{PATH_KEY}", args)
|
||||
result = parse_arg(f"{field_name}.{PATH_KEY}", args)
|
||||
if result is None:
|
||||
result = _config_path_args.get(field_name)
|
||||
return result
|
||||
|
||||
|
||||
def get_yaml_overrides(field_name: str) -> list[str]:
|
||||
return _config_yaml_overrides.get(field_name, [])
|
||||
|
||||
|
||||
def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||
@@ -192,6 +225,52 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
||||
return filtered_args
|
||||
|
||||
|
||||
def extract_path_fields_from_config(config_path: str, path_fields: list[str]) -> str:
|
||||
"""Extract `path` fields from a YAML/JSON config before draccus processes it.
|
||||
|
||||
When a user specifies e.g. ``policy.path: lerobot/smolvla_base`` in a YAML config,
|
||||
draccus will fail because ``path`` is not a valid field on policy config classes.
|
||||
This function extracts those path values, stores them in ``_config_path_args`` for
|
||||
later retrieval by ``get_path_arg()``, and returns a cleaned temp config file path.
|
||||
"""
|
||||
config_file = Path(config_path)
|
||||
suffix = config_file.suffix.lower()
|
||||
|
||||
if suffix in (".yaml", ".yml"):
|
||||
with open(config_file) as f:
|
||||
config_data = yaml.safe_load(f)
|
||||
elif suffix == ".json":
|
||||
with open(config_file) as f:
|
||||
config_data = json.load(f)
|
||||
else:
|
||||
return config_path
|
||||
|
||||
if not isinstance(config_data, dict):
|
||||
return config_path
|
||||
|
||||
modified = False
|
||||
for field in path_fields:
|
||||
if field in config_data and isinstance(config_data[field], dict) and PATH_KEY in config_data[field]:
|
||||
_config_path_args[field] = str(config_data[field].pop(PATH_KEY))
|
||||
remaining = config_data[field]
|
||||
if remaining:
|
||||
_config_yaml_overrides[field] = _flatten_to_cli_args(remaining)
|
||||
else:
|
||||
del config_data[field]
|
||||
modified = True
|
||||
|
||||
if not modified:
|
||||
return config_path
|
||||
|
||||
# Write cleaned config to a temp file
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False) as tmp:
|
||||
if suffix in (".yaml", ".yml"):
|
||||
yaml.dump(config_data, tmp, default_flow_style=False)
|
||||
else:
|
||||
json.dump(config_data, tmp, indent=2)
|
||||
return tmp.name
|
||||
|
||||
|
||||
def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||
"""
|
||||
HACK: Similar to draccus.wrap but does three additional things:
|
||||
@@ -225,6 +304,9 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||
if has_method(argtype, "__get_path_fields__"):
|
||||
path_fields = argtype.__get_path_fields__()
|
||||
cli_args = filter_path_args(path_fields, cli_args)
|
||||
# Also extract path fields from the YAML/JSON config file
|
||||
if config_path_cli:
|
||||
config_path_cli = extract_path_fields_from_config(config_path_cli, path_fields)
|
||||
if has_method(argtype, "from_pretrained") and config_path_cli:
|
||||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
||||
|
||||
206
src/lerobot/configs/recipe.py
Normal file
206
src/lerobot/configs/recipe.py
Normal file
@@ -0,0 +1,206 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
MessageRole = Literal["user", "assistant", "system", "tool"]
|
||||
MessageStream = Literal["high_level", "low_level"]
|
||||
|
||||
DEFAULT_BINDINGS = {
|
||||
"subtask": "active_at(t, style=subtask)",
|
||||
"memory": "active_at(t, style=memory)",
|
||||
"plan": "active_at(t, style=plan)",
|
||||
"speech": "emitted_at(t, role=assistant, tool_name=say)",
|
||||
"interjection": "emitted_at(t, style=interjection)",
|
||||
"vqa": "emitted_at(t, style=vqa, role=assistant)",
|
||||
"vqa_query": "emitted_at(t, style=vqa, role=user)",
|
||||
}
|
||||
|
||||
PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
"""``${name}`` placeholder pattern used by both recipe binding-reference
|
||||
discovery (here) and rendered-message substitution (in ``language_render``)."""
|
||||
|
||||
_VALID_ROLES = frozenset(get_args(MessageRole))
|
||||
_VALID_STREAMS = frozenset(get_args(MessageStream))
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageTurn:
|
||||
"""A single chat-style turn in a recipe template.
|
||||
|
||||
``content`` may be a plain string, a list of HF-style multimodal blocks, or
|
||||
``None`` when ``tool_calls_from`` supplies tool-call payloads instead.
|
||||
``stream`` tags the turn for downstream filtering, ``target`` flags it as a
|
||||
training target, and ``if_present`` skips the turn when the named binding
|
||||
resolves to ``None``.
|
||||
"""
|
||||
|
||||
role: MessageRole
|
||||
content: str | list[dict[str, Any]] | None = None
|
||||
stream: MessageStream | None = None
|
||||
target: bool = False
|
||||
if_present: str | None = None
|
||||
tool_calls_from: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate role, stream, and content after dataclass construction."""
|
||||
if self.role not in _VALID_ROLES:
|
||||
raise ValueError(f"Unsupported message role: {self.role!r}")
|
||||
# ``stream`` is typed Optional only so the dataclass can keep its
|
||||
# field ordering, but recipes must always tag every turn with a
|
||||
# stream — the renderer's ``_validate_rendered`` would reject
|
||||
# ``None`` later on. Fail at construction so the bad recipe is
|
||||
# caught at YAML load time rather than at the first sample.
|
||||
if self.stream is None:
|
||||
raise ValueError(
|
||||
f"MessageTurn(role={self.role!r}) is missing a stream — "
|
||||
f"every turn must declare one of {sorted(_VALID_STREAMS)}."
|
||||
)
|
||||
if self.stream not in _VALID_STREAMS:
|
||||
raise ValueError(f"Unsupported message stream: {self.stream!r}")
|
||||
if self.content is None and self.tool_calls_from is None:
|
||||
raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
|
||||
if self.content is not None and not isinstance(self.content, (str, list)):
|
||||
raise TypeError("MessageTurn.content must be a string, a list of HF-style blocks, or None.")
|
||||
if isinstance(self.content, list):
|
||||
for block in self.content:
|
||||
if not isinstance(block, dict) or "type" not in block:
|
||||
raise ValueError(
|
||||
"Multimodal content blocks must be HF-style dictionaries with a type key."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> MessageTurn:
|
||||
"""Construct a :class:`MessageTurn` from a plain dictionary."""
|
||||
return cls(**data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingRecipe:
|
||||
"""A recipe describing how to render training samples from language rows.
|
||||
|
||||
A recipe is either a *message recipe* (``messages`` plus optional
|
||||
``bindings``) or a *blend recipe* (``blend`` mapping names to weighted
|
||||
sub-recipes). ``weight`` is only meaningful inside a blend.
|
||||
"""
|
||||
|
||||
messages: list[MessageTurn] | None = None
|
||||
bindings: dict[str, str] | None = None
|
||||
blend: dict[str, TrainingRecipe] | None = None
|
||||
weight: float | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate that exactly one of ``messages`` or ``blend`` is set."""
|
||||
if self.messages is not None and self.blend is not None:
|
||||
raise ValueError("TrainingRecipe must set only one of messages or blend.")
|
||||
if self.messages is None and self.blend is None:
|
||||
raise ValueError("TrainingRecipe must set one of messages or blend.")
|
||||
|
||||
if self.messages is not None:
|
||||
self._validate_message_recipe()
|
||||
if self.blend is not None:
|
||||
self._validate_blend_recipe()
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> TrainingRecipe:
|
||||
"""Construct a :class:`TrainingRecipe` from a nested dictionary."""
|
||||
data = dict(data)
|
||||
if data.get("messages") is not None:
|
||||
data["messages"] = [
|
||||
turn if isinstance(turn, MessageTurn) else MessageTurn.from_dict(turn)
|
||||
for turn in data["messages"]
|
||||
]
|
||||
if data.get("blend") is not None:
|
||||
data["blend"] = {
|
||||
name: recipe if isinstance(recipe, TrainingRecipe) else cls.from_dict(recipe)
|
||||
for name, recipe in data["blend"].items()
|
||||
}
|
||||
return cls(**data)
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: str | Path) -> TrainingRecipe:
|
||||
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
|
||||
import yaml # type: ignore[import-untyped]
|
||||
|
||||
with open(path) as f:
|
||||
data = yaml.safe_load(f)
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Recipe YAML must contain a mapping at the top level: {path}")
|
||||
return cls.from_dict(data)
|
||||
|
||||
def _validate_message_recipe(self) -> None:
|
||||
"""Ensure every templated binding is known and at least one turn is a target."""
|
||||
assert self.messages is not None
|
||||
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
|
||||
|
||||
for turn in self.messages:
|
||||
missing = self._referenced_bindings(turn) - known_bindings
|
||||
if missing:
|
||||
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
|
||||
|
||||
if not any(turn.target for turn in self.messages):
|
||||
raise ValueError("Message recipes must contain at least one target turn.")
|
||||
|
||||
def _validate_blend_recipe(self) -> None:
|
||||
"""Ensure each blend component is a non-empty, weighted message recipe."""
|
||||
assert self.blend is not None
|
||||
if not self.blend:
|
||||
raise ValueError("Blend recipes must contain at least one component.")
|
||||
|
||||
for name, recipe in self.blend.items():
|
||||
if recipe.blend is not None:
|
||||
raise ValueError(f"Blend component {name!r} cannot itself define a blend.")
|
||||
if recipe.messages is None:
|
||||
raise ValueError(f"Blend component {name!r} must define messages.")
|
||||
if recipe.weight is None:
|
||||
raise ValueError(f"Blend component {name!r} must define weight.")
|
||||
if recipe.weight <= 0:
|
||||
raise ValueError(f"Blend component {name!r} must have a positive weight.")
|
||||
|
||||
def _referenced_bindings(self, turn: MessageTurn) -> set[str]:
|
||||
"""Return the binding names that ``turn`` references via placeholders or attributes."""
|
||||
names: set[str] = set()
|
||||
if turn.if_present is not None:
|
||||
names.add(turn.if_present)
|
||||
if turn.tool_calls_from is not None:
|
||||
names.add(turn.tool_calls_from)
|
||||
names.update(_placeholders_in_content(turn.content))
|
||||
return names
|
||||
|
||||
|
||||
def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[str]:
|
||||
"""Return the set of ``${name}`` placeholders found anywhere in ``content``."""
|
||||
if content is None:
|
||||
return set()
|
||||
if isinstance(content, str):
|
||||
return set(PLACEHOLDER_RE.findall(content))
|
||||
|
||||
names: set[str] = set()
|
||||
for block in content:
|
||||
for value in block.values():
|
||||
if isinstance(value, str):
|
||||
names.update(PLACEHOLDER_RE.findall(value))
|
||||
return names
|
||||
|
||||
|
||||
def load_recipe(path: str | Path) -> TrainingRecipe:
|
||||
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
|
||||
return TrainingRecipe.from_yaml(path)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user