mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
Compare commits
41 Commits
d55b581ca1
...
docs/add-l
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
71848ebb2e | ||
|
|
5c98e80430 | ||
|
|
f65f3f7a4a | ||
|
|
8194897994 | ||
|
|
9f437d86b6 | ||
|
|
b74a551d38 | ||
|
|
c0a2e9814d | ||
|
|
bac4f61eae | ||
|
|
f4b834844e | ||
|
|
dfdc48a7f1 | ||
|
|
6a8878a639 | ||
|
|
d38eb89f71 | ||
|
|
7ab4936b1b | ||
|
|
ca8c60a0ed | ||
|
|
3c15fd8537 | ||
|
|
5ebbdf3d05 | ||
|
|
6e035fb169 | ||
|
|
01dcb4c292 | ||
|
|
bd9619dfc3 | ||
|
|
0a4a7c40ad | ||
|
|
ca9028ad64 | ||
|
|
9db9c35cb4 | ||
|
|
fe96b28c74 | ||
|
|
2438df1307 | ||
|
|
f218d5ab30 | ||
|
|
04125492e4 | ||
|
|
e963e5a0c4 | ||
|
|
26ff40ddd7 | ||
|
|
6d269b28c8 | ||
|
|
b607c8458e | ||
|
|
9e83510c99 | ||
|
|
1f7b03f5f2 | ||
|
|
cb8edf17e6 | ||
|
|
5699f6cbf4 | ||
|
|
0e6114ac36 | ||
|
|
c8ce413d73 | ||
|
|
82dffde7fa | ||
|
|
eaf0218bc8 | ||
|
|
a0e52d52fe | ||
|
|
e99c55af4b | ||
|
|
408e0ca763 |
6
.github/workflows/benchmark_tests.yml
vendored
6
.github/workflows/benchmark_tests.yml
vendored
@@ -382,6 +382,7 @@ jobs:
|
||||
--policy.path=\"\$ROBOTWIN_POLICY\" \
|
||||
--env.type=robotwin \
|
||||
--env.task=\"\$ROBOTWIN_TASKS\" \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
@@ -482,6 +483,7 @@ jobs:
|
||||
--policy.path=lerobot/smolvla_robocasa \
|
||||
--env.type=robocasa \
|
||||
--env.task=CloseFridge,OpenCabinet,OpenDrawer,TurnOnMicrowave,TurnOffStove,CloseToasterOvenDoor,SlideDishwasherRack,TurnOnSinkFaucet,NavigateKitchen,TurnOnElectricKettle \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
@@ -693,6 +695,7 @@ jobs:
|
||||
--env.task=\"\$ROBOMME_TASKS\" \
|
||||
--env.dataset_split=test \
|
||||
--env.task_ids=[0] \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
@@ -800,6 +803,7 @@ jobs:
|
||||
--env.type=libero_plus \
|
||||
--env.task=\"\$LIBERO_PLUS_SUITE\" \
|
||||
--env.task_ids=\"\$LIBERO_PLUS_TASK_IDS\" \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
@@ -900,6 +904,8 @@ jobs:
|
||||
--policy.path=lerobot/smolvla_vlabench \
|
||||
--env.type=vlabench \
|
||||
--env.task=select_fruit,select_toy,select_book,select_painting,select_drink,select_ingredient,select_billiards,select_poker,add_condiment,insert_flower \
|
||||
--env.episode_length=50 \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
|
||||
3
.github/workflows/release.yml
vendored
3
.github/workflows/release.yml
vendored
@@ -152,13 +152,14 @@ jobs:
|
||||
BASE_VERSION="${VERSION%%-*}"
|
||||
echo "Installing pre-release version $BASE_VERSION from TestPyPI..."
|
||||
uv pip install \
|
||||
--torch-backend cpu \
|
||||
--index-url https://test.pypi.org/simple/ \
|
||||
--extra-index-url https://pypi.org/simple \
|
||||
--index-strategy unsafe-best-match \
|
||||
"lerobot[all]==$BASE_VERSION"
|
||||
else
|
||||
echo "Installing release version $VERSION from PyPI..."
|
||||
uv pip install "lerobot[all]==$VERSION"
|
||||
uv pip install --torch-backend cpu "lerobot[all]==$VERSION"
|
||||
fi
|
||||
- name: Check lerobot version
|
||||
run: uv run python -c "import lerobot; print(lerobot.__version__)"
|
||||
|
||||
4
.github/workflows/stale.yml
vendored
4
.github/workflows/stale.yml
vendored
@@ -19,8 +19,8 @@ on:
|
||||
workflow_dispatch:
|
||||
|
||||
# Runs at 02:00
|
||||
schedule:
|
||||
- cron: "0 2 * * *"
|
||||
# schedule:
|
||||
# - cron: "0 2 * * *"
|
||||
|
||||
env:
|
||||
CLOSE_ISSUE_MESSAGE: >
|
||||
|
||||
@@ -232,6 +232,8 @@ Match the policy to the user's **GPU memory** and **time budget**. Numbers below
|
||||
|
||||
All policies typically train for **5–10 epochs** (see §7).
|
||||
|
||||
> **Human-facing version:** the [Compute Hardware Guide](./docs/source/hardware_guide.mdx) reuses the table below and adds a cloud-GPU tier guide and a Hugging Face Jobs pointer.
|
||||
|
||||
| Policy | Batch | Update (ms) | Peak GPU mem (GB) | Best for |
|
||||
| ----------- | ----: | ----------: | ----------------: | ------------------------------------------------------------------------------------------------ |
|
||||
| `act` | 4 | **83.9** | **0.94** | First-time users, laptops, single-task. Fast and reliable. |
|
||||
|
||||
@@ -109,7 +109,7 @@ lerobot-train \
|
||||
|
||||
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
||||
|
||||
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies).
|
||||
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies). For GPU/RAM requirements and expected training time per policy, see the [Compute Hardware Guide](https://huggingface.co/docs/lerobot/hardware_guide).
|
||||
|
||||
## Inference & Evaluation
|
||||
|
||||
|
||||
@@ -1,288 +0,0 @@
|
||||
# Video benchmark
|
||||
|
||||
## Questions
|
||||
|
||||
What is the optimal trade-off between:
|
||||
|
||||
- maximizing loading time with random access,
|
||||
- minimizing memory space on disk,
|
||||
- maximizing success rate of policies,
|
||||
- compatibility across devices/platforms for decoding videos (e.g. video players, web browsers).
|
||||
|
||||
How to encode videos?
|
||||
|
||||
- Which video codec (`-vcodec`) to use? h264, h265, AV1?
|
||||
- What pixel format to use (`-pix_fmt`)? `yuv444p` or `yuv420p`?
|
||||
- How much compression (`-crf`)? No compression with `0`, intermediate compression with `25` or extreme with `50+`?
|
||||
- Which frequency to chose for key frames (`-g`)? A key frame every `10` frames?
|
||||
|
||||
How to decode videos?
|
||||
|
||||
- Which `decoder`? `torchvision`, `torchaudio`, `ffmpegio`, `decord`, or `nvc`?
|
||||
- What scenarios to use for the requesting timestamps during benchmark? (`timestamps_mode`)
|
||||
|
||||
## Variables
|
||||
|
||||
**Image content & size**
|
||||
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an apartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
|
||||
For these reasons, we run this benchmark on four representative datasets:
|
||||
|
||||
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
|
||||
- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
||||
- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
|
||||
- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
|
||||
|
||||
Note: The datasets used for this benchmark need to be image datasets, not video datasets.
|
||||
|
||||
**Data augmentations**
|
||||
We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robust (e.g. robust to color changes, compression, etc.).
|
||||
|
||||
### Encoding parameters
|
||||
|
||||
| parameter | values |
|
||||
| ----------- | ------------------------------------------------------------ |
|
||||
| **vcodec** | `libx264`, `libx265`, `libsvtav1` |
|
||||
| **pix_fmt** | `yuv444p`, `yuv420p` |
|
||||
| **g** | `1`, `2`, `3`, `4`, `5`, `6`, `10`, `15`, `20`, `40`, `None` |
|
||||
| **crf** | `0`, `5`, `10`, `15`, `20`, `25`, `30`, `40`, `50`, `None` |
|
||||
|
||||
Note that `crf` value might be interpreted differently by various video codecs. In other words, the same value used with one codec doesn't necessarily translate into the same compression level with another codec. In fact, the default value (`None`) isn't the same amongst the different video codecs. Importantly, it is also the case for many other ffmpeg arguments like `g` which specifies the frequency of the key frames.
|
||||
|
||||
For a comprehensive list and documentation of these parameters, see the ffmpeg documentation depending on the video codec used:
|
||||
|
||||
- h264: https://trac.ffmpeg.org/wiki/Encode/H.264
|
||||
- h265: https://trac.ffmpeg.org/wiki/Encode/H.265
|
||||
- AV1: https://trac.ffmpeg.org/wiki/Encode/AV1
|
||||
|
||||
### Decoding parameters
|
||||
|
||||
**Decoder**
|
||||
We tested two video decoding backends from torchvision:
|
||||
|
||||
- `pyav`
|
||||
- `video_reader` (requires to build torchvision from source)
|
||||
|
||||
**Requested timestamps**
|
||||
Given the way video decoding works, once a keyframe has been loaded, the decoding of subsequent frames is fast.
|
||||
This of course is affected by the `-g` parameter during encoding, which specifies the frequency of the keyframes. Given our typical use cases in robotics policies which might request a few timestamps in different random places, we want to replicate these use cases with the following scenarios:
|
||||
|
||||
- `1_frame`: 1 frame,
|
||||
- `2_frames`: 2 consecutive frames (e.g. `[t, t + 1 / fps]`),
|
||||
- `6_frames`: 6 consecutive frames (e.g. `[t + i / fps for i in range(6)]`)
|
||||
|
||||
Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.
|
||||
|
||||
Additionally, because some policies might request single timestamps that are a few frames apart, we also have the following scenario:
|
||||
|
||||
- `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`),
|
||||
|
||||
However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded.
|
||||
|
||||
## Metrics
|
||||
|
||||
**Data compression ratio (lower is better)**
|
||||
`video_images_size_ratio` is the ratio of the memory space on disk taken by the encoded video over the memory space taken by the original images. For instance, `video_images_size_ratio=25%` means that the video takes 4 times less memory space on disk compared to the original images.
|
||||
|
||||
**Loading time ratio (lower is better)**
|
||||
`video_images_load_time_ratio` is the ratio of the time it takes to decode frames from the video at a given timestamps over the time it takes to load the exact same original images. Lower is better. For instance, `video_images_load_time_ratio=200%` means that decoding from video is 2 times slower than loading the original images.
|
||||
|
||||
**Average Mean Square Error (lower is better)**
|
||||
`avg_mse` is the average mean square error between each decoded frame and its corresponding original image over all requested timestamps, and also divided by the number of pixels in the image to be comparable when switching to different image sizes.
|
||||
|
||||
**Average Peak Signal to Noise Ratio (higher is better)**
|
||||
`avg_psnr` measures the ratio between the maximum possible power of a signal and the power of corrupting noise that affects the fidelity of its representation. Higher PSNR indicates better quality.
|
||||
|
||||
**Average Structural Similarity Index Measure (higher is better)**
|
||||
`avg_ssim` evaluates the perceived quality of images by comparing luminance, contrast, and structure. SSIM values range from -1 to 1, where 1 indicates perfect similarity.
|
||||
|
||||
One aspect that can't be measured here with those metrics is the compatibility of the encoding across platforms, in particular on web browser, for visualization purposes.
|
||||
h264, h265 and AV1 are all commonly used codecs and should not pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
|
||||
|
||||
- `yuv420p` is more widely supported across various platforms, including web browsers.
|
||||
- `yuv444p` offers higher color fidelity but might not be supported as broadly.
|
||||
|
||||
<!-- **Loss of a pretrained policy (higher is better)** (not available)
|
||||
`loss_pretrained` is the result of evaluating with the selected encoding/decoding settings a policy pretrained on original images. It is easier to understand than `avg_l2_error`.
|
||||
|
||||
**Success rate after retraining (higher is better)** (not available)
|
||||
`success_rate` is the result of training and evaluating a policy with the selected encoding/decoding settings. It is the most difficult metric to get but also the very best. -->
|
||||
|
||||
## How the benchmark works
|
||||
|
||||
The benchmark evaluates both encoding and decoding of video frames on the first episode of each dataset.
|
||||
|
||||
**Encoding:** for each `vcodec` and `pix_fmt` pair, we use a default value for `g` and `crf` upon which we change a single value (either `g` or `crf`) to one of the specified values (we don't test every combination of those as this would be computationally too heavy).
|
||||
This gives a unique set of encoding parameters which is used to encode the episode.
|
||||
|
||||
**Decoding:** Then, for each of those unique encodings, we iterate through every combination of the decoding parameters `backend` and `timestamps_mode`. For each of them, we record the metrics of a number of samples (given by `--num-samples`). This is parallelized for efficiency and the number of processes can be controlled with `--num-workers`. Ideally, it's best to have a `--num-samples` that is divisible by `--num-workers`.
|
||||
|
||||
Intermediate results saved for each `vcodec` and `pix_fmt` combination in csv tables.
|
||||
These are then all concatenated to a single table ready for analysis.
|
||||
|
||||
## Caveats
|
||||
|
||||
We tried to measure the most impactful parameters for both encoding and decoding. However, for computational reasons we can't test out every combination.
|
||||
|
||||
Additional encoding parameters exist that are not included in this benchmark. In particular:
|
||||
|
||||
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
|
||||
- `-tune` which allows to optimize the encoding for certain aspects (e.g. film quality, fast decoding, etc.).
|
||||
|
||||
See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
|
||||
|
||||
Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few:
|
||||
|
||||
- `torchaudio`
|
||||
- `ffmpegio`
|
||||
- `decord`
|
||||
- `nvc`
|
||||
|
||||
Note as well that since we are mostly interested in the performance at decoding time (also because encoding is done only once before uploading a dataset), we did not measure encoding times nor have any metrics regarding encoding.
|
||||
However, besides the necessity to build ffmpeg from source, encoding did not pose any issue and it didn't take a significant amount of time during this benchmark.
|
||||
|
||||
## Install
|
||||
|
||||
Building ffmpeg from source is required to include libx265 and libaom/libsvtav1 (av1) video codecs ([compilation guide](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu)).
|
||||
|
||||
**Note:** While you still need to build torchvision with a conda-installed `ffmpeg<4.3` to use the `video_reader` decoder (as described in [#220](https://github.com/huggingface/lerobot/pull/220)), you also need another version which is custom-built with all the video codecs for encoding. For the script to then use that version, you can prepend the command above with `PATH="$HOME/bin:$PATH"`, which is where ffmpeg should be built.
|
||||
|
||||
## Adding a video decoder
|
||||
|
||||
Right now, we're only benchmarking the two video decoder available with torchvision: `pyav` and `video_reader`.
|
||||
You can easily add a new decoder to benchmark by adding it to this function in the script:
|
||||
|
||||
```diff
|
||||
def decode_video_frames(
|
||||
video_path: str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str,
|
||||
) -> torch.Tensor:
|
||||
if backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(
|
||||
video_path, timestamps, tolerance_s, backend
|
||||
)
|
||||
+ elif backend == ["your_decoder"]:
|
||||
+ return your_decoder_function(
|
||||
+ video_path, timestamps, tolerance_s, backend
|
||||
+ )
|
||||
else:
|
||||
raise NotImplementedError(backend)
|
||||
```
|
||||
|
||||
## Example
|
||||
|
||||
For a quick run, you can try these parameters:
|
||||
|
||||
```bash
|
||||
python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 2 20 None \
|
||||
--crf 10 40 None \
|
||||
--timestamps-modes 1_frame 2_frames \
|
||||
--backends pyav video_reader \
|
||||
--num-samples 5 \
|
||||
--num-workers 5 \
|
||||
--save-frames 0
|
||||
```
|
||||
|
||||
## Results
|
||||
|
||||
### Reproduce
|
||||
|
||||
We ran the benchmark with the following parameters:
|
||||
|
||||
```bash
|
||||
# h264 and h265 encodings
|
||||
python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
lerobot/paris_street \
|
||||
lerobot/kitchen \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
--crf 0 5 10 15 20 25 30 40 50 None \
|
||||
--timestamps-modes 1_frame 2_frames 6_frames \
|
||||
--backends pyav video_reader \
|
||||
--num-samples 50 \
|
||||
--num-workers 5 \
|
||||
--save-frames 1
|
||||
|
||||
# av1 encoding (only compatible with yuv420p and pyav decoder)
|
||||
python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
lerobot/paris_street \
|
||||
lerobot/kitchen \
|
||||
--vcodec libsvtav1 \
|
||||
--pix-fmt yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
--crf 0 5 10 15 20 25 30 40 50 None \
|
||||
--timestamps-modes 1_frame 2_frames 6_frames \
|
||||
--backends pyav \
|
||||
--num-samples 50 \
|
||||
--num-workers 5 \
|
||||
--save-frames 1
|
||||
```
|
||||
|
||||
The full results are available [here](https://docs.google.com/spreadsheets/d/1OYJB43Qu8fC26k_OyoMFgGBBKfQRCi4BIuYitQnq3sw/edit?usp=sharing)
|
||||
|
||||
### Parameters selected for LeRobotDataset
|
||||
|
||||
Considering these results, we chose what we think is the best set of encoding parameter:
|
||||
|
||||
- vcodec: `libsvtav1`
|
||||
- pix-fmt: `yuv420p`
|
||||
- g: `2`
|
||||
- crf: `30`
|
||||
|
||||
Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_reader` does not support it (and `pyav` doesn't require a custom build of `torchvision`).
|
||||
|
||||
### Summary
|
||||
|
||||
These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`
|
||||
|
||||
| video_images_size_ratio | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | ---------- | ------- | --------- | --------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
|
||||
| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
|
||||
| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
|
||||
| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
|
||||
|
||||
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | ------- | ------- | -------- | ------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
|
||||
| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
|
||||
| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
|
||||
| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
|
||||
|
||||
| | | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
|
||||
| | | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
|
||||
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
|
||||
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
|
||||
| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
|
||||
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
|
||||
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
|
||||
| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
|
||||
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
|
||||
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
|
||||
| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
|
||||
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
|
||||
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
|
||||
@@ -1,488 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Assess the performance of video decoding in various configurations.
|
||||
|
||||
This script will benchmark different video encoding and decoding parameters.
|
||||
See the provided README.md or run `python benchmark/video/run_video_benchmark.py --help` for usage info.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import itertools
|
||||
import random
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import PIL
|
||||
import torch
|
||||
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.video_utils import (
|
||||
decode_video_frames,
|
||||
encode_video_frames,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGE
|
||||
from lerobot.utils.utils import TimerManager
|
||||
|
||||
BASE_ENCODING = OrderedDict(
|
||||
[
|
||||
("vcodec", "libx264"),
|
||||
("pix_fmt", "yuv444p"),
|
||||
("g", 2),
|
||||
("crf", None),
|
||||
# TODO(aliberts): Add fastdecode
|
||||
# ("fastdecode", 0),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# TODO(rcadene, aliberts): move to `utils.py` folder when we want to refactor
|
||||
def parse_int_or_none(value) -> int | None:
|
||||
if value.lower() == "none":
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError as e:
|
||||
raise argparse.ArgumentTypeError(f"Invalid int or None: {value}") from e
|
||||
|
||||
|
||||
def check_datasets_formats(repo_ids: list) -> None:
|
||||
for repo_id in repo_ids:
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
raise ValueError(
|
||||
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
|
||||
)
|
||||
|
||||
|
||||
def get_directory_size(directory: Path) -> int:
|
||||
total_size = 0
|
||||
for item in directory.rglob("*"):
|
||||
if item.is_file():
|
||||
total_size += item.stat().st_size
|
||||
return total_size
|
||||
|
||||
|
||||
def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor:
|
||||
frames = []
|
||||
for ts in timestamps:
|
||||
idx = int(ts * fps)
|
||||
frame = PIL.Image.open(imgs_dir / f"frame-{idx:06d}.png")
|
||||
frame = torch.from_numpy(np.array(frame))
|
||||
frame = frame.type(torch.float32) / 255
|
||||
frame = einops.rearrange(frame, "h w c -> c h w")
|
||||
frames.append(frame)
|
||||
return torch.stack(frames)
|
||||
|
||||
|
||||
def save_decoded_frames(
|
||||
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
|
||||
) -> None:
|
||||
if save_dir.exists() and len(list(save_dir.glob("frame-*.png"))) == len(timestamps):
|
||||
return
|
||||
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i, ts in enumerate(timestamps):
|
||||
idx = int(ts * fps)
|
||||
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
|
||||
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame-{idx:06d}_decoded.png")
|
||||
shutil.copyfile(imgs_dir / f"frame-{idx:06d}.png", save_dir / f"frame-{idx:06d}_original.png")
|
||||
|
||||
|
||||
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
episode_index = 0
|
||||
ep_num_images = dataset.meta.episodes["length"][episode_index]
|
||||
if imgs_dir.exists() and len(list(imgs_dir.glob("frame-*.png"))) == ep_num_images:
|
||||
return
|
||||
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
|
||||
# We only save images from the first camera
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||
imgs_dataset = hf_dataset.select_columns(img_keys[0])
|
||||
|
||||
for i, item in enumerate(
|
||||
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
|
||||
):
|
||||
img = item[img_keys[0]]
|
||||
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
|
||||
if i >= ep_num_images - 1:
|
||||
break
|
||||
|
||||
|
||||
def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]:
|
||||
# Start at 5 to allow for 2_frames_4_space and 6_frames
|
||||
idx = random.randint(5, ep_num_images - 1)
|
||||
match timestamps_mode:
|
||||
case "1_frame":
|
||||
frame_indexes = [idx]
|
||||
case "2_frames":
|
||||
frame_indexes = [idx - 1, idx]
|
||||
case "2_frames_4_space":
|
||||
frame_indexes = [idx - 5, idx]
|
||||
case "6_frames":
|
||||
frame_indexes = [idx - i for i in range(6)][::-1]
|
||||
case _:
|
||||
raise ValueError(timestamps_mode)
|
||||
|
||||
return [idx / fps for idx in frame_indexes]
|
||||
|
||||
|
||||
def benchmark_decoding(
|
||||
imgs_dir: Path,
|
||||
video_path: Path,
|
||||
timestamps_mode: str,
|
||||
backend: str,
|
||||
ep_num_images: int,
|
||||
fps: int,
|
||||
num_samples: int = 50,
|
||||
num_workers: int = 4,
|
||||
save_frames: bool = False,
|
||||
) -> dict:
|
||||
def process_sample(sample: int, lock: Lock):
|
||||
time_benchmark = TimerManager(log=False)
|
||||
timestamps = sample_timestamps(timestamps_mode, ep_num_images, fps)
|
||||
num_frames = len(timestamps)
|
||||
result = {
|
||||
"psnr_values": [],
|
||||
"ssim_values": [],
|
||||
"mse_values": [],
|
||||
}
|
||||
|
||||
with time_benchmark, lock:
|
||||
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
|
||||
result["load_time_video_ms"] = (time_benchmark.last * 1000) / num_frames
|
||||
|
||||
with time_benchmark:
|
||||
original_frames = load_original_frames(imgs_dir, timestamps, fps)
|
||||
result["load_time_images_ms"] = (time_benchmark.last * 1000) / num_frames
|
||||
|
||||
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
|
||||
for i in range(num_frames):
|
||||
result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
|
||||
result["psnr_values"].append(
|
||||
peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0)
|
||||
)
|
||||
result["ssim_values"].append(
|
||||
structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0)
|
||||
)
|
||||
|
||||
if save_frames and sample == 0:
|
||||
save_dir = video_path.with_suffix("") / f"{timestamps_mode}_{backend}"
|
||||
save_decoded_frames(imgs_dir, save_dir, frames, timestamps, fps)
|
||||
|
||||
return result
|
||||
|
||||
load_times_video_ms = []
|
||||
load_times_images_ms = []
|
||||
mse_values = []
|
||||
psnr_values = []
|
||||
ssim_values = []
|
||||
|
||||
# A sample is a single set of decoded frames specified by timestamps_mode (e.g. a single frame, 2 frames, etc.).
|
||||
# For each sample, we record metrics (loading time and quality metrics) which are then averaged over all samples.
|
||||
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
|
||||
# Use a single shared lock for all worker threads
|
||||
shared_lock = Lock()
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(process_sample, i, shared_lock) for i in range(num_samples)]
|
||||
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
|
||||
result = future.result()
|
||||
load_times_video_ms.append(result["load_time_video_ms"])
|
||||
load_times_images_ms.append(result["load_time_images_ms"])
|
||||
psnr_values.extend(result["psnr_values"])
|
||||
ssim_values.extend(result["ssim_values"])
|
||||
mse_values.extend(result["mse_values"])
|
||||
|
||||
avg_load_time_video_ms = float(np.array(load_times_video_ms).mean())
|
||||
avg_load_time_images_ms = float(np.array(load_times_images_ms).mean())
|
||||
video_images_load_time_ratio = avg_load_time_video_ms / avg_load_time_images_ms
|
||||
|
||||
return {
|
||||
"avg_load_time_video_ms": avg_load_time_video_ms,
|
||||
"avg_load_time_images_ms": avg_load_time_images_ms,
|
||||
"video_images_load_time_ratio": video_images_load_time_ratio,
|
||||
"avg_mse": float(np.mean(mse_values)),
|
||||
"avg_psnr": float(np.mean(psnr_values)),
|
||||
"avg_ssim": float(np.mean(ssim_values)),
|
||||
}
|
||||
|
||||
|
||||
def benchmark_encoding_decoding(
|
||||
dataset: LeRobotDataset,
|
||||
video_path: Path,
|
||||
imgs_dir: Path,
|
||||
encoding_cfg: dict,
|
||||
decoding_cfg: dict,
|
||||
num_samples: int,
|
||||
num_workers: int,
|
||||
save_frames: bool,
|
||||
overwrite: bool = False,
|
||||
seed: int = 1337,
|
||||
) -> list[dict]:
|
||||
fps = dataset.fps
|
||||
|
||||
if overwrite or not video_path.is_file():
|
||||
tqdm.write(f"encoding {video_path}")
|
||||
encode_video_frames(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=encoding_cfg["vcodec"],
|
||||
pix_fmt=encoding_cfg["pix_fmt"],
|
||||
g=encoding_cfg.get("g"),
|
||||
crf=encoding_cfg.get("crf"),
|
||||
# fast_decode=encoding_cfg.get("fastdecode"),
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
episode_index = 0
|
||||
ep_num_images = dataset.meta.episodes["length"][episode_index]
|
||||
width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
|
||||
num_pixels = width * height
|
||||
video_size_bytes = video_path.stat().st_size
|
||||
images_size_bytes = get_directory_size(imgs_dir)
|
||||
video_images_size_ratio = video_size_bytes / images_size_bytes
|
||||
|
||||
random.seed(seed)
|
||||
benchmark_table = []
|
||||
for timestamps_mode in tqdm(
|
||||
decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
|
||||
):
|
||||
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
|
||||
benchmark_row = benchmark_decoding(
|
||||
imgs_dir,
|
||||
video_path,
|
||||
timestamps_mode,
|
||||
backend,
|
||||
ep_num_images,
|
||||
fps,
|
||||
num_samples,
|
||||
num_workers,
|
||||
save_frames,
|
||||
)
|
||||
benchmark_row.update(
|
||||
**{
|
||||
"repo_id": dataset.repo_id,
|
||||
"resolution": f"{width} x {height}",
|
||||
"num_pixels": num_pixels,
|
||||
"video_size_bytes": video_size_bytes,
|
||||
"images_size_bytes": images_size_bytes,
|
||||
"video_images_size_ratio": video_images_size_ratio,
|
||||
"timestamps_mode": timestamps_mode,
|
||||
"backend": backend,
|
||||
},
|
||||
**encoding_cfg,
|
||||
)
|
||||
benchmark_table.append(benchmark_row)
|
||||
|
||||
return benchmark_table
|
||||
|
||||
|
||||
def main(
|
||||
output_dir: Path,
|
||||
repo_ids: list[str],
|
||||
vcodec: list[str],
|
||||
pix_fmt: list[str],
|
||||
g: list[int],
|
||||
crf: list[int],
|
||||
# fastdecode: list[int],
|
||||
timestamps_modes: list[str],
|
||||
backends: list[str],
|
||||
num_samples: int,
|
||||
num_workers: int,
|
||||
save_frames: bool,
|
||||
):
|
||||
check_datasets_formats(repo_ids)
|
||||
encoding_benchmarks = {
|
||||
"g": g,
|
||||
"crf": crf,
|
||||
# "fastdecode": fastdecode,
|
||||
}
|
||||
decoding_benchmarks = {
|
||||
"timestamps_modes": timestamps_modes,
|
||||
"backends": backends,
|
||||
}
|
||||
headers = ["repo_id", "resolution", "num_pixels"]
|
||||
headers += list(BASE_ENCODING.keys())
|
||||
headers += [
|
||||
"timestamps_mode",
|
||||
"backend",
|
||||
"video_size_bytes",
|
||||
"images_size_bytes",
|
||||
"video_images_size_ratio",
|
||||
"avg_load_time_video_ms",
|
||||
"avg_load_time_images_ms",
|
||||
"video_images_load_time_ratio",
|
||||
"avg_mse",
|
||||
"avg_psnr",
|
||||
"avg_ssim",
|
||||
]
|
||||
file_paths = []
|
||||
for video_codec in tqdm(vcodec, desc="encodings (vcodec)"):
|
||||
for pixel_format in tqdm(pix_fmt, desc="encodings (pix_fmt)", leave=False):
|
||||
benchmark_table = []
|
||||
for repo_id in tqdm(repo_ids, desc="encodings (datasets)", leave=False):
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
|
||||
# We only use the first episode
|
||||
save_first_episode(imgs_dir, dataset)
|
||||
for duet in [
|
||||
dict(zip(encoding_benchmarks.keys(), unique_combination, strict=False))
|
||||
for unique_combination in itertools.product(*encoding_benchmarks.values())
|
||||
]:
|
||||
encoding_cfg = BASE_ENCODING.copy()
|
||||
encoding_cfg["vcodec"] = video_codec
|
||||
encoding_cfg["pix_fmt"] = pixel_format
|
||||
for key, value in duet.items():
|
||||
encoding_cfg[key] = value
|
||||
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
|
||||
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
|
||||
benchmark_table += benchmark_encoding_decoding(
|
||||
dataset,
|
||||
video_path,
|
||||
imgs_dir,
|
||||
encoding_cfg,
|
||||
decoding_benchmarks,
|
||||
num_samples,
|
||||
num_workers,
|
||||
save_frames,
|
||||
)
|
||||
|
||||
# Save intermediate results
|
||||
benchmark_df = pd.DataFrame(benchmark_table, columns=headers)
|
||||
now = dt.datetime.now()
|
||||
csv_path = (
|
||||
output_dir
|
||||
/ f"{now:%Y-%m-%d}_{now:%H-%M-%S}_{video_codec}_{pixel_format}_{num_samples}-samples.csv"
|
||||
)
|
||||
benchmark_df.to_csv(csv_path, header=True, index=False)
|
||||
file_paths.append(csv_path)
|
||||
del benchmark_df
|
||||
|
||||
# Concatenate all results
|
||||
df_list = [pd.read_csv(csv_path) for csv_path in file_paths]
|
||||
concatenated_df = pd.concat(df_list, ignore_index=True)
|
||||
concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
|
||||
concatenated_df.to_csv(concatenated_path, header=True, index=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=Path("outputs/video_benchmark"),
|
||||
help="Directory where the video benchmark outputs are written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-ids",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[
|
||||
"lerobot/pusht_image",
|
||||
"lerobot/aloha_mobile_shrimp_image",
|
||||
"lerobot/paris_street",
|
||||
"lerobot/kitchen",
|
||||
],
|
||||
help="Datasets repo-ids to test against. First episodes only are used. Must be images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vcodec",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["h264", "hevc", "libsvtav1"],
|
||||
help="Video codecs to be tested",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pix-fmt",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["yuv444p", "yuv420p"],
|
||||
help="Pixel formats (chroma subsampling) to be tested",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--g",
|
||||
type=parse_int_or_none,
|
||||
nargs="*",
|
||||
default=[1, 2, 3, 4, 5, 6, 10, 15, 20, 40, 100, None],
|
||||
help="Group of pictures sizes to be tested.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crf",
|
||||
type=parse_int_or_none,
|
||||
nargs="*",
|
||||
default=[0, 5, 10, 15, 20, 25, 30, 40, 50, None],
|
||||
help="Constant rate factors to be tested.",
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--fastdecode",
|
||||
# type=int,
|
||||
# nargs="*",
|
||||
# default=[0, 1],
|
||||
# help="Use the fastdecode tuning option. 0 disables it. "
|
||||
# "For libx264 and libx265/hevc, only 1 is possible. "
|
||||
# "For libsvtav1, 1, 2 or 3 are possible values with a higher number meaning a faster decoding optimization",
|
||||
# )
|
||||
parser.add_argument(
|
||||
"--timestamps-modes",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[
|
||||
"1_frame",
|
||||
"2_frames",
|
||||
"2_frames_4_space",
|
||||
"6_frames",
|
||||
],
|
||||
help="Timestamps scenarios to be tested.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backends",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["torchcodec", "pyav"],
|
||||
help="Torchvision decoding backend to be tested.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-samples",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of samples for each encoding x decoding config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of processes for parallelized sample processing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-frames",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Whether to save decoded frames or not. Enter a non-zero number for true.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(**vars(args))
|
||||
@@ -35,7 +35,7 @@ USER root
|
||||
ARG ROBOTWIN_SHA=0aeea2d669c0f8516f4d5785f0aa33ba812c14b4
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
cuda-nvcc-12-4 cuda-cudart-dev-12-4 \
|
||||
cuda-nvcc-12-8 cuda-cudart-dev-12-8 \
|
||||
libvulkan1 vulkan-tools \
|
||||
&& mkdir -p /usr/share/vulkan/icd.d \
|
||||
&& echo '{"file_format_version":"1.0.0","ICD":{"library_path":"libGLX_nvidia.so.0","api_version":"1.3.0"}}' \
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
# docker build -f docker/Dockerfile.internal -t lerobot-internal .
|
||||
|
||||
# Configure the base image for CI with GPU access
|
||||
ARG CUDA_VERSION=12.6.3
|
||||
ARG CUDA_VERSION=12.8.1
|
||||
ARG OS_VERSION=24.04
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
||||
|
||||
|
||||
@@ -3,12 +3,16 @@
|
||||
title: LeRobot
|
||||
- local: installation
|
||||
title: Installation
|
||||
- local: cheat-sheet
|
||||
title: Cheat sheet
|
||||
title: Get started
|
||||
- sections:
|
||||
- local: il_robots
|
||||
title: Imitation Learning for Robots
|
||||
- local: lelab
|
||||
title: LeLab - Lerobot GUI
|
||||
- local: bring_your_own_policies
|
||||
title: Bring Your Own Policies
|
||||
title: Adding a Policy
|
||||
- local: integrate_hardware
|
||||
title: Bring Your Own Hardware
|
||||
- local: hilserl
|
||||
@@ -24,6 +28,12 @@
|
||||
- local: rename_map
|
||||
title: Using Rename Map and Empty Cameras
|
||||
title: "Tutorials"
|
||||
- sections:
|
||||
- local: hardware_guide
|
||||
title: Compute Hardware Guide
|
||||
- local: torch_accelerators
|
||||
title: PyTorch accelerators
|
||||
title: "Compute & Hardware"
|
||||
- sections:
|
||||
- local: lerobot-dataset-v3
|
||||
title: Using LeRobotDataset
|
||||
@@ -31,8 +41,12 @@
|
||||
title: Porting Large Datasets
|
||||
- local: using_dataset_tools
|
||||
title: Using the Dataset Tools
|
||||
- local: dataset_subtask
|
||||
title: Using Subtasks in the Dataset
|
||||
- local: language_and_recipes
|
||||
title: Language Columns and Recipes
|
||||
- local: tools
|
||||
title: Tools
|
||||
- local: video_encoding_parameters
|
||||
title: Video encoding parameters
|
||||
- local: streaming_video_encoding
|
||||
title: Streaming Video Encoding
|
||||
title: "Datasets"
|
||||
@@ -47,6 +61,8 @@
|
||||
title: π₀-FAST (Pi0Fast)
|
||||
- local: pi05
|
||||
title: π₀.₅ (Pi05)
|
||||
- local: eo1
|
||||
title: EO-1
|
||||
- local: groot
|
||||
title: NVIDIA GR00T N1.5
|
||||
- local: xvla
|
||||
@@ -131,6 +147,8 @@
|
||||
title: OMX
|
||||
- local: openarm
|
||||
title: OpenArm
|
||||
- local: rebot_b601
|
||||
title: reBot B601-DM
|
||||
title: "Robots"
|
||||
- sections:
|
||||
- local: phone_teleop
|
||||
@@ -140,10 +158,6 @@
|
||||
- local: cameras
|
||||
title: Cameras
|
||||
title: "Sensors"
|
||||
- sections:
|
||||
- local: torch_accelerators
|
||||
title: PyTorch accelerators
|
||||
title: "Supported Hardware"
|
||||
- sections:
|
||||
- local: notebooks
|
||||
title: Notebooks
|
||||
|
||||
@@ -79,17 +79,13 @@ If your local computer doesn't have a powerful GPU, you can utilize Google Colab
|
||||
Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/act_policy \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_robot \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Your task description" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--policy.path=${HF_USER}/act_policy
|
||||
--task="Your task description" \ # can be skipped for ACT
|
||||
--duration=60
|
||||
```
|
||||
|
||||
@@ -1,60 +1,37 @@
|
||||
# Bring Your Own Policies
|
||||
# Adding a Policy
|
||||
|
||||
This tutorial explains how to integrate your own custom policy implementations into the LeRobot ecosystem, allowing you to leverage all LeRobot tools for training, evaluation, and deployment while using your own algorithms.
|
||||
This guide walks you through implementing a custom policy and getting it to work with LeRobot's training, evaluation, and deployment tools. There are two paths:
|
||||
|
||||
## Step 1: Create a Policy Package
|
||||
- **Plugin (out-of-tree)** — ship your policy as a standalone `lerobot_policy_*` package. Faster, no PR required, easy to iterate. Right for experimentation, internal use, or when you want to publish independently.
|
||||
- **In-tree (contributed to LeRobot)** — land your policy directly in `src/lerobot/policies/`. Requires a PR, but makes your policy a first-class citizen of the library.
|
||||
|
||||
Your custom policy should be organized as an installable Python package following LeRobot's plugin conventions.
|
||||
The plugin route is usually the right starting point — promote to in-tree once the policy has stabilized and there's clear value in shipping it with the library.
|
||||
|
||||
### Package Structure
|
||||
Either way, the building blocks are the same: a configuration class, a policy class, and a processor factory. The first half of this guide covers those shared pieces; the second half covers the path-specific scaffolding ([Path A](#path-a-out-of-tree-plugin), [Path B](#path-b-contributing-in-tree)).
|
||||
|
||||
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
|
||||
A note on tone: robot-learning is an actively evolving field, and "what a policy looks like" can shift with each new architecture. The conventions described here exist because they let `lerobot-train` and `lerobot-eval` work uniformly across very different models. When a new policy genuinely doesn't fit them, raise it (in your PR, or an issue) — the conventions are not sacred.
|
||||
|
||||
```bash
|
||||
lerobot_policy_my_custom_policy/
|
||||
├── pyproject.toml
|
||||
└── src/
|
||||
└── lerobot_policy_my_custom_policy/
|
||||
├── __init__.py
|
||||
├── configuration_my_custom_policy.py
|
||||
├── modeling_my_custom_policy.py
|
||||
└── processor_my_custom_policy.py
|
||||
```
|
||||
---
|
||||
|
||||
### Package Configuration
|
||||
## Anatomy of a policy
|
||||
|
||||
Set up your `pyproject.toml`:
|
||||
Three building blocks make up every policy. The names below use `my_policy` as a placeholder — replace with your policy's name. That name is load-bearing: it must match the string you pass to `@PreTrainedConfig.register_subclass`, the `MyPolicy.name` class attribute, and the `make_<name>_pre_post_processors` factory function (more on each below).
|
||||
|
||||
```toml
|
||||
[project]
|
||||
name = "lerobot_policy_my_custom_policy"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
# your policy-specific dependencies
|
||||
]
|
||||
requires-python = ">= 3.12"
|
||||
### Configuration class
|
||||
|
||||
[build-system]
|
||||
build-backend = # your-build-backend
|
||||
requires = # your-build-system
|
||||
```
|
||||
|
||||
## Step 2: Define the Policy Configuration
|
||||
|
||||
Create a configuration class that inherits from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and registers your policy type:
|
||||
Here is a template to get you started, customize the parameters and methods as needed for your policy's architecture and training requirements.
|
||||
Inherit from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and register your policy type. Here is a template — customize the parameters and methods as needed for your policy's architecture and training requirements.
|
||||
|
||||
```python
|
||||
# configuration_my_custom_policy.py
|
||||
# configuration_my_policy.py
|
||||
from dataclasses import dataclass, field
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.optim import AdamWConfig
|
||||
from lerobot.optim import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
@PreTrainedConfig.register_subclass("my_custom_policy")
|
||||
@PreTrainedConfig.register_subclass("my_policy")
|
||||
@dataclass
|
||||
class MyCustomPolicyConfig(PreTrainedConfig):
|
||||
"""Configuration class for MyCustomPolicy.
|
||||
class MyPolicyConfig(PreTrainedConfig):
|
||||
"""Configuration class for MyPolicy.
|
||||
|
||||
Args:
|
||||
n_obs_steps: Number of observation steps to use as input
|
||||
@@ -77,16 +54,20 @@ class MyCustomPolicyConfig(PreTrainedConfig):
|
||||
raise ValueError("n_action_steps cannot exceed horizon")
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate input/output feature compatibility."""
|
||||
"""Validate input/output feature compatibility.
|
||||
|
||||
Call this explicitly from your policy's __init__ — the base class does not.
|
||||
"""
|
||||
if not self.image_features:
|
||||
raise ValueError("MyCustomPolicy requires at least one image feature.")
|
||||
raise ValueError("MyPolicy requires at least one image feature.")
|
||||
if self.action_feature is None:
|
||||
raise ValueError("MyCustomPolicy requires 'action' in output_features.")
|
||||
raise ValueError("MyPolicy requires 'action' in output_features.")
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
"""Return a LRSchedulerConfig from lerobot.optim, or None."""
|
||||
return None
|
||||
|
||||
@property
|
||||
@@ -101,8 +82,7 @@ class MyCustomPolicyConfig(PreTrainedConfig):
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
"""Relative timestep offsets for the action chunk the dataset loader returns.
|
||||
"""
|
||||
"""Relative timestep offsets for the action chunk the dataset loader returns."""
|
||||
return list(range(self.horizon))
|
||||
|
||||
@property
|
||||
@@ -110,32 +90,34 @@ class MyCustomPolicyConfig(PreTrainedConfig):
|
||||
return None
|
||||
```
|
||||
|
||||
## Step 3: Implement the Policy Class
|
||||
The string you pass to `@register_subclass` must match `MyPolicy.name` (next section) and is what users supply as `--policy.type` on the CLI. Default to `AdamW` from `lerobot.optim` for `get_optimizer_preset` unless you genuinely need otherwise.
|
||||
|
||||
Create your policy implementation by inheriting from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py):
|
||||
### Policy class
|
||||
|
||||
Inherit from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py) and set two class attributes — both are checked by `__init_subclass__`:
|
||||
|
||||
```python
|
||||
# modeling_my_custom_policy.py
|
||||
# modeling_my_policy.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Any
|
||||
|
||||
from lerobot.policies import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION
|
||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
||||
from .configuration_my_policy import MyPolicyConfig
|
||||
|
||||
class MyCustomPolicy(PreTrainedPolicy):
|
||||
config_class = MyCustomPolicyConfig # must match the string in @register_subclass
|
||||
name = "my_custom_policy"
|
||||
class MyPolicy(PreTrainedPolicy):
|
||||
config_class = MyPolicyConfig # must match the string in @register_subclass
|
||||
name = "my_policy"
|
||||
|
||||
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
|
||||
def __init__(self, config: MyPolicyConfig, dataset_stats: dict[str, Any] = None):
|
||||
super().__init__(config, dataset_stats)
|
||||
config.validate_features() # not called automatically by the base class
|
||||
self.config = config
|
||||
self.model = ... # your nn.Module here
|
||||
|
||||
def reset(self):
|
||||
"""Reset episode state."""
|
||||
"""Reset per-episode state. Called by lerobot-eval at the start of each episode."""
|
||||
...
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
@@ -147,35 +129,51 @@ class MyCustomPolicy(PreTrainedPolicy):
|
||||
...
|
||||
|
||||
def select_action(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
|
||||
"""Return a single action for the current timestep (called at inference)."""
|
||||
"""Return a single action for the current timestep (called every step at inference)."""
|
||||
...
|
||||
|
||||
def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
def forward(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict | None]:
|
||||
"""Compute the training loss.
|
||||
|
||||
Returns `(loss, output_dict)`. `output_dict` may be `None`; everything in it must be
|
||||
logging-friendly Python natives (no tensors with gradients).
|
||||
|
||||
`batch["action_is_pad"]` is a bool mask of shape (B, horizon) that marks
|
||||
timesteps padded because the episode ended before `horizon` steps, you
|
||||
timesteps padded because the episode ended before `horizon` steps; you
|
||||
can exclude those from your loss.
|
||||
"""
|
||||
actions = batch[ACTION]
|
||||
action_is_pad = batch.get("action_is_pad")
|
||||
...
|
||||
return {"loss": ...}
|
||||
return loss, {"some_loss_component": some_loss_component.item()}
|
||||
```
|
||||
|
||||
## Step 4: Add Data Processors
|
||||
The methods called by the train/eval loops:
|
||||
|
||||
Create processor functions. For a concrete reference, see [processor_act.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [processor_diffusion.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
|
||||
| Method | Used by | What it does |
|
||||
| ----------------------------------------------------------------- | ----------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `reset() -> None` | `lerobot-eval` | Clear per-episode state at the start of each episode. |
|
||||
| `select_action(batch, **kwargs) -> Tensor` | `lerobot-eval` | Return the next action `(B, action_dim)`. Called every step. |
|
||||
| `predict_action_chunk(batch, **kwargs) -> Tensor` | the policy itself | Return an action chunk `(B, chunk_size, action_dim)`. Currently abstract on the base class — raise `NotImplementedError` if your policy doesn't chunk. |
|
||||
| `forward(batch, reduction="mean") -> tuple[Tensor, dict \| None]` | `lerobot-train` | Return `(loss, output_dict)`. Accept `reduction="none"` if you want to support per-sample weighting. |
|
||||
| `get_optim_params() -> dict` | the optimizer | Return `self.parameters()` for simple policies; return a named parameter dict for [multi-optimizer policies](https://github.com/huggingface/lerobot/blob/ecd38c50d7d15b4184cf42649ff1185ee2e11eeb/src/lerobot/policies/sac/modeling_sac.py#L61-L73). |
|
||||
| `update() -> None` _(optional)_ | `lerobot-train` | Called after each optimizer step _if defined_. Use for EMA, target nets, replay buffers (TDMPC uses this). |
|
||||
|
||||
Batches are flat dictionaries keyed by the constants in [`lerobot.utils.constants`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/utils/constants.py): `OBS_STATE` (`observation.state.<motor>`), `OBS_IMAGES` (`observation.images.<camera>`), `OBS_LANGUAGE`, `ACTION`, etc. Reuse the constants — don't invent new prefixes.
|
||||
|
||||
### Processor functions
|
||||
|
||||
LeRobot uses `PolicyProcessorPipeline`s to normalize inputs and de-normalize outputs around your policy. For a concrete reference, see [`processor_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [`processor_diffusion.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
|
||||
|
||||
```python
|
||||
# processor_my_custom_policy.py
|
||||
# processor_my_policy.py
|
||||
from typing import Any
|
||||
import torch
|
||||
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
|
||||
|
||||
def make_my_custom_policy_pre_post_processors(
|
||||
def make_my_policy_pre_post_processors(
|
||||
config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
@@ -187,11 +185,48 @@ def make_my_custom_policy_pre_post_processors(
|
||||
return preprocessor, postprocessor
|
||||
```
|
||||
|
||||
**Important - function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
|
||||
**Important — function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
|
||||
|
||||
## Step 5: Package Initialization
|
||||
---
|
||||
|
||||
Expose your classes in the package's `__init__.py`:
|
||||
## Path A: Out-of-tree plugin
|
||||
|
||||
The fastest way to ship a policy: package it as a standalone Python distribution and install it alongside LeRobot. No PR required, you own the release cycle, and you can publish to PyPI under your own namespace.
|
||||
|
||||
### Package structure
|
||||
|
||||
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
|
||||
|
||||
```bash
|
||||
lerobot_policy_my_policy/
|
||||
├── pyproject.toml
|
||||
└── src/
|
||||
└── lerobot_policy_my_policy/
|
||||
├── __init__.py
|
||||
├── configuration_my_policy.py
|
||||
├── modeling_my_policy.py
|
||||
└── processor_my_policy.py
|
||||
```
|
||||
|
||||
### `pyproject.toml`
|
||||
|
||||
```toml
|
||||
[project]
|
||||
name = "lerobot_policy_my_policy"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
# your policy-specific dependencies
|
||||
]
|
||||
requires-python = ">= 3.12"
|
||||
|
||||
[build-system]
|
||||
build-backend = # your-build-backend
|
||||
requires = # your-build-system
|
||||
```
|
||||
|
||||
### Package `__init__.py`
|
||||
|
||||
Expose your classes in the package's `__init__.py` and guard against missing `lerobot`:
|
||||
|
||||
```python
|
||||
# __init__.py
|
||||
@@ -204,44 +239,148 @@ except ImportError:
|
||||
"lerobot is not installed. Please install lerobot to use this policy package."
|
||||
)
|
||||
|
||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
||||
from .modeling_my_custom_policy import MyCustomPolicy
|
||||
from .processor_my_custom_policy import make_my_custom_policy_pre_post_processors
|
||||
from .configuration_my_policy import MyPolicyConfig
|
||||
from .modeling_my_policy import MyPolicy
|
||||
from .processor_my_policy import make_my_policy_pre_post_processors
|
||||
|
||||
__all__ = [
|
||||
"MyCustomPolicyConfig",
|
||||
"MyCustomPolicy",
|
||||
"make_my_custom_policy_pre_post_processors",
|
||||
"MyPolicyConfig",
|
||||
"MyPolicy",
|
||||
"make_my_policy_pre_post_processors",
|
||||
]
|
||||
```
|
||||
|
||||
## Step 6: Installation and Usage
|
||||
|
||||
### Install Your Policy Package
|
||||
### Install and use
|
||||
|
||||
```bash
|
||||
cd lerobot_policy_my_custom_policy
|
||||
cd lerobot_policy_my_policy
|
||||
pip install -e .
|
||||
|
||||
# Or install from PyPI if published
|
||||
pip install lerobot_policy_my_custom_policy
|
||||
pip install lerobot_policy_my_policy
|
||||
```
|
||||
|
||||
### Use Your Policy
|
||||
|
||||
Once installed, your policy automatically integrates with LeRobot's training and evaluation tools:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type my_custom_policy \
|
||||
--policy.type my_policy \
|
||||
--env.type pusht \
|
||||
--steps 200000
|
||||
```
|
||||
|
||||
## Examples and Community Contributions
|
||||
---
|
||||
|
||||
## Path B: Contributing in-tree
|
||||
|
||||
When your policy has stabilized and there's clear value in shipping it with the library, you can land it directly in LeRobot. Read the general [contribution guide](./contributing) and the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md) first — that's where you'll find the testing/quality expectations every PR has to meet (`pre-commit run -a`, `pytest`, the community-review rule, etc.). What's below is the policy-specific layer on top of that.
|
||||
|
||||
### In-tree layout
|
||||
|
||||
```
|
||||
src/lerobot/policies/my_policy/
|
||||
├── __init__.py # re-exports config + modeling + processor factory
|
||||
├── configuration_my_policy.py # MyPolicyConfig + @register_subclass
|
||||
├── modeling_my_policy.py # MyPolicy(PreTrainedPolicy)
|
||||
├── processor_my_policy.py # make_my_policy_pre_post_processors
|
||||
└── README.md # symlink → ../../../../docs/source/policy_my_policy_README.md
|
||||
```
|
||||
|
||||
Two notes:
|
||||
|
||||
- The `README.md` next to the source is a **symlink** into `docs/source/policy_<name>_README.md` — the actual file lives under `docs/`. Existing policies (act, smolvla, diffusion, …) all do this; copy one of those symlinks. The policy README is conventionally minimal: paper link + BibTeX citation.
|
||||
- The user-facing tutorial — what to install, how to train, hyperparameters, benchmark numbers — lives separately at `docs/source/<my_policy>.mdx` and is registered in `_toctree.yml` under "Policies".
|
||||
|
||||
The file names are load-bearing: the factory does lazy imports by name, and the processor is discovered by the `make_<policy_name>_pre_post_processors` convention.
|
||||
|
||||
### Wiring
|
||||
|
||||
Three places need to know about your policy. All by name.
|
||||
|
||||
1. **`policies/__init__.py`** — re-export `MyPolicyConfig` and add it to `__all__`. **Don't** re-export the modeling class; it loads lazily through the factory (so `import lerobot` stays fast).
|
||||
2. **`factory.py:get_policy_class`** — add a branch returning `MyPolicy` from a lazy import.
|
||||
3. **`factory.py:make_policy_config`** and **`factory.py:make_pre_post_processors`** — same idea, two more branches.
|
||||
|
||||
Mirror an existing policy that's structurally similar to yours; the diff is small.
|
||||
|
||||
### Heavy / optional dependencies
|
||||
|
||||
Most policies need a heavy backbone (transformers, diffusers, a specific VLM SDK). The convention is **two-step gating**: a `TYPE_CHECKING`-guarded import at module top, and a `require_package` runtime check in the constructor. [`modeling_diffusion.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/modeling_diffusion.py) is the canonical reference:
|
||||
|
||||
```python
|
||||
from typing import TYPE_CHECKING
|
||||
from lerobot.utils.import_utils import _diffusers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _diffusers_available:
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
else:
|
||||
DDIMScheduler = None # keeps the symbol bindable at import time
|
||||
|
||||
class DiffusionPolicy(PreTrainedPolicy):
|
||||
def __init__(self, config):
|
||||
require_package("diffusers", extra="diffusion")
|
||||
super().__init__(config)
|
||||
...
|
||||
```
|
||||
|
||||
This way:
|
||||
|
||||
- `import lerobot.policies` keeps working without the extra installed (the symbol is just bound to `None`).
|
||||
- Type checkers see the real symbol.
|
||||
- Instantiating the policy without the extra raises a clear `ImportError` pointing at `pip install 'lerobot[diffusion]'`.
|
||||
|
||||
Add a matching extra to [`pyproject.toml`](https://github.com/huggingface/lerobot/blob/main/pyproject.toml) `[project.optional-dependencies]` and include it in the `all` extra so `pip install 'lerobot[all]'` keeps installing everything.
|
||||
|
||||
### Benchmarks and a published checkpoint
|
||||
|
||||
A new policy is much easier to review — and far more useful — when it ships with a working checkpoint and at least one number you can reproduce.
|
||||
|
||||
**Pick at least one in-tree benchmark.** LeRobot ships sim benchmarks with per-benchmark Docker images (LIBERO, LIBERO-plus, Meta-World, RoboTwin 2.0, RoboCasa365, RoboCerebra, RoboMME, VLABench and more). Pick the one that matches your policy's modality — VLAs usually go to LIBERO or VLABench; image-only BC to LIBERO or Meta-World. The full list lives under [Benchmarks](./libero) in the docs sidebar.
|
||||
|
||||
**Push the checkpoint & processors** to the Hub under `lerobot/<policy>_<benchmark>` (or your namespace if you don't have write access; a maintainer can mirror it). Use `PreTrainedPolicy.push_model_to_hub` so the repo gets `config.json`, `model.safetensors`, and a model card.
|
||||
|
||||
**Report results in your policy's MDX**, with the exact `lerobot-eval` command and hardware so anyone can re-run:
|
||||
|
||||
```markdown
|
||||
## Results
|
||||
|
||||
Evaluated on LIBERO with `lerobot/<policy>_libero`:
|
||||
|
||||
| Suite | Success rate | n_episodes |
|
||||
| -------------- | -----------: | ---------: |
|
||||
| libero_spatial | 87.5% | 50 |
|
||||
| libero_object | 93.0% | 50 |
|
||||
| libero_goal | 81.5% | 50 |
|
||||
| libero_10 | 62.0% | 50 |
|
||||
| **average** | **81.0%** | 200 |
|
||||
|
||||
Reproduce: `lerobot-eval --policy.path=lerobot/<policy>_libero --env.type=libero --env.task=libero_spatial --eval.n_episodes=50` (1× A100 40 GB).
|
||||
```
|
||||
|
||||
Use `n_episodes ≥ 50` per suite for stable success-rate estimates.
|
||||
|
||||
If your policy is real-robot-only and no sim benchmark applies, swap the sim eval for: a public training dataset on the Hub, the `lerobot-train` command, the checkpoint, and a real-robot success rate over ≥10 episodes via `lerobot-rollout --policy.path=...`.
|
||||
|
||||
### PR checklist
|
||||
|
||||
The general expectations are in [`CONTRIBUTING.md`](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md) and the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md). On top of those, reviewers will look for:
|
||||
|
||||
- [ ] `MyPolicy` and `MyPolicyConfig` cover the surface above; `__init_subclass__` accepts the class.
|
||||
- [ ] `factory.py` and `policies/__init__.py` are wired (lazy imports for modeling).
|
||||
- [ ] `make_my_policy_pre_post_processors` follows the naming convention.
|
||||
- [ ] Optional deps live behind a `[project.optional-dependencies]` extra and the `TYPE_CHECKING + require_package` guard.
|
||||
- [ ] `tests/policies/` updated; backward-compat artifact committed & policy-specific tests.
|
||||
- [ ] `src/lerobot/policies/<name>/README.md` symlinked into `docs/source/policy_<name>_README.md`; user-facing `docs/source/<name>.mdx` written and added to `_toctree.yml`.
|
||||
- [ ] At least one reproducible benchmark eval in the policy MDX with a published checkpoint (sim benchmark, or real-robot dataset + checkpoint).
|
||||
|
||||
The fastest way to get a clean PR is to copy the directory of the existing policy closest to yours, rename, and replace contents method by method. Don't wait until everything is polished — open a draft PR early and iterate with us; reviewers would much rather give feedback on a half-finished branch than a fully-merged one.
|
||||
|
||||
---
|
||||
|
||||
## Examples and community contributions
|
||||
|
||||
Check out these example policy implementations:
|
||||
|
||||
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) - Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
|
||||
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) — Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
|
||||
|
||||
Share your policy implementations with the community! 🤗
|
||||
Thanks for taking the time to bring a new policy into LeRobot. Every architecture that lands in `main` — and every plugin published by the community — makes the library a little more useful for the next person, and a little more representative of where robot learning is going. We're looking forward to seeing what you ship. 🤗
|
||||
|
||||
139
docs/source/cheat-sheet.mdx
Normal file
139
docs/source/cheat-sheet.mdx
Normal file
@@ -0,0 +1,139 @@
|
||||
# Cheat sheet
|
||||
|
||||
All of the LeRobot commands in one place. If you forgot how to use a specific command or want to learn about a new one you can do it here.
|
||||
|
||||
> [!WARNING]
|
||||
> For all of the commands listed below remember to change the ports/names/ids to your own values!
|
||||
|
||||
> [!TIP]
|
||||
> Another great way to look at all the commands and get them configured for your specific setup is to use this [Jupyter Notebook](https://github.com/huggingface/lerobot/blob/main/examples/notebooks/quickstart.ipynb).
|
||||
|
||||
### Setup and installation
|
||||
|
||||
For installation please look at [LeRobot Installation](https://huggingface.co/docs/lerobot/main/en/installation).
|
||||
|
||||
### Useful tools
|
||||
|
||||
###### Find port
|
||||
|
||||
Use this to identify which serial ports your robots are connected to. Follow the instructions in your terminal: you will be asked to unplug the USB cable and press Enter. The script will then detect and print the correct serial port for that robot.
|
||||
|
||||
```bash
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
###### Find cameras
|
||||
|
||||
Quickly find camera indices and verify their output. This command prints camera information to the terminal and saves test frames from each detected camera to `lerobot/outputs/captured_images`
|
||||
|
||||
```bash
|
||||
lerobot-find-cameras
|
||||
```
|
||||
|
||||
### Calibration
|
||||
|
||||
In most cases you will need to perform calibration just once for each robot and teleoperation device. Before performing the calibration make sure that all the joints are roughly in the middle position.
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_follower_arm
|
||||
```
|
||||
|
||||
Make sure that you use the same IDs used during calibration later for the other scripts. That's how LeRobot finds the calibration files.
|
||||
|
||||
### Teleoperation
|
||||
|
||||
Teleoperating with two cameras and displaying the data with Rerun.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_follower_arm \
|
||||
--robot.cameras="{ top: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30} }" \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/ttyACM1 \
|
||||
--teleop.id=my_leader_arm \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
### Recording a dataset
|
||||
|
||||
The dataset is automatically uploaded to the server and saved under repo_id, make sure you are logged in to your HF account with CLI:
|
||||
`hf auth login`
|
||||
|
||||
You can get the token from: [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_follower_arm \
|
||||
--robot.cameras="{ top: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30} }" \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/ttyACM1 \
|
||||
--teleop.id=my_leader_arm \
|
||||
--dataset.repo_id=${HF_USER}/so101_dataset_test \
|
||||
--dataset.num_episodes=30 \
|
||||
--dataset.single_task="put the red brick in a bowl" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
While collecting the dataset you can control the process with your keyboard:
|
||||
Control the data recording flow using keyboard shortcuts:
|
||||
|
||||
- Press **Right Arrow (`→`)**: Save episode and move to the next.
|
||||
- Press **Left Arrow (`←`)**: Delete current episode and retry.
|
||||
- Press **Escape (`ESC`)**: Stop, encode videos, and upload.
|
||||
|
||||
### Training
|
||||
|
||||
Depending on your hardware training the policy might take a few hours. That's how you train simple `ACT` policy:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/so101_dataset_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_so101_test \
|
||||
--job_name=act_so101_test \
|
||||
--policy.device=cuda \
|
||||
--wandb.enable=true \
|
||||
--policy.repo_id=${HF_USER}/policy_test \
|
||||
--steps=20000
|
||||
```
|
||||
|
||||
- Policy Types: `act`, `diffusion`, `smolvla`, `pi05`
|
||||
- Devices: `cuda` (NVIDIA), `mps` (Apple Silicon), `cpu`
|
||||
|
||||
If you want to fine-tune a specific model you can provide the path to the model. In this case path is enough and type can be skipped.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/so101_dataset_test \
|
||||
--policy.path=username/the_policy_to_finetune \
|
||||
--policy.device=cuda \
|
||||
--policy.repo_id=${HF_USER}/policy_test \
|
||||
--output_dir=outputs/train/act_so101_test \
|
||||
--steps=20000
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
Inference means running the trained policy/model on a robot. For that we use `lerobot-rollout`. You will need to provide a path to your policy. It can be a local path or a path to Hugging Face for example "lerobot/folding_latest". Your cameras configuration needs to match what was used when collecting the dataset. Duration is in seconds if unspecified, it will run forever.
|
||||
|
||||
> [!TIP]
|
||||
> If you are using the previous release V0.5.1 instead of `lerobot-rollout` you need to use `lerobot-record`. More information [here](https://huggingface.co/docs/lerobot/v0.5.1/en/il_robots#run-inference-and-evaluate-your-policy).
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM1 \
|
||||
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video1, width: 640, height: 480, fps: 30}, side: {type: opencv, index_or_path: /dev/video5, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Put lego brick into the transparent box" \
|
||||
--duration=60
|
||||
```
|
||||
@@ -1,277 +0,0 @@
|
||||
# Using Subtasks in LeRobot Datasets
|
||||
|
||||
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
|
||||
|
||||
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
|
||||
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
|
||||
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
|
||||
|
||||
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
|
||||
|
||||
## What are Subtasks?
|
||||
|
||||
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
|
||||
|
||||
1. "Approach the apple"
|
||||
2. "Grasp the apple"
|
||||
3. "Lift the apple"
|
||||
4. "Move to basket"
|
||||
5. "Release the apple"
|
||||
|
||||
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
|
||||
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
|
||||
width="80%"
|
||||
/>
|
||||
|
||||
<p>
|
||||
<em>Figure: Overview of subtask annotation.</em>
|
||||
</p>
|
||||
|
||||
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
|
||||
|
||||
## Dataset Structure
|
||||
|
||||
Subtask information is stored in the dataset metadata:
|
||||
|
||||
```
|
||||
my-dataset/
|
||||
├── data/
|
||||
│ └── ...
|
||||
├── meta/
|
||||
│ ├── info.json
|
||||
│ ├── stats.json
|
||||
│ ├── tasks.parquet
|
||||
│ ├── subtasks.parquet # Subtask index → subtask string mapping
|
||||
│ └── episodes/
|
||||
│ └── ...
|
||||
└── videos/
|
||||
└── ...
|
||||
```
|
||||
|
||||
### Subtasks Parquet File
|
||||
|
||||
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
|
||||
|
||||
| subtask_index | subtask (index column) |
|
||||
| ------------- | ---------------------- |
|
||||
| 0 | "Approach the apple" |
|
||||
| 1 | "Grasp the apple" |
|
||||
| 2 | "Lift the apple" |
|
||||
| ... | ... |
|
||||
|
||||
### Frame-Level Annotations
|
||||
|
||||
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
|
||||
|
||||
```python
|
||||
# Example frame data in the parquet file
|
||||
{
|
||||
"index": 42,
|
||||
"timestamp": 1.4,
|
||||
"episode_index": 0,
|
||||
"task_index": 0,
|
||||
"subtask_index": 2, # References "Lift the apple"
|
||||
"observation.state": [...],
|
||||
"action": [...],
|
||||
}
|
||||
```
|
||||
|
||||
## Annotating Datasets with Subtasks
|
||||
|
||||
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
|
||||
|
||||
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
|
||||
|
||||
After completing your annotation:
|
||||
|
||||
1. Click "Push to Hub" to upload your annotated dataset
|
||||
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
|
||||
|
||||
## Loading Datasets with Subtasks
|
||||
|
||||
When you load a dataset with subtask annotations, the subtask information is automatically available:
|
||||
|
||||
```python
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
# Load a dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Access a sample
|
||||
sample = dataset[100]
|
||||
|
||||
# The sample includes both task and subtask information
|
||||
print(sample["task"]) # "Collect the fruit"
|
||||
print(sample["subtask"]) # "Grasp the apple"
|
||||
print(sample["task_index"]) # tensor(0)
|
||||
print(sample["subtask_index"]) # tensor(2)
|
||||
```
|
||||
|
||||
### Checking for Subtask Support
|
||||
|
||||
You can check if a dataset has subtask annotations:
|
||||
|
||||
```python
|
||||
# Check if subtasks are available
|
||||
has_subtasks = (
|
||||
"subtask_index" in dataset.features
|
||||
and dataset.meta.subtasks is not None
|
||||
)
|
||||
|
||||
if has_subtasks:
|
||||
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
|
||||
print("Subtasks:", list(dataset.meta.subtasks.index))
|
||||
```
|
||||
|
||||
## Using Subtasks for Training
|
||||
|
||||
### With the Tokenizer Processor
|
||||
|
||||
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
|
||||
|
||||
```python
|
||||
from lerobot.processor import TokenizerProcessorStep
|
||||
|
||||
# Create a tokenizer processor step
|
||||
tokenizer_processor = TokenizerProcessorStep(
|
||||
tokenizer_name_or_path="google/paligemma-3b-pt-224",
|
||||
padding="max_length",
|
||||
max_length=64,
|
||||
)
|
||||
|
||||
# The processor will automatically tokenize subtasks if present in the batch
|
||||
# and add them to the observation under:
|
||||
# - "observation.subtask.tokens"
|
||||
# - "observation.subtask.attention_mask"
|
||||
```
|
||||
|
||||
When subtasks are available in the batch, the tokenizer processor adds:
|
||||
|
||||
- `observation.subtask.tokens`: Tokenized subtask text
|
||||
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
|
||||
|
||||
### DataLoader with Subtasks
|
||||
|
||||
```python
|
||||
import torch
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=16,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
for batch in dataloader:
|
||||
# Access subtask information in the batch
|
||||
subtasks = batch["subtask"] # List of subtask strings
|
||||
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
|
||||
|
||||
# Use for training hierarchical policies or reward models
|
||||
print(f"Batch subtasks: {set(subtasks)}")
|
||||
```
|
||||
|
||||
## Example Datasets with Subtask Annotations
|
||||
|
||||
Try loading a dataset with subtask annotations:
|
||||
|
||||
```python
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
# Example dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Explore the subtasks
|
||||
print("Available subtasks:")
|
||||
for subtask_name in dataset.meta.subtasks.index:
|
||||
print(f" - {subtask_name}")
|
||||
|
||||
# Get subtask distribution
|
||||
subtask_counts = {}
|
||||
for i in range(len(dataset)):
|
||||
sample = dataset[i]
|
||||
subtask = sample["subtask"]
|
||||
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
|
||||
|
||||
print("\nSubtask distribution:")
|
||||
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
|
||||
print(f" {subtask}: {count} frames")
|
||||
```
|
||||
|
||||
## Use Cases
|
||||
|
||||
### 1. Hierarchical Policy Training
|
||||
|
||||
Train policies that predict both actions and current subtask:
|
||||
|
||||
```python
|
||||
class HierarchicalPolicy(nn.Module):
|
||||
def __init__(self, num_subtasks):
|
||||
super().__init__()
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
|
||||
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
actions = self.action_head(features)
|
||||
subtask_logits = self.subtask_head(features)
|
||||
return actions, subtask_logits
|
||||
```
|
||||
|
||||
### 2. Stage-Aware Reward Modeling (SARM)
|
||||
|
||||
Build reward models that understand task progression:
|
||||
|
||||
```python
|
||||
# SARM predicts:
|
||||
# - Stage: Which subtask is being executed (discrete)
|
||||
# - Progress: How far along the subtask (continuous 0-1)
|
||||
|
||||
class SARMRewardModel(nn.Module):
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
stage_logits = self.stage_classifier(features)
|
||||
progress = self.progress_regressor(features)
|
||||
return stage_logits, progress
|
||||
```
|
||||
|
||||
### 3. Progress Visualization
|
||||
|
||||
Monitor robot execution by tracking subtask progression:
|
||||
|
||||
```python
|
||||
def visualize_execution(model, observations):
|
||||
for t, obs in enumerate(observations):
|
||||
action, subtask_logits = model(obs)
|
||||
predicted_subtask = subtask_names[subtask_logits.argmax()]
|
||||
print(f"t={t}: Executing '{predicted_subtask}'")
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### LeRobotDataset Properties
|
||||
|
||||
| Property | Type | Description |
|
||||
| --------------------------- | ---------------------- | ------------------------------------------ |
|
||||
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
|
||||
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
|
||||
|
||||
### Sample Keys
|
||||
|
||||
When subtasks are available, each sample includes:
|
||||
|
||||
| Key | Type | Description |
|
||||
| --------------- | -------------- | ------------------------------------ |
|
||||
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
|
||||
| `subtask` | `str` | Natural language subtask description |
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
|
||||
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
|
||||
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation
|
||||
@@ -194,7 +194,7 @@ lerobot-record \
|
||||
--dataset.single_task="Navigate around obstacles" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
|
||||
168
docs/source/eo1.mdx
Normal file
168
docs/source/eo1.mdx
Normal file
@@ -0,0 +1,168 @@
|
||||
# EO-1
|
||||
|
||||
EO-1 is a **Vision-Language-Action policy for robot control**. The LeRobot implementation integrates EO-1 with the standard LeRobot training, evaluation, processor interface.
|
||||
|
||||
## Model Overview
|
||||
|
||||
EO-1 uses a Qwen2.5-VL backbone for vision-language understanding and adds a continuous flow-matching action head for robot control. The policy formats each robot-control sample as a multimodal conversation: camera images are passed to Qwen2.5-VL, the robot state is represented with EO-1 state tokens, and the future action chunk is represented with EO-1 action tokens.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/HaomingSong/lerobot-documentation-images/resolve/main/lerobot/eo_pipeline.png"
|
||||
alt="An overview of EO-1"
|
||||
width="85%"
|
||||
/>
|
||||
|
||||
During training, EO-1 learns to denoise continuous action chunks at the action-token positions. During inference, it samples an action chunk, returns continuous actions, and executes `n_action_steps` from the chunk before sampling again.
|
||||
|
||||
### What the LeRobot Integration Covers
|
||||
|
||||
- Standard `policy.type=eo1` configuration through LeRobot
|
||||
- Qwen2.5-VL image and text preprocessing through policy processors
|
||||
- Continuous flow-matching action prediction
|
||||
- Checkpoint save/load through LeRobot policy APIs
|
||||
- Training with `lerobot-train` and evaluation with `lerobot-eval`
|
||||
|
||||
The broader EO-1 project also includes interleaved vision-text-action pretraining and multimodal reasoning workflows. This page focuses on the LeRobot robot-control policy path.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following the [Installation Guide](./installation).
|
||||
2. Install EO-1 dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install -e ".[eo1]"
|
||||
```
|
||||
|
||||
3. If you want to train or evaluate on LIBERO, install the LIBERO dependencies too:
|
||||
|
||||
```bash
|
||||
pip install -e ".[eo1,libero]"
|
||||
```
|
||||
|
||||
EO-1 can use the standard PyTorch scaled-dot-product attention backend through `policy.attn_implementation=sdpa`. If your environment has a compatible `flash_attn` installation, you can request `policy.attn_implementation=flash_attention_2`.
|
||||
|
||||
## Data Requirements
|
||||
|
||||
EO-1 expects a LeRobot dataset with:
|
||||
|
||||
- At least one visual observation, for example `observation.images.image`
|
||||
- `observation.state`
|
||||
- `action`
|
||||
- A language task instruction through the dataset `task` field
|
||||
|
||||
If your dataset uses different observation names, use `rename_map` to align them with the names expected by your training or evaluation setup.
|
||||
|
||||
## Usage
|
||||
|
||||
To use EO-1 in a LeRobot configuration, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=eo1
|
||||
```
|
||||
|
||||
By default, a new EO-1 policy initializes its backbone from:
|
||||
|
||||
```python
|
||||
policy.vlm_base=Qwen/Qwen2.5-VL-3B-Instruct
|
||||
```
|
||||
|
||||
Once a LeRobot-format EO-1 checkpoint is available, load it with:
|
||||
|
||||
```python
|
||||
policy.path=your-org/your-eo1-checkpoint
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Training Command Example
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_org/your_dataset \
|
||||
--policy.type=eo1 \
|
||||
--policy.vlm_base=Qwen/Qwen2.5-VL-3B-Instruct \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.attn_implementation=sdpa \
|
||||
--policy.gradient_checkpointing=false \
|
||||
--output_dir=./outputs/eo1_training \
|
||||
--job_name=eo1_training \
|
||||
--steps=300000 \
|
||||
--batch_size=16 \
|
||||
--policy.device=cuda
|
||||
```
|
||||
|
||||
### Key Training Parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| -------------------------------------- | ----------------------------- | ----------------------------------------------------------------------- |
|
||||
| `policy.vlm_base` | `Qwen/Qwen2.5-VL-3B-Instruct` | Qwen2.5-VL checkpoint used to initialize a new policy |
|
||||
| `policy.dtype` | `auto` | Backbone dtype request: `auto`, `bfloat16`, or `float32` |
|
||||
| `policy.attn_implementation` | `None` | Optional Qwen attention backend, such as `sdpa` |
|
||||
| `policy.gradient_checkpointing` | `false` | Reduces memory usage during training |
|
||||
| `policy.chunk_size` | `8` | Number of future actions predicted per chunk |
|
||||
| `policy.n_action_steps` | `8` | Number of actions consumed from a sampled chunk |
|
||||
| `policy.num_denoise_steps` | `10` | Number of flow-matching denoising steps used during sampling |
|
||||
| `policy.max_state_dim` | `32` | State padding dimension |
|
||||
| `policy.max_action_dim` | `32` | Action padding dimension |
|
||||
| `policy.force_fp32_autocast` | `true` | Keeps the flow head in fp32 even when the backbone uses mixed precision |
|
||||
| `policy.supervise_padding_action_dims` | `true` | Controls whether padded action dimensions are supervised |
|
||||
| `policy.supervise_padding_actions` | `true` | Controls whether padded future action rows are supervised |
|
||||
|
||||
## Evaluation
|
||||
|
||||
EO-1 can be evaluated through `lerobot-eval` once you have a LeRobot-format checkpoint:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=your-org/your-eo1-checkpoint \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=20
|
||||
```
|
||||
|
||||
For datasets or environments whose camera names differ from the checkpoint configuration, pass a `rename_map`:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=your-org/your-eo1-checkpoint \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--rename_map='{"observation.images.image2":"observation.images.wrist_image"}'
|
||||
```
|
||||
|
||||
## Configuration Notes
|
||||
|
||||
### Image Processing
|
||||
|
||||
EO-1 uses the Qwen2.5-VL processor. The `policy.image_min_pixels` and `policy.image_max_pixels` settings control the image resizing bounds before the visual tokens are passed into the backbone.
|
||||
|
||||
### State and Action Dimensions
|
||||
|
||||
The policy pads state and action vectors to `policy.max_state_dim` and `policy.max_action_dim` before the EO-1 flow head. Predictions are cropped back to the original action dimension before being returned by the policy.
|
||||
|
||||
### Attention Backend
|
||||
|
||||
Use `policy.attn_implementation=sdpa` for a portable setup. Use `flash_attention_2` only when `flash_attn` is installed and compatible with your environment.
|
||||
|
||||
## References
|
||||
|
||||
- [EO-1 project](https://github.com/EO-Robotics/EO1)
|
||||
- [EO-1 paper](https://arxiv.org/abs/2508.21112)
|
||||
- [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{eo1,
|
||||
title={EO-1: Interleaved Vision-Text-Action Pretraining for General Robot Control},
|
||||
author={Delin Qu and Haoming Song and Qizhi Chen and Zhaoqing Chen and Xianqiang Gao and Xinyi Ye and Qi Lv and Modi Shi and Guanghui Ren and Cheng Ruan and Maoqing Yao and Haoran Yang and Jiacheng Bao and Bin Zhao and Dong Wang},
|
||||
journal={arXiv preprint},
|
||||
year={2025},
|
||||
url={https://arxiv.org/abs/2508.21112}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This LeRobot integration follows the **Apache 2.0 License** used by LeRobot. Check the upstream EO-1 model and dataset pages for the licenses of released EO-1 checkpoints and data.
|
||||
@@ -105,10 +105,12 @@ These results demonstrate GR00T's strong generalization capabilities across dive
|
||||
|
||||
### Evaluate in your hardware setup
|
||||
|
||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
|
||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Policy Deployment (lerobot-rollout)](./inference). For example:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
lerobot-rollout\
|
||||
--strategy.type=sentry \
|
||||
--strategy.upload_every_n_episodes=5 \
|
||||
--robot.type=bi_so_follower \
|
||||
--robot.left_arm_port=/dev/ttyACM1 \
|
||||
--robot.right_arm_port=/dev/ttyACM0 \
|
||||
@@ -119,14 +121,12 @@ lerobot-record \
|
||||
}' \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=<user>/eval_groot-bimanual \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--policy.path=<user>/groot-bimanual \ # your trained model
|
||||
--dataset.episode_time_s=30 \
|
||||
--dataset.reset_time_s=10
|
||||
--duration=600
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
98
docs/source/hardware_guide.mdx
Normal file
98
docs/source/hardware_guide.mdx
Normal file
@@ -0,0 +1,98 @@
|
||||
# Compute HW Guide for LeRobot Training
|
||||
|
||||
Rough sizing for training a LeRobot policy: how much VRAM each policy needs, what training time looks like, and where to run when local hardware isn't enough.
|
||||
|
||||
The numbers below are **indicative** — order-of-magnitude figures for picking hardware, not exact predictions. Throughput depends heavily on dataset I/O, image resolution, batch size, and number of GPUs.
|
||||
|
||||
## Memory by policy group
|
||||
|
||||
Policies cluster by backbone size; the groupings below give a single VRAM envelope per group instead of repeating numbers per policy. Memory scales roughly linearly with batch size; AdamW (the LeRobot default) carries optimizer state that adds ~30–100% over a forward+backward pass alone.
|
||||
|
||||
| Group | Policies | Peak VRAM (BS 8, AdamW) | Suitable starter GPUs |
|
||||
| ---------- | ------------------------------------------- | ----------------------: | --------------------------------- |
|
||||
| Light BC | `act`, `vqbet`, `tdmpc` | ~2–6GB | Laptop GPU (RTX 3060), L4, A10G |
|
||||
| Diffusion | `diffusion`, `multi_task_dit` | ~8–14GB | RTX 4070+ / L4 / A10G |
|
||||
| Small VLA | `smolvla` | ~10–16GB | RTX 4080+ / L4 / A10G |
|
||||
| Large VLA | `pi0`, `pi0_fast`, `pi05`, `xvla`, `wall_x` | ~24–40GB | A100 40 GB+ (24 GB tight at BS 1) |
|
||||
| Multimodal | `groot`, `eo1` | ~24–40GB | A100 40 GB+ |
|
||||
| RL | `sac` | config-dep. | See [HIL-SERL guide](./hilserl) |
|
||||
|
||||
Memory-bound? Drop the batch size (~linear), use gradient accumulation to recover effective batch, or for SmolVLA leave `freeze_vision_encoder=True`.
|
||||
|
||||
## Training time
|
||||
|
||||
Robotics imitation learning typically converges in **5–10 epochs over the dataset**, not hundreds of thousands of raw steps. Once you know your epoch count, wall-clock is essentially:
|
||||
|
||||
```text
|
||||
total_frames = sum of frames over all episodes # 50 ep × 30 fps × 30 s ≈ 45,000
|
||||
steps_per_epoch = ceil(total_frames / (num_gpus × batch_size))
|
||||
total_steps = epochs × steps_per_epoch
|
||||
wall_clock ≈ total_steps × per_step_time
|
||||
```
|
||||
|
||||
Per-step time depends on the policy and the GPU. The numbers in the table below are anchors — pick the row closest to your setup and scale linearly with `total_steps` if you train longer or shorter.
|
||||
|
||||
### Common scenarios
|
||||
|
||||
Indicative wall-clock for **5 epochs on a ~50-episode dataset (~45k frames at 30 fps × 30 s)**, default optimizer (AdamW), 640×480 images:
|
||||
|
||||
| Setup | Policy | Batch | Wall-clock |
|
||||
| ------------------------------------ | -------------- | ----- | ---------: |
|
||||
| Single RTX 4090 / RTX 3090 (24 GB) | `act` | 8 | ~30–60min |
|
||||
| Single RTX 4090 / RTX 3090 (24 GB) | `diffusion` | 8 | ~2–4h |
|
||||
| Single L4 / A10G (24 GB) | `act` | 8 | ~1–2h |
|
||||
| Single L4 / A10G (24 GB) | `smolvla` | 4 | ~3–6h |
|
||||
| Single A100 40 GB | `smolvla` | 16 | ~1–2h |
|
||||
| Single A100 40 GB | `pi0` / `pi05` | 4 | ~4–8h |
|
||||
| 4× H100 80 GB cluster (`accelerate`) | `diffusion` | 32 | ~30–60min |
|
||||
| 4× H100 80 GB cluster (`accelerate`) | `smolvla` | 32 | ~1–2h |
|
||||
| Apple Silicon M1/M2/M3 Max (MPS) | `act` | 4 | ~6–14h |
|
||||
|
||||
These are order-of-magnitude figures. Real runs deviate by ±50% depending on image resolution, dataset I/O, dataloader threading, and exact GPU SKU. They are useful as "is this run going to take an hour or a day?" intuition, not as SLAs.
|
||||
|
||||
### Multi-GPU matters a lot
|
||||
|
||||
`accelerate launch --num_processes=N` is the easiest way to cut training time. Each optimizer step processes `N × batch_size` samples in roughly the same wall-clock as a single-GPU step, so 4 GPUs ≈ 4× speedup for compute-bound runs. See the [Multi GPU training](./multi_gpu_training) guide for the full setup.
|
||||
|
||||
Reference data points on a 4×H100 80 GB cluster (`accelerate launch --num_processes=4`), 5000 steps, batch 32, AdamW, dataset [`imstevenpmwork/super_poulain_draft`](https://huggingface.co/datasets/imstevenpmwork/super_poulain_draft) (~50 episodes, ~640×480 images):
|
||||
|
||||
| Policy | Wall-clock | `update_s` | `dataloading_s` | GPU util | Notable flags |
|
||||
| ----------- | ---------- | ---------: | --------------: | -------- | ------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `diffusion` | 16m 17s | 0.167 | 0.015 | ~90% | defaults (training from scratch) |
|
||||
| `smolvla` | 27m 49s | 0.312 | 0.011 | ~80% | `--policy.path=lerobot/smolvla_base`, `freeze_vision_encoder=false`, `train_expert_only=false` |
|
||||
| `pi05` | 3h 41m | 2.548 | 0.014 | ~95% | `--policy.pretrained_path=lerobot/pi05_base`, `gradient_checkpointing=true`, `dtype=bfloat16`, vision encoder + expert trained |
|
||||
|
||||
The `dataloading_s` vs. `update_s` ratio is the diagnostic that matters: when `dataloading_s` approaches `update_s`, more GPUs stop helping — your dataloader is the bottleneck and you should look at `--num_workers`, image resolution, and disk speed before adding compute.
|
||||
|
||||
### Schedule and checkpoints
|
||||
|
||||
If you shorten training (e.g. 5k–10k steps on a small dataset), also shorten the LR schedule with `--policy.scheduler_decay_steps≈--steps`. Otherwise the LR stays near its peak and never decays. Same for `--save_freq`.
|
||||
|
||||
## Where to run
|
||||
|
||||
VRAM is the first filter. Within a tier, pick by budget and availability — the `$`–`$$$$` columns are relative; check current pricing on the provider you actually use.
|
||||
|
||||
| Class | VRAM | Tier | Comfortable for |
|
||||
| -------------------------- | ----- | ------ | ----------------------------------------------------------- |
|
||||
| RTX 3090 / 4090 (consumer) | 24 GB | `$` | Light BC, Diffusion, SmolVLA. Tight for VLAs at batch 1. |
|
||||
| L4 / A10G (cloud) | 24 GB | `$–$$` | Same envelope; common on Google Cloud, RunPod, AWS `g5/g6`. |
|
||||
| A100 40 GB | 40 GB | `$$$` | Any policy at reasonable batch sizes. |
|
||||
| A100 80 GB / H100 80 GB | 80 GB | `$$$$` | Multi-GPU clusters; large batches for VLAs. |
|
||||
| **CPU only** | — | — | Don't train. Use Colab or rent a GPU. |
|
||||
|
||||
### Hugging Face Jobs
|
||||
|
||||
[Hugging Face Jobs](https://huggingface.co/docs/hub/jobs) lets you run training on managed HF infrastructure, billed by the second. The repo publishes a ready-to-use image: **`huggingface/lerobot-gpu:latest`**, rebuilt **every night at 02:00 UTC from `main`** ([`docker_publish.yml`](https://github.com/huggingface/lerobot/blob/main/.github/workflows/docker_publish.yml)) — so it tracks the current state of the repo, not a tagged release.
|
||||
|
||||
```bash
|
||||
hf jobs run --flavor a10g-large huggingface/lerobot-gpu:latest \
|
||||
bash -c "nvidia-smi && lerobot-train \
|
||||
--policy.type=act --dataset.repo_id=<USER>/<DATASET> \
|
||||
--policy.repo_id=<USER>/act_<task> --batch_size=8 --steps=50000"
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- The leading `nvidia-smi` is a quick sanity check that CUDA is visible inside the container — useful to fail fast if the flavor or driver mismatched.
|
||||
- The default Job timeout is 30 minutes; pass `--timeout 4h` (or longer) for real training.
|
||||
- `--flavor` maps onto the table above: `t4-small`/`t4-medium` (T4, ACT only), `l4x1`/`l4x4` (L4 24 GB), `a10g-small/large/largex2/largex4` (A10G 24 GB scaled out), `a100-large` (A100). For the current full catalogue + pricing see [https://huggingface.co/docs/hub/jobs](https://huggingface.co/docs/hub/jobs).
|
||||
@@ -62,7 +62,7 @@ pip install -e ".[hilserl]"
|
||||
|
||||
### Understanding Configuration
|
||||
|
||||
The training process begins with proper configuration for the HILSerl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
|
||||
The training process begins with proper configuration for the HILSERl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` (defined in `lerobot/envs/configs.py`) and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
@@ -95,6 +95,7 @@ class HILSerlProcessorConfig:
|
||||
class ObservationConfig:
|
||||
add_joint_velocity_to_observation: bool = False # Add joint velocities to state
|
||||
add_current_to_observation: bool = False # Add motor currents to state
|
||||
add_ee_pose_to_observation: bool = False # Add end-effector pose to state
|
||||
display_cameras: bool = False # Display camera feeds during execution
|
||||
|
||||
class ImagePreprocessingConfig:
|
||||
@@ -326,14 +327,22 @@ lerobot-find-joint-limits \
|
||||
Max joint positions [-20.0, -20.0, -20.0, -20.0, -20.0, -20.0]
|
||||
Min joint positions [50.0, 50.0, 50.0, 50.0, 50.0, 50.0]
|
||||
```
|
||||
3. Use these values in the configuration of your teleoperation device (TeleoperatorConfig) under the `end_effector_bounds` field
|
||||
3. Use these values in your environment configuration under `env.processor.inverse_kinematics.end_effector_bounds` (see `InverseKinematicsConfig` in `lerobot/envs/configs.py`)
|
||||
|
||||
**Example Configuration**
|
||||
|
||||
```json
|
||||
"end_effector_bounds": {
|
||||
"max": [0.24, 0.20, 0.10],
|
||||
"min": [0.16, -0.08, 0.03]
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"inverse_kinematics": {
|
||||
"end_effector_bounds": {
|
||||
"max": [0.24, 0.2, 0.1],
|
||||
"min": [0.16, -0.08, 0.03]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -404,30 +413,24 @@ We support using a gamepad or a keyboard or the leader arm of the robot.
|
||||
|
||||
HIL-Serl learns actions in the end-effector space of the robot. Therefore, the teleoperation will control the end-effector's x,y,z displacements.
|
||||
|
||||
For that we need to define a version of the robot that takes actions in the end-effector space. Check the robot class `SO100FollowerEndEffector` and its configuration `SO100FollowerEndEffectorConfig` for the default parameters related to the end-effector space.
|
||||
The end-effector transformation is applied by the processor pipeline (`InverseKinematicsRLStep`, `EEBoundsAndSafety`, `EEReferenceAndDelta`, `GripperVelocityToJoint`) configured under `env.processor.inverse_kinematics` (`InverseKinematicsConfig`) and `env.processor.gripper` / `env.processor.max_gripper_pos`. The defaults related to the end-effector space are:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
class SO100FollowerEndEffectorConfig(SO100FollowerConfig):
|
||||
"""Configuration for the SO100FollowerEndEffector robot."""
|
||||
class InverseKinematicsConfig:
|
||||
"""Configuration for inverse kinematics processing."""
|
||||
|
||||
# Default bounds for the end-effector position (in meters)
|
||||
end_effector_bounds: dict[str, list[float]] = field( # bounds for the end-effector in x,y,z direction
|
||||
default_factory=lambda: {
|
||||
"min": [-1.0, -1.0, -1.0], # min x, y, z
|
||||
"max": [1.0, 1.0, 1.0], # max x, y, z
|
||||
}
|
||||
)
|
||||
urdf_path: str | None = None
|
||||
target_frame_name: str | None = None
|
||||
# bounds for the end-effector in x,y,z direction
|
||||
end_effector_bounds: dict[str, list[float]] | None = None
|
||||
# maximum step size for the end-effector in x,y,z direction
|
||||
end_effector_step_sizes: dict[str, float] | None = None
|
||||
|
||||
max_gripper_pos: float = 50 # maximum gripper position that the gripper will be open at
|
||||
|
||||
end_effector_step_sizes: dict[str, float] = field( # maximum step size for the end-effector in x,y,z direction
|
||||
default_factory=lambda: {
|
||||
"x": 0.02,
|
||||
"y": 0.02,
|
||||
"z": 0.02,
|
||||
}
|
||||
)
|
||||
class HILSerlProcessorConfig:
|
||||
...
|
||||
# maximum gripper position that the gripper will be open at
|
||||
max_gripper_pos: float | None = 100.0
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -606,11 +609,11 @@ This guide explains how to train a reward classifier for human-in-the-loop reinf
|
||||
|
||||
**Note**: Training a reward classifier is optional. You can start the first round of RL experiments by annotating the success manually with your gamepad or keyboard device.
|
||||
|
||||
The reward classifier implementation in `modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings.
|
||||
The reward classifier implementation in `lerobot/rewards/classifier/modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings.
|
||||
|
||||
**Collecting a Dataset for the reward classifier**
|
||||
|
||||
Before training, you need to collect a dataset with labeled examples. The `record_dataset` function in `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards.
|
||||
Before training, you need to collect a dataset with labeled examples. Setting `mode: "record"` in your config and running `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards.
|
||||
|
||||
To collect a dataset, you need to modify some parameters in the environment configuration based on HILSerlRobotEnvConfig.
|
||||
|
||||
@@ -658,7 +661,7 @@ Example configuration section for data collection:
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "hf_username/dataset_name",
|
||||
"dataset_root": "data/your_dataset",
|
||||
"root": "data/your_dataset",
|
||||
"task": "reward_classifier_task",
|
||||
"num_episodes_to_record": 20,
|
||||
"replay_episode": null,
|
||||
@@ -671,7 +674,7 @@ Example configuration section for data collection:
|
||||
|
||||
**Reward Classifier Configuration**
|
||||
|
||||
The reward classifier is configured using `configuration_classifier.py`. Here are the key parameters:
|
||||
The reward classifier is configured using `lerobot/rewards/classifier/configuration_classifier.py`. Here are the key parameters:
|
||||
|
||||
- **model_name**: Base model architecture (e.g., we mainly use `"helper2424/resnet10"`)
|
||||
- **model_type**: `"cnn"` or `"transformer"`
|
||||
@@ -689,7 +692,7 @@ Example configuration for training the [reward classifier](https://huggingface.c
|
||||
"repo_id": "hf_username/dataset_name",
|
||||
"root": null
|
||||
},
|
||||
"policy": {
|
||||
"reward_model": {
|
||||
"type": "reward_classifier",
|
||||
"model_name": "helper2424/resnet10",
|
||||
"model_type": "cnn",
|
||||
@@ -699,7 +702,6 @@ Example configuration for training the [reward classifier](https://huggingface.c
|
||||
"dropout_rate": 0.1,
|
||||
"learning_rate": 1e-4,
|
||||
"device": "cuda",
|
||||
"use_amp": true,
|
||||
"input_features": {
|
||||
"observation.images.front": {
|
||||
"type": "VISUAL",
|
||||
@@ -818,13 +820,14 @@ The LeRobot system uses a distributed actor-learner architecture for training. T
|
||||
|
||||
**Configuration Setup**
|
||||
|
||||
Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/configs/train.py`.
|
||||
Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/rl/train_rl.py`.
|
||||
|
||||
1. Configure the policy settings (`type="sac"`, `device`, etc.)
|
||||
2. Set `dataset` to your cropped dataset
|
||||
3. Configure environment settings with crop parameters
|
||||
4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/sac/configuration_sac.py#L79).
|
||||
5. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
|
||||
1. Configure the policy settings (`type="gaussian_actor"`, `device`, etc.)
|
||||
2. Configure the algorithm settings under the top-level `algorithm` block (`type="sac"`, learning rates, discount, etc., defined in `lerobot/rl/algorithms/sac/configuration_sac.py`).
|
||||
3. Set `dataset` to your cropped dataset
|
||||
4. Configure environment settings with crop parameters
|
||||
5. Check the other parameters related to the Gaussian Actor in [configuration_gaussian_actor.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/gaussian_actor/configuration_gaussian_actor.py#L79).
|
||||
6. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
|
||||
|
||||
**Starting the Learner**
|
||||
|
||||
@@ -926,7 +929,7 @@ The ideal behaviour is that your intervention rate should drop gradually during
|
||||
|
||||
Some configuration values have a disproportionate impact on training stability and speed:
|
||||
|
||||
- **`temperature_init`** (`policy.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
|
||||
- **`temperature_init`** (`algorithm.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
|
||||
- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) – interval in _seconds_ between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency.
|
||||
- **`storage_device`** (`policy.storage_device`) – device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second.
|
||||
|
||||
|
||||
@@ -232,7 +232,7 @@ lerobot-record \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
@@ -278,6 +278,6 @@ lerobot-record \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
@@ -68,13 +68,13 @@ from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760431541",
|
||||
id="my_red_robot_arm",
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
)
|
||||
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem58760431551",
|
||||
id="my_blue_leader_arm",
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
|
||||
robot = SO101Follower(robot_config)
|
||||
@@ -108,13 +108,13 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
||||
--teleop.type=koch_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=my_awesome_leader_arm \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem5AB90687491 \
|
||||
--robot.id=my_follower_arm \
|
||||
--robot.cameras="{front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/tty.usbmodem5AB90689011 \
|
||||
--teleop.id=my_leader_arm \
|
||||
--display_data=true
|
||||
```
|
||||
</hfoption>
|
||||
@@ -122,34 +122,48 @@ lerobot-teleoperate \
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
import time
|
||||
from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig
|
||||
from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
|
||||
|
||||
camera_config = {
|
||||
"front": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30)
|
||||
}
|
||||
|
||||
robot_config = KochFollowerConfig(
|
||||
port="/dev/tty.usbmodem585A0076841",
|
||||
id="my_red_robot_arm",
|
||||
cameras=camera_config
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
cameras={
|
||||
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||
}
|
||||
)
|
||||
|
||||
teleop_config = KochLeaderConfig(
|
||||
port="/dev/tty.usbmodem58760431551",
|
||||
id="my_blue_leader_arm",
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
|
||||
robot = KochFollower(robot_config)
|
||||
teleop_device = KochLeader(teleop_config)
|
||||
init_rerun(session_name="teleoperation")
|
||||
|
||||
robot = SO101Follower(robot_config)
|
||||
teleop_device = SO101Leader(teleop_config)
|
||||
robot.connect()
|
||||
teleop_device.connect()
|
||||
|
||||
TARGET_HZ = 30
|
||||
TIME_PER_FRAME = 1.0 / TARGET_HZ
|
||||
|
||||
while True:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
observation = robot.get_observation()
|
||||
action = teleop_device.get_action()
|
||||
robot.send_action(action)
|
||||
log_rerun_data(observation=observation, action=action)
|
||||
|
||||
elapsed_time = time.perf_counter() - start_time
|
||||
sleep_time = TIME_PER_FRAME - elapsed_time
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -193,7 +207,7 @@ lerobot-record \
|
||||
--dataset.num_episodes=5 \
|
||||
--dataset.single_task="Grab the black cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
</hfoption>
|
||||
@@ -202,10 +216,11 @@ lerobot-record \
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
from lerobot.teleoperators.so_leader.config_so_leader import SO101LeaderConfig
|
||||
from lerobot.teleoperators.so_leader.so_leader import SO101Leader
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
@@ -218,71 +233,56 @@ EPISODE_TIME_SEC = 60
|
||||
RESET_TIME_SEC = 10
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
|
||||
# Create robot configuration
|
||||
robot_config = SO100FollowerConfig(
|
||||
id="my_awesome_follower_arm",
|
||||
cameras={
|
||||
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error.
|
||||
},
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
)
|
||||
|
||||
teleop_config = SO100LeaderConfig(
|
||||
id="my_awesome_leader_arm",
|
||||
port="/dev/tty.usbmodem585A0077581",
|
||||
)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
teleop = SO100Leader(teleop_config)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/<dataset_repo_id>",
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Create the required processors
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
def main():
|
||||
# Create robot configuration
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
cameras={
|
||||
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||
}
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO101Follower(robot_config)
|
||||
teleop = SO101Leader(teleop_config)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/<dataset_repo_id>",
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Create the required processors
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
@@ -291,26 +291,50 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
teleop.disconnect()
|
||||
dataset.push_to_hub()
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# finalize dataset
|
||||
log_say("Finalizing dataset...")
|
||||
dataset.finalize()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
teleop.disconnect()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -348,7 +372,7 @@ The `record` function provides a suite of tools for capturing and managing data
|
||||
##### 2. Checkpointing and Resuming
|
||||
|
||||
- Checkpoints are automatically created during recording.
|
||||
- If an issue occurs, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset !
|
||||
- If an issue occurs or you want to record additional episodes in the same dataset, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset! Make sure that you also set `--dataset.root="local_path"`, it's a local path to save the new part of the dataset and is required to resume.
|
||||
- To start recording from scratch, **manually delete** the dataset directory.
|
||||
|
||||
##### 3. Recording Parameters
|
||||
@@ -422,7 +446,7 @@ from lerobot.utils.utils import log_say
|
||||
|
||||
episode_idx = 0
|
||||
|
||||
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm")
|
||||
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem5AB90687491", id="my_follower_arm")
|
||||
|
||||
robot = SO100Follower(robot_config)
|
||||
robot.connect()
|
||||
@@ -490,6 +514,83 @@ Additionally you can provide extra `tags` or specify a `license` for your model
|
||||
|
||||
If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
|
||||
|
||||
#### Train using Hugging Face Jobs
|
||||
|
||||
Hugging Face jobs let's you easily select hardware and run the training in the cloud. So if you don't have a powerful GPU or you need more VRAM or just want to train a model much faster use HF Jobs! It's pay as you go and you simply pay for each second of use, you can see the pricing and additional information [here](https://huggingface.co/docs/hub/jobs).
|
||||
|
||||
To run the training use this command:
|
||||
|
||||
<hfoptions id="train_with_hf_jobs">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
hf jobs run \
|
||||
--flavor a10g-small \
|
||||
--timeout 4h \
|
||||
--secrets HF_TOKEN \
|
||||
huggingface/lerobot-gpu:latest \
|
||||
-- \
|
||||
python -m lerobot.scripts.lerobot_train \
|
||||
--dataset.repo_id=username/dataset \
|
||||
--policy.type=act \
|
||||
--steps=5000 \
|
||||
--batch_size=16 \
|
||||
--policy.device=cuda \
|
||||
--policy.repo_id=username/your_policy \
|
||||
--log_freq=100
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from huggingface_hub import run_job, get_token
|
||||
|
||||
run_name = "act_so101_hf_jobs"
|
||||
dataset_id = "username/dataset"
|
||||
user_hub_id = "username"
|
||||
|
||||
command_args = [
|
||||
"python", "-m", "lerobot.scripts.lerobot_train",
|
||||
"--dataset.repo_id", dataset_id,
|
||||
"--policy.type", "act",
|
||||
"--steps", "5000",
|
||||
"--batch_size", "16",
|
||||
"--num_workers", "4",
|
||||
"--policy.device", "cuda",
|
||||
"--log_freq", "100",
|
||||
"--save_freq", "1000",
|
||||
"--save_checkpoint", "true",
|
||||
"--wandb.enable", "false",
|
||||
"--policy.repo_id", f"{user_hub_id}/{run_name}"
|
||||
]
|
||||
|
||||
print(f"Submitting job '{run_name}' to Hugging Face Infrastructure...")
|
||||
|
||||
job_info = run_job(
|
||||
image="huggingface/lerobot-gpu:latest",
|
||||
command=command_args,
|
||||
flavor="a10g-small",
|
||||
timeout="4h",
|
||||
secrets={"HF_TOKEN": get_token()}
|
||||
)
|
||||
|
||||
print("\n🚀 Job successfully launched!")
|
||||
print(f"🔹 Job ID: {job_info.id}")
|
||||
print(f"🔗 Live UI Dashboard & Logs: {job_info.url}")
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
You can modify the `--flavor` to use different hardware, for example: `t4-small`, `a100-large`, `h200`. Use `hf jobs hardware` to see the full list with pricing.
|
||||
Depending on the model you want to train and the hardware you selected you can also modify the `--batch_size` and `--number_of_workers`.
|
||||
For longer training sessions increase the timeout.
|
||||
|
||||
Once the training is started you can go to [Jobs](https://huggingface.co/settings/jobs) and see if your jobs is running as well as all the outputs. Sometimes it takes a few minutes to schedule your job so be patient.
|
||||
|
||||
After training the model will be pushed to hub and you can use it as any other model with LeRobot.
|
||||
|
||||
#### Upload policy checkpoints
|
||||
|
||||
Once training is done, upload the latest checkpoint with:
|
||||
|
||||
@@ -207,6 +207,56 @@ pip install 'lerobot[feetech]' # Feetech motor support
|
||||
|
||||
_Multiple extras can be combined (e.g., `.[core_scripts,pi,pusht]`). For a full list of available extras, refer to `pyproject.toml`._
|
||||
|
||||
### PyTorch CUDA variant (Linux only)
|
||||
|
||||
On Linux, the install path determines which CUDA wheel you get. macOS and Windows installs use the PyPI default (MPS / CPU / CUDA-Windows wheel respectively) and can skip this section.
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
|
||||
<hfoptions id="cuda_variant">
|
||||
<hfoption id="uv-source">
|
||||
|
||||
**Source install via `uv` (`uv sync` or `uv pip install -e .`)**
|
||||
|
||||
`torch` and `torchvision` are pinned by the project to the **CUDA 12.8** PyTorch index (`https://download.pytorch.org/whl/cu128`, driver floor **570.86**) — covers Ampere/Ada/Hopper/Blackwell GPUs. No action needed for typical NVIDIA setups.
|
||||
|
||||
To override for a different CUDA variant:
|
||||
|
||||
```bash
|
||||
uv pip install --force-reinstall torch torchvision \
|
||||
--index-url https://download.pytorch.org/whl/cu126 # older drivers; or cu130 for Blackwell on driver ≥ 580
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="pip-conda">
|
||||
|
||||
**Source install via `pip`/`conda`, or `pip install lerobot` from PyPI**
|
||||
|
||||
PyPI default torch wheel is currently a cu130-bundled Linux wheel, driver floor **580.65**.
|
||||
|
||||
To pick a specific CUDA variant:
|
||||
|
||||
**Using `pip` or `conda`** — install torch first with an explicit index, then lerobot:
|
||||
|
||||
```bash
|
||||
pip install --index-url https://download.pytorch.org/whl/cu128 torch torchvision
|
||||
pip install -e ".[all]" # source
|
||||
# — or —
|
||||
pip install lerobot # from PyPI
|
||||
```
|
||||
|
||||
**Using `uv` to install from PyPI** — one-liner via `--torch-backend` (uv ≥ 0.6):
|
||||
|
||||
```bash
|
||||
uv pip install --torch-backend cu128 lerobot
|
||||
```
|
||||
|
||||
Supported values include `auto`, `cpu`, `cu126`, `cu128`, `cu129`, `cu130`, plus various `rocm*` and `xpu`. Swap as needed for your driver.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
If you encounter build errors, you may need to install additional system dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
||||
|
||||
147
docs/source/language_and_recipes.mdx
Normal file
147
docs/source/language_and_recipes.mdx
Normal file
@@ -0,0 +1,147 @@
|
||||
# Language columns and recipes
|
||||
|
||||
Most LeRobot datasets ship with a single `task` string per episode — fine for
|
||||
short, single-instruction skills, but not enough for the longer-horizon,
|
||||
multi-modal robot policies the field is moving toward (high-level planning,
|
||||
memory, interjections, VQA, tool use). To support those policies without
|
||||
forking the dataset format, LeRobot extends `LeRobotDataset` with two optional
|
||||
language columns and a small recipe layer that turns those rows into
|
||||
chat-style training samples on the fly.
|
||||
|
||||
The design splits cleanly into three layers:
|
||||
|
||||
1. **Data in the dataset** — language annotations stored next to frames in
|
||||
`data/chunk-*/file-*.parquet` as two optional columns (`language_persistent`
|
||||
and `language_events`). Datasets without these columns keep their existing
|
||||
behavior.
|
||||
2. **Recipe** — a YAML file that declares which annotation rows to bind and
|
||||
how to lay them out as chat turns (`role`, `content`, optional images,
|
||||
optional tool calls). Recipes are pure config; no Python required to add a
|
||||
new one.
|
||||
3. **Training format** — at sample time, `RenderMessagesStep` resolves the
|
||||
recipe against the per-frame annotations and emits HF-style `messages` plus
|
||||
LeRobot-specific sidecars (`message_streams`, `target_message_indices`)
|
||||
that policy processors consume.
|
||||
|
||||
This page describes each layer in turn.
|
||||
|
||||
## Layer 1 — language columns in the dataset
|
||||
|
||||
The two optional columns live next to frame data in
|
||||
`data/chunk-*/file-*.parquet`:
|
||||
|
||||
- `language_persistent`: a list of rows broadcast across every frame in an episode for state that remains active, such as `subtask`, `plan`, and `memory`.
|
||||
- `language_events`: a list of rows only on the exact frame where an event was emitted, such as `interjection`, `vqa`, and speech tool calls.
|
||||
|
||||
Both columns share the same row shape (event rows omit `timestamp` because the
|
||||
frame the row sits on already provides it):
|
||||
|
||||
```text
|
||||
role: string
|
||||
content: string | null
|
||||
style: string | null
|
||||
timestamp: float32 # persistent rows only
|
||||
camera: string | null # observation.images.* feature key, view-dependent rows only
|
||||
tool_calls: list[Json] | null
|
||||
```
|
||||
|
||||
The `camera` field tags rows whose `content` is grounded in a specific camera
|
||||
view. Rows of view-dependent styles (`vqa` and `trace`) MUST set `camera` to
|
||||
the matching `observation.images.*` feature key. Rows of every other style —
|
||||
including `motion`, which describes robot-frame primitives in joint / Cartesian
|
||||
terms — MUST leave `camera` as `null`. Pipeline writers and the validator
|
||||
enforce this via `validate_camera_field(style, camera)`.
|
||||
|
||||
`meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations.
|
||||
|
||||
### Architecture
|
||||
|
||||
The language stack itself has three internal modules backing layer 1:
|
||||
|
||||
1. `lerobot.datasets.language` defines the schema, style registry, and `column_for_style`.
|
||||
2. `lerobot.datasets.language_render` resolves rows and renders messages.
|
||||
3. `RenderMessagesStep` turns dataset samples into `messages`, `message_streams`, and `target_message_indices`.
|
||||
|
||||
`LeRobotDataset` stays recipe-agnostic. It passes `language_persistent` and `language_events` through when present, and unannotated datasets keep their existing behavior.
|
||||
|
||||
## Layer 2 — recipe anatomy
|
||||
|
||||
Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`. They
|
||||
declare which annotation rows to pull (via `bindings`) and how to compose them
|
||||
into chat turns (`messages`).
|
||||
|
||||
```yaml
|
||||
messages:
|
||||
- { role: user, content: "${task}", stream: high_level }
|
||||
- { role: assistant, content: "${subtask}", stream: low_level, target: true }
|
||||
```
|
||||
|
||||
A recipe can also branch into a weighted **blend** of sub-recipes. At sample
|
||||
time, exactly one branch is selected deterministically from the sample index,
|
||||
so different frames train different objectives (e.g. memory updates vs.
|
||||
low-level execution vs. VQA) without any Python wiring.
|
||||
|
||||
### Temporal semantics
|
||||
|
||||
Persistent styles are active after emission until replaced:
|
||||
|
||||
- `active_at(t, style=subtask)`
|
||||
- `nth_prev(style=memory, offset=1)`
|
||||
- `nth_next(style=subtask, offset=1)`
|
||||
|
||||
Event styles only exist on their exact timestamp:
|
||||
|
||||
- `emitted_at(t, style=interjection)`
|
||||
- `emitted_at(t, style=vqa, role=user, camera=observation.images.top)`
|
||||
- `emitted_at(t, role=assistant, tool_name=say)`
|
||||
|
||||
Exact event matching has no tolerance window, so writers must stamp event rows with frame timestamps from the parquet data.
|
||||
|
||||
### View-dependent resolution
|
||||
|
||||
For view-dependent styles (`vqa` and `trace`), the resolver gains a
|
||||
`camera=` filter parallel to `role=` and `tool_name=`. Datasets with multiple
|
||||
cameras typically emit one (`vqa`, `user`) + (`vqa`, `assistant`) pair per
|
||||
camera at the same timestamp; without `camera=`, those resolvers see two
|
||||
matches and raise an ambiguity error. Recipes consume each camera through its
|
||||
own binding plus a matching image block, e.g.
|
||||
|
||||
```yaml
|
||||
ask_vqa_top:
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- { type: image, feature: observation.images.top }
|
||||
- { type: text, text: "${vqa_query}" }
|
||||
- {
|
||||
role: assistant,
|
||||
content: "${vqa}",
|
||||
stream: high_level,
|
||||
target: true,
|
||||
if_present: vqa,
|
||||
}
|
||||
```
|
||||
|
||||
Add one such sub-recipe per camera the dataset records.
|
||||
|
||||
## Layer 3 — training format
|
||||
|
||||
Rendered samples use HF-style chat messages plus LeRobot sidecars:
|
||||
|
||||
```python
|
||||
sample["messages"]
|
||||
sample["message_streams"]
|
||||
sample["target_message_indices"]
|
||||
```
|
||||
|
||||
The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone, which keeps the same dataset usable across SmolVLA, Pi0.5, and any future VLM that expects OpenAI-style chat messages.
|
||||
|
||||
## Graceful absence
|
||||
|
||||
If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.
|
||||
If an event-scoped branch is selected on a frame without the required event row, rendering returns `None`, allowing a loader to retry another sample.
|
||||
42
docs/source/lelab.mdx
Normal file
42
docs/source/lelab.mdx
Normal file
@@ -0,0 +1,42 @@
|
||||
# LeLab - LeRobot Guide
|
||||
Graphical user interfaces are the easiest to use for beginners because it's easy to just click everything without remembering the proper commands. That's why we built LeLab which is a GUI built on top of the LeRobot library. With this app you will be able to add robots, collect datasets, train and deploy models.
|
||||
|
||||
### Installation
|
||||
To install lerobot you can simply copy the following command and paste into your terminal. For it to work you need to have `uv` installed, [here is how to do it.](https://docs.astral.sh/uv/getting-started/installation/)
|
||||
```
|
||||
uv tool install git+https://github.com/huggingface/leLab.git && lelab
|
||||
```
|
||||
|
||||
Once installed you will be able to run lelab anytime you want with `lelab` command from your terminal (above command has it included at the end so it will run it right after installation).
|
||||
|
||||
### Adding robots
|
||||
|
||||
##### Calibration
|
||||
You will need to select the proper arm type (leader or follower) and calibrate each arm as shown in the video available inside LeLab. Make sure that all joints are in the middle position when starting the calibration.
|
||||
|
||||
##### Adding cameras
|
||||
At the bottom of the add robot page you can also add the cameras and name them accordingly.
|
||||
|
||||
### Teleoperation
|
||||
Once the robots have been configured you can go back and click the teleoperation button. You will see the 3D visualization of the arm and will be able to control the follower with the leader. If something doesn't work there, remove and add your robot again following the steps described in LeLab.
|
||||
|
||||
### Recording a dataset
|
||||
Type a new name for your dataset and press on the plus button. You will need to provide:
|
||||
- Task description, for example "put the cube in a container"
|
||||
- Number of episodes that you want to record, at least 30 recommended
|
||||
- Episode and reset durations. These are max durations and can be shortened while recording with a spacebar press.
|
||||
- If you configured your cameras earlier you don't need to do that again.
|
||||
|
||||
Press start recording, wait for it to load, perform the task with confident movements but don't rush. Once the task is finished and you moved your robot to the initial position press the spacebar. You will have time to reset the environment for example grab the cube from the container and placing it on the desk again. Once ready press the spacebar and record the next episode. Repeat until all the episodes are recorded.
|
||||
|
||||
### Training a model
|
||||
This is the most powerful function with LeLab! You can easily train models locally on your own computer but also with [HF Jobs](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) which gives you easy access to very powerful GPUs with clear pricing.
|
||||
|
||||
> [!TIP]
|
||||
> To use HF Jobs make sure that you are logged in to HF, you can do that by running `hf auth login` in the terminal.
|
||||
|
||||
In the training tab select if you want to train locally or specific HF hardware you want to use. You will also need to provide the dataset that will be used for training. Your own datasets will be listed in a dropdown list, you can also use other datasets by providing its id. Set the policy you want to train, batch size and number of steps. For guide on choosing hardware and batch size check out our [Compute HW Guide for LeRobot Training.](hardware_guide.mdx)
|
||||
|
||||
Once you start training the progress will be visualized inside LeLab. Checkpoints will be saved as well.
|
||||
### Running the model on a robot
|
||||
In the main view of the LeLab under jobs you will see all the models that you trained. To run the policy on the robot just click the green run button and press start inference. After loading the policy the robot should start solving the task that it learned during training.
|
||||
@@ -10,6 +10,7 @@ This docs will guide you to:
|
||||
- Stream datasets without downloading using `StreamingLeRobotDataset`
|
||||
- Apply image transforms for data augmentation during training
|
||||
- Migrate existing `v2.1` datasets to `v3.0`
|
||||
- Experiment with other `LeRobotDataset` formats and implementations like Lance
|
||||
|
||||
## What’s new in `v3`
|
||||
|
||||
@@ -43,7 +44,7 @@ lerobot-record \
|
||||
--dataset.num_episodes=5 \
|
||||
--dataset.single_task="Grab the black cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
|
||||
@@ -315,3 +316,39 @@ Dataset v3.0 uses incremental parquet writing with buffered metadata for efficie
|
||||
- Ensures the dataset is valid for loading
|
||||
|
||||
Without calling `finalize()`, your parquet files will be incomplete and the dataset won't load properly.
|
||||
|
||||
## Other formats and implementations
|
||||
|
||||
### Lance
|
||||
|
||||
Lance is a useful format for multimodal AI datasets, especially for large-scale training requiring high performance IO and random access.
|
||||
|
||||
The `lerobot-lancedb` package implements `LeRobotLanceDataset` (for JPEG images) and `LeRobotLanceVideoDataset` (for mp4 videos).
|
||||
Those two storage layouts both subclass LeRobotDataset and can provide data loading speed ups.
|
||||
|
||||
`LeRobotLanceDataset` is a drop-in replacement for `LeRobotDataset`:
|
||||
|
||||
```python
|
||||
from lerobot.datasets import LeRobotDatasetMetadata
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot_lancedb import LeRobotLanceDataset, LeRobotLanceVideoDataset
|
||||
|
||||
cfg = DiffusionConfig(...)
|
||||
meta = LeRobotDatasetMetadata(root=local_dataset_path) # or use repo_id=... to load metadata from the Hub
|
||||
delta_timestamps = {...}
|
||||
|
||||
# Use LeRobotLanceDataset for image datasets
|
||||
dataset = LeRobotLanceDataset(
|
||||
root=local_dataset_path, # or use repo_id=... to stream from the Hub
|
||||
delta_timestamps=delta_timestamps,
|
||||
return_uint8=True,
|
||||
)
|
||||
# Or use LeRobotLanceVideoDataset for video datasets:
|
||||
dataset = LeRobotLanceVideoDataset(
|
||||
root=local_dataset_path, # or use repo_id=... to stream from the Hub
|
||||
delta_timestamps=delta_timestamps,
|
||||
return_uint8=True,
|
||||
)
|
||||
```
|
||||
|
||||
Join the discussion on [Github](https://github.com/huggingface/lerobot/issues/3608) and explore the `lerobot-lancedb` documentation [here](https://lancedb.github.io/lerobot-lancedb/).
|
||||
|
||||
@@ -28,13 +28,15 @@ lerobot-train \
|
||||
--steps=100000 \
|
||||
--batch_size=32 \
|
||||
--peft.method_type=LORA \
|
||||
--peft.r=64
|
||||
--peft.r=64 \
|
||||
--peft.lora_alpha=64
|
||||
```
|
||||
|
||||
Note the `--peft.method_type` parameter that let's you select which PEFT method to use. Here we use
|
||||
[LoRA](https://huggingface.co/docs/peft/main/en/package_reference/lora) (Low-Rank Adapter) which is probably the most
|
||||
popular fine-tuning method to date. Low-rank adaption means that we only fine-tune a matrix with comparably low rank
|
||||
instead of the full weight matrix. This rank can be specified using the `--peft.r` parameter. The higher the rank
|
||||
instead of the full weight matrix. This rank can be specified using the `--peft.r` parameter, and the LoRA scaling factor with
|
||||
`--peft.lora_alpha` (where `scaling = lora_alpha / r`). The higher the rank
|
||||
the closer you get to full fine-tuning
|
||||
|
||||
There are more complex methods that have more parameters. These are not yet supported, feel free to raise an issue
|
||||
|
||||
@@ -161,7 +161,7 @@ lerobot-record \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
@@ -203,7 +203,7 @@ lerobot-record \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
|
||||
186
docs/source/rebot_b601.mdx
Normal file
186
docs/source/rebot_b601.mdx
Normal file
@@ -0,0 +1,186 @@
|
||||
# reBot B601-DM
|
||||
|
||||
[reBot B601-DM](https://wiki.seeedstudio.com/rebot_arm_b601_dm_lerobot/) is an open-source, low-cost robot arm from Seeed Studio for embodied-AI and imitation learning. It comes as a **follower** arm (the `B601-DM`, a 6-DOF arm plus gripper driven by Damiao CAN motors) and a **leader** arm (the `StarArm102` / `reBot Arm 102`, driven by FashionStar UART smart servos) used to teleoperate it.
|
||||
|
||||
This page covers **calibration** and **teleoperation** for both single-arm and bimanual (dual-arm) setups.
|
||||
|
||||
<div style="display: flex; align-items: center; gap: 10px;">
|
||||
<img
|
||||
src="https://files.seeedstudio.com/wiki/robotics/projects/lerobot/b601dm_zeroposition.jpg"
|
||||
alt="reBot B601-DM follower arm at its zero position"
|
||||
width="48%"
|
||||
/>
|
||||
<img
|
||||
src="https://files.seeedstudio.com/wiki/robotics/projects/lerobot/102_zeroposition.jpg"
|
||||
alt="reBot Arm 102 leader arm at its zero position"
|
||||
width="48%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
_Left: the B601-DM follower at its zero position. Right: the reBot Arm 102 leader at its zero position. Images courtesy of [Seeed Studio](https://wiki.seeedstudio.com/rebot_arm_b601_dm_lerobot/)._
|
||||
|
||||
## Install LeRobot 🤗
|
||||
|
||||
Follow our [Installation Guide](./installation), then install the reBot support:
|
||||
|
||||
```bash
|
||||
pip install -e ".[rebot]"
|
||||
```
|
||||
|
||||
This pulls in `motorbridge` (CAN motor control for the B601-DM follower) and `motorbridge-smart-servo` (FashionStar UART servos for the reBot Arm 102 leader).
|
||||
|
||||
## Registered device types
|
||||
|
||||
| Type | Kind |
|
||||
| ------------------------ | -------------------------------------------- |
|
||||
| `rebot_b601_follower` | single-arm B601-DM follower robot |
|
||||
| `bi_rebot_b601_follower` | bimanual (dual-arm) follower robot |
|
||||
| `rebot_102_leader` | single-arm reBot Arm 102 leader teleoperator |
|
||||
| `bi_rebot_102_leader` | bimanual (dual-arm) leader teleoperator |
|
||||
|
||||
The bimanual types compose two single-arm instances and namespace each arm's
|
||||
observation/action keys with a `left_` / `right_` prefix. Per-arm settings are
|
||||
passed through nested `left_arm_config.*` / `right_arm_config.*` arguments.
|
||||
|
||||
## Find the USB ports
|
||||
|
||||
For each device, find the USB port associated with its motor bus using:
|
||||
|
||||
```bash
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
On Linux, remove `brltty` (`sudo apt remove brltty`) so it does not hold the
|
||||
leader's USB serial port. You may also need to grant access to the serial
|
||||
devices: `sudo chmod 666 /dev/ttyACM* /dev/ttyUSB*`.
|
||||
</Tip>
|
||||
|
||||
## Calibration
|
||||
|
||||
Neither arm stores a persistent hardware calibration: every time it connects, the motors are re-zeroed against the pose the arm is physically holding. Calibration simply records that zero pose. When prompted, **manually move the arm to its zero position** (the default sit-down pose shown above, gripper fully closed) and press <kbd>ENTER</kbd>.
|
||||
|
||||
### Follower (B601-DM)
|
||||
|
||||
<hfoptions id="calibrate-follower">
|
||||
<hfoption id="Single arm">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--robot.type=rebot_b601_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=follower \
|
||||
--robot.can_adapter=damiao
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Dual arm">
|
||||
|
||||
Connect the bimanual follower; calibration runs for the left arm, then the right arm.
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--robot.type=bi_rebot_b601_follower \
|
||||
--robot.id=bi_follower \
|
||||
--robot.left_arm_config.port=/dev/ttyACM0 \
|
||||
--robot.left_arm_config.can_adapter=damiao \
|
||||
--robot.right_arm_config.port=/dev/ttyACM1 \
|
||||
--robot.right_arm_config.can_adapter=damiao
|
||||
```
|
||||
|
||||
Per-arm calibration files are saved with `_left` / `_right` suffixes on the id.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Leader (reBot Arm 102)
|
||||
|
||||
<hfoptions id="calibrate-leader">
|
||||
<hfoption id="Single arm">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=rebot_102_leader \
|
||||
--teleop.port=/dev/ttyUSB0 \
|
||||
--teleop.id=leader
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Dual arm">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=bi_rebot_102_leader \
|
||||
--teleop.id=bi_leader \
|
||||
--teleop.left_arm_config.port=/dev/ttyUSB0 \
|
||||
--teleop.right_arm_config.port=/dev/ttyUSB1
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Teleoperation
|
||||
|
||||
Once both arms are calibrated, drive the follower with the leader. The follower talks to its CAN bus through a Damiao serial bridge (`can_adapter=damiao`, the default) or a SocketCAN adapter (`can_adapter=socketcan`). See the [OpenArm page](./openarm) for more details on the SocketCAN adapter configuration.
|
||||
|
||||
<hfoptions id="teleoperate">
|
||||
<hfoption id="Single arm">
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=rebot_b601_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=follower \
|
||||
--robot.can_adapter=damiao \
|
||||
--teleop.type=rebot_102_leader \
|
||||
--teleop.port=/dev/ttyUSB0 \
|
||||
--teleop.id=leader
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Dual arm">
|
||||
|
||||
The bimanual leader and follower reuse the single-arm classes; each arm is
|
||||
configured through nested `left_arm_config.*` / `right_arm_config.*` arguments,
|
||||
so a bimanual reBot Arm 102 leader drives a bimanual B601-DM follower.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=bi_rebot_b601_follower \
|
||||
--robot.id=bi_follower \
|
||||
--robot.left_arm_config.port=/dev/ttyACM0 \
|
||||
--robot.left_arm_config.can_adapter=damiao \
|
||||
--robot.right_arm_config.port=/dev/ttyACM1 \
|
||||
--robot.right_arm_config.can_adapter=damiao \
|
||||
--teleop.type=bi_rebot_102_leader \
|
||||
--teleop.id=bi_leader \
|
||||
--teleop.left_arm_config.port=/dev/ttyUSB0 \
|
||||
--teleop.right_arm_config.port=/dev/ttyUSB1
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
<Tip>
|
||||
The leader and follower share the same joint names (`shoulder_pan,
|
||||
shoulder_lift, elbow_flex, wrist_flex, wrist_yaw, wrist_roll, gripper`), so
|
||||
leader actions map directly onto the follower.
|
||||
</Tip>
|
||||
|
||||
If the motion of a joint is reversed, flip its sign in the leader's `joint_directions` (the gripper also carries a scale to widen its range to the follower):
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=rebot_b601_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.can_adapter=damiao \
|
||||
--teleop.type=rebot_102_leader \
|
||||
--teleop.port=/dev/ttyUSB0 \
|
||||
--teleop.joint_directions='{"shoulder_pan":-1,"shoulder_lift":-1,"elbow_flex":1,"wrist_flex":1,"wrist_yaw":1,"wrist_roll":-1,"gripper":-6}'
|
||||
```
|
||||
|
||||
## Recording datasets
|
||||
|
||||
Swap `lerobot-teleoperate` for `lerobot-record` (with the same `--robot.*` / `--teleop.*` arguments, plus `--dataset.*`) to record demonstrations for training. See [Imitation Learning for Robots](./il_robots) for the full workflow.
|
||||
|
||||
For hardware assembly and wiring, see the [Seeed Studio reBot wiki](https://wiki.seeedstudio.com/rebot_arm_b601_dm_lerobot/).
|
||||
@@ -97,22 +97,22 @@ Similarly for when recording an episode, it is recommended that you are logged i
|
||||
Once you are logged in, you can run inference in your setup by doing:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \ # <- Use your port
|
||||
--robot.id=my_blue_follower_arm \ # <- Use your robot id
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras
|
||||
--dataset.single_task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
|
||||
--dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub
|
||||
--dataset.episode_time_s=50 \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
|
||||
# <- RTC optional, use when running on low power hardware \
|
||||
# --inference.type=rtc \
|
||||
# --inference.rtc.execution_horizon=10 \
|
||||
# --inference.rtc.max_guidance_weight=10.0 \
|
||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
||||
# --teleop.type=so100_leader \
|
||||
# --teleop.port=/dev/ttyACM0 \
|
||||
# --teleop.id=my_red_leader_arm \
|
||||
# --display_data=true #optional use if you want to see the camera stream \
|
||||
--policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model
|
||||
```
|
||||
|
||||
|
||||
@@ -17,9 +17,9 @@ This makes `save_episode()` near-instant (the video is already encoded by the ti
|
||||
| Parameter | CLI Flag | Type | Default | Description |
|
||||
| ----------------------- | --------------------------------- | ------------- | ------------- | ----------------------------------------------------------------- |
|
||||
| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture |
|
||||
| `vcodec` | `--dataset.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
|
||||
| `vcodec` | `--dataset.camera_encoder.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
|
||||
| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance. `None` will leave the vcoded decide |
|
||||
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM |
|
||||
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `30` | Max buffered frames per camera (~1s at 30fps). Consumes RAM |
|
||||
|
||||
## 3. Performance Considerations
|
||||
|
||||
@@ -48,7 +48,7 @@ This parameter controls how many threads each encoder instance uses internally:
|
||||
|
||||
### Backpressure and Frame Dropping
|
||||
|
||||
Each camera has a bounded queue (`encoder_queue_maxsize`, default 60 frames). When the encoder can't keep up:
|
||||
Each camera has a bounded queue (`encoder_queue_maxsize`, default 30 frames). When the encoder can't keep up:
|
||||
|
||||
1. The queue fills up (consuming RAM)
|
||||
2. New frames are **dropped** (not blocked) — the capture loop continues uninterrupted
|
||||
@@ -82,15 +82,15 @@ Use HW encoding when:
|
||||
|
||||
### Available HW Encoders
|
||||
|
||||
| Encoder | Platform | Hardware | CLI Value |
|
||||
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------ |
|
||||
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=h264_videotoolbox` |
|
||||
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=hevc_videotoolbox` |
|
||||
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=h264_nvenc` |
|
||||
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=hevc_nvenc` |
|
||||
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.vcodec=h264_vaapi` |
|
||||
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.vcodec=h264_qsv` |
|
||||
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.vcodec=auto` |
|
||||
| Encoder | Platform | Hardware | CLI Value |
|
||||
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | --------------------------------------------------- |
|
||||
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder.vcodec=h264_videotoolbox` |
|
||||
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder.vcodec=hevc_videotoolbox` |
|
||||
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder.vcodec=h264_nvenc` |
|
||||
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder.vcodec=hevc_nvenc` |
|
||||
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.camera_encoder.vcodec=h264_vaapi` |
|
||||
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.camera_encoder.vcodec=h264_qsv` |
|
||||
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.camera_encoder.vcodec=auto` |
|
||||
|
||||
> [!NOTE]
|
||||
> In order to use the HW accelerated encoders you might need to upgrade your GPU drivers.
|
||||
@@ -100,15 +100,15 @@ Use HW encoding when:
|
||||
|
||||
## 5. Troubleshooting
|
||||
|
||||
| Symptom | Likely Cause | Fix |
|
||||
| ------------------------------------------------------------------ | -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.vcodec=auto`) |
|
||||
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.vcodec=auto`). |
|
||||
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
|
||||
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
|
||||
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
|
||||
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.vcodec=auto` |
|
||||
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
|
||||
| Symptom | Likely Cause | Fix |
|
||||
| ------------------------------------------------------------------ | -------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.camera_encoder.vcodec=auto`) |
|
||||
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.camera_encoder.vcodec=auto`). |
|
||||
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
|
||||
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
|
||||
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
|
||||
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.camera_encoder.vcodec=auto` |
|
||||
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
|
||||
|
||||
## 6. Recommended Configurations
|
||||
|
||||
@@ -146,7 +146,7 @@ On very constrained systems, streaming encoding may compete too heavily with the
|
||||
# 2camsx 640x480x3 @30fps: Requires some tuning.
|
||||
|
||||
# Use H.264, disable streaming, consider batching encoding
|
||||
lerobot-record --dataset.vcodec=h264 --dataset.streaming_encoding=false ...
|
||||
lerobot-record --dataset.camera_encoder.vcodec=h264 --dataset.streaming_encoding=false ...
|
||||
```
|
||||
|
||||
## 7. Closing note
|
||||
|
||||
210
docs/source/tools.mdx
Normal file
210
docs/source/tools.mdx
Normal file
@@ -0,0 +1,210 @@
|
||||
# Tools
|
||||
|
||||
LeRobot v3.1 supports **tool calls** in policies — assistant messages can
|
||||
emit structured invocations like `say(text="OK, starting now")` that the
|
||||
runtime dispatches to a real implementation (TTS, controller, logger, …).
|
||||
|
||||
This page covers:
|
||||
|
||||
1. Where the tool catalog lives.
|
||||
2. How the annotation pipeline produces tool-call atoms.
|
||||
3. How to add your own tool.
|
||||
|
||||
## Where tools are declared
|
||||
|
||||
Two layers.
|
||||
|
||||
**The catalog** — a list of OpenAI-style function schemas — lives at
|
||||
`meta/info.json["tools"]` on each dataset. Example:
|
||||
|
||||
```json
|
||||
{
|
||||
"features": { "...": "..." },
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"description": "Speak a short utterance to the user via the TTS executor.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The verbatim text to speak."
|
||||
}
|
||||
},
|
||||
"required": ["text"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Read it via the dataset metadata accessor:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
|
||||
meta = LeRobotDatasetMetadata(repo_id="pepijn/super_poulain_final_annotations")
|
||||
tools = meta.tools # list[dict] — OpenAI tool schemas
|
||||
```
|
||||
|
||||
If the dataset's `info.json` doesn't declare any tools, `meta.tools`
|
||||
returns `DEFAULT_TOOLS` from `lerobot.datasets.language` — currently a
|
||||
single-entry list with the canonical `say` schema. So unannotated
|
||||
datasets and chat-template consumers keep working without any
|
||||
configuration:
|
||||
|
||||
```python
|
||||
prompt_str = tokenizer.apply_chat_template(
|
||||
sample["messages"],
|
||||
tools=meta.tools, # works either way
|
||||
add_generation_prompt=False,
|
||||
tokenize=False,
|
||||
)
|
||||
```
|
||||
|
||||
**The implementations** — runnable Python — will live under
|
||||
`src/lerobot/tools/`, one file per tool. The runtime dispatcher and
|
||||
the canonical `say` implementation (wrapping Kyutai's pocket-tts) are
|
||||
not part of the catalog layer described here; today this layer ships
|
||||
only the schema storage and the `DEFAULT_TOOLS` fallback constant.
|
||||
|
||||
## Per-row tool _invocations_
|
||||
|
||||
The catalog above describes _what can be called_. The actual _call_ — the
|
||||
function name plus the argument values — is stored per-row, on the
|
||||
assistant atoms in `language_events`:
|
||||
|
||||
```python
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"style": null,
|
||||
"timestamp": 12.4,
|
||||
"camera": null,
|
||||
"tool_calls": [
|
||||
{ "type": "function",
|
||||
"function": { "name": "say", "arguments": { "text": "On it." } } }
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Recipes splice these into rendered messages via `tool_calls_from`:
|
||||
|
||||
```yaml
|
||||
user_interjection_response:
|
||||
bindings:
|
||||
speech: "emitted_at(t, role=assistant, tool_name=say)"
|
||||
messages:
|
||||
- { role: user, content: "${task}", stream: high_level }
|
||||
- {
|
||||
role: assistant,
|
||||
content: "${current_plan}",
|
||||
stream: high_level,
|
||||
target: true,
|
||||
tool_calls_from: speech,
|
||||
}
|
||||
```
|
||||
|
||||
The model's training target is one assistant turn that carries both the
|
||||
plan text _and_ the `say` tool call. At inference, the runtime parses
|
||||
the generated text back into structured `tool_calls` and dispatches to
|
||||
the matching implementation.
|
||||
|
||||
## How to add your own tool
|
||||
|
||||
> **Note:** Steps 2 and 3 below describe the runtime layer
|
||||
> (`src/lerobot/tools/`, the `Tool` protocol, `TOOL_REGISTRY`,
|
||||
> `get_tools(meta)`) which is not part of the catalog layer shipped
|
||||
> today — those modules don't yet exist in the tree. Step 1 alone is
|
||||
> enough to make the tool visible to the chat template via
|
||||
> `meta.tools` so the model can learn to _generate_ the call;
|
||||
> executing the call at inference requires the runtime layer.
|
||||
|
||||
Three steps. Concrete example: a `record_observation` tool the policy
|
||||
can call to capture an extra observation outside the regular control
|
||||
loop.
|
||||
|
||||
### Step 1 — declare the schema
|
||||
|
||||
Add an entry under `meta/info.json["tools"]`. Either edit the file
|
||||
directly on disk _before_ running the annotation pipeline (it'll be
|
||||
preserved) or hand it to `lerobot-annotate` via a config flag.
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": [
|
||||
{ "type": "function", "function": { "name": "say", "...": "..." } },
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "record_observation",
|
||||
"description": "Capture a high-resolution still image for the user.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"label": {
|
||||
"type": "string",
|
||||
"description": "Short label for the saved image."
|
||||
}
|
||||
},
|
||||
"required": ["label"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
The schema follows OpenAI's function-calling convention exactly, so the
|
||||
chat template can render it natively.
|
||||
|
||||
### Step 2 — implement the call
|
||||
|
||||
Create `src/lerobot/tools/record_observation.py`:
|
||||
|
||||
```python
|
||||
from .base import Tool
|
||||
from typing import Any
|
||||
|
||||
RECORD_OBSERVATION_SCHEMA: dict[str, Any] = { "...": "..." } # mirrors the JSON above
|
||||
|
||||
|
||||
class RecordObservationTool:
|
||||
name = "record_observation"
|
||||
schema = RECORD_OBSERVATION_SCHEMA
|
||||
|
||||
def __init__(self, schema: dict | None = None, output_dir: str = "."):
|
||||
self.output_dir = output_dir
|
||||
|
||||
def call(self, arguments: dict) -> str:
|
||||
label = arguments["label"]
|
||||
# ... save the latest camera frame to <output_dir>/<label>.png ...
|
||||
return f"saved {label}.png"
|
||||
```
|
||||
|
||||
One file per tool keeps dependencies isolated — `record_observation`
|
||||
might pull `pillow`, while `say` pulls `pocket-tts`. Users installing
|
||||
only the tools they need avoid heavy transitive deps.
|
||||
|
||||
### Step 3 — register it
|
||||
|
||||
Add to `src/lerobot/tools/registry.py`:
|
||||
|
||||
```python
|
||||
from .record_observation import RecordObservationTool
|
||||
|
||||
TOOL_REGISTRY["record_observation"] = RecordObservationTool
|
||||
```
|
||||
|
||||
That's it. At runtime `get_tools(meta)` looks up each schema in
|
||||
`meta.tools`, instantiates the matching registered class, and returns
|
||||
a name → instance dict the dispatcher can route into.
|
||||
|
||||
If you want to use a tool _without_ writing an implementation (e.g. for
|
||||
training-time chat-template formatting only), step 1 alone is enough —
|
||||
the model still learns to _generate_ the call. Steps 2 and 3 are only
|
||||
needed to actually _execute_ it at inference.
|
||||
@@ -117,10 +117,10 @@ lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.vcodec libsvtav1 \
|
||||
--operation.pix_fmt yuv420p \
|
||||
--operation.g 2 \
|
||||
--operation.crf 30
|
||||
--operation.camera_encoder.vcodec libsvtav1 \
|
||||
--operation.camera_encoder.pix_fmt yuv420p \
|
||||
--operation.camera_encoder.g 2 \
|
||||
--operation.camera_encoder.crf 30
|
||||
|
||||
# Convert only specific episodes
|
||||
lerobot-edit-dataset \
|
||||
@@ -147,11 +147,7 @@ lerobot-edit-dataset \
|
||||
**Parameters:**
|
||||
|
||||
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
|
||||
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`)
|
||||
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`)
|
||||
- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2)
|
||||
- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30)
|
||||
- `fast_decode`: Fast decode tuning option (default: 0)
|
||||
- `camera_encoder`: Video encoder settings — all sub-fields accessible via `--operation.camera_encoder.<field>. See [Video Encoding Parameters](./video_encoding_parameters) for more details.
|
||||
- `episode_indices`: List of specific episodes to convert (default: all episodes)
|
||||
- `num_workers`: Number of parallel workers for processing (default: 4)
|
||||
|
||||
|
||||
117
docs/source/video_encoding_parameters.mdx
Normal file
117
docs/source/video_encoding_parameters.mdx
Normal file
@@ -0,0 +1,117 @@
|
||||
# Video encoding parameters
|
||||
|
||||
When video storage is enabled, LeRobot stores each camera stream as an **MP4** file instead of saving one image file per timestep. Video encoding compresses across time, which usually cuts dataset size and I/O compared to a pile of PNG, while keeping MP4 — a format every player and loader understands.
|
||||
|
||||
Encoding frames into an MP4 is a full FFmpeg pipeline: choice of encoder, pixel format, GOP/keyframes, quality vs. speed, and optional extra encoder flags. Most of these knobs are user-tunable through `camera_encoder`, a nested `VideoEncoderConfig` (`lerobot.configs.video.VideoEncoderConfig`) passed through PyAV.
|
||||
|
||||
You can set these parameters from the CLI with `--dataset.camera_encoder.<field>` (e.g. with `lerobot-record` or `lerobot-rollout`). The same block applies to every camera video stream in that run.
|
||||
|
||||
<Tip>
|
||||
Video storage must be on for `camera_encoder` to have any effect —
|
||||
`use_videos=True` in Python APIs, or `--dataset.video=true` on the CLI (the
|
||||
recording default). With video off, inputs stay as images and `camera_encoder`
|
||||
is ignored.
|
||||
</Tip>
|
||||
|
||||
For details on **when** frames are written vs. encoded (streaming vs. post-episode), queues, and other top-level `--dataset.*` switches, see [Streaming Video Encoding](./streaming_video_encoding). For an encoding-parameter comparison and experiments, see the [video-benchmark Space](https://huggingface.co/spaces/lerobot/video-benchmark).
|
||||
|
||||
---
|
||||
|
||||
## Example
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--robot.id=black \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=blue \
|
||||
--dataset.repo_id=<my_username>/<my_dataset_name> \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.single_task="Grab the cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
--dataset.camera_encoder.vcodec=h264 \
|
||||
--dataset.camera_encoder.preset=fast \
|
||||
--dataset.camera_encoder.extra_options={"tune": "film", "profile:v": "high", "bf": 2} \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tuning parameters
|
||||
|
||||
<Tip warning={true}>
|
||||
The defaults are tuned to balance **compression ratio**, **visual quality**, and **decoding/seek speed** for typical robotics datasets. Changing them can affect both recording (CPU load, frame drops) and training (decoding throughput, image quality).
|
||||
|
||||
Only override these parameters if you have a specific reason to, and measure the impact on your pipeline before relying on the new settings.
|
||||
|
||||
</Tip>
|
||||
|
||||
All flags below are prefixed with `--dataset.camera_encoder.` on the CLI.
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
| --------------- | ---------------- | ------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `vcodec` | `str` | `"libsvtav1"` | Video codec name. `"auto"` picks the first available hardware encoder from a fixed preference list, falling back to `libsvtav1`. |
|
||||
| `pix_fmt` | `str` | `"yuv420p"` | Output pixel format. Must be supported by the chosen codec in your FFmpeg build. |
|
||||
| `g` | `int` | `2` | GOP size — a keyframe every `g` frames. Emitted as FFmpeg option `g`. |
|
||||
| `crf` | `int` or `float` | `30` | Abstract quality value, mapped per codec (see the [mapping](#mapping-videoencoderconfig--ffmpeg-options) below). Lower → higher quality / larger output where the mapping is monotone. |
|
||||
| `preset` | `int` or `str` | `12` \* | Encoder speed preset; meaning depends on the codec. <br/>\* When unset and `vcodec=libsvtav1`, LeRobot defaults to `12`. |
|
||||
| `fast_decode` | `int` | `0` | `libsvtav1`: `0–2`, passed via `svtav1-params`. <br/>`h264` / `hevc` (software): if `>0`, sets `tune=fastdecode`. <br/>Other codecs: usually unused. |
|
||||
| `video_backend` | `str` | `"pyav"` | Only `"pyav"` is currently implemented for video encoding. |
|
||||
| `extra_options` | `dict` | `{}` | Extra FFmpeg or codec specific options merged after the structured fields above. Cannot override keys already set by those fields. |
|
||||
|
||||
---
|
||||
|
||||
## Persistence in dataset metadata
|
||||
|
||||
After the first episode of a video stream is encoded, the encoder configuration is **persisted into the dataset metadata** (`meta/info.json`) under each video feature, alongside the values probed from the file itself. For a video feature `observation.images.<camera>`, the layout in `info.json` is:
|
||||
|
||||
```json
|
||||
{
|
||||
"features": {
|
||||
"observation.images.laptop": {
|
||||
"dtype": "video",
|
||||
"shape": [480, 640, 3],
|
||||
"info": {
|
||||
"video.height": 480,
|
||||
"video.width": 640,
|
||||
"video.codec": "h264",
|
||||
"video.pix_fmt": "yuv420p",
|
||||
"video.fps": 30,
|
||||
"video.channels": 3,
|
||||
"video.is_depth_map": false,
|
||||
"video.g": 2,
|
||||
"video.crf": 30,
|
||||
"video.preset": "fast",
|
||||
"video.fast_decode": 0,
|
||||
"video.video_backend": "pyav",
|
||||
"video.extra_options": { "tune": "film", "profile:v": "high", "bf": 2 }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Two sources contribute to the `info` block:
|
||||
|
||||
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `video.is_depth_map`, plus `audio.*` if an audio stream is present.
|
||||
- **Encoder-derived** (taken from `VideoEncoderConfig`): `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.video_backend`, `video.extra_options`.
|
||||
|
||||
<Tip>
|
||||
This block is populated **once**, from the **first** episode. It assumes every
|
||||
episode in the dataset was encoded with the same `camera_encoder`. Changing
|
||||
encoder settings partway through a recording is not supported — the
|
||||
`info.json` will only reflect the parameters used for the first episode.
|
||||
</Tip>
|
||||
|
||||
---
|
||||
|
||||
## Merging datasets
|
||||
|
||||
When aggregating datasets with `merge_datasets`, video files are concatenated as-is (no re-encoding), and encoder fields in `info.json` are merged per-key:
|
||||
|
||||
- **Stream-derived fields must match** across sources: `video.codec`, `video.pix_fmt`, `video.height`, `video.width`, `video.fps`. Otherwise FFmpeg's concat demuxer fails.
|
||||
- **Encoder-tuning fields are merged loosely**: `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.extra_options`. If every source agrees, the value is kept; if not, it's set to `null` (or `{}` for `video.extra_options`) and a warning is logged.
|
||||
@@ -15,10 +15,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes.
|
||||
Create MP4 (or GIF) videos with per-frame progress overlay for specified episodes.
|
||||
|
||||
Downloads datasets from HuggingFace, seeks directly into the episode segment
|
||||
of the source video, draws a progress line on each frame, and writes the result.
|
||||
The progress data is read from a parquet file that lives alongside the dataset
|
||||
(configurable via ``--progress-file``).
|
||||
|
||||
Usage:
|
||||
python examples/dataset/create_progress_videos.py \
|
||||
@@ -56,22 +58,26 @@ SCORE_FONT_SCALE = 0.8
|
||||
TASK_FONT_SCALE = 0.55
|
||||
|
||||
|
||||
def download_episode_metadata(repo_id: str, episode: int) -> Path:
|
||||
"""Download only the metadata and sarm_progress files for a dataset.
|
||||
def download_episode_metadata(
|
||||
repo_id: str, episode: int, progress_file: str = "sarm_progress.parquet"
|
||||
) -> Path:
|
||||
"""Download only the metadata and per-frame progress file for a dataset.
|
||||
|
||||
Args:
|
||||
repo_id: HuggingFace dataset repository ID.
|
||||
episode: Episode index (used for logging only; all meta is fetched).
|
||||
progress_file: Filename of the per-frame progress parquet inside the
|
||||
dataset repo.
|
||||
|
||||
Returns:
|
||||
Local cache path for the downloaded snapshot.
|
||||
"""
|
||||
logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode)
|
||||
logging.info("[1/4] Downloading metadata + %s for %s (episode %d) ...", progress_file, repo_id, episode)
|
||||
local_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**", "sarm_progress.parquet"],
|
||||
allow_patterns=["meta/**", progress_file],
|
||||
ignore_patterns=["*.mp4"],
|
||||
)
|
||||
)
|
||||
@@ -215,25 +221,28 @@ def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
|
||||
return video_path
|
||||
|
||||
|
||||
def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None:
|
||||
"""Load sarm_progress values for an episode.
|
||||
def load_progress_data(
|
||||
local_path: Path, episode: int, progress_file: str = "sarm_progress.parquet"
|
||||
) -> np.ndarray | None:
|
||||
"""Load per-frame progress values for an episode.
|
||||
|
||||
Args:
|
||||
local_path: Dataset cache root.
|
||||
episode: Episode index.
|
||||
progress_file: Filename of the per-frame progress parquet.
|
||||
|
||||
Returns:
|
||||
Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
|
||||
"""
|
||||
parquet_path = local_path / "sarm_progress.parquet"
|
||||
parquet_path = local_path / progress_file
|
||||
if not parquet_path.exists():
|
||||
logging.warning("sarm_progress.parquet not found")
|
||||
logging.warning("%s not found", progress_file)
|
||||
return None
|
||||
df = pd.read_parquet(parquet_path)
|
||||
logging.info(" sarm_progress.parquet columns: %s", list(df.columns))
|
||||
logging.info(" %s columns: %s", progress_file, list(df.columns))
|
||||
episode_df = df[df["episode_index"] == episode].copy()
|
||||
if episode_df.empty:
|
||||
logging.warning("No sarm_progress rows for episode %d", episode)
|
||||
logging.warning("No progress rows for episode %d in %s", episode, progress_file)
|
||||
return None
|
||||
episode_df = episode_df.sort_values("frame_index")
|
||||
|
||||
@@ -576,6 +585,7 @@ def process_dataset(
|
||||
camera_key: str | None,
|
||||
output_dir: Path,
|
||||
create_gif: bool = False,
|
||||
progress_file: str = "sarm_progress.parquet",
|
||||
) -> Path | None:
|
||||
"""Full pipeline: download, extract metadata, composite progress, write output.
|
||||
|
||||
@@ -585,6 +595,8 @@ def process_dataset(
|
||||
camera_key: Camera key to use, or None for auto-selection.
|
||||
output_dir: Directory to write output files.
|
||||
create_gif: If True, also generate a GIF from the MP4.
|
||||
progress_file: Filename of the per-frame progress parquet inside the
|
||||
dataset repo.
|
||||
|
||||
Returns:
|
||||
Path to the final output file, or None on failure.
|
||||
@@ -592,7 +604,7 @@ def process_dataset(
|
||||
safe_name = repo_id.replace("/", "_")
|
||||
logging.info("Processing: %s | episode %d", repo_id, episode)
|
||||
|
||||
local_path = download_episode_metadata(repo_id, episode)
|
||||
local_path = download_episode_metadata(repo_id, episode, progress_file)
|
||||
logging.info(" Local cache: %s", local_path)
|
||||
|
||||
episode_meta = load_episode_meta(local_path, episode, camera_key)
|
||||
@@ -600,9 +612,9 @@ def process_dataset(
|
||||
|
||||
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
|
||||
|
||||
progress_data = load_progress_data(local_path, episode)
|
||||
progress_data = load_progress_data(local_path, episode, progress_file)
|
||||
if progress_data is None:
|
||||
logging.error("Could not load sarm_progress data. Skipping overlay.")
|
||||
logging.error("Could not load progress data from %s. Skipping overlay.", progress_file)
|
||||
return None
|
||||
|
||||
logging.info(" Progress frames: %d", len(progress_data))
|
||||
@@ -627,7 +639,7 @@ def process_dataset(
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes."
|
||||
description="Create MP4/GIF videos with per-frame progress overlay for dataset episodes."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
@@ -658,6 +670,15 @@ def main() -> None:
|
||||
action="store_true",
|
||||
help="Also generate a GIF from the MP4 output.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--progress-file",
|
||||
type=str,
|
||||
default="sarm_progress.parquet",
|
||||
help=(
|
||||
"Filename of the per-frame progress parquet inside the dataset repo "
|
||||
"(default: 'sarm_progress.parquet')."
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
@@ -670,6 +691,7 @@ def main() -> None:
|
||||
camera_key=args.camera_key,
|
||||
output_dir=args.output_dir,
|
||||
create_gif=args.gif,
|
||||
progress_file=args.progress_file,
|
||||
)
|
||||
|
||||
if result:
|
||||
|
||||
@@ -80,7 +80,7 @@
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Dataset\n",
|
||||
"HF_USER = \"your_hf_username\" # `huggingface-cli whoami` to find your username\n",
|
||||
"HF_USER = \"your_hf_username\" # `hf auth whoami` to find your username\n",
|
||||
"DATASET_NAME = \"my_so101_dataset\"\n",
|
||||
"TASK_DESCRIPTION = \"pick and place the block\"\n",
|
||||
"NUM_EPISODES = 10\n",
|
||||
@@ -291,7 +291,34 @@
|
||||
"\n",
|
||||
"Uses `POLICY_PATH` from the Configuration cell (defaults to the Hub repo ID). You can also put there the `LAST_CHECKPOINT_PATH`.\n",
|
||||
"\n",
|
||||
"See the [inference docs](https://huggingface.co/docs/lerobot/il_robots#run-inference-and-evaluate-your-policy) for details."
|
||||
"See the [inference docs](https://huggingface.co/docs/lerobot/il_robots#run-inference-and-evaluate-your-policy) for details.\n",
|
||||
"\n",
|
||||
"Recently ```lerobot-rollout``` was introduced, you can [read more about it here](https://huggingface.co/docs/lerobot/main/en/il_robots?eval=Base+mode+%28no+recording%29#run-inference-and-evaluate-your-policy)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print_cmd(\n",
|
||||
" \"lerobot-rollout\",\n",
|
||||
" \"--strategy.type=base\",\n",
|
||||
" f\"--policy.path={POLICY_PATH}\",\n",
|
||||
" f\"--robot.type={ROBOT_TYPE}\",\n",
|
||||
" f\"--robot.port={ROBOT_PORT}\",\n",
|
||||
" CAMERAS_FLAG,\n",
|
||||
" f'--task=\"{TASK_DESCRIPTION}\"',\n",
|
||||
" \"--duration=60\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"if you are using the V0.5.1 release you should use ```lerobot-record``` instead of rollout"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
136
examples/omx/README.md
Normal file
136
examples/omx/README.md
Normal file
@@ -0,0 +1,136 @@
|
||||
# OMX Follower — Cube Pick And Place Example
|
||||
|
||||
This is an example of what is possible to do with LeRobot on a physical setup.
|
||||
It is a WIP and being used internally at LeRobot and specific to our setup, but we hope it can be a useful reference for how to use LeRobot APIs and CLIs.
|
||||
|
||||
It includes an end-to-end example for the **OMX Follower** robot arm: pick and place a cube dataset, train a policy, and deploy it autonomously.
|
||||
|
||||
## Hardware
|
||||
|
||||
| Component | Value |
|
||||
| --------- | ------------------------------------ |
|
||||
| Robot | OMX Follower |
|
||||
| Cameras | 2× OpenCV cameras (wrist + top-down) |
|
||||
|
||||
## Scripts
|
||||
|
||||
| Script | Purpose |
|
||||
| ---------------------- | --------------------------------------------------------------- |
|
||||
| `reset_environment.py` | Standalone utility: sweep workspace, grab cube, place cube |
|
||||
| `record_grab.py` | Automated data collection: reset → place → record grab episodes |
|
||||
|
||||
## Setup
|
||||
|
||||
Make sure you have LeRobot installed in your env. (See [the installation guide](https://huggingface.co/docs/lerobot/installation))
|
||||
|
||||
Next, we will declare some environment variables for convenience. Adjust the camera indices and robot port to match your system configuration.
|
||||
|
||||
```bash
|
||||
export ROBOT_PORT=/dev/ttyACM0
|
||||
export TELEOP_PORT=/dev/ttyACM1
|
||||
export HF_USERNAME=<your_hf_username>
|
||||
export ROBOT_CAMERAS="{ wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30, fourcc: MJPG}, top: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30, fourcc: MJPG} }"
|
||||
```
|
||||
|
||||
## Step 1 — Collect Data
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=$ROBOT_PORT \
|
||||
--robot.id=omx_follower \
|
||||
--robot.cameras="$ROBOT_CAMERAS" \
|
||||
--teleop.type=omx_leader \
|
||||
--teleop.port=$TELEOP_PORT \
|
||||
--teleop.id=omx_leader \
|
||||
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
|
||||
--dataset.root=data/omx_pickandplace \
|
||||
--dataset.num_episodes=50 \
|
||||
--dataset.single_task="Pick the cube and place it in the blue square" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.push_to_hub=true
|
||||
```
|
||||
|
||||
### Bonus Auto-Collect script
|
||||
|
||||
/!\ This is specific to our setup and the task of picking and placing a cube. It is not a general-purpose data collection script. As you may notice, it doesn't require a teleop.
|
||||
|
||||
```bash
|
||||
python -m examples.omx.record_grab \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=$ROBOT_PORT \
|
||||
--robot.id=omx_follower \
|
||||
--robot.cameras="$ROBOT_CAMERAS" \
|
||||
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
|
||||
--dataset.root=data/omx_pickandplace \
|
||||
--dataset.num_episodes=50 \
|
||||
--dataset.single_task="Pick the cube and place it in the blue square" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.push_to_hub=true
|
||||
```
|
||||
|
||||
Each episode:
|
||||
|
||||
1. The arm grabs the cube from the center of the workspace and places it at a random position.
|
||||
2. The arm returns to HOME.
|
||||
3. A targeted grab is recorded: HOME → approach raised → lower onto cube → grasp → lift → carry → drop → HOME.
|
||||
|
||||
A dataset is already available here [`maximellerbach/omx_pickandplace`](https://huggingface.co/datasets/maximellerbach/omx_pickandplace), so you can skip directly to training if you want.
|
||||
|
||||
## Step 2 — Train
|
||||
|
||||
To train a simple `ACT` policy on the collected dataset, you can use the `lerobot-train` CLI:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/omx_pickandplace_act \
|
||||
--policy.device=cuda \
|
||||
--policy.repo_id=$HF_USERNAME/omx_pickandplace_act \
|
||||
--steps=20000 \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
A pretrained `ACT` policy is already available here [`maximellerbach/omx_pickandplace_act`](https://huggingface.co/maximellerbach/omx_pickandplace_act).
|
||||
|
||||
## Step 3 — Rollout
|
||||
|
||||
Use the `lerobot-rollout` CLI with base strategy:
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=$ROBOT_PORT \
|
||||
--robot.id=omx_follower \
|
||||
--robot.cameras="$ROBOT_CAMERAS" \
|
||||
--policy.path=$HF_USERNAME/omx_pickandplace_act \
|
||||
```
|
||||
|
||||
For continuous recording with automatic upload (sentry mode):
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=sentry \
|
||||
--strategy.upload_every_n_episodes=10 \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=$ROBOT_PORT \
|
||||
--robot.id=omx_follower \
|
||||
--robot.cameras="$ROBOT_CAMERAS" \
|
||||
--policy.path=$HF_USERNAME/omx_pickandplace_act \
|
||||
--dataset.repo_id=$HF_USERNAME/rollout_omx_pickandplace_act \
|
||||
```
|
||||
|
||||
## Environment Reset Utility
|
||||
|
||||
Those are specific to this particular physical setup. Those are scripts that execute hardcoded sequences of actions on the robot to reset the environment, which is useful for data collection and evaluation. They are not general-purpose scripts.
|
||||
|
||||
`reset_environment.py` can be run standalone to prepare the workspace:
|
||||
|
||||
```bash
|
||||
# Grab cube + place it at a random position on the left side
|
||||
python -m examples.omx.reset_environment --port $ROBOT_PORT --mode grab_and_place
|
||||
```
|
||||
|
||||
It also exposes `grab_cube(robot)` and `place_cube(robot)` for use in custom scripts.
|
||||
422
examples/omx/record_grab.py
Normal file
422
examples/omx/record_grab.py
Normal file
@@ -0,0 +1,422 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Auto-record grab episodes for the OMX robot arm.
|
||||
|
||||
Each episode cycle:
|
||||
1. grab_and_place — grab cube from workspace center and place at a random (pan, reach) position
|
||||
2. HOME — return arm to home with gripper open
|
||||
3. record_grab — execute a targeted grab to the stored position while recording
|
||||
observations + actions to a LeRobotDataset
|
||||
|
||||
Usage (run from repo root):
|
||||
python -m examples.omx.record_grab \\
|
||||
--robot.type=omx_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--robot.id=omx_follower \\
|
||||
--robot.cameras="{ wrist: {type: opencv, index_or_path: 6, width: 640, height: 480, fps: 30, fourcc: MJPG}, top: {type: opencv, index_or_path: 4, width: 640, height: 480, fps: 30, fourcc: MJPG} }" \\
|
||||
--dataset.repo_id=<hf_username>/<dataset_name> \\
|
||||
--dataset.root=data/omx_grab \\
|
||||
--dataset.num_episodes=50 \\
|
||||
--dataset.single_task="Grab the cube" \\
|
||||
--dataset.streaming_encoding=true
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pprint import pformat
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.cameras import CameraConfig # noqa: F401
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.dataset import DatasetRecordConfig
|
||||
from lerobot.datasets import (
|
||||
LeRobotDataset,
|
||||
VideoEncodingManager,
|
||||
aggregate_pipeline_dataset_features,
|
||||
create_initial_features,
|
||||
)
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.robots import RobotConfig, make_robot_from_config
|
||||
from lerobot.robots.omx_follower import OmxFollower
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
from .reset_environment import (
|
||||
APPROACH_SPEED,
|
||||
GRIPPER_CLOSE_POS,
|
||||
HOME_POSE,
|
||||
PUSH_END_ELBOW_FLEX,
|
||||
PUSH_END_SHOULDER_LIFT,
|
||||
PUSH_START_ELBOW_FLEX,
|
||||
PUSH_START_SHOULDER_LIFT,
|
||||
array_to_pose,
|
||||
grab_cube,
|
||||
horizontal_wrist_flex,
|
||||
move_to_pose,
|
||||
place_cube,
|
||||
pose_to_array,
|
||||
)
|
||||
|
||||
# ── Grab-episode motion parameters ────────────────────────────────────────────
|
||||
|
||||
# Shoulder-lift offset for the raised approach phase (subtracted from the target sl, arm is higher).
|
||||
GRAB_RAISE_SL_OFFSET = 20.0
|
||||
GRAB_LOWER_SPEED = 20.0
|
||||
RECORD_SPEED = 30.0
|
||||
|
||||
# Pose the arm travels to after closing the gripper (cube held).
|
||||
GRAB_CARRY_POSE = {
|
||||
"shoulder_pan.pos": -23.0,
|
||||
"shoulder_lift.pos": 5.0,
|
||||
"elbow_flex.pos": 18.0,
|
||||
"wrist_flex.pos": -14.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||
}
|
||||
|
||||
# Per-joint jitter limits (degrees) applied to transit waypoints for human-like variation.
|
||||
# Cube-approach and carry poses are never jittered to preserve precision.
|
||||
_JITTER_LIMITS: dict[str, float] = {
|
||||
"shoulder_pan.pos": 5.0,
|
||||
"shoulder_lift.pos": 4.0,
|
||||
"elbow_flex.pos": 4.0,
|
||||
"wrist_flex.pos": 3.0,
|
||||
"wrist_roll.pos": 2.0,
|
||||
"gripper.pos": 0.0,
|
||||
}
|
||||
|
||||
|
||||
def _jitter_pose(pose: dict, rng: np.random.Generator) -> dict:
|
||||
"""Return a copy of pose with independent per-joint random perturbations."""
|
||||
return {
|
||||
k: v + rng.uniform(-_JITTER_LIMITS.get(k, 0.0), _JITTER_LIMITS.get(k, 0.0)) for k, v in pose.items()
|
||||
}
|
||||
|
||||
|
||||
def _random_stuck_pose(rng: np.random.Generator) -> dict:
|
||||
"""Return a physically plausible stuck pose (failed grasp), gripper closed.
|
||||
|
||||
ef bounds are piecewise-linear in sl so the arm stays in a reachable,
|
||||
table-safe envelope across the full sl range:
|
||||
sl=-50 → ef ∈ [ 0, 50] (arm raised, can be bent forward)
|
||||
sl= 0 → ef ∈ [-25, 25] (mid reach)
|
||||
sl= 30 → ef ∈ [-20, 0] (arm extended, little room to flex)
|
||||
wrist_flex is randomly offset from the horizontal value.
|
||||
"""
|
||||
pan = float(rng.uniform(-5.0, 35.0))
|
||||
sl = float(rng.uniform(-50.0, 30.0))
|
||||
|
||||
if sl <= 0.0:
|
||||
alpha = (sl + 50.0) / 50.0 # 0 at sl=-50, 1 at sl=0
|
||||
ef_lo = alpha * -25.0 # 0 → -25
|
||||
ef_hi = 50.0 + alpha * -25.0 # 50 → 25
|
||||
else:
|
||||
alpha = sl / 30.0 # 0 at sl=0, 1 at sl=30
|
||||
ef_lo = -25.0 + alpha * 5.0 # -25 → -20
|
||||
ef_hi = 25.0 + alpha * -25.0 # 25 → 0
|
||||
|
||||
ef = float(rng.uniform(ef_lo, ef_hi))
|
||||
wf = horizontal_wrist_flex(sl, ef) + float(rng.uniform(-15.0, 15.0))
|
||||
return {
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl,
|
||||
"elbow_flex.pos": ef,
|
||||
"wrist_flex.pos": wf,
|
||||
"wrist_roll.pos": float(rng.uniform(-15.0, 15.0)),
|
||||
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||
}
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OmxRecordGrabConfig:
|
||||
robot: RobotConfig
|
||||
dataset: DatasetRecordConfig
|
||||
# Resume recording on an existing dataset.
|
||||
resume: bool = False
|
||||
# Fraction of episodes that start from a random stuck pose (gripper closed) to
|
||||
# generate recovery data. 0.0 = disabled, 1.0 = all episodes are recovery starts.
|
||||
recovery_prob: float = 0.5
|
||||
|
||||
|
||||
def record_episode_spline(
|
||||
robot: OmxFollower,
|
||||
waypoints: list[dict],
|
||||
speeds: list[float],
|
||||
dataset: LeRobotDataset,
|
||||
task: str,
|
||||
) -> None:
|
||||
"""Execute a Catmull-Rom-style spline through waypoints, recording each frame.
|
||||
|
||||
Segment durations are parameterized from the maximum absolute joint delta
|
||||
between consecutive waypoints divided by the requested segment speed,
|
||||
producing non-uniform timing in joint space. Interior tangents are derived
|
||||
from the adjacent per-segment velocities, with clamped (zero-velocity)
|
||||
endpoints so the arm starts and stops smoothly. Each segment is cubic
|
||||
Hermite, giving C1 continuity at every waypoint.
|
||||
"""
|
||||
pts = [pose_to_array(w) for w in waypoints]
|
||||
n = len(pts)
|
||||
|
||||
# Steps and duration per segment
|
||||
n_steps_list = []
|
||||
timestamps = []
|
||||
for i in range(n - 1):
|
||||
max_dist = float(np.max(np.abs(pts[i + 1] - pts[i])))
|
||||
ns = max(1, int(max_dist / speeds[i] * dataset.fps)) if max_dist >= 0.5 else 0
|
||||
n_steps_list.append(ns)
|
||||
timestamps.append(ns / dataset.fps)
|
||||
|
||||
# Velocity tangents (deg/sec) — clamped at endpoints, Catmull-Rom for interior
|
||||
vels = [np.zeros_like(pts[0])]
|
||||
for i in range(1, n - 1):
|
||||
v_prev = (pts[i] - pts[i - 1]) / timestamps[i - 1] if timestamps[i - 1] > 0 else np.zeros_like(pts[0])
|
||||
v_next = (pts[i + 1] - pts[i]) / timestamps[i] if timestamps[i] > 0 else np.zeros_like(pts[0])
|
||||
vels.append(0.5 * (v_prev + v_next))
|
||||
vels.append(np.zeros_like(pts[0]))
|
||||
|
||||
dt = 1.0 / dataset.fps
|
||||
for seg in range(n - 1):
|
||||
ns = n_steps_list[seg]
|
||||
if ns == 0:
|
||||
continue
|
||||
p0, p1 = pts[seg], pts[seg + 1]
|
||||
# Scale velocity (deg/sec) to t-space tangent (deg/t-unit, where t: 0→1 over ns steps)
|
||||
m0 = vels[seg] * timestamps[seg]
|
||||
m1 = vels[seg + 1] * timestamps[seg]
|
||||
|
||||
for step in range(1, ns + 1):
|
||||
t = step / ns
|
||||
h00 = 2 * t**3 - 3 * t**2 + 1
|
||||
h10 = t**3 - 2 * t**2 + t
|
||||
h01 = -2 * t**3 + 3 * t**2
|
||||
h11 = t**3 - t**2
|
||||
commanded = h00 * p0 + h10 * m0 + h01 * p1 + h11 * m1
|
||||
|
||||
action = array_to_pose(commanded)
|
||||
robot.send_action(action)
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(dataset.features, action, prefix=ACTION)
|
||||
dataset.add_frame({**obs_frame, **action_frame, "task": task})
|
||||
precise_sleep(dt)
|
||||
|
||||
|
||||
def record_grab_episode(
|
||||
robot: OmxFollower,
|
||||
dataset: LeRobotDataset,
|
||||
pan: float,
|
||||
t: float,
|
||||
task: str,
|
||||
recovery_start: bool = False,
|
||||
) -> None:
|
||||
"""Execute a targeted grab to the stored (pan, t) position, recording every frame.
|
||||
|
||||
Normal sequence (initial HOME move is NOT recorded):
|
||||
HOME → raised approach above cube → lower → close gripper
|
||||
→ raise [jittered] → retract [jittered] → GRAB_CARRY_POSE → drop → HOME
|
||||
|
||||
Recovery sequence (recovery_start=True): arm is moved to a random stuck pose
|
||||
(gripper closed) without recording, then recording begins from there:
|
||||
stuck_pose → raised approach above cube → [normal grab sequence from there]
|
||||
|
||||
All segments are joined by a Catmull-Rom spline (C1-continuous velocities).
|
||||
"""
|
||||
sl = PUSH_START_SHOULDER_LIFT + t * (PUSH_END_SHOULDER_LIFT - PUSH_START_SHOULDER_LIFT)
|
||||
ef = PUSH_START_ELBOW_FLEX + t * (PUSH_END_ELBOW_FLEX - PUSH_START_ELBOW_FLEX)
|
||||
sl_raised = sl - GRAB_RAISE_SL_OFFSET
|
||||
wf_horizontal = horizontal_wrist_flex(sl, ef)
|
||||
|
||||
rng = np.random.default_rng()
|
||||
|
||||
if recovery_start:
|
||||
stuck_pose = _random_stuck_pose(rng)
|
||||
logger.info(f"Recovery start: {stuck_pose}")
|
||||
move_to_pose(robot, stuck_pose, APPROACH_SPEED)
|
||||
first_waypoints = [stuck_pose]
|
||||
first_speeds = []
|
||||
else:
|
||||
jittery_start = _jitter_pose(HOME_POSE, rng)
|
||||
move_to_pose(robot, jittery_start, APPROACH_SPEED)
|
||||
first_waypoints = [jittery_start]
|
||||
first_speeds = []
|
||||
|
||||
waypoints = first_waypoints + [
|
||||
{ # raised approach: arm above cube
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl_raised,
|
||||
"elbow_flex.pos": ef,
|
||||
"wrist_flex.pos": horizontal_wrist_flex(sl_raised, ef),
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
},
|
||||
{ # lower onto cube — no jitter: precision needed
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl,
|
||||
"elbow_flex.pos": ef,
|
||||
"wrist_flex.pos": wf_horizontal,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
},
|
||||
{ # close gripper — no jitter: precision needed
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl,
|
||||
"elbow_flex.pos": ef,
|
||||
"wrist_flex.pos": wf_horizontal,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||
},
|
||||
_jitter_pose(
|
||||
{ # raise with cube
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl_raised,
|
||||
"elbow_flex.pos": ef,
|
||||
"wrist_flex.pos": horizontal_wrist_flex(sl_raised, ef),
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||
},
|
||||
rng,
|
||||
),
|
||||
_jitter_pose(
|
||||
{ # retract: fold arm toward HOME before sweeping to carry zone
|
||||
"shoulder_pan.pos": pan * 0.25,
|
||||
"shoulder_lift.pos": HOME_POSE["shoulder_lift.pos"] + 5.0,
|
||||
"elbow_flex.pos": HOME_POSE["elbow_flex.pos"] - 5.0,
|
||||
"wrist_flex.pos": 0.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||
},
|
||||
rng,
|
||||
),
|
||||
GRAB_CARRY_POSE, # no jitter: target drop zone
|
||||
{**GRAB_CARRY_POSE, "gripper.pos": 60.0}, # drop cube
|
||||
HOME_POSE,
|
||||
]
|
||||
speeds = first_speeds + [
|
||||
RECORD_SPEED, # (HOME →) raised approach
|
||||
GRAB_LOWER_SPEED, # raised approach → lower
|
||||
GRAB_LOWER_SPEED, # lower → close gripper
|
||||
RECORD_SPEED, # close gripper → raise
|
||||
RECORD_SPEED, # raise → retract
|
||||
RECORD_SPEED, # retract → carry pose
|
||||
RECORD_SPEED, # carry pose → drop
|
||||
RECORD_SPEED, # drop → HOME
|
||||
]
|
||||
|
||||
record_episode_spline(robot, waypoints, speeds, dataset, task)
|
||||
|
||||
# Dwell at HOME for ~0.5 s before next episode
|
||||
home_action = build_dataset_frame(dataset.features, HOME_POSE, prefix=ACTION)
|
||||
dt = 1.0 / dataset.fps
|
||||
for _ in range(int(dataset.fps * 0.5)):
|
||||
robot.send_action(HOME_POSE)
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
|
||||
dataset.add_frame({**obs_frame, **home_action, "task": task})
|
||||
precise_sleep(dt)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def record_grab(cfg: OmxRecordGrabConfig) -> LeRobotDataset:
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
logger.info(pformat(cfg))
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
use_videos = cfg.dataset.video
|
||||
|
||||
teleop_action_processor, _, robot_obs_processor = make_default_processors()
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_action_processor,
|
||||
initial_features=create_initial_features(action=robot.action_features),
|
||||
use_videos=use_videos,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_obs_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=use_videos,
|
||||
),
|
||||
)
|
||||
|
||||
num_cameras = len(robot.cameras) if hasattr(robot, "cameras") else 0
|
||||
dataset = None
|
||||
|
||||
try:
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset.resume(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
|
||||
if num_cameras > 0
|
||||
else 0,
|
||||
)
|
||||
else:
|
||||
cfg.dataset.stamp_repo_id()
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
root=cfg.dataset.root,
|
||||
robot_type=robot.name,
|
||||
features=dataset_features,
|
||||
use_videos=use_videos,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
|
||||
if num_cameras > 0
|
||||
else 0,
|
||||
)
|
||||
|
||||
robot.connect(calibrate=True)
|
||||
|
||||
rng = np.random.default_rng()
|
||||
with VideoEncodingManager(dataset):
|
||||
for episode_idx in range(cfg.dataset.num_episodes):
|
||||
logger.info(f"=== Episode {episode_idx + 1}/{cfg.dataset.num_episodes} ===")
|
||||
|
||||
logger.info("Step 1: grabbing and placing cube...")
|
||||
grab_cube(robot)
|
||||
pan, t = place_cube(robot)
|
||||
logger.info(f"Cube placed at pan={pan:.1f}, reach={t:.2f}")
|
||||
|
||||
recovery_start = cfg.recovery_prob > 0 and float(rng.random()) < cfg.recovery_prob
|
||||
logger.info(f"Step 2: recording {'recovery ' if recovery_start else ''}grab episode...")
|
||||
record_grab_episode(
|
||||
robot,
|
||||
dataset,
|
||||
pan,
|
||||
t,
|
||||
cfg.dataset.single_task,
|
||||
recovery_start=recovery_start,
|
||||
)
|
||||
|
||||
dataset.save_episode()
|
||||
logger.info(f"Episode {episode_idx + 1} saved.")
|
||||
|
||||
finally:
|
||||
if dataset:
|
||||
dataset.finalize()
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
|
||||
if cfg.dataset.push_to_hub and dataset and dataset.num_episodes > 0:
|
||||
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
record_grab()
|
||||
267
examples/omx/reset_environment.py
Normal file
267
examples/omx/reset_environment.py
Normal file
@@ -0,0 +1,267 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Auto-reset and cube-grab utility for the OMX robot arm.
|
||||
|
||||
Provides:
|
||||
- grab_cube(robot): sweep workspace, center cube, close gripper
|
||||
- place_cube(robot): carry cube to a random position, release
|
||||
|
||||
Standalone usage (run from repo root):
|
||||
python -m examples.omx.reset_environment --port /dev/ttyACM1 --mode grab
|
||||
python -m examples.omx.reset_environment --port /dev/ttyACM1 --mode grab_and_place
|
||||
|
||||
Joint range: -100 to 100 for arm joints; gripper: 50 = closed, 80 = open.
|
||||
|
||||
To read current joint values for calibration, add after robot.connect():
|
||||
obs = robot.get_observation()
|
||||
print({k: round(obs[k], 1) for k in JOINT_NAMES})
|
||||
robot.disconnect(); raise SystemExit
|
||||
|
||||
Parallel-to-ground IK: wrist_flex = WRIST_HORIZONTAL_OFFSET - shoulder_lift - elbow_flex.
|
||||
Linear interpolation preserves this constraint between any two poses that satisfy it.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.robots.omx_follower import OmxFollower, OmxFollowerConfig
|
||||
from lerobot.robots.robot import Robot
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Poses ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
HOME_POSE = {
|
||||
"shoulder_pan.pos": 0.0,
|
||||
"shoulder_lift.pos": -50.0,
|
||||
"elbow_flex.pos": 50.0,
|
||||
"wrist_flex.pos": 0.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
}
|
||||
|
||||
SWEEP_WAYPOINTS = [
|
||||
{
|
||||
"shoulder_pan.pos": -60.0,
|
||||
"shoulder_lift.pos": 50.0,
|
||||
"elbow_flex.pos": -60.0,
|
||||
"wrist_flex.pos": -20.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
},
|
||||
{
|
||||
"shoulder_pan.pos": -30.0,
|
||||
"shoulder_lift.pos": 50.0,
|
||||
"elbow_flex.pos": -60.0,
|
||||
"wrist_flex.pos": -5.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
},
|
||||
{
|
||||
"shoulder_pan.pos": 20.0,
|
||||
"shoulder_lift.pos": 50.0,
|
||||
"elbow_flex.pos": -55.0,
|
||||
"wrist_flex.pos": -5.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
},
|
||||
]
|
||||
|
||||
# ── Motion parameters ─────────────────────────────────────────────────────────
|
||||
|
||||
CONTROL_HZ = 30
|
||||
APPROACH_SPEED = 50.0
|
||||
SWEEP_SPEED = 40.0
|
||||
|
||||
# ── Grab-sequence parameters ──────────────────────────────────────────────────
|
||||
|
||||
GRAB_PAN = 0.0
|
||||
SWEEP_LEFT_PAN = -60.0
|
||||
SWEEP_RIGHT_PAN = 60.0
|
||||
SWEEP_END_OFFSET = 5.0 # stop before center so the cube isn't pushed past GRAB_PAN
|
||||
SWEEP_END_PAN_RANGE = (15.0, 20.0)
|
||||
|
||||
SWEEP_LOW_SHOULDER_LIFT = 50.0
|
||||
SWEEP_LOW_ELBOW_FLEX_START = -60.0
|
||||
SWEEP_LOW_ELBOW_FLEX_END = -55.0
|
||||
|
||||
SWEEP_HIGH_WRIST_FLEX = -20.0 # wrist tilted up during high approach to clear obstacles
|
||||
|
||||
PUSH_START_SHOULDER_LIFT = 0.0
|
||||
PUSH_START_ELBOW_FLEX = 45.0
|
||||
PUSH_END_SHOULDER_LIFT = 50.0
|
||||
PUSH_END_ELBOW_FLEX = -50.0
|
||||
# Subtracted from shoulder_lift during the push sweep to clear the platform surface.
|
||||
# Does not affect the grab-target interpolation in record_grab.py.
|
||||
PUSH_RAISE_OFFSET = 5.0
|
||||
|
||||
WRIST_HORIZONTAL_OFFSET = 0.0 # tune if gripper tilts during push: + tilts nose up, - down
|
||||
GRIPPER_CLOSE_POS = 50.0
|
||||
|
||||
PLACE_LEFT_PAN_RANGE = (5.0, 30.0) # random pan range for cube placement on the left side
|
||||
PLACE_REACH_RANGE = (0.1, 0.7) # 0 = arm retracted (PUSH_START), 1 = fully extended (PUSH_END)
|
||||
|
||||
JOINT_NAMES = [
|
||||
"shoulder_pan.pos",
|
||||
"shoulder_lift.pos",
|
||||
"elbow_flex.pos",
|
||||
"wrist_flex.pos",
|
||||
"wrist_roll.pos",
|
||||
"gripper.pos",
|
||||
]
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def pose_to_array(pose: dict) -> np.ndarray:
|
||||
return np.array([pose[k] for k in JOINT_NAMES])
|
||||
|
||||
|
||||
def array_to_pose(arr: np.ndarray) -> dict:
|
||||
return {k: float(arr[i]) for i, k in enumerate(JOINT_NAMES)}
|
||||
|
||||
|
||||
def horizontal_wrist_flex(shoulder_lift: float, elbow_flex: float) -> float:
|
||||
return WRIST_HORIZONTAL_OFFSET - shoulder_lift - elbow_flex
|
||||
|
||||
|
||||
def _low_sweep_pose(pan: float, elbow_flex: float, wrist_flex: float | None = None) -> dict:
|
||||
sl = SWEEP_LOW_SHOULDER_LIFT
|
||||
return {
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl,
|
||||
"elbow_flex.pos": elbow_flex,
|
||||
"wrist_flex.pos": horizontal_wrist_flex(sl, elbow_flex) if wrist_flex is None else wrist_flex,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
}
|
||||
|
||||
|
||||
def _high_sweep_pose(pan: float) -> dict:
|
||||
return {**HOME_POSE, "shoulder_pan.pos": pan, "wrist_flex.pos": SWEEP_HIGH_WRIST_FLEX}
|
||||
|
||||
|
||||
def _push_pose(shoulder_lift: float, elbow_flex: float, pan: float = GRAB_PAN, gripper: float = 70.0) -> dict:
|
||||
return {
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": shoulder_lift,
|
||||
"elbow_flex.pos": elbow_flex,
|
||||
"wrist_flex.pos": horizontal_wrist_flex(shoulder_lift, elbow_flex),
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": gripper,
|
||||
}
|
||||
|
||||
|
||||
def move_to_pose(robot: Robot, target: dict, speed: float) -> None:
|
||||
"""Interpolate from current position to target at the given speed (units/s)."""
|
||||
obs = robot.get_observation()
|
||||
current = np.array([obs[k] for k in JOINT_NAMES])
|
||||
goal = pose_to_array(target)
|
||||
|
||||
max_distance = float(np.max(np.abs(goal - current)))
|
||||
if max_distance < 0.5:
|
||||
return
|
||||
|
||||
n_steps = max(1, int(max_distance / speed * CONTROL_HZ))
|
||||
dt = 1.0 / CONTROL_HZ
|
||||
for step in range(1, n_steps + 1):
|
||||
t = step / n_steps
|
||||
robot.send_action(array_to_pose(current + t * (goal - current)))
|
||||
precise_sleep(dt)
|
||||
|
||||
|
||||
# ── Sequences ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def grab_cube(robot: Robot) -> None:
|
||||
"""Left sweep → right sweep → extend arm parallel to ground → close gripper."""
|
||||
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
|
||||
|
||||
for pan, end_pan in [
|
||||
(SWEEP_LEFT_PAN, GRAB_PAN - SWEEP_END_OFFSET),
|
||||
(SWEEP_RIGHT_PAN, GRAB_PAN + SWEEP_END_OFFSET),
|
||||
]:
|
||||
logger.info(f"Sweeping {'left' if pan < 0 else 'right'} → center...")
|
||||
move_to_pose(robot, _high_sweep_pose(pan), APPROACH_SPEED)
|
||||
move_to_pose(
|
||||
robot, _low_sweep_pose(pan, SWEEP_LOW_ELBOW_FLEX_START, wrist_flex=-20.0), APPROACH_SPEED
|
||||
)
|
||||
move_to_pose(robot, _low_sweep_pose(end_pan, SWEEP_LOW_ELBOW_FLEX_END, wrist_flex=0.0), SWEEP_SPEED)
|
||||
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
|
||||
|
||||
logger.info("Extending to push cube into gripper...")
|
||||
move_to_pose(
|
||||
robot,
|
||||
_push_pose(PUSH_START_SHOULDER_LIFT - PUSH_RAISE_OFFSET, PUSH_START_ELBOW_FLEX),
|
||||
APPROACH_SPEED,
|
||||
)
|
||||
move_to_pose(
|
||||
robot,
|
||||
_push_pose(PUSH_END_SHOULDER_LIFT - PUSH_RAISE_OFFSET, PUSH_END_ELBOW_FLEX),
|
||||
SWEEP_SPEED,
|
||||
)
|
||||
|
||||
logger.info("Closing gripper...")
|
||||
move_to_pose(
|
||||
robot,
|
||||
_push_pose(PUSH_END_SHOULDER_LIFT, PUSH_END_ELBOW_FLEX, gripper=GRIPPER_CLOSE_POS),
|
||||
APPROACH_SPEED,
|
||||
)
|
||||
|
||||
logger.info("Grab complete.")
|
||||
|
||||
|
||||
def place_cube(robot: Robot) -> tuple[float, float]:
|
||||
"""Carry the cube (gripper closed) to a random position on the left side, then release.
|
||||
|
||||
Returns:
|
||||
(pan, t): pan angle and reach scalar [0, 1] of the placement position.
|
||||
"""
|
||||
pan = float(np.random.uniform(*PLACE_LEFT_PAN_RANGE))
|
||||
t = float(np.random.uniform(*PLACE_REACH_RANGE))
|
||||
sl = PUSH_START_SHOULDER_LIFT + t * (PUSH_END_SHOULDER_LIFT - PUSH_START_SHOULDER_LIFT)
|
||||
ef = PUSH_START_ELBOW_FLEX + t * (PUSH_END_ELBOW_FLEX - PUSH_START_ELBOW_FLEX)
|
||||
logger.info(f"Placing cube at pan={pan:.1f}, reach={t:.2f}...")
|
||||
|
||||
move_to_pose(robot, {**HOME_POSE, "gripper.pos": GRIPPER_CLOSE_POS}, APPROACH_SPEED)
|
||||
move_to_pose(
|
||||
robot, {**HOME_POSE, "shoulder_pan.pos": pan, "gripper.pos": GRIPPER_CLOSE_POS}, APPROACH_SPEED
|
||||
)
|
||||
move_to_pose(robot, _push_pose(sl, ef, pan=pan, gripper=GRIPPER_CLOSE_POS), APPROACH_SPEED)
|
||||
move_to_pose(robot, _push_pose(sl, ef, pan=pan, gripper=80.0), APPROACH_SPEED)
|
||||
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
|
||||
logger.info("Place complete.")
|
||||
return pan, t
|
||||
|
||||
|
||||
# ── Entry point ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="OMX arm reset / grab script")
|
||||
parser.add_argument("--port", default="/dev/ttyACM1")
|
||||
parser.add_argument("--robot_id", default="omx_follower")
|
||||
parser.add_argument("--mode", choices=["grab", "grab_and_place"], default="grab_and_place")
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
|
||||
robot = OmxFollower(OmxFollowerConfig(port=args.port, id=args.robot_id))
|
||||
robot.connect(calibrate=True)
|
||||
|
||||
try:
|
||||
if args.mode == "grab":
|
||||
grab_cube(robot)
|
||||
elif args.mode == "grab_and_place":
|
||||
grab_cube(robot)
|
||||
place_cube(robot)
|
||||
|
||||
finally:
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -4,13 +4,13 @@ from pathlib import Path
|
||||
from queue import Empty, Full
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.policies import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies import GaussianActorConfig
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
from lerobot.rewards.classifier.modeling_classifier import Classifier
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
@@ -28,7 +28,7 @@ def run_learner(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_learner: SACPolicy,
|
||||
policy_learner: GaussianActorPolicy,
|
||||
online_buffer: ReplayBuffer,
|
||||
offline_buffer: ReplayBuffer,
|
||||
lr: float = 3e-4,
|
||||
@@ -40,8 +40,9 @@ def run_learner(
|
||||
policy_learner.train()
|
||||
policy_learner.to(device)
|
||||
|
||||
# Create Adam optimizer from scratch - simple and clean
|
||||
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(policy_learner.config)
|
||||
algorithm = SACAlgorithm(policy=policy_learner, config=algo_config)
|
||||
algorithm.make_optimizers_and_scheduler()
|
||||
|
||||
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
|
||||
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
|
||||
@@ -83,24 +84,26 @@ def run_learner(
|
||||
else:
|
||||
batch[key] = online_batch[key]
|
||||
|
||||
loss, _ = policy_learner.forward(batch)
|
||||
def batch_iter(b=batch):
|
||||
while True:
|
||||
yield b
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
stats = algorithm.update(batch_iter())
|
||||
training_step += 1
|
||||
|
||||
if training_step % LOG_EVERY == 0:
|
||||
log_dict = stats.to_log_dict()
|
||||
print(
|
||||
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
|
||||
f"[LEARNER] Training step {training_step}, "
|
||||
f"critic_loss: {log_dict.get('critic', 'N/A'):.4f}, "
|
||||
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
|
||||
)
|
||||
|
||||
# Send updated parameters to actor every 10 training steps
|
||||
if training_step % SEND_EVERY == 0:
|
||||
try:
|
||||
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
|
||||
parameters_queue.put_nowait(state_dict)
|
||||
weights = algorithm.get_weights()
|
||||
parameters_queue.put_nowait(weights)
|
||||
print("[LEARNER] Sent updated parameters to actor")
|
||||
except Full:
|
||||
# Missing write due to queue not being consumed (should happen rarely)
|
||||
@@ -113,7 +116,7 @@ def run_actor(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_actor: SACPolicy,
|
||||
policy_actor: GaussianActorPolicy,
|
||||
reward_classifier: Classifier,
|
||||
env_cfg: HILSerlRobotEnvConfig,
|
||||
device: torch.device = "mps",
|
||||
@@ -144,15 +147,15 @@ def run_actor(
|
||||
|
||||
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
|
||||
try:
|
||||
new_params = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_params)
|
||||
new_weights = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_weights)
|
||||
print("[ACTOR] Updated policy parameters from learner")
|
||||
except Empty: # No new updated parameters available from learner, waiting
|
||||
pass
|
||||
|
||||
# Get action from policy
|
||||
# Get action from policy (returns full action: continuous + discrete)
|
||||
policy_obs = make_policy_obs(obs, device=device)
|
||||
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
|
||||
action_tensor = policy_actor.select_action(policy_obs)
|
||||
action = action_tensor.squeeze(0).cpu().numpy()
|
||||
|
||||
# Step environment
|
||||
@@ -261,14 +264,14 @@ def main():
|
||||
action_features = hw_to_dataset_features(env.robot.action_features, "action")
|
||||
|
||||
# Create SAC policy for action selection
|
||||
policy_cfg = SACConfig(
|
||||
policy_cfg = GaussianActorConfig(
|
||||
device=device,
|
||||
input_features=obs_features,
|
||||
output_features=action_features,
|
||||
)
|
||||
|
||||
policy_actor = SACPolicy(policy_cfg)
|
||||
policy_learner = SACPolicy(policy_cfg)
|
||||
policy_actor = GaussianActorPolicy(policy_cfg)
|
||||
policy_learner = GaussianActorPolicy(policy_cfg)
|
||||
|
||||
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
|
||||
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
|
||||
|
||||
@@ -59,8 +59,8 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
|
||||
|
||||
dependencies = [
|
||||
# Core ML
|
||||
"torch>=2.7,<2.11.0",
|
||||
"torchvision>=0.22.0,<0.26.0",
|
||||
"torch>=2.7,<2.12.0",
|
||||
"torchvision>=0.22.0,<0.27.0",
|
||||
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
|
||||
"opencv-python-headless>=4.9.0,<4.14.0",
|
||||
"Pillow>=10.0.0,<13.0.0",
|
||||
@@ -95,11 +95,22 @@ dependencies = [
|
||||
|
||||
# ── Feature-scoped extras ──────────────────────────────────
|
||||
dataset = [
|
||||
"datasets>=4.0.0,<5.0.0",
|
||||
"datasets>=4.7.0,<5.0.0",
|
||||
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
|
||||
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
|
||||
"lerobot[av-dep]",
|
||||
"torchcodec>=0.3.0,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10).
|
||||
|
||||
# NOTE: torchcodec wheel availability matrix (PyPI):
|
||||
# - linux x86_64/amd64 + macOS arm64 : wheels since 0.3.0 (the historic supported set).
|
||||
# - win32 x86_64 : wheels since 0.7.0 (needs torch>=2.8).
|
||||
# - linux aarch64/arm64 : wheels since 0.11.0 (needs torch>=2.11).
|
||||
# - macOS x86_64 (Intel) and linux armv7l: no wheels in any released version -> fall through to the PyAV decoder.
|
||||
# Each platform gets its own line so the resolver picks the minimum version that has a wheel for it.
|
||||
|
||||
# Other torch/torchcodec pairings (informational): 0.8.1 = ffmpeg>=8 support, 0.10 = system-wide ffmpeg support, 0.12 needs torch==2.12.
|
||||
"torchcodec>=0.3.0,<0.12.0; (sys_platform == 'linux' and (platform_machine == 'x86_64' or platform_machine == 'AMD64')) or (sys_platform == 'darwin' and platform_machine == 'arm64')",
|
||||
"torchcodec>=0.7.0,<0.12.0; sys_platform == 'win32'",
|
||||
"torchcodec>=0.11.0,<0.12.0; sys_platform == 'linux' and (platform_machine == 'aarch64' or platform_machine == 'arm64')",
|
||||
"jsonlines>=4.0.0,<5.0.0",
|
||||
]
|
||||
training = [
|
||||
@@ -127,7 +138,9 @@ dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
|
||||
# Common
|
||||
av-dep = ["av>=15.0.0,<16.0.0"]
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||
# NOTE: 0.9.16 links against liburdfdom_sensor.so.4, which is unavailable on Ubuntu 24.04
|
||||
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
|
||||
placo-dep = ["placo>=0.9.6,<0.9.16"]
|
||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
@@ -140,6 +153,8 @@ pyserial-dep = ["pyserial>=3.5,<4.0"]
|
||||
deepdiff-dep = ["deepdiff>=7.0.1,<9.0.0"]
|
||||
pynput-dep = ["pynput>=1.7.8,<1.9.0"]
|
||||
pyzmq-dep = ["pyzmq>=26.2.1,<28.0.0"]
|
||||
motorbridge-dep = ["motorbridge>=0.3.2,<0.4.0"]
|
||||
motorbridge-smart-servo-dep = ["motorbridge-smart-servo>=0.0.4,<0.1.0"]
|
||||
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0", "lerobot[pyserial-dep]", "lerobot[deepdiff-dep]"]
|
||||
@@ -163,6 +178,9 @@ unitree_g1 = [
|
||||
"lerobot[pygame-dep]",
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||
# Seeed Studio reBot B601-DM follower (motorbridge / CAN) + StarArm102 / reBot Arm 102
|
||||
# leader (motorbridge-smart-servo / FashionStar UART servos).
|
||||
rebot = ["lerobot[motorbridge-dep]", "lerobot[motorbridge-smart-servo-dep]"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
|
||||
@@ -194,7 +212,8 @@ groot = [
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
@@ -248,6 +267,7 @@ all = [
|
||||
"lerobot[lekiwi]",
|
||||
"lerobot[openarms]",
|
||||
"lerobot[reachy2]",
|
||||
"lerobot[rebot]",
|
||||
"lerobot[kinematics]",
|
||||
"lerobot[intelrealsense]",
|
||||
"lerobot[diffusion]",
|
||||
@@ -292,6 +312,20 @@ lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
|
||||
# cu128 wheels keep broad hardware reach; the driver floor is 570.86.
|
||||
# To use a different CUDA variant, reinstall torch with an explicit index, e.g.:
|
||||
# uv pip install --force-reinstall torch torchvision \
|
||||
# --index-url https://download.pytorch.org/whl/cu130
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json"]
|
||||
|
||||
|
||||
@@ -199,12 +199,13 @@ class OpenCVCamera(Camera):
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
"""
|
||||
|
||||
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
|
||||
if self.config.fourcc is not None:
|
||||
self._validate_fourcc()
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
set_fourcc_after_size_and_fps = platform.system() == "Windows"
|
||||
if self.config.fourcc is not None and not set_fourcc_after_size_and_fps:
|
||||
self._validate_fourcc()
|
||||
|
||||
default_width = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_WIDTH)))
|
||||
default_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
|
||||
|
||||
@@ -222,6 +223,11 @@ class OpenCVCamera(Camera):
|
||||
else:
|
||||
self._validate_fps()
|
||||
|
||||
if self.config.fourcc is not None and set_fourcc_after_size_and_fps:
|
||||
# On Windows with DSHOW, changing the resolution can silently override the FOURCC setting.
|
||||
# Set FOURCC last to make sure the requested pixel format is actually enforced.
|
||||
self._validate_fourcc()
|
||||
|
||||
def _validate_fps(self) -> None:
|
||||
"""Validates and sets the camera's frames per second (FPS)."""
|
||||
|
||||
|
||||
@@ -99,6 +99,7 @@ def save_checkpoint(
|
||||
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
|
||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||
postprocessor: The postprocessor/pipeline to save. Defaults to None.
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
|
||||
@@ -24,6 +24,7 @@ Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
|
||||
from .dataset import DatasetRecordConfig
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .recipe import MessageTurn, TrainingRecipe, load_recipe
|
||||
from .types import (
|
||||
FeatureType,
|
||||
NormalizationMode,
|
||||
@@ -31,6 +32,12 @@ from .types import (
|
||||
PolicyFeature,
|
||||
RTCAttentionSchedule,
|
||||
)
|
||||
from .video import (
|
||||
VALID_VIDEO_CODECS,
|
||||
VIDEO_ENCODER_INFO_KEYS,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Types
|
||||
@@ -43,7 +50,16 @@ __all__ = [
|
||||
"DatasetRecordConfig",
|
||||
"DatasetConfig",
|
||||
"EvalConfig",
|
||||
"MessageTurn",
|
||||
"PeftConfig",
|
||||
"PreTrainedConfig",
|
||||
"TrainingRecipe",
|
||||
"WandBConfig",
|
||||
"load_recipe",
|
||||
"VideoEncoderConfig",
|
||||
# Defaults
|
||||
"camera_encoder_defaults",
|
||||
# Constants
|
||||
"VALID_VIDEO_CODECS",
|
||||
"VIDEO_ENCODER_INFO_KEYS",
|
||||
]
|
||||
|
||||
@@ -14,10 +14,12 @@
|
||||
|
||||
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from .video import VideoEncoderConfig, camera_encoder_defaults
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetRecordConfig:
|
||||
@@ -55,10 +57,9 @@ class DatasetRecordConfig:
|
||||
# Number of episodes to record before batch encoding videos
|
||||
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||
video_encoding_batch_size: int = 1
|
||||
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
|
||||
# or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
|
||||
# Use 'auto' to auto-detect the best available hardware encoder.
|
||||
vcodec: str = "libsvtav1"
|
||||
# Video encoder settings for camera MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys,
|
||||
# e.g. ``--dataset.camera_encoder.vcodec=h264`` (see ``VideoEncoderConfig``).
|
||||
camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
# Enable streaming video encoding: encode frames in real-time during capture instead
|
||||
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
|
||||
streaming_encoding: bool = False
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.transforms import ImageTransformsConfig
|
||||
from lerobot.utils.import_utils import get_safe_default_codec
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -34,7 +34,7 @@ class DatasetConfig:
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||
video_backend: str = field(default_factory=get_safe_default_video_backend)
|
||||
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
@@ -117,3 +117,9 @@ class PeftConfig:
|
||||
# the rank used for the adapter. In general a higher rank means more trainable parameters and closer to full
|
||||
# fine-tuning.
|
||||
r: int = 16
|
||||
|
||||
# Alpha parameter for LoRA scaling (scaling = lora_alpha / r).
|
||||
# In general, a higher alpha means stronger adaptation signal.
|
||||
# If None, the PEFT library defaults to alpha=8, which may dampen high-rank adapters.
|
||||
# Common values are r (alpha == rank) or 2*r.
|
||||
lora_alpha: int | None = None
|
||||
|
||||
@@ -18,8 +18,8 @@ from logging import getLogger
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot import envs, policies # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
|
||||
from . import parser
|
||||
from .default import EvalConfig
|
||||
from .policies import PreTrainedConfig
|
||||
|
||||
@@ -46,8 +46,11 @@ class EvalPipelineConfig:
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
yaml_overrides = parser.get_yaml_overrides("policy")
|
||||
cli_overrides = parser.get_cli_overrides("policy") or []
|
||||
self.policy = PreTrainedConfig.from_pretrained(
|
||||
policy_path, cli_overrides=yaml_overrides + cli_overrides
|
||||
)
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
|
||||
else:
|
||||
|
||||
@@ -13,8 +13,10 @@
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import pkgutil
|
||||
import sys
|
||||
import tempfile
|
||||
from argparse import ArgumentError
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from functools import wraps
|
||||
@@ -24,6 +26,7 @@ from types import ModuleType
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
import draccus
|
||||
import yaml # type: ignore[import-untyped]
|
||||
|
||||
from lerobot.utils.utils import has_method
|
||||
|
||||
@@ -32,6 +35,29 @@ F = TypeVar("F", bound=Callable[..., object])
|
||||
PATH_KEY = "path"
|
||||
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
|
||||
|
||||
# Storage for path args extracted from YAML/JSON config files, so that
|
||||
# get_path_arg() can find them even when they weren't passed via CLI.
|
||||
_config_path_args: dict[str, str] = {}
|
||||
|
||||
# Storage for non-path YAML overrides so validate() can pass them to from_pretrained.
|
||||
_config_yaml_overrides: dict[str, list[str]] = {}
|
||||
|
||||
|
||||
def _flatten_to_cli_args(d: dict, prefix: str = "") -> list[str]:
|
||||
"""Recursively flatten a nested dict to CLI-style args (e.g. {"lr": 1e-4} -> ["--lr=0.0001"])."""
|
||||
args = []
|
||||
for key, value in d.items():
|
||||
if key in (PATH_KEY, draccus.CHOICE_TYPE_KEY):
|
||||
continue
|
||||
full_key = f"{prefix}.{key}" if prefix else key
|
||||
if isinstance(value, bool):
|
||||
value = str(value).lower()
|
||||
if isinstance(value, dict):
|
||||
args.extend(_flatten_to_cli_args(value, full_key))
|
||||
elif value is not None and not isinstance(value, list):
|
||||
args.append(f"--{full_key}={value}")
|
||||
return args
|
||||
|
||||
|
||||
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
|
||||
"""Parses arguments from cli at a given nested attribute level.
|
||||
@@ -145,7 +171,14 @@ def load_plugin(plugin_path: str) -> None:
|
||||
|
||||
|
||||
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||
return parse_arg(f"{field_name}.{PATH_KEY}", args)
|
||||
result = parse_arg(f"{field_name}.{PATH_KEY}", args)
|
||||
if result is None:
|
||||
result = _config_path_args.get(field_name)
|
||||
return result
|
||||
|
||||
|
||||
def get_yaml_overrides(field_name: str) -> list[str]:
|
||||
return _config_yaml_overrides.get(field_name, [])
|
||||
|
||||
|
||||
def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||
@@ -192,6 +225,51 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
||||
return filtered_args
|
||||
|
||||
|
||||
def extract_path_fields_from_config(config_path: str, path_fields: list[str]) -> str:
|
||||
"""Extract `path` fields from a YAML/JSON config before draccus processes it.
|
||||
|
||||
When a user specifies e.g. ``policy.path: lerobot/smolvla_base`` in a YAML config,
|
||||
draccus will fail because ``path`` is not a valid field on policy config classes.
|
||||
This function extracts those path values, stores them in ``_config_path_args`` for
|
||||
later retrieval by ``get_path_arg()``, and returns a cleaned temp config file path.
|
||||
"""
|
||||
config_file = Path(config_path)
|
||||
suffix = config_file.suffix.lower()
|
||||
|
||||
if suffix in (".yaml", ".yml"):
|
||||
with open(config_file) as f:
|
||||
config_data = yaml.safe_load(f)
|
||||
elif suffix == ".json":
|
||||
with open(config_file) as f:
|
||||
config_data = json.load(f)
|
||||
else:
|
||||
return config_path
|
||||
|
||||
if not isinstance(config_data, dict):
|
||||
return config_path
|
||||
|
||||
modified = False
|
||||
for field in path_fields:
|
||||
if field in config_data and isinstance(config_data[field], dict) and PATH_KEY in config_data[field]:
|
||||
_config_path_args[field] = str(config_data[field].pop(PATH_KEY))
|
||||
remaining = config_data[field]
|
||||
if remaining:
|
||||
_config_yaml_overrides[field] = _flatten_to_cli_args(remaining)
|
||||
del config_data[field]
|
||||
modified = True
|
||||
|
||||
if not modified:
|
||||
return config_path
|
||||
|
||||
# Write cleaned config to a temp file
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False) as tmp:
|
||||
if suffix in (".yaml", ".yml"):
|
||||
yaml.dump(config_data, tmp, default_flow_style=False)
|
||||
else:
|
||||
json.dump(config_data, tmp, indent=2)
|
||||
return tmp.name
|
||||
|
||||
|
||||
def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||
"""
|
||||
HACK: Similar to draccus.wrap but does three additional things:
|
||||
@@ -225,11 +303,20 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||
if has_method(argtype, "__get_path_fields__"):
|
||||
path_fields = argtype.__get_path_fields__()
|
||||
cli_args = filter_path_args(path_fields, cli_args)
|
||||
# Also extract path fields from the YAML/JSON config file
|
||||
if config_path_cli:
|
||||
config_path_cli = extract_path_fields_from_config(config_path_cli, path_fields)
|
||||
if has_method(argtype, "from_pretrained") and config_path_cli:
|
||||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
||||
else:
|
||||
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
|
||||
if config_path_cli:
|
||||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = draccus.parse(
|
||||
config_class=argtype,
|
||||
config_path=config_path_cli or config_path,
|
||||
args=cli_args,
|
||||
)
|
||||
response = fn(cfg, *args, **kwargs)
|
||||
return response
|
||||
|
||||
|
||||
206
src/lerobot/configs/recipe.py
Normal file
206
src/lerobot/configs/recipe.py
Normal file
@@ -0,0 +1,206 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
MessageRole = Literal["user", "assistant", "system", "tool"]
|
||||
MessageStream = Literal["high_level", "low_level"]
|
||||
|
||||
DEFAULT_BINDINGS = {
|
||||
"subtask": "active_at(t, style=subtask)",
|
||||
"memory": "active_at(t, style=memory)",
|
||||
"plan": "active_at(t, style=plan)",
|
||||
"speech": "emitted_at(t, role=assistant, tool_name=say)",
|
||||
"interjection": "emitted_at(t, style=interjection)",
|
||||
"vqa": "emitted_at(t, style=vqa, role=assistant)",
|
||||
"vqa_query": "emitted_at(t, style=vqa, role=user)",
|
||||
}
|
||||
|
||||
PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
"""``${name}`` placeholder pattern used by both recipe binding-reference
|
||||
discovery (here) and rendered-message substitution (in ``language_render``)."""
|
||||
|
||||
_VALID_ROLES = frozenset(get_args(MessageRole))
|
||||
_VALID_STREAMS = frozenset(get_args(MessageStream))
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageTurn:
|
||||
"""A single chat-style turn in a recipe template.
|
||||
|
||||
``content`` may be a plain string, a list of HF-style multimodal blocks, or
|
||||
``None`` when ``tool_calls_from`` supplies tool-call payloads instead.
|
||||
``stream`` tags the turn for downstream filtering, ``target`` flags it as a
|
||||
training target, and ``if_present`` skips the turn when the named binding
|
||||
resolves to ``None``.
|
||||
"""
|
||||
|
||||
role: MessageRole
|
||||
content: str | list[dict[str, Any]] | None = None
|
||||
stream: MessageStream | None = None
|
||||
target: bool = False
|
||||
if_present: str | None = None
|
||||
tool_calls_from: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate role, stream, and content after dataclass construction."""
|
||||
if self.role not in _VALID_ROLES:
|
||||
raise ValueError(f"Unsupported message role: {self.role!r}")
|
||||
# ``stream`` is typed Optional only so the dataclass can keep its
|
||||
# field ordering, but recipes must always tag every turn with a
|
||||
# stream — the renderer's ``_validate_rendered`` would reject
|
||||
# ``None`` later on. Fail at construction so the bad recipe is
|
||||
# caught at YAML load time rather than at the first sample.
|
||||
if self.stream is None:
|
||||
raise ValueError(
|
||||
f"MessageTurn(role={self.role!r}) is missing a stream — "
|
||||
f"every turn must declare one of {sorted(_VALID_STREAMS)}."
|
||||
)
|
||||
if self.stream not in _VALID_STREAMS:
|
||||
raise ValueError(f"Unsupported message stream: {self.stream!r}")
|
||||
if self.content is None and self.tool_calls_from is None:
|
||||
raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
|
||||
if self.content is not None and not isinstance(self.content, (str, list)):
|
||||
raise TypeError("MessageTurn.content must be a string, a list of HF-style blocks, or None.")
|
||||
if isinstance(self.content, list):
|
||||
for block in self.content:
|
||||
if not isinstance(block, dict) or "type" not in block:
|
||||
raise ValueError(
|
||||
"Multimodal content blocks must be HF-style dictionaries with a type key."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> MessageTurn:
|
||||
"""Construct a :class:`MessageTurn` from a plain dictionary."""
|
||||
return cls(**data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingRecipe:
|
||||
"""A recipe describing how to render training samples from language rows.
|
||||
|
||||
A recipe is either a *message recipe* (``messages`` plus optional
|
||||
``bindings``) or a *blend recipe* (``blend`` mapping names to weighted
|
||||
sub-recipes). ``weight`` is only meaningful inside a blend.
|
||||
"""
|
||||
|
||||
messages: list[MessageTurn] | None = None
|
||||
bindings: dict[str, str] | None = None
|
||||
blend: dict[str, TrainingRecipe] | None = None
|
||||
weight: float | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate that exactly one of ``messages`` or ``blend`` is set."""
|
||||
if self.messages is not None and self.blend is not None:
|
||||
raise ValueError("TrainingRecipe must set only one of messages or blend.")
|
||||
if self.messages is None and self.blend is None:
|
||||
raise ValueError("TrainingRecipe must set one of messages or blend.")
|
||||
|
||||
if self.messages is not None:
|
||||
self._validate_message_recipe()
|
||||
if self.blend is not None:
|
||||
self._validate_blend_recipe()
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> TrainingRecipe:
|
||||
"""Construct a :class:`TrainingRecipe` from a nested dictionary."""
|
||||
data = dict(data)
|
||||
if data.get("messages") is not None:
|
||||
data["messages"] = [
|
||||
turn if isinstance(turn, MessageTurn) else MessageTurn.from_dict(turn)
|
||||
for turn in data["messages"]
|
||||
]
|
||||
if data.get("blend") is not None:
|
||||
data["blend"] = {
|
||||
name: recipe if isinstance(recipe, TrainingRecipe) else cls.from_dict(recipe)
|
||||
for name, recipe in data["blend"].items()
|
||||
}
|
||||
return cls(**data)
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: str | Path) -> TrainingRecipe:
|
||||
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
|
||||
import yaml # type: ignore[import-untyped]
|
||||
|
||||
with open(path) as f:
|
||||
data = yaml.safe_load(f)
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Recipe YAML must contain a mapping at the top level: {path}")
|
||||
return cls.from_dict(data)
|
||||
|
||||
def _validate_message_recipe(self) -> None:
|
||||
"""Ensure every templated binding is known and at least one turn is a target."""
|
||||
assert self.messages is not None
|
||||
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
|
||||
|
||||
for turn in self.messages:
|
||||
missing = self._referenced_bindings(turn) - known_bindings
|
||||
if missing:
|
||||
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
|
||||
|
||||
if not any(turn.target for turn in self.messages):
|
||||
raise ValueError("Message recipes must contain at least one target turn.")
|
||||
|
||||
def _validate_blend_recipe(self) -> None:
|
||||
"""Ensure each blend component is a non-empty, weighted message recipe."""
|
||||
assert self.blend is not None
|
||||
if not self.blend:
|
||||
raise ValueError("Blend recipes must contain at least one component.")
|
||||
|
||||
for name, recipe in self.blend.items():
|
||||
if recipe.blend is not None:
|
||||
raise ValueError(f"Blend component {name!r} cannot itself define a blend.")
|
||||
if recipe.messages is None:
|
||||
raise ValueError(f"Blend component {name!r} must define messages.")
|
||||
if recipe.weight is None:
|
||||
raise ValueError(f"Blend component {name!r} must define weight.")
|
||||
if recipe.weight <= 0:
|
||||
raise ValueError(f"Blend component {name!r} must have a positive weight.")
|
||||
|
||||
def _referenced_bindings(self, turn: MessageTurn) -> set[str]:
|
||||
"""Return the binding names that ``turn`` references via placeholders or attributes."""
|
||||
names: set[str] = set()
|
||||
if turn.if_present is not None:
|
||||
names.add(turn.if_present)
|
||||
if turn.tool_calls_from is not None:
|
||||
names.add(turn.tool_calls_from)
|
||||
names.update(_placeholders_in_content(turn.content))
|
||||
return names
|
||||
|
||||
|
||||
def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[str]:
|
||||
"""Return the set of ``${name}`` placeholders found anywhere in ``content``."""
|
||||
if content is None:
|
||||
return set()
|
||||
if isinstance(content, str):
|
||||
return set(PLACEHOLDER_RE.findall(content))
|
||||
|
||||
names: set[str] = set()
|
||||
for block in content:
|
||||
for value in block.values():
|
||||
if isinstance(value, str):
|
||||
names.update(PLACEHOLDER_RE.findall(value))
|
||||
return names
|
||||
|
||||
|
||||
def load_recipe(path: str | Path) -> TrainingRecipe:
|
||||
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
|
||||
return TrainingRecipe.from_yaml(path)
|
||||
@@ -27,12 +27,13 @@ from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.optim.optimizers import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available
|
||||
from lerobot.utils.hub import HubMixin
|
||||
|
||||
from .types import PolicyFeature
|
||||
|
||||
T = TypeVar("T", bound="RewardModelConfig")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -89,9 +90,9 @@ class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||
raise NotImplementedError
|
||||
def get_optimizer_preset(self) -> OptimizerConfig | None:
|
||||
"""Default optimizer for this reward model, or ``None`` for zero-shot models."""
|
||||
return None
|
||||
|
||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||
return None
|
||||
|
||||
@@ -25,11 +25,11 @@ from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot import envs
|
||||
from lerobot.configs import parser
|
||||
from lerobot.optim import LRSchedulerConfig, OptimizerConfig
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
||||
|
||||
from . import parser
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .rewards import RewardModelConfig
|
||||
@@ -144,8 +144,11 @@ class TrainPipelineConfig(HubMixin):
|
||||
)
|
||||
self.reward_model.pretrained_path = str(Path(reward_model_path))
|
||||
elif policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
yaml_overrides = parser.get_yaml_overrides("policy")
|
||||
cli_overrides = parser.get_cli_overrides("policy") or []
|
||||
self.policy = PreTrainedConfig.from_pretrained(
|
||||
policy_path, cli_overrides=yaml_overrides + cli_overrides
|
||||
)
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
elif self.resume:
|
||||
config_path = parser.parse_arg("config_path")
|
||||
@@ -256,7 +259,9 @@ class TrainPipelineConfig(HubMixin):
|
||||
) from e
|
||||
|
||||
cli_args = kwargs.pop("cli_args", [])
|
||||
if config_file is not None:
|
||||
# Legacy RA-BC migration only applies to framework-saved checkpoints (always JSON).
|
||||
# Hand-written YAML/TOML configs are expected to use the current sample_weighting schema.
|
||||
if config_file is not None and config_file.endswith(".json"):
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
migrated_config = _migrate_legacy_rabc_fields(config)
|
||||
@@ -267,10 +272,3 @@ class TrainPipelineConfig(HubMixin):
|
||||
|
||||
with draccus.config_type("json"):
|
||||
return draccus.parse(cls, config_file, args=cli_args)
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||
# NOTE: In RL, we don't need an offline dataset
|
||||
# TODO: Make `TrainPipelineConfig.dataset` optional
|
||||
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
||||
|
||||
235
src/lerobot/configs/video.py
Normal file
235
src/lerobot/configs/video.py
Normal file
@@ -0,0 +1,235 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Note: We subclass str so that serialization is straightforward
|
||||
# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
|
||||
|
||||
"""Video encoder configurations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and the chosen video backend.
|
||||
# Determines the order of preference for auto-selection when vcodec="auto" is used.
|
||||
HW_VIDEO_CODECS = [
|
||||
"h264_videotoolbox", # macOS
|
||||
"hevc_videotoolbox", # macOS
|
||||
"h264_nvenc", # NVIDIA GPU
|
||||
"hevc_nvenc", # NVIDIA GPU
|
||||
"h264_vaapi", # Linux Intel/AMD
|
||||
"h264_qsv", # Intel Quick Sync
|
||||
]
|
||||
VALID_VIDEO_CODECS: frozenset[str] = frozenset({"h264", "hevc", "libsvtav1", "auto", *HW_VIDEO_CODECS})
|
||||
# Aliases for legacy video codec names.
|
||||
VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"}
|
||||
|
||||
|
||||
LIBSVTAV1_DEFAULT_PRESET: int = 12
|
||||
|
||||
# Keys persisted under ``features[*]["info"]`` as ``video.<name>`` (from :class:`VideoEncoderConfig`).
|
||||
# ``vcodec``` and ``pix_fmt`` are derived from the video stream directly.
|
||||
VIDEO_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset(
|
||||
{"g", "crf", "preset", "fast_decode", "extra_options", "video_backend"}
|
||||
)
|
||||
VIDEO_ENCODER_INFO_KEYS: frozenset[str] = frozenset(
|
||||
f"video.{name}" for name in VIDEO_ENCODER_INFO_FIELD_NAMES
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoEncoderConfig:
|
||||
"""Video encoder configuration.
|
||||
|
||||
Attributes:
|
||||
vcodec: Video encoder name. ``"auto"`` is resolved during
|
||||
construction (HW encoder if available, else ``libsvtav1``).
|
||||
pix_fmt: Pixel format (e.g. ``"yuv420p"``).
|
||||
g: GOP size (keyframe interval).
|
||||
crf: Quality level — mapped to the native quality parameter of the
|
||||
codec (``crf`` for software, ``qp`` for NVENC/VAAPI,
|
||||
``q:v`` for VideoToolbox, ``global_quality`` for QSV).
|
||||
preset: Speed/quality preset. Accepted type is per-codec.
|
||||
fast_decode: Fast-decode tuning. For ``libsvtav1`` this is a level (0-2)
|
||||
embedded in ``svtav1-params``. For ``h264`` and ``hevc`` non-zero values
|
||||
set ``tune=fastdecode``. Ignored for other codecs.
|
||||
video_backend: Python to be used for encoding. Only ``"pyav"``
|
||||
is currently supported.
|
||||
extra_options: Free-form dictionary of additional video encoder options
|
||||
(e.g. ``{"tune": "film", "profile:v": "high", "bf": 2}``).
|
||||
"""
|
||||
|
||||
vcodec: str = "libsvtav1" # TODO(CarolinePascal): rename to codec ?
|
||||
pix_fmt: str = "yuv420p"
|
||||
g: int | None = 2
|
||||
crf: int | float | None = 30
|
||||
preset: int | str | None = None
|
||||
fast_decode: int = 0
|
||||
# TODO(CarolinePascal): add torchcodec support + find a way to unify the
|
||||
# two backends (encoding and decoding).
|
||||
video_backend: str = "pyav"
|
||||
extra_options: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.resolve_vcodec()
|
||||
# Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work".
|
||||
if self.preset is None and self.vcodec == "libsvtav1":
|
||||
self.preset = LIBSVTAV1_DEFAULT_PRESET
|
||||
self.validate()
|
||||
|
||||
@classmethod
|
||||
def from_video_info(cls, video_info: dict | None) -> VideoEncoderConfig:
|
||||
"""Reconstruct a :class:`VideoEncoderConfig` from a video feature's ``info`` block.
|
||||
Missing or ``None`` values fall back to the class defaults.
|
||||
"""
|
||||
video_info = video_info or {}
|
||||
kwargs: dict[str, Any] = {}
|
||||
|
||||
for src_key, dst_field in (("video.codec", "vcodec"), ("video.pix_fmt", "pix_fmt")):
|
||||
value = video_info.get(src_key)
|
||||
if value is not None:
|
||||
kwargs[dst_field] = value
|
||||
|
||||
for field_name in VIDEO_ENCODER_INFO_FIELD_NAMES:
|
||||
value = video_info.get(f"video.{field_name}")
|
||||
if value is None:
|
||||
continue
|
||||
# Persisted as ``{}`` after merges with disagreeing sources — treat as default.
|
||||
if field_name == "extra_options" and not value:
|
||||
continue
|
||||
kwargs[field_name] = value
|
||||
|
||||
return cls(**kwargs)
|
||||
|
||||
def detect_available_encoders(self, encoders: list[str] | str) -> list[str]:
|
||||
"""Return the subset of available encoders based on the specified video backend.
|
||||
|
||||
Args:
|
||||
encoders: List of encoder names to detect. If a string, it is converted to a list.
|
||||
Returns:
|
||||
List of available encoder names. If the video backend is not "pyav", returns an empty list.
|
||||
"""
|
||||
if self.video_backend == "pyav":
|
||||
require_package("av", extra="dataset")
|
||||
from lerobot.datasets import detect_available_encoders_pyav
|
||||
|
||||
return detect_available_encoders_pyav(encoders)
|
||||
return []
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate the video encoder configuration."""
|
||||
if self.video_backend == "pyav":
|
||||
require_package("av", extra="dataset")
|
||||
from lerobot.datasets import check_video_encoder_parameters_pyav
|
||||
|
||||
check_video_encoder_parameters_pyav(self.vcodec, self.pix_fmt, self.get_codec_options())
|
||||
|
||||
def resolve_vcodec(self) -> None:
|
||||
"""Check ``vcodec`` and, when it is ``"auto"``, pick a concrete encoder.
|
||||
|
||||
For ``"auto"``, the first hardware encoder in the preference list that is available is chosen; if none are available, ``libsvtav1`` is used. If the
|
||||
resolved codec (explicit or after auto-selection) is not available, raises ``ValueError``.
|
||||
|
||||
Stream-derived canonical codec names listed in :data:`VIDEO_CODECS_ALIASES` are
|
||||
rewritten to their corresponding encoder name (e.g. ``"av1"`` → ``"libsvtav1"``).
|
||||
"""
|
||||
self.vcodec = VIDEO_CODECS_ALIASES.get(self.vcodec, self.vcodec)
|
||||
if self.vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{self.vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
if self.vcodec == "auto":
|
||||
available = self.detect_available_encoders(HW_VIDEO_CODECS)
|
||||
for encoder in HW_VIDEO_CODECS:
|
||||
if encoder in available:
|
||||
logger.info(f"Auto-selected video codec: {encoder}")
|
||||
self.vcodec = encoder
|
||||
return
|
||||
logger.warning("No hardware encoder available, falling back to software encoder 'libsvtav1'")
|
||||
self.vcodec = "libsvtav1"
|
||||
|
||||
if self.detect_available_encoders(self.vcodec):
|
||||
logger.info(f"Using video codec: {self.vcodec}")
|
||||
return
|
||||
raise ValueError(f"Unsupported video codec: {self.vcodec} with video backend {self.video_backend}")
|
||||
|
||||
def get_codec_options(
|
||||
self, encoder_threads: int | None = None, as_strings: bool = False
|
||||
) -> dict[str, Any]:
|
||||
"""Translate the tuning fields to codec-specific options.
|
||||
|
||||
``VideoEncoderConfig.extra_options`` are merged last but never override a structured field.
|
||||
|
||||
Args:
|
||||
encoder_threads: Number of encoder threads set globally for all VideoEncoderConfigs.
|
||||
For libsvtav1, this is mapped to ``lp`` via ``svtav1-params``.
|
||||
For h264/hevc, this is mapped to ``threads``.
|
||||
Hardware encoders ignore this parameter.
|
||||
as_strings: If ``True``, casts values to strings.
|
||||
"""
|
||||
opts: dict[str, Any] = {}
|
||||
|
||||
def set_if(key: str, value: Any) -> None:
|
||||
if value is not None:
|
||||
opts[key] = value if not as_strings else str(value)
|
||||
|
||||
# GOP size is not a codec-specific option, so it is always set.
|
||||
set_if("g", self.g)
|
||||
|
||||
if self.vcodec == "libsvtav1":
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
svtav1_parts: list[str] = []
|
||||
if self.fast_decode is not None:
|
||||
svtav1_parts.append(f"fast-decode={max(0, min(2, self.fast_decode))}")
|
||||
if encoder_threads is not None:
|
||||
svtav1_parts.append(f"lp={encoder_threads}")
|
||||
if svtav1_parts:
|
||||
opts["svtav1-params"] = ":".join(svtav1_parts)
|
||||
elif self.vcodec in ("h264", "hevc"):
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
if self.fast_decode:
|
||||
opts["tune"] = "fastdecode"
|
||||
set_if("threads", encoder_threads)
|
||||
elif self.vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
|
||||
if self.crf is not None:
|
||||
opts["q:v"] = max(1, min(100, 100 - self.crf * 2))
|
||||
elif self.vcodec in ("h264_nvenc", "hevc_nvenc"):
|
||||
opts["rc"] = 0
|
||||
set_if("qp", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
elif self.vcodec == "h264_vaapi":
|
||||
set_if("qp", self.crf)
|
||||
elif self.vcodec == "h264_qsv":
|
||||
set_if("global_quality", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
else:
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
|
||||
# Extra options are merged last but never override structured fields (values are kept as given).
|
||||
for k, v in self.extra_options.items():
|
||||
if k not in opts:
|
||||
set_if(k, v)
|
||||
|
||||
return opts
|
||||
|
||||
|
||||
def camera_encoder_defaults() -> VideoEncoderConfig:
|
||||
"""Return a :class:`VideoEncoderConfig` with RGB-camera defaults."""
|
||||
return VideoEncoderConfig()
|
||||
@@ -31,15 +31,25 @@ from .dataset_tools import (
|
||||
modify_features,
|
||||
modify_tasks,
|
||||
recompute_stats,
|
||||
reencode_dataset,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from .factory import make_dataset, resolve_delta_timestamps
|
||||
from .image_writer import safe_stop_image_writer
|
||||
from .io_utils import load_episodes, write_stats
|
||||
from .language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
STYLE_REGISTRY,
|
||||
column_for_style,
|
||||
)
|
||||
from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
|
||||
from .sampler import EpisodeAwareSampler
|
||||
from .streaming_dataset import StreamingLeRobotDataset
|
||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||
@@ -53,12 +63,19 @@ __all__ = [
|
||||
"CODEBASE_VERSION",
|
||||
"DEFAULT_EPISODES_PATH",
|
||||
"DEFAULT_QUANTILES",
|
||||
"EVENT_ONLY_STYLES",
|
||||
"EpisodeAwareSampler",
|
||||
"LANGUAGE_EVENTS",
|
||||
"LANGUAGE_PERSISTENT",
|
||||
"LeRobotDataset",
|
||||
"LeRobotDatasetMetadata",
|
||||
"MultiLeRobotDataset",
|
||||
"PERSISTENT_STYLES",
|
||||
"STYLE_REGISTRY",
|
||||
"StreamingLeRobotDataset",
|
||||
"VideoEncodingManager",
|
||||
"check_video_encoder_parameters_pyav",
|
||||
"detect_available_encoders_pyav",
|
||||
"add_features",
|
||||
"aggregate_datasets",
|
||||
"aggregate_pipeline_dataset_features",
|
||||
@@ -66,6 +83,7 @@ __all__ = [
|
||||
"convert_image_to_video_dataset",
|
||||
"create_initial_features",
|
||||
"create_lerobot_dataset_card",
|
||||
"column_for_style",
|
||||
"delete_episodes",
|
||||
"get_feature_stats",
|
||||
"load_episodes",
|
||||
@@ -74,6 +92,7 @@ __all__ = [
|
||||
"modify_features",
|
||||
"modify_tasks",
|
||||
"recompute_stats",
|
||||
"reencode_dataset",
|
||||
"remove_feature",
|
||||
"resolve_delta_timestamps",
|
||||
"safe_stop_image_writer",
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
@@ -23,9 +24,11 @@ import datasets
|
||||
import pandas as pd
|
||||
import tqdm
|
||||
|
||||
from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
|
||||
|
||||
from .compute_stats import aggregate_stats
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
from .feature_utils import get_hf_features_from_features
|
||||
from .feature_utils import features_equal_for_merge, get_hf_features_from_features
|
||||
from .io_utils import (
|
||||
get_file_size_in_mb,
|
||||
get_parquet_file_size_in_mb,
|
||||
@@ -46,11 +49,54 @@ from .utils import (
|
||||
from .video_utils import concatenate_video_files, get_video_duration_in_s
|
||||
|
||||
|
||||
def merge_video_feature_info_for_aggregate(all_metadata: list[LeRobotDatasetMetadata]) -> dict[str, dict]:
|
||||
"""Create a merged video feature info dictionary for aggregation. The video encoder info is merged field-by-field: each key is kept only when every source agrees; otherwise that key is set to ``null`` (or ``{}`` for ``video.extra_options``) and a warning is logged.
|
||||
|
||||
Args:
|
||||
all_metadata: List of LeRobotDatasetMetadata objects to merge.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary of merged video feature info.
|
||||
"""
|
||||
merged_info = copy.deepcopy(all_metadata[0].features)
|
||||
video_keys = [k for k in merged_info if merged_info[k].get("dtype") == "video"]
|
||||
|
||||
for vk in video_keys:
|
||||
video_infos = [m.features.get(vk, {}).get("info") or {} for m in all_metadata]
|
||||
base_video_info = video_infos[0]
|
||||
|
||||
merged_encoder_info: dict = {}
|
||||
fallback_keys: list[str] = []
|
||||
for info_key in VIDEO_ENCODER_INFO_KEYS:
|
||||
values = [info.get(info_key, None) for info in video_infos]
|
||||
first_value = values[0]
|
||||
all_match = all(v == first_value for v in values[1:])
|
||||
|
||||
if all_match:
|
||||
merged_encoder_info[info_key] = first_value
|
||||
else:
|
||||
fallback_keys.append(info_key)
|
||||
merged_encoder_info[info_key] = {} if info_key == "video.extra_options" else None
|
||||
|
||||
if fallback_keys:
|
||||
logging.warning(
|
||||
f"Merging heterogeneous or incomplete video encoder metadata for feature {vk}. "
|
||||
f"Setting these keys to null: {fallback_keys}.",
|
||||
)
|
||||
|
||||
merged_info[vk]["info"] = {**base_video_info, **merged_encoder_info}
|
||||
# TODO(CarolinePascal): make this variable once we have support for other video backends.
|
||||
merged_info[vk]["info"]["video.video_backend"] = "pyav"
|
||||
|
||||
return merged_info
|
||||
|
||||
|
||||
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||
"""Validates that all dataset metadata have consistent properties.
|
||||
|
||||
Ensures all datasets have the same fps, robot_type, and features to guarantee
|
||||
compatibility when aggregating them into a single dataset.
|
||||
Video encoder info is not considered for validation but is merged during aggregation in ``merge_video_feature_info_for_aggregate``.
|
||||
|
||||
Args:
|
||||
all_metadata: List of LeRobotDatasetMetadata objects to validate.
|
||||
@@ -74,7 +120,7 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||
raise ValueError(
|
||||
f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}."
|
||||
)
|
||||
if features != meta.features:
|
||||
if not features_equal_for_merge(features, meta.features):
|
||||
raise ValueError(
|
||||
f"Same features is expected, but got features={meta.features} instead of {features}."
|
||||
)
|
||||
@@ -274,7 +320,8 @@ def aggregate_datasets(
|
||||
LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
|
||||
]
|
||||
)
|
||||
fps, robot_type, features = validate_all_metadata(all_metadata)
|
||||
fps, robot_type, _ = validate_all_metadata(all_metadata)
|
||||
features = merge_video_feature_info_for_aggregate(all_metadata)
|
||||
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
||||
|
||||
dst_meta = LeRobotDatasetMetadata.create(
|
||||
@@ -332,7 +379,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
videos_idx: Dictionary tracking video chunk and file indices.
|
||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
|
||||
Returns:
|
||||
dict: Updated videos_idx with current chunk and file indices.
|
||||
"""
|
||||
@@ -414,9 +460,11 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
current_dst_duration = dst_file_durations.get(dst_key, 0)
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
|
||||
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
||||
# TODO(CarolinePascal): Move the check before the loop to avoid failing in the middle + add possibility to re-encode the video if the check fails
|
||||
concatenate_video_files(
|
||||
[dst_path, src_path],
|
||||
dst_path,
|
||||
compatibility_check=True,
|
||||
)
|
||||
# Update duration of this destination file
|
||||
dst_file_durations[dst_key] = current_dst_duration + src_duration
|
||||
|
||||
@@ -512,7 +512,7 @@ def compute_episode_stats(
|
||||
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] == "string":
|
||||
if features[key]["dtype"] in {"string", "language"}:
|
||||
continue
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -23,6 +24,7 @@ import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES, HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
|
||||
from lerobot.utils.feature_utils import _validate_feature_names
|
||||
from lerobot.utils.utils import flatten_dict
|
||||
@@ -34,12 +36,12 @@ from .io_utils import (
|
||||
load_episodes,
|
||||
load_info,
|
||||
load_stats,
|
||||
load_subtasks,
|
||||
load_tasks,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from .language import DEFAULT_TOOLS, LANGUAGE_COLUMNS
|
||||
from .utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
check_version_compatibility,
|
||||
@@ -175,7 +177,6 @@ class LeRobotDatasetMetadata:
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.subtasks = load_subtasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
@@ -189,6 +190,29 @@ class LeRobotDatasetMetadata:
|
||||
if self.episodes is None:
|
||||
self._load_metadata()
|
||||
|
||||
def filter_episodes(
|
||||
self,
|
||||
predicate: Callable[[dict], bool],
|
||||
candidates: list[int] | None = None,
|
||||
) -> list[int]:
|
||||
"""Filter episodes whose metadata satisfies a given predicate.
|
||||
|
||||
Args:
|
||||
predicate: Predicate over per-episode metadata rows used to select episodes.
|
||||
candidates: Optional list of episode indices to restrict evaluation to.
|
||||
|
||||
Returns:
|
||||
List of sorted episode indices that satisfy the predicate.
|
||||
"""
|
||||
self.ensure_readable()
|
||||
if candidates is not None:
|
||||
candidate_set = set(candidates)
|
||||
combined = lambda ep: ep["episode_index"] in candidate_set and predicate(ep) # noqa: E731
|
||||
else:
|
||||
combined = predicate
|
||||
filtered = self.episodes.filter(combined, keep_in_memory=True, load_from_cache_file=False)
|
||||
return sorted(int(idx) for idx in filtered["episode_index"])
|
||||
|
||||
def _pull_from_repo(
|
||||
self,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
@@ -318,6 +342,49 @@ class LeRobotDatasetMetadata:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||
|
||||
@property
|
||||
def has_language_columns(self) -> bool:
|
||||
"""Return ``True`` if the dataset declares any language column.
|
||||
|
||||
Used to gate language-aware code paths (collate, render step) so
|
||||
unannotated datasets keep PyTorch's default collate behavior.
|
||||
"""
|
||||
return any(col in self.features for col in LANGUAGE_COLUMNS)
|
||||
|
||||
@property
|
||||
def tools(self) -> list[dict]:
|
||||
"""OpenAI-style tool schemas declared by this dataset.
|
||||
|
||||
Read from ``meta/info.json["tools"]``. Returns a copy, so callers
|
||||
can mutate the result safely. Falls back to
|
||||
:data:`lerobot.datasets.language.DEFAULT_TOOLS` (the canonical
|
||||
``say`` schema) when the dataset doesn't declare any — that way
|
||||
unannotated datasets and chat-template consumers
|
||||
(``apply_chat_template(messages, tools=meta.tools)``) keep
|
||||
working out of the box.
|
||||
|
||||
Implementations live under :mod:`lerobot.tools` (one file per
|
||||
tool); see ``docs/source/tools.mdx`` for the authoring guide.
|
||||
"""
|
||||
declared = self.info.tools
|
||||
if declared:
|
||||
return [dict(t) for t in declared]
|
||||
return [dict(t) for t in DEFAULT_TOOLS]
|
||||
|
||||
@tools.setter
|
||||
def tools(self, value: list[dict] | None) -> None:
|
||||
"""Persist a tool catalog to ``meta/info.json`` and reload metadata.
|
||||
|
||||
Writes ``value`` into the on-disk ``info.json`` (or clears the
|
||||
``tools`` key when ``value`` is ``None`` or empty), then reloads
|
||||
``self.info`` so the in-memory metadata matches what's on disk.
|
||||
Saves callers from hand-editing ``info.json`` and re-instantiating
|
||||
the metadata object.
|
||||
"""
|
||||
self.info.tools = [dict(t) for t in value] if value else None
|
||||
write_info(self.info, self.root)
|
||||
self.info = load_info(self.root)
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
@@ -510,10 +577,23 @@ class LeRobotDatasetMetadata:
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
|
||||
write_stats(self.stats, self.root)
|
||||
|
||||
def update_video_info(self, video_key: str | None = None) -> None:
|
||||
"""
|
||||
def update_video_info(
|
||||
self,
|
||||
video_key: str | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
) -> None:
|
||||
"""Populate per-feature video info in ``info.json``.
|
||||
|
||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
|
||||
Args:
|
||||
video_key: If provided, only update this video key. Otherwise update
|
||||
all video keys in the dataset.
|
||||
camera_encoder: Encoder configuration used to produce the
|
||||
videos. When provided, its fields are recorded as
|
||||
``video.<field>`` entries alongside the stream-derived
|
||||
``video.*`` entries (see :func:`get_video_info`).
|
||||
"""
|
||||
if video_key is not None and video_key not in self.video_keys:
|
||||
raise ValueError(f"Video key {video_key} not found in dataset")
|
||||
@@ -522,7 +602,7 @@ class LeRobotDatasetMetadata:
|
||||
for key in video_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
self.info.features[key]["info"] = get_video_info(video_path)
|
||||
self.info.features[key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
@@ -633,7 +713,6 @@ class LeRobotDatasetMetadata:
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
obj.subtasks = None
|
||||
obj.episodes = None
|
||||
obj.stats = None
|
||||
obj.info = create_empty_dataset_info(
|
||||
|
||||
@@ -295,9 +295,4 @@ class DatasetReader:
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self._meta.tasks.iloc[task_idx].name
|
||||
|
||||
# add subtask information if available
|
||||
if "subtask_index" in self._meta.features and self._meta.subtasks is not None:
|
||||
subtask_idx = item["subtask_index"].item()
|
||||
item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name
|
||||
|
||||
return item
|
||||
|
||||
@@ -26,7 +26,7 @@ This module provides utilities for:
|
||||
import logging
|
||||
import shutil
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -36,6 +36,7 @@ import pyarrow.parquet as pq
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
|
||||
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.utils import flatten_dict
|
||||
|
||||
@@ -60,9 +61,14 @@ from .utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
VIDEO_DIR,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from .video_utils import encode_video_frames, get_video_info
|
||||
from .video_utils import (
|
||||
encode_video_frames,
|
||||
get_video_info,
|
||||
reencode_video,
|
||||
)
|
||||
|
||||
|
||||
def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict:
|
||||
@@ -95,6 +101,11 @@ def delete_episodes(
|
||||
) -> LeRobotDataset:
|
||||
"""Delete episodes from a LeRobotDataset and create a new dataset.
|
||||
|
||||
Video segments that need re-encoding (because the source file mixes kept and
|
||||
deleted episodes) are re-encoded with the source dataset's existing encoder
|
||||
settings — read back from ``meta/info.json`` — so the output dataset stays
|
||||
consistent with its own metadata.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobotDataset.
|
||||
episode_indices: List of episode indices to delete.
|
||||
@@ -157,6 +168,11 @@ def split_dataset(
|
||||
) -> dict[str, LeRobotDataset]:
|
||||
"""Split a LeRobotDataset into multiple smaller datasets.
|
||||
|
||||
Video segments that need re-encoding (because the source file mixes episodes
|
||||
that fall into different splits) are re-encoded with the source dataset's
|
||||
existing encoder settings — read back from ``meta/info.json`` — so each
|
||||
output split stays consistent with its own metadata.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobotDataset to split.
|
||||
splits: Either a dict mapping split names to episode indices, or a dict mapping
|
||||
@@ -578,8 +594,7 @@ def _keep_episodes_from_video_with_av(
|
||||
output_path: Path,
|
||||
episodes_to_keep: list[tuple[int, int]],
|
||||
fps: float,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
camera_encoder: VideoEncoderConfig,
|
||||
) -> None:
|
||||
"""Keep only specified episodes from a video file using PyAV.
|
||||
|
||||
@@ -593,8 +608,7 @@ def _keep_episodes_from_video_with_av(
|
||||
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
|
||||
is inclusive and end_frame is exclusive.
|
||||
fps: Frame rate of the video.
|
||||
vcodec: Video codec to use for encoding.
|
||||
pix_fmt: Pixel format for output video.
|
||||
camera_encoder: Video encoder settings used to re-encode the kept frames.
|
||||
"""
|
||||
from fractions import Fraction
|
||||
|
||||
@@ -619,12 +633,13 @@ def _keep_episodes_from_video_with_av(
|
||||
|
||||
# Convert fps to Fraction for PyAV compatibility.
|
||||
fps_fraction = Fraction(fps).limit_denominator(1000)
|
||||
v_out = out.add_stream(vcodec, rate=fps_fraction)
|
||||
codec_options = camera_encoder.get_codec_options(as_strings=True)
|
||||
v_out = out.add_stream(camera_encoder.vcodec, rate=fps_fraction, options=codec_options)
|
||||
|
||||
# PyAV type stubs don't distinguish video streams from audio/subtitle streams.
|
||||
v_out.width = v_in.codec_context.width
|
||||
v_out.height = v_in.codec_context.height
|
||||
v_out.pix_fmt = pix_fmt
|
||||
v_out.pix_fmt = camera_encoder.pix_fmt
|
||||
|
||||
# Set time_base to match the frame rate for proper timestamp handling.
|
||||
v_out.time_base = Fraction(1, int(fps))
|
||||
@@ -687,14 +702,14 @@ def _copy_and_reindex_videos(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_mapping: dict[int, int],
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
) -> dict[int, dict]:
|
||||
"""Copy and filter video files, only re-encoding files with deleted episodes.
|
||||
|
||||
For video files that only contain kept episodes, we copy them directly.
|
||||
For files with mixed kept/deleted episodes, we use PyAV filters to efficiently
|
||||
re-encode only the desired segments.
|
||||
re-encode only the desired segments. The encoder used for re-encoding is
|
||||
derived per video key from the source dataset's ``meta/info.json`` so the
|
||||
destination metadata keeps describing the videos accurately.
|
||||
|
||||
Args:
|
||||
src_dataset: Source dataset to copy from
|
||||
@@ -711,6 +726,9 @@ def _copy_and_reindex_videos(
|
||||
|
||||
for video_key in src_dataset.meta.video_keys:
|
||||
logging.info(f"Processing videos for {video_key}")
|
||||
camera_encoder = VideoEncoderConfig.from_video_info(
|
||||
src_dataset.meta.info.features.get(video_key, {}).get("info")
|
||||
)
|
||||
|
||||
if dst_meta.video_path is None:
|
||||
raise ValueError("Destination metadata has no video_path defined")
|
||||
@@ -792,8 +810,7 @@ def _copy_and_reindex_videos(
|
||||
dst_video_path,
|
||||
episodes_to_keep_ranges,
|
||||
src_dataset.meta.fps,
|
||||
vcodec,
|
||||
pix_fmt,
|
||||
camera_encoder,
|
||||
)
|
||||
|
||||
cumulative_ts = 0.0
|
||||
@@ -1264,11 +1281,7 @@ def _estimate_frame_size_via_calibration(
|
||||
episode_indices: list[int],
|
||||
temp_dir: Path,
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int,
|
||||
crf: int,
|
||||
fast_decode: int,
|
||||
camera_encoder: VideoEncoderConfig,
|
||||
num_calibration_frames: int = 30,
|
||||
) -> float:
|
||||
"""Estimate MB per frame by encoding a small calibration sample.
|
||||
@@ -1282,11 +1295,7 @@ def _estimate_frame_size_via_calibration(
|
||||
episode_indices: List of episode indices being processed.
|
||||
temp_dir: Temporary directory for calibration files.
|
||||
fps: Frames per second for video encoding.
|
||||
vcodec: Video codec (libsvtav1, h264, hevc).
|
||||
pix_fmt: Pixel format (yuv420p, etc.).
|
||||
g: GOP size (group of pictures).
|
||||
crf: Constant Rate Factor (quality).
|
||||
fast_decode: Fast decode tuning parameter.
|
||||
camera_encoder: Video encoder settings used for calibration encoding.
|
||||
num_calibration_frames: Number of frames to use for calibration (default: 30).
|
||||
|
||||
Returns:
|
||||
@@ -1322,11 +1331,7 @@ def _estimate_frame_size_via_calibration(
|
||||
imgs_dir=calibration_dir,
|
||||
video_path=calibration_video_path,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
camera_encoder=camera_encoder,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1644,11 +1649,7 @@ def convert_image_to_video_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int = 2,
|
||||
crf: int = 30,
|
||||
fast_decode: int = 0,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
episode_indices: list[int] | None = None,
|
||||
num_workers: int = 4,
|
||||
max_episodes_per_batch: int | None = None,
|
||||
@@ -1663,11 +1664,8 @@ def convert_image_to_video_dataset(
|
||||
dataset: The source LeRobot dataset with images
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
vcodec: Video codec (default: libsvtav1)
|
||||
pix_fmt: Pixel format (default: yuv420p)
|
||||
g: Group of pictures size (default: 2)
|
||||
crf: Constant rate factor (default: 30)
|
||||
fast_decode: Fast decode tuning (default: 0)
|
||||
camera_encoder: Video encoder settings
|
||||
(``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`).
|
||||
episode_indices: List of episode indices to convert (None = all episodes)
|
||||
num_workers: Number of threads for parallel processing (default: 4)
|
||||
max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit)
|
||||
@@ -1676,6 +1674,9 @@ def convert_image_to_video_dataset(
|
||||
Returns:
|
||||
New LeRobotDataset with images encoded as videos
|
||||
"""
|
||||
if camera_encoder is None:
|
||||
camera_encoder = camera_encoder_defaults()
|
||||
|
||||
# Check that it's an image dataset
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
raise ValueError(
|
||||
@@ -1699,7 +1700,10 @@ def convert_image_to_video_dataset(
|
||||
logging.info(
|
||||
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
|
||||
)
|
||||
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
|
||||
logging.info(
|
||||
f"Video codec: {camera_encoder.vcodec}, pixel format: {camera_encoder.pix_fmt}, "
|
||||
f"GOP: {camera_encoder.g}, CRF: {camera_encoder.crf}"
|
||||
)
|
||||
|
||||
# Create new features dict, converting image features to video features
|
||||
new_features = {}
|
||||
@@ -1769,11 +1773,7 @@ def convert_image_to_video_dataset(
|
||||
episode_indices=episode_indices,
|
||||
temp_dir=temp_dir,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
camera_encoder=camera_encoder,
|
||||
)
|
||||
|
||||
logging.info(f"Processing camera: {img_key}")
|
||||
@@ -1815,11 +1815,7 @@ def convert_image_to_video_dataset(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
camera_encoder=camera_encoder,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1865,7 +1861,9 @@ def convert_image_to_video_dataset(
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info.features[img_key]["info"] = get_video_info(video_path)
|
||||
new_meta.info.features[img_key]["info"] = get_video_info(
|
||||
video_path, camera_encoder=camera_encoder
|
||||
)
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
@@ -1888,3 +1886,83 @@ def convert_image_to_video_dataset(
|
||||
|
||||
# Return new dataset
|
||||
return LeRobotDataset(repo_id=repo_id, root=output_dir)
|
||||
|
||||
|
||||
def _reencode_video_worker(args: tuple) -> Path:
|
||||
"""Picklable worker for :func:`reencode_dataset`'s process pool."""
|
||||
video_path, camera_encoder, encoder_threads = args
|
||||
reencode_video(
|
||||
input_video_path=video_path,
|
||||
output_video_path=video_path,
|
||||
camera_encoder=camera_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
overwrite=True,
|
||||
)
|
||||
return video_path
|
||||
|
||||
|
||||
def reencode_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
camera_encoder: VideoEncoderConfig,
|
||||
encoder_threads: int | None = None,
|
||||
num_workers: int | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Re-encode every video in a dataset with a new set of encoding parameters.
|
||||
|
||||
Videos are re-encoded in-place and the video information in ``info.json`` is refreshed.
|
||||
|
||||
Args:
|
||||
dataset: An existing :class:`LeRobotDataset` whose videos will be
|
||||
re-encoded.
|
||||
camera_encoder: Target encoder configuration applied to every video
|
||||
file.
|
||||
encoder_threads: Per-encoder thread count forwarded to
|
||||
:func:`reencode_video`. ``None`` lets the codec decide.
|
||||
num_workers: Number of parallel processes. ``None`` or ``0`` means
|
||||
sequential (no multiprocessing); ``1+`` spawns a
|
||||
:class:`~concurrent.futures.ProcessPoolExecutor`.
|
||||
|
||||
Returns:
|
||||
The same :class:`LeRobotDataset` instance with its metadata updated
|
||||
on disk.
|
||||
"""
|
||||
meta = dataset.meta
|
||||
video_paths_list = []
|
||||
|
||||
# Only re-encode if the videos are not already encoded with the given video encoding parameters
|
||||
for video_key in meta.video_keys:
|
||||
current_info = meta.info.features[video_key].get("info", {})
|
||||
current_encoder = VideoEncoderConfig.from_video_info(current_info)
|
||||
if current_encoder != camera_encoder:
|
||||
video_paths_list.extend((meta.root / VIDEO_DIR / video_key).rglob("*.mp4"))
|
||||
else:
|
||||
logging.info(f"{video_key} videos are already encoded with {camera_encoder}. Nothing to do.")
|
||||
|
||||
if len(video_paths_list) == 0:
|
||||
logging.warning("Dataset has no videos to re-encode.")
|
||||
return dataset
|
||||
logging.info(f"Re-encoding {len(video_paths_list)} video file(s) with {camera_encoder}")
|
||||
|
||||
worker_args = [(vp, camera_encoder, encoder_threads) for vp in video_paths_list]
|
||||
if num_workers and num_workers > 1:
|
||||
with ProcessPoolExecutor(max_workers=num_workers) as pool:
|
||||
futures = [pool.submit(_reencode_video_worker, args) for args in worker_args]
|
||||
for future in tqdm(
|
||||
as_completed(futures),
|
||||
total=len(futures),
|
||||
desc="Re-encoding videos",
|
||||
):
|
||||
future.result()
|
||||
else:
|
||||
for args in tqdm(worker_args, desc="Re-encoding videos"):
|
||||
_reencode_video_worker(args)
|
||||
|
||||
# Refresh video info in metadata for every video key.
|
||||
for vid_key in meta.video_keys:
|
||||
video_path = meta.root / meta.get_video_file_path(0, vid_key)
|
||||
meta.info.features[vid_key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
|
||||
|
||||
write_info(meta.info, meta.root)
|
||||
logging.info("Dataset metadata updated.")
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -31,6 +31,8 @@ import PIL.Image
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
|
||||
|
||||
from .compute_stats import compute_episode_stats
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
from .feature_utils import (
|
||||
@@ -65,14 +67,19 @@ def _encode_video_worker(
|
||||
episode_index: int,
|
||||
root: Path,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
) -> Path:
|
||||
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
img_dir = (root / fpath).parent
|
||||
encode_video_frames(
|
||||
img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads
|
||||
img_dir,
|
||||
temp_path,
|
||||
fps,
|
||||
camera_encoder=camera_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
overwrite=True,
|
||||
)
|
||||
shutil.rmtree(img_dir)
|
||||
return temp_path
|
||||
@@ -89,20 +96,22 @@ class DatasetWriter:
|
||||
self,
|
||||
meta: LeRobotDatasetMetadata,
|
||||
root: Path,
|
||||
vcodec: str,
|
||||
camera_encoder: VideoEncoderConfig | None,
|
||||
encoder_threads: int | None,
|
||||
batch_encoding_size: int,
|
||||
streaming_encoder: StreamingVideoEncoder | None = None,
|
||||
initial_frames: int = 0,
|
||||
):
|
||||
"""Initialize the writer with metadata, codec, and encoding config.
|
||||
"""Initialize the writer with metadata, codec, and encoder config.
|
||||
|
||||
Args:
|
||||
meta: Dataset metadata instance (used for feature schema, chunk
|
||||
settings, and episode persistence).
|
||||
root: Local dataset root directory.
|
||||
vcodec: Video codec for encoding (e.g. ``'libsvtav1'``, ``'h264'``).
|
||||
encoder_threads: Threads per encoder instance. ``None`` for auto.
|
||||
camera_encoder: Video encoder settings applied to all cameras.
|
||||
``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
batch-encoding videos.
|
||||
streaming_encoder: Optional pre-built :class:`StreamingVideoEncoder`
|
||||
@@ -111,7 +120,7 @@ class DatasetWriter:
|
||||
"""
|
||||
self._meta = meta
|
||||
self._root = root
|
||||
self._vcodec = vcodec
|
||||
self._camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
self._encoder_threads = encoder_threads
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._streaming_encoder = streaming_encoder
|
||||
@@ -241,7 +250,14 @@ class DatasetWriter:
|
||||
for key, ft in self._meta.features.items():
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||
continue
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
stacked_values = np.stack(episode_buffer[key])
|
||||
|
||||
# `shape=(1,)` numeric features are serialized as `datasets.Value`, which expects scalars.
|
||||
# Normalizing to `(N,)` keeps save semantics stable across dependency versions.
|
||||
if tuple(ft["shape"]) == (1,) and ft["dtype"] != "string":
|
||||
stacked_values = stacked_values.reshape(episode_length)
|
||||
|
||||
episode_buffer[key] = stacked_values
|
||||
|
||||
# Wait for image writer to end, so that episode stats over images can be computed
|
||||
self._wait_image_writer()
|
||||
@@ -284,7 +300,7 @@ class DatasetWriter:
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._vcodec,
|
||||
self._camera_encoder,
|
||||
self._encoder_threads,
|
||||
): video_key
|
||||
for video_key in self._meta.video_keys
|
||||
@@ -495,7 +511,7 @@ class DatasetWriter:
|
||||
|
||||
# Update video info (only needed when first episode is encoded)
|
||||
if episode_index == 0:
|
||||
self._meta.update_video_info(video_key)
|
||||
self._meta.update_video_info(video_key, camera_encoder=self._camera_encoder)
|
||||
write_info(self._meta.info, self._meta.root)
|
||||
|
||||
metadata = {
|
||||
@@ -564,7 +580,12 @@ class DatasetWriter:
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
|
||||
"""Use ffmpeg to convert frames stored as png into mp4 videos."""
|
||||
return _encode_video_worker(
|
||||
video_key, episode_index, self._root, self._meta.fps, self._vcodec, self._encoder_threads
|
||||
video_key,
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._camera_encoder,
|
||||
self._encoder_threads,
|
||||
)
|
||||
|
||||
def close_writer(self) -> None:
|
||||
|
||||
@@ -13,15 +13,23 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from pprint import pformat
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
||||
|
||||
from .language import (
|
||||
LANGUAGE_PERSISTENT,
|
||||
is_language_column,
|
||||
language_events_column_feature,
|
||||
language_persistent_column_feature,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
@@ -46,7 +54,13 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
"""
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "video":
|
||||
if is_language_column(key):
|
||||
hf_features[key] = (
|
||||
language_persistent_column_feature()
|
||||
if key == LANGUAGE_PERSISTENT
|
||||
else language_events_column_feature()
|
||||
)
|
||||
elif ft["dtype"] == "video":
|
||||
continue
|
||||
elif ft["dtype"] == "image":
|
||||
hf_features[key] = datasets.Image()
|
||||
@@ -108,6 +122,41 @@ def create_empty_dataset_info(
|
||||
)
|
||||
|
||||
|
||||
def features_equal_for_merge(features_a: dict[str, dict], features_b: dict[str, dict]) -> bool:
|
||||
"""Return whether two LeRobotDatasetMetadata ``features`` dicts are compatible for aggregation.
|
||||
|
||||
For video features, keys under ``info`` related to video encoding parameters are ignored during
|
||||
comparison as they do not prevent aggregation.
|
||||
"""
|
||||
|
||||
def _without_encoder_info_keys(feature: dict) -> dict:
|
||||
filtered = dict(feature)
|
||||
filtered_info = filtered.get("info")
|
||||
if isinstance(filtered_info, dict):
|
||||
filtered["info"] = {
|
||||
info_key: info_value
|
||||
for info_key, info_value in filtered_info.items()
|
||||
if info_key not in VIDEO_ENCODER_INFO_KEYS
|
||||
}
|
||||
return filtered
|
||||
|
||||
if set(features_a) != set(features_b):
|
||||
return False
|
||||
for key in features_a:
|
||||
fa_key = features_a[key]
|
||||
fb_key = features_b[key]
|
||||
if fa_key.get("dtype") != fb_key.get("dtype"):
|
||||
return False
|
||||
if fa_key.get("dtype") != "video":
|
||||
if fa_key != fb_key:
|
||||
return False
|
||||
continue
|
||||
|
||||
if _without_encoder_info_keys(fa_key) != _without_encoder_info_keys(fb_key):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def check_delta_timestamps(
|
||||
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
|
||||
) -> bool:
|
||||
@@ -242,6 +291,8 @@ def validate_feature_dtype_and_shape(
|
||||
return validate_feature_image_or_video(name, expected_shape, value)
|
||||
elif expected_dtype == "string":
|
||||
return validate_feature_string(name, value)
|
||||
elif expected_dtype == "language":
|
||||
return validate_feature_language(name, value)
|
||||
else:
|
||||
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
|
||||
|
||||
@@ -321,6 +372,30 @@ def validate_feature_string(name: str, value: str) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def validate_feature_language(name: str, value) -> str:
|
||||
"""Validate a feature that is expected to hold language annotations.
|
||||
|
||||
Language columns (``language_persistent`` / ``language_events``) are
|
||||
populated after recording by the annotation pipeline, not at record time.
|
||||
Any value supplied here is dropped before the frame is written, so a
|
||||
non-empty value almost certainly signals a mistake. We warn rather than
|
||||
fail to keep recording resilient.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
value: The value to validate.
|
||||
|
||||
Returns:
|
||||
str: Always an empty string — language values are non-fatal.
|
||||
"""
|
||||
if value is not None:
|
||||
logging.warning(
|
||||
f"The feature '{name}' is a 'language' column populated by the annotation pipeline, "
|
||||
f"not at record time. The provided value will be dropped."
|
||||
)
|
||||
return ""
|
||||
|
||||
|
||||
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None:
|
||||
"""Validate the episode buffer before it's written to disk.
|
||||
|
||||
|
||||
@@ -31,10 +31,10 @@ from torchvision import transforms
|
||||
from lerobot.utils.io_utils import load_json, write_json
|
||||
from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_dict
|
||||
|
||||
from .language import LANGUAGE_COLUMNS
|
||||
from .utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_SUBTASKS_PATH,
|
||||
DEFAULT_TASKS_PATH,
|
||||
EPISODES_DIR,
|
||||
INFO_PATH,
|
||||
@@ -186,14 +186,6 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||
return tasks
|
||||
|
||||
|
||||
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
|
||||
"""Load subtasks from subtasks.parquet if it exists."""
|
||||
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
|
||||
if subtasks_path.exists():
|
||||
return pd.read_parquet(subtasks_path)
|
||||
return None
|
||||
|
||||
|
||||
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
|
||||
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
|
||||
This function writes episode-level metadata to a single parquet file.
|
||||
@@ -265,11 +257,13 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
|
||||
dict: The batch with items converted to torch tensors.
|
||||
"""
|
||||
for key in items_dict:
|
||||
if key in LANGUAGE_COLUMNS:
|
||||
continue
|
||||
first_item = items_dict[key][0]
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
elif first_item is None:
|
||||
elif first_item is None or isinstance(first_item, dict):
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||
@@ -304,8 +298,9 @@ def item_to_torch(item: dict) -> dict:
|
||||
Returns:
|
||||
dict: Dictionary with all tensor-like items converted to torch.Tensor.
|
||||
"""
|
||||
skip_keys = {"task", *LANGUAGE_COLUMNS}
|
||||
for key, val in item.items():
|
||||
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
|
||||
if isinstance(val, (np.ndarray | list)) and key not in skip_keys:
|
||||
# Convert numpy arrays and lists to torch tensors
|
||||
item[key] = torch.tensor(val)
|
||||
return item
|
||||
|
||||
242
src/lerobot/datasets/language.py
Normal file
242
src/lerobot/datasets/language.py
Normal file
@@ -0,0 +1,242 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import datasets
|
||||
import pyarrow as pa
|
||||
|
||||
LANGUAGE_PERSISTENT = "language_persistent"
|
||||
LANGUAGE_EVENTS = "language_events"
|
||||
LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
|
||||
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls")
|
||||
EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls")
|
||||
|
||||
CORE_STYLES = {
|
||||
"subtask",
|
||||
"plan",
|
||||
"memory",
|
||||
"motion",
|
||||
"interjection",
|
||||
"vqa",
|
||||
"trace",
|
||||
"task_aug",
|
||||
}
|
||||
# Project-local styles can be registered at import time by appending to
|
||||
# ``EXTENDED_STYLES`` before ``column_for_style`` is called. Anything added
|
||||
# here is treated as a known style alongside ``CORE_STYLES`` for resolver
|
||||
# validation. Empty by default — populate from a downstream module that
|
||||
# also extends ``PERSISTENT_STYLES`` or ``EVENT_ONLY_STYLES`` to declare
|
||||
# the new style's column.
|
||||
EXTENDED_STYLES: set[str] = set()
|
||||
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
||||
|
||||
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
|
||||
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
|
||||
|
||||
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
|
||||
# styles MUST carry a non-null ``camera`` referencing an ``observation.images.*``
|
||||
# feature key. Rows of every other style MUST have ``camera=None``. ``motion``
|
||||
# is intentionally NOT in this set: motion primitives are described in
|
||||
# robot-frame (joint / Cartesian) terms, not pixel space, so they are
|
||||
# camera-agnostic. ``trace`` is the pixel-trajectory event style and IS
|
||||
# view-dependent. The ``camera`` field nevertheless lives on
|
||||
# ``PERSISTENT_ROW_FIELDS`` too so the schema, validator, and resolver
|
||||
# behave symmetrically across the two columns; persistent rows simply
|
||||
# always have ``camera=None`` in practice today.
|
||||
VIEW_DEPENDENT_STYLES = {"vqa", "trace"}
|
||||
|
||||
LanguageColumn = Literal["language_persistent", "language_events"]
|
||||
|
||||
|
||||
def _json_arrow_type() -> pa.DataType:
|
||||
"""Return the Arrow JSON type, falling back to ``string`` on older pyarrow."""
|
||||
return pa.json_() if hasattr(pa, "json_") else pa.string()
|
||||
|
||||
|
||||
def _json_feature() -> object:
|
||||
"""Return the HF ``datasets`` JSON feature, falling back to a string value."""
|
||||
return datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
|
||||
|
||||
|
||||
def language_persistent_row_arrow_type() -> pa.StructType:
|
||||
"""Return the Arrow struct type for a single persistent language row.
|
||||
|
||||
Persistent rows carry their own ``timestamp`` because they represent a state
|
||||
that became active at a specific moment and remains active until superseded.
|
||||
``timestamp`` is ``float32`` to match the timestamp dtype LeRobotDataset
|
||||
uses for frame data.
|
||||
"""
|
||||
return pa.struct(
|
||||
[
|
||||
pa.field("role", pa.string(), nullable=False),
|
||||
pa.field("content", pa.string(), nullable=True),
|
||||
pa.field("style", pa.string(), nullable=True),
|
||||
pa.field("timestamp", pa.float32(), nullable=False),
|
||||
pa.field("camera", pa.string(), nullable=True),
|
||||
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def language_event_row_arrow_type() -> pa.StructType:
|
||||
"""Return the Arrow struct type for a single event language row.
|
||||
|
||||
Event rows have no ``timestamp`` field: each event is stored on the dataset
|
||||
row whose frame timestamp is the event's firing time.
|
||||
"""
|
||||
return pa.struct(
|
||||
[
|
||||
pa.field("role", pa.string(), nullable=False),
|
||||
pa.field("content", pa.string(), nullable=True),
|
||||
pa.field("style", pa.string(), nullable=True),
|
||||
pa.field("camera", pa.string(), nullable=True),
|
||||
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def language_persistent_arrow_type() -> pa.ListType:
|
||||
"""Return the Arrow list type for the ``language_persistent`` column."""
|
||||
return pa.list_(language_persistent_row_arrow_type())
|
||||
|
||||
|
||||
def language_events_arrow_type() -> pa.ListType:
|
||||
"""Return the Arrow list type for the ``language_events`` column."""
|
||||
return pa.list_(language_event_row_arrow_type())
|
||||
|
||||
|
||||
def language_persistent_row_feature() -> dict[str, object]:
|
||||
"""Return the HF ``datasets`` feature mapping for a persistent language row."""
|
||||
return {
|
||||
"role": datasets.Value("string"),
|
||||
"content": datasets.Value("string"),
|
||||
"style": datasets.Value("string"),
|
||||
"timestamp": datasets.Value("float32"),
|
||||
"camera": datasets.Value("string"),
|
||||
"tool_calls": datasets.List(_json_feature()),
|
||||
}
|
||||
|
||||
|
||||
def language_event_row_feature() -> dict[str, object]:
|
||||
"""Return the HF ``datasets`` feature mapping for an event language row."""
|
||||
return {
|
||||
"role": datasets.Value("string"),
|
||||
"content": datasets.Value("string"),
|
||||
"style": datasets.Value("string"),
|
||||
"camera": datasets.Value("string"),
|
||||
"tool_calls": datasets.List(_json_feature()),
|
||||
}
|
||||
|
||||
|
||||
def language_persistent_column_feature() -> datasets.List:
|
||||
"""Return the HF ``datasets`` feature for the ``language_persistent`` column."""
|
||||
return datasets.List(language_persistent_row_feature())
|
||||
|
||||
|
||||
def language_events_column_feature() -> datasets.List:
|
||||
"""Return the HF ``datasets`` feature for the ``language_events`` column."""
|
||||
return datasets.List(language_event_row_feature())
|
||||
|
||||
|
||||
def language_feature_info() -> dict[str, dict]:
|
||||
"""Return the ``info["features"]`` entries for both language columns."""
|
||||
return {
|
||||
LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None},
|
||||
LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
|
||||
def is_language_column(key: str) -> bool:
|
||||
"""Return ``True`` if ``key`` is one of the dataset's language column names."""
|
||||
return key in LANGUAGE_COLUMNS
|
||||
|
||||
|
||||
def is_view_dependent_style(style: str | None) -> bool:
|
||||
"""Return ``True`` if rows of ``style`` must be tagged with a ``camera`` key."""
|
||||
return style in VIEW_DEPENDENT_STYLES
|
||||
|
||||
|
||||
def validate_camera_field(style: str | None, camera: str | None) -> None:
|
||||
"""Enforce the ``camera`` invariant: required iff ``style`` is view-dependent.
|
||||
|
||||
Raises ``ValueError`` if a view-dependent style is missing ``camera`` or if
|
||||
a non-view-dependent style carries one. Pipeline writers and the validator
|
||||
should call this on every emitted row.
|
||||
"""
|
||||
if is_view_dependent_style(style):
|
||||
if not camera:
|
||||
raise ValueError(
|
||||
f"Rows of view-dependent style {style!r} require a non-empty 'camera' "
|
||||
f"field referencing an 'observation.images.*' feature key."
|
||||
)
|
||||
elif camera is not None:
|
||||
raise ValueError(f"Rows of style {style!r} must have camera=None; got camera={camera!r}.")
|
||||
|
||||
|
||||
# --- Tool registry --------------------------------------------------------
|
||||
# Tools declared on a dataset live in ``meta/info.json["tools"]`` as a list
|
||||
# of OpenAI-style function schemas. The runtime / training stack reads them
|
||||
# through :class:`LeRobotDatasetMetadata.tools` (with these constants as
|
||||
# fallback when the dataset doesn't declare any). Implementations live
|
||||
# under :mod:`lerobot.tools` (one file per tool); see
|
||||
# ``docs/source/tools.mdx`` for the authoring guide.
|
||||
|
||||
SAY_TOOL_SCHEMA: dict = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"description": "Speak a short utterance to the user via the TTS executor.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The verbatim text to speak.",
|
||||
}
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
},
|
||||
}
|
||||
"""Canonical schema for the ``say`` tool emitted by the steerable
|
||||
annotation pipeline (PR 2 Module 2). Single source of truth — PR 2's
|
||||
writer, PR 3's runtime tool registry, and the dataset visualizer all
|
||||
import this constant rather than duplicating the dict."""
|
||||
|
||||
DEFAULT_TOOLS: list[dict] = [SAY_TOOL_SCHEMA]
|
||||
"""Fallback tools list. Returned by ``LeRobotDatasetMetadata.tools``
|
||||
when ``meta/info.json["tools"]`` is unset, so unannotated datasets and
|
||||
chat-template consumers (``apply_chat_template(messages, tools=...)``)
|
||||
keep working out of the box."""
|
||||
|
||||
|
||||
def column_for_style(style: str | None) -> LanguageColumn:
|
||||
"""Map a language style to the column where rows of that style are stored.
|
||||
|
||||
Styles in :data:`PERSISTENT_STYLES` route to :data:`LANGUAGE_PERSISTENT`.
|
||||
Styles in :data:`EVENT_ONLY_STYLES` and the implicit ``None`` style route
|
||||
to :data:`LANGUAGE_EVENTS`.
|
||||
"""
|
||||
if style is None:
|
||||
return LANGUAGE_EVENTS
|
||||
if style in PERSISTENT_STYLES:
|
||||
return LANGUAGE_PERSISTENT
|
||||
if style in EVENT_ONLY_STYLES:
|
||||
return LANGUAGE_EVENTS
|
||||
raise ValueError(f"Unknown language style: {style!r}")
|
||||
545
src/lerobot/datasets/language_render.py
Normal file
545
src/lerobot/datasets/language_render.py
Normal file
@@ -0,0 +1,545 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import hashlib
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.recipe import DEFAULT_BINDINGS, PLACEHOLDER_RE, TrainingRecipe
|
||||
from lerobot.utils.utils import unwrap_scalar
|
||||
|
||||
from .language import LANGUAGE_PERSISTENT, column_for_style
|
||||
|
||||
LanguageRow = dict[str, Any]
|
||||
RenderedMessages = dict[str, list[Any]]
|
||||
|
||||
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
|
||||
|
||||
|
||||
def active_at(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row of ``style`` that is active at time ``t``.
|
||||
|
||||
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
|
||||
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``/
|
||||
``camera`` selector. Only valid for persistent styles.
|
||||
"""
|
||||
_validate_persistent_resolver("active_at", style)
|
||||
matches = [
|
||||
row
|
||||
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
if _timestamp(row) <= t
|
||||
]
|
||||
if not matches:
|
||||
return None
|
||||
latest_ts = max(_timestamp(row) for row in matches)
|
||||
return _select_one(
|
||||
[row for row in matches if _timestamp(row) == latest_ts],
|
||||
style=style,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
camera=camera,
|
||||
)
|
||||
|
||||
|
||||
EMITTED_AT_TOLERANCE_S = 0.1
|
||||
"""Half-window for matching persistent rows to a frame timestamp in
|
||||
``emitted_at``. Persistent timestamps come from parquet (float32) and ``t``
|
||||
is also a float32 from parquet, so in the ideal hot path an exact match
|
||||
would suffice — but any caller that derives ``t`` arithmetically (e.g.
|
||||
``frame_idx / fps``) breaks bit-equality. A 0.1 s tolerance covers
|
||||
common arithmetic drift without admitting frames that are visibly far
|
||||
apart at typical control rates (30–100 Hz). This does mean two persistent
|
||||
rows of the same selector emitted within 0.1 s of each other cannot be
|
||||
told apart by ``emitted_at`` — acceptable because persistent annotations
|
||||
(subtask / plan / memory transitions) change on a human-action timescale,
|
||||
not at the camera frame rate."""
|
||||
|
||||
|
||||
def emitted_at(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the row of ``style`` emitted at exactly time ``t``.
|
||||
|
||||
For persistent styles, this matches persistent rows whose own ``timestamp``
|
||||
is within ``EMITTED_AT_TOLERANCE_S`` of ``t`` (see that constant for why
|
||||
we use a tolerance instead of bit-equality). For event styles, the
|
||||
``events`` list is assumed to come from the dataset row at frame ``t``
|
||||
(event rows carry no timestamp of their own), so all matching event rows
|
||||
are considered emitted at ``t``. ``camera`` filters by the row's
|
||||
``camera`` field — required to disambiguate when multiple view-dependent
|
||||
rows share ``(t, role)`` across cameras.
|
||||
"""
|
||||
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||
matches = [
|
||||
row
|
||||
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
if abs(_timestamp(row) - t) <= EMITTED_AT_TOLERANCE_S
|
||||
]
|
||||
else:
|
||||
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
return _select_one(matches, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
|
||||
|
||||
def nth_prev(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
offset: int = 1,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row that was active ``offset`` steps before ``t``.
|
||||
|
||||
Walks back through chronologically sorted persistent rows of ``style``
|
||||
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
|
||||
one ``offset`` positions before the row active at ``t``. Only valid for
|
||||
persistent styles.
|
||||
"""
|
||||
return _nth_relative("nth_prev", t, persistent, style, -offset, role, tool_name, camera)
|
||||
|
||||
|
||||
def nth_next(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
offset: int = 1,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row that becomes active ``offset`` steps after ``t``.
|
||||
|
||||
Walks forward through chronologically sorted persistent rows of ``style``
|
||||
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
|
||||
one ``offset`` positions after the row active at ``t``. Only valid for
|
||||
persistent styles.
|
||||
"""
|
||||
return _nth_relative("nth_next", t, persistent, style, offset, role, tool_name, camera)
|
||||
|
||||
|
||||
def render_sample(
|
||||
*,
|
||||
recipe: TrainingRecipe,
|
||||
persistent: Sequence[LanguageRow] | None,
|
||||
events: Sequence[LanguageRow] | None,
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None = None,
|
||||
dataset_ctx: Any | None = None,
|
||||
) -> RenderedMessages | None:
|
||||
"""Render the chat-style messages for a single dataset sample.
|
||||
|
||||
Resolves the recipe's bindings against ``persistent`` and ``events`` rows
|
||||
at frame timestamp ``t``, then expands the recipe's message templates.
|
||||
Returns ``None`` if the resolved sample contains no target message.
|
||||
"""
|
||||
persistent_rows = _normalize_rows(persistent or [])
|
||||
event_rows = _normalize_rows(events or [])
|
||||
selected_recipe = _select_recipe(recipe, sample_idx)
|
||||
bindings = _resolve_bindings(
|
||||
selected_recipe,
|
||||
persistent=persistent_rows,
|
||||
events=event_rows,
|
||||
t=t,
|
||||
sample_idx=sample_idx,
|
||||
task=task,
|
||||
dataset_ctx=dataset_ctx,
|
||||
)
|
||||
return _render_message_recipe(selected_recipe, bindings)
|
||||
|
||||
|
||||
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
|
||||
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
|
||||
if recipe.blend is None:
|
||||
return recipe
|
||||
|
||||
total_weight = sum(component.weight or 0.0 for component in recipe.blend.values())
|
||||
if total_weight <= 0:
|
||||
raise ValueError("Blend weights must sum to a positive value.")
|
||||
|
||||
digest = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest()
|
||||
draw = int.from_bytes(digest, "big") / 2**64 * total_weight
|
||||
cumulative = 0.0
|
||||
last_component: TrainingRecipe | None = None
|
||||
for component in recipe.blend.values():
|
||||
last_component = component
|
||||
cumulative += component.weight or 0.0
|
||||
if draw < cumulative:
|
||||
return component
|
||||
assert last_component is not None
|
||||
return last_component
|
||||
|
||||
|
||||
def _resolve_bindings(
|
||||
recipe: TrainingRecipe,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
) -> dict[str, LanguageRow | str | None]:
|
||||
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
|
||||
bindings: dict[str, LanguageRow | str | None] = {
|
||||
"task": _resolve_task(task, dataset_ctx, persistent=persistent, sample_idx=sample_idx),
|
||||
}
|
||||
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
|
||||
for name, spec in specs.items():
|
||||
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
|
||||
return bindings
|
||||
|
||||
|
||||
def _resolve_task(
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow] = (),
|
||||
sample_idx: int = 0,
|
||||
) -> str | None:
|
||||
"""Return the task string for ``sample_idx``.
|
||||
|
||||
Resolution order:
|
||||
|
||||
1. Explicit ``task`` override (caller-supplied) wins.
|
||||
2. If ``persistent`` contains rows of style ``task_aug`` (role=user),
|
||||
deterministically pick one by ``sample_idx`` so each frame of an
|
||||
episode rotates through the available rephrasings across an epoch.
|
||||
This realizes Xiao 2022 / CAST-style task-prompt diversity without
|
||||
changing ``meta/tasks.parquet`` and without forcing recipes to opt
|
||||
in: ``${task}`` automatically picks a rephrasing when one exists,
|
||||
and falls back to the canonical task otherwise. Recipes that want
|
||||
the literal canonical task can override the binding.
|
||||
3. Otherwise read the canonical task from ``dataset_ctx`` (which is
|
||||
backed by ``meta/tasks.parquet``).
|
||||
"""
|
||||
if task is not None:
|
||||
return task
|
||||
|
||||
aug_rows = [r for r in persistent if r.get("style") == "task_aug" and r.get("role") == "user"]
|
||||
if aug_rows:
|
||||
# Deterministic, blake2b-based pick keyed on sample_idx so the
|
||||
# rotation is reproducible across runs (Python's built-in ``hash``
|
||||
# is process-randomized).
|
||||
digest = hashlib.blake2b(f"task_aug:{sample_idx}".encode(), digest_size=8).digest()
|
||||
idx = int.from_bytes(digest, "big") % len(aug_rows)
|
||||
chosen = aug_rows[idx].get("content")
|
||||
if chosen:
|
||||
return str(chosen)
|
||||
|
||||
if dataset_ctx is None:
|
||||
return None
|
||||
if isinstance(dataset_ctx, dict):
|
||||
return dataset_ctx.get("task")
|
||||
return getattr(dataset_ctx, "task", None)
|
||||
|
||||
|
||||
def _resolve_spec(
|
||||
spec: str,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
) -> LanguageRow | None:
|
||||
"""Parse a single binding's resolver expression and dispatch to its function."""
|
||||
match = _RESOLVER_RE.match(spec.strip())
|
||||
if match is None:
|
||||
raise ValueError(f"Invalid resolver expression: {spec!r}")
|
||||
name = match.group("name")
|
||||
kwargs = _parse_resolver_args(match.group("args"))
|
||||
kwargs.pop("t_arg", None)
|
||||
|
||||
if name == "emitted_at":
|
||||
return emitted_at(t, persistent=persistent, events=events, **kwargs)
|
||||
if name == "active_at":
|
||||
return active_at(t, persistent=persistent, **kwargs)
|
||||
if name == "nth_prev":
|
||||
return nth_prev(t, persistent=persistent, **kwargs)
|
||||
if name == "nth_next":
|
||||
return nth_next(t, persistent=persistent, **kwargs)
|
||||
raise ValueError(f"Unknown language resolver: {name!r}")
|
||||
|
||||
|
||||
def _parse_resolver_args(args: str) -> dict[str, Any]:
|
||||
"""Parse a comma-separated resolver argument list into a kwargs dict."""
|
||||
kwargs: dict[str, Any] = {}
|
||||
if not args.strip():
|
||||
return kwargs
|
||||
|
||||
parts = [part.strip() for part in args.split(",") if part.strip()]
|
||||
for part in parts:
|
||||
if part == "t":
|
||||
kwargs["t_arg"] = True
|
||||
continue
|
||||
if "=" not in part:
|
||||
raise ValueError(f"Invalid resolver argument: {part!r}")
|
||||
key, value = (item.strip() for item in part.split("=", 1))
|
||||
if key == "offset":
|
||||
kwargs[key] = int(value)
|
||||
else:
|
||||
kwargs[key] = value.strip("\"'")
|
||||
return kwargs
|
||||
|
||||
|
||||
def _render_message_recipe(
|
||||
recipe: TrainingRecipe,
|
||||
bindings: dict[str, LanguageRow | str | None],
|
||||
) -> RenderedMessages | None:
|
||||
"""Expand ``recipe.messages`` into rendered chat messages using ``bindings``."""
|
||||
assert recipe.messages is not None
|
||||
messages: list[dict[str, Any]] = []
|
||||
streams: list[str | None] = []
|
||||
target_indices: list[int] = []
|
||||
|
||||
for turn in recipe.messages:
|
||||
if turn.if_present is not None and bindings.get(turn.if_present) is None:
|
||||
continue
|
||||
|
||||
message = {"role": turn.role}
|
||||
if turn.content is not None:
|
||||
message["content"] = _render_content(turn.content, bindings)
|
||||
|
||||
if turn.tool_calls_from is not None:
|
||||
row = bindings.get(turn.tool_calls_from)
|
||||
tool_calls = row.get("tool_calls") if isinstance(row, dict) else None
|
||||
if tool_calls:
|
||||
message["tool_calls"] = copy.deepcopy(tool_calls)
|
||||
|
||||
message_idx = len(messages)
|
||||
messages.append(message)
|
||||
streams.append(turn.stream)
|
||||
if turn.target:
|
||||
target_indices.append(message_idx)
|
||||
|
||||
if not target_indices:
|
||||
return None
|
||||
|
||||
rendered = {
|
||||
"messages": messages,
|
||||
"message_streams": streams,
|
||||
"target_message_indices": target_indices,
|
||||
}
|
||||
_validate_rendered(rendered)
|
||||
return rendered
|
||||
|
||||
|
||||
def _render_content(
|
||||
content: str | list[dict[str, Any]],
|
||||
bindings: dict[str, LanguageRow | str | None],
|
||||
) -> str | list[dict[str, Any]]:
|
||||
"""Substitute bindings into a string or each string field of multimodal blocks."""
|
||||
if isinstance(content, str):
|
||||
return _substitute(content, bindings)
|
||||
|
||||
rendered_blocks = []
|
||||
for block in content:
|
||||
rendered_block = copy.deepcopy(block)
|
||||
for key, value in rendered_block.items():
|
||||
if isinstance(value, str):
|
||||
rendered_block[key] = _substitute(value, bindings)
|
||||
rendered_blocks.append(rendered_block)
|
||||
return rendered_blocks
|
||||
|
||||
|
||||
def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) -> str:
|
||||
"""Replace ``${name}`` placeholders in ``template`` with their bound values."""
|
||||
|
||||
def replace(match: re.Match[str]) -> str:
|
||||
"""Resolve a single ``${name}`` match to its bound string value."""
|
||||
name = match.group(1)
|
||||
if name not in bindings:
|
||||
raise ValueError(f"Unknown template binding: {name!r}")
|
||||
value = bindings[name]
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, dict):
|
||||
content = value.get("content")
|
||||
return "" if content is None else str(content)
|
||||
return str(value)
|
||||
|
||||
return PLACEHOLDER_RE.sub(replace, template)
|
||||
|
||||
|
||||
def _validate_rendered(rendered: RenderedMessages) -> None:
|
||||
"""Sanity-check the rendered output for stream/target alignment."""
|
||||
messages = rendered["messages"]
|
||||
streams = rendered["message_streams"]
|
||||
target_indices = rendered["target_message_indices"]
|
||||
|
||||
if len(streams) != len(messages):
|
||||
raise ValueError("message_streams must be aligned with messages.")
|
||||
if not target_indices:
|
||||
raise ValueError("Rendered samples must contain at least one target message.")
|
||||
for idx in target_indices:
|
||||
if idx < 0 or idx >= len(messages):
|
||||
raise ValueError(f"Target message index {idx} is out of bounds.")
|
||||
# ``stream`` is enforced non-None at MessageTurn construction time
|
||||
# (see ``MessageTurn.__post_init__``), so a missing stream here would
|
||||
# mean the dataclass invariant was bypassed; no need to re-check.
|
||||
|
||||
|
||||
def _nth_relative(
|
||||
name: str,
|
||||
t: float,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None,
|
||||
offset: int,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> LanguageRow | None:
|
||||
"""Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``."""
|
||||
_validate_persistent_resolver(name, style)
|
||||
if abs(offset) < 1:
|
||||
raise ValueError(f"{name} offset must be non-zero.")
|
||||
|
||||
rows = sorted(
|
||||
_matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera),
|
||||
key=_row_sort_key,
|
||||
)
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
anchor_idx = None
|
||||
for idx, row in enumerate(rows):
|
||||
if _timestamp(row) <= t:
|
||||
anchor_idx = idx
|
||||
else:
|
||||
break
|
||||
|
||||
target_idx = (offset - 1 if offset > 0 else None) if anchor_idx is None else anchor_idx + offset
|
||||
|
||||
if target_idx is None or target_idx < 0 or target_idx >= len(rows):
|
||||
return None
|
||||
return rows[target_idx]
|
||||
|
||||
|
||||
def _validate_persistent_resolver(name: str, style: str | None) -> None:
|
||||
"""Reject calls with missing or event-only ``style`` for persistent resolvers."""
|
||||
if style is None:
|
||||
raise ValueError(f"{name} requires a persistent style.")
|
||||
if column_for_style(style) != LANGUAGE_PERSISTENT:
|
||||
raise ValueError(f"{name} cannot be used with event-only style {style!r}.")
|
||||
|
||||
|
||||
def _matching_rows(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> list[LanguageRow]:
|
||||
"""Return ``rows`` filtered by optional ``style``/``role``/``tool_name``/``camera`` selectors."""
|
||||
return [
|
||||
row
|
||||
for row in rows
|
||||
if (style is None or row.get("style") == style)
|
||||
and (role is None or row.get("role") == role)
|
||||
and (tool_name is None or _row_has_tool_name(row, tool_name))
|
||||
and (camera is None or row.get("camera") == camera)
|
||||
]
|
||||
|
||||
|
||||
def _select_one(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the single matching row, or raise if the resolver is ambiguous.
|
||||
|
||||
Multiple matches always raise — even when the caller already passed
|
||||
some selectors — because remaining ambiguity means the data has
|
||||
several rows that look identical to the resolver and the caller
|
||||
needs to pin down a specific one (e.g. add ``camera=...`` for VQA
|
||||
rows shared across cameras).
|
||||
"""
|
||||
if not rows:
|
||||
return None
|
||||
if len(rows) > 1:
|
||||
raise ValueError(
|
||||
f"Ambiguous resolver for style={style!r} role={role!r} "
|
||||
f"tool_name={tool_name!r} camera={camera!r}: {len(rows)} matching rows. "
|
||||
f"Add a selector that distinguishes them."
|
||||
)
|
||||
return rows[0]
|
||||
|
||||
|
||||
def _row_sort_key(row: LanguageRow) -> tuple[float, str, str]:
|
||||
"""Stable sort key for both persistent and event rows.
|
||||
|
||||
Event rows lack ``timestamp`` (it is implicit in the frame), so default
|
||||
to ``0.0`` — within a single frame all event rows share the same sort
|
||||
bucket and are tiebroken by ``(style, role)``.
|
||||
"""
|
||||
timestamp = row.get("timestamp")
|
||||
ts = float(unwrap_scalar(timestamp)) if timestamp is not None else 0.0
|
||||
return (ts, row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _timestamp(row: LanguageRow) -> float:
|
||||
"""Extract a row's ``timestamp`` as a Python float (unwrapping numpy scalars)."""
|
||||
return float(unwrap_scalar(row["timestamp"]))
|
||||
|
||||
|
||||
def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool:
|
||||
"""Return ``True`` if any of the row's tool calls invokes ``tool_name``."""
|
||||
for tool_call in row.get("tool_calls") or []:
|
||||
if isinstance(tool_call, str):
|
||||
continue
|
||||
function = tool_call.get("function") if isinstance(tool_call, dict) else None
|
||||
if isinstance(function, dict) and function.get("name") == tool_name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _normalize_rows(rows: Sequence[Any]) -> list[LanguageRow]:
|
||||
"""Convert pyarrow scalars / mappings into a fresh list of plain dict rows."""
|
||||
normalized = []
|
||||
for row in rows:
|
||||
if row is None:
|
||||
continue
|
||||
if hasattr(row, "as_py"):
|
||||
row = row.as_py()
|
||||
if not isinstance(row, dict):
|
||||
raise TypeError(f"Language rows must be dictionaries, got {type(row).__name__}.")
|
||||
normalized.append(dict(row))
|
||||
return normalized
|
||||
@@ -24,6 +24,7 @@ import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig
|
||||
from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
|
||||
|
||||
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
@@ -36,8 +37,7 @@ from .utils import (
|
||||
)
|
||||
from .video_utils import (
|
||||
StreamingVideoEncoder,
|
||||
get_safe_default_codec,
|
||||
resolve_vcodec,
|
||||
get_safe_default_video_backend,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -49,6 +49,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
episodes: list[int] | None = None,
|
||||
episode_filter: Callable[[dict], bool] | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[str, list[float]] | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
@@ -58,10 +59,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
return_uint8: bool = False,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
"""
|
||||
2 modes are available for instantiating this class, depending on 2 different use cases:
|
||||
@@ -153,6 +154,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
``$HF_LEROBOT_HOME/hub``.
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list. Defaults to None.
|
||||
episode_filter (Callable[[dict], bool] | None, optional): Predicate over per-episode
|
||||
metadata rows used to select episodes. Evaluated against ``meta/`` without ``stats`` keys
|
||||
(e.g.``task_index``, ``episode_index``, ``length``, ``from_timestamp``, ``to_timestamp``).
|
||||
Intersected with ``episodes`` when both are set. Example: ``lambda ep: ep["length"] >= 100``.
|
||||
Defaults to None.
|
||||
image_transforms (Callable | None, optional):
|
||||
Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
|
||||
conversion. This works for both image-backed and video-backed observations and can later be
|
||||
@@ -177,16 +183,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
|
||||
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
|
||||
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
|
||||
'libsvtav1', 'auto', or hardware-specific codecs like 'h264_videotoolbox', 'h264_nvenc'.
|
||||
Defaults to 'libsvtav1'. Use 'auto' to auto-detect the best available hardware encoder.
|
||||
camera_encoder (VideoEncoderConfig | None, optional): Video encoder settings for cameras
|
||||
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults`
|
||||
is used by the writer.
|
||||
encoder_threads (int | None, optional): Number of encoder threads (global). ``None`` lets the
|
||||
codec decide.
|
||||
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
|
||||
instead of writing PNG images first. This makes save_episode() near-instant. Defaults to False.
|
||||
encoder_queue_maxsize (int, optional): Maximum number of frames to buffer per camera when using
|
||||
streaming encoding. Defaults to 30 (~1s at 30fps).
|
||||
encoder_threads (int | None, optional): Number of threads per encoder instance. None lets the
|
||||
codec auto-detect (default). Lower values reduce CPU usage per encoder. Maps to 'lp' (via svtav1-params) for
|
||||
libsvtav1 and 'threads' for h264/hevc.
|
||||
|
||||
Note:
|
||||
Write-mode parameters (``streaming_encoding``, ``batch_encoding_size``) passed to
|
||||
@@ -199,13 +204,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.reader = None
|
||||
self.set_image_transforms(image_transforms)
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self._video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
self._video_backend = video_backend if video_backend else get_safe_default_video_backend()
|
||||
self._return_uint8 = return_uint8
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._vcodec = resolve_vcodec(vcodec)
|
||||
self._encoder_threads = encoder_threads
|
||||
|
||||
if self._requested_root is not None:
|
||||
@@ -218,6 +221,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.root = self.meta.root
|
||||
self.revision = self.meta.revision
|
||||
|
||||
if episodes is not None and any(
|
||||
episode >= self.meta.total_episodes or episode < 0 for episode in episodes
|
||||
):
|
||||
logger.warning(
|
||||
f"Some episodes in the provided episodes list are out of range for this dataset ({self.meta.total_episodes})."
|
||||
)
|
||||
|
||||
if episode_filter is not None:
|
||||
resolved = self.meta.filter_episodes(episode_filter, candidates=episodes)
|
||||
if not resolved:
|
||||
raise ValueError(
|
||||
"The episode filter did not match any episode. Make sure the filter and episodes list are valid and compatible."
|
||||
)
|
||||
logger.info(f"The episode filter matched {len(resolved)} episode(s).")
|
||||
episodes = resolved
|
||||
self.episodes = episodes
|
||||
|
||||
# Create reader (hf_dataset loaded below)
|
||||
self.reader = DatasetReader(
|
||||
meta=self.meta,
|
||||
@@ -251,12 +271,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(self.meta.video_keys) > 0:
|
||||
streaming_enc = self._build_streaming_encoder(
|
||||
self.meta.fps, self._vcodec, encoder_queue_maxsize, encoder_threads
|
||||
self.meta.fps,
|
||||
camera_encoder,
|
||||
encoder_queue_maxsize,
|
||||
encoder_threads,
|
||||
)
|
||||
self.writer = DatasetWriter(
|
||||
meta=self.meta,
|
||||
root=self.root,
|
||||
vcodec=self._vcodec,
|
||||
camera_encoder=camera_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -298,17 +321,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
@staticmethod
|
||||
def _build_streaming_encoder(
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
camera_encoder: VideoEncoderConfig | None,
|
||||
encoder_queue_maxsize: int,
|
||||
encoder_threads: int | None,
|
||||
) -> StreamingVideoEncoder:
|
||||
return StreamingVideoEncoder(
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=None,
|
||||
camera_encoder=camera_encoder,
|
||||
queue_maxsize=encoder_queue_maxsize,
|
||||
encoder_threads=encoder_threads,
|
||||
)
|
||||
@@ -625,7 +644,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
image_writer_threads: int = 0,
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
metadata_buffer_size: int = 10,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
@@ -656,20 +675,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: Video decoding backend (used when reading back).
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
batch-encoding videos. ``1`` means encode immediately.
|
||||
vcodec: Video codec for encoding. Options include ``'libsvtav1'``,
|
||||
``'h264'``, ``'hevc'``, ``'auto'``.
|
||||
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
metadata_buffer_size: Number of episode metadata records to buffer
|
||||
before flushing to parquet.
|
||||
streaming_encoding: If ``True``, encode video frames in real-time
|
||||
during capture instead of writing images first.
|
||||
encoder_queue_maxsize: Max buffered frames per camera when using
|
||||
streaming encoding.
|
||||
encoder_threads: Threads per encoder instance. ``None`` for auto.
|
||||
|
||||
Returns:
|
||||
A new :class:`LeRobotDataset` in write mode.
|
||||
"""
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
obj = cls.__new__(cls)
|
||||
obj.meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
@@ -690,23 +709,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
obj._video_backend = video_backend if video_backend is not None else get_safe_default_video_backend()
|
||||
obj._return_uint8 = False
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
obj._vcodec = vcodec
|
||||
obj._encoder_threads = encoder_threads
|
||||
|
||||
# Reader is lazily created on first access (write-only mode)
|
||||
obj.reader = None
|
||||
|
||||
# Create writer
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(fps, vcodec, encoder_queue_maxsize, encoder_threads)
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
fps, camera_encoder, encoder_queue_maxsize, encoder_threads
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
vcodec=vcodec,
|
||||
camera_encoder=camera_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -729,12 +748,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
force_cache_sync: bool = False,
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
) -> "LeRobotDataset":
|
||||
"""Resume recording on an existing dataset.
|
||||
|
||||
@@ -757,13 +776,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: Video decoding backend for reading back data.
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
batch-encoding videos.
|
||||
vcodec: Video codec for encoding.
|
||||
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
image_writer_processes: Subprocesses for async image writing.
|
||||
image_writer_threads: Threads for async image writing.
|
||||
streaming_encoding: If ``True``, encode video in real-time during
|
||||
capture.
|
||||
encoder_queue_maxsize: Max buffered frames per camera for streaming.
|
||||
encoder_threads: Threads per encoder instance. ``None`` for auto.
|
||||
|
||||
Returns:
|
||||
A :class:`LeRobotDataset` in write mode, ready to append episodes.
|
||||
@@ -774,7 +795,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"Writing into the revision-safe Hub snapshot cache (used when root=None) would corrupt "
|
||||
"the shared cache. Please provide a local directory path."
|
||||
)
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj._requested_root = Path(root)
|
||||
@@ -783,11 +803,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
obj._video_backend = video_backend if video_backend else get_safe_default_video_backend()
|
||||
obj._return_uint8 = False
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
obj._vcodec = vcodec
|
||||
obj._encoder_threads = encoder_threads
|
||||
|
||||
if obj._requested_root is not None:
|
||||
obj._requested_root.mkdir(exist_ok=True, parents=True)
|
||||
@@ -796,21 +814,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.meta = LeRobotDatasetMetadata(
|
||||
obj.repo_id, obj._requested_root, obj.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
|
||||
obj._encoder_threads = encoder_threads
|
||||
obj.root = obj.meta.root
|
||||
|
||||
# Reader is lazily created on first access (write-only mode)
|
||||
obj.reader = None
|
||||
|
||||
# Create writer for appending
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
obj.meta.fps, vcodec, encoder_queue_maxsize, encoder_threads
|
||||
obj.meta.fps, camera_encoder, encoder_queue_maxsize, encoder_threads
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
vcodec=vcodec,
|
||||
camera_encoder=camera_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
|
||||
174
src/lerobot/datasets/pyav_utils.py
Normal file
174
src/lerobot/datasets/pyav_utils.py
Normal file
@@ -0,0 +1,174 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyAV-based compatibility checks for :class:`VideoEncoderConfig`.
|
||||
|
||||
Centralises all :mod:`av` introspection of the bundled FFmpeg build.
|
||||
Checks degrade to a no-op when the target codec isn't available locally.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import av
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
|
||||
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_codec(vcodec: str) -> av.codec.Codec | None:
|
||||
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
|
||||
try:
|
||||
return av.codec.Codec(vcodec, "w")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _get_codec_options_by_name(vcodec: str) -> dict[str, av.option.Option]:
|
||||
"""Private-option name → PyAV ``Option`` for *vcodec* (empty if unavailable)."""
|
||||
codec = get_codec(vcodec)
|
||||
if codec is None:
|
||||
return {}
|
||||
return {opt.name: opt for opt in codec.descriptor.options}
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _get_codec_video_formats(vcodec: str) -> tuple[str, ...]:
|
||||
"""Pixel formats accepted by *vcodec* in PyAV's preferred order (empty if unknown)."""
|
||||
codec = get_codec(vcodec)
|
||||
if codec is None:
|
||||
return ()
|
||||
return tuple(fmt.name for fmt in (codec.video_formats or []))
|
||||
|
||||
|
||||
def detect_available_encoders_pyav(encoders: list[str] | str) -> list[str]:
|
||||
"""Return the subset of *encoders* available as video encoders in the local FFmpeg build.
|
||||
|
||||
Each name is probed directly via :func:`get_codec`; input order is preserved.
|
||||
"""
|
||||
if isinstance(encoders, str):
|
||||
encoders = [encoders]
|
||||
|
||||
available: list[str] = []
|
||||
for name in encoders:
|
||||
codec = get_codec(name)
|
||||
if codec is not None and codec.type == "video":
|
||||
available.append(name)
|
||||
else:
|
||||
logger.debug("encoder '%s' not available as video encoder", name)
|
||||
return available
|
||||
|
||||
|
||||
def _check_option_value(vcodec: str, label: str, value: Any, opt: av.option.Option) -> None:
|
||||
"""Range-check numeric *value* and choice-check string *value* against *opt*."""
|
||||
type_name = opt.type.name
|
||||
if type_name in FFMPEG_NUMERIC_OPTION_TYPES:
|
||||
if isinstance(value, bool):
|
||||
raise ValueError(
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
)
|
||||
elif isinstance(value, str):
|
||||
try:
|
||||
num_val = float(value)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
) from e
|
||||
elif isinstance(value, (float, int)):
|
||||
num_val = value
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
)
|
||||
|
||||
# Check integer type compatibility
|
||||
if type_name in FFMPEG_INTEGER_OPTION_TYPES and not num_val.is_integer():
|
||||
raise ValueError(
|
||||
f"{label}={num_val!r} must be an integer for codec {vcodec!r} "
|
||||
f"(FFmpeg option {opt.name!r} is {type_name}); float values are not allowed."
|
||||
)
|
||||
|
||||
# Check numeric range compatibility
|
||||
lo, hi = float(opt.min), float(opt.max)
|
||||
if lo < hi and not (lo <= num_val <= hi):
|
||||
raise ValueError(
|
||||
f"{label}={num_val} is out of range for codec {vcodec!r}; must be in [{lo}, {hi}]"
|
||||
)
|
||||
|
||||
elif type_name == "STRING":
|
||||
if isinstance(value, bool):
|
||||
raise ValueError(f"{label}={value!r} is not a valid string value for codec {vcodec!r}.")
|
||||
if isinstance(value, str):
|
||||
str_val = value
|
||||
elif isinstance(value, (int, float)):
|
||||
str_val = str(value)
|
||||
else:
|
||||
raise ValueError(f"{label}={value!r} has unsupported type for STRING option on codec {vcodec!r}")
|
||||
|
||||
# Check string choice compatibility
|
||||
choices = [c.name for c in (opt.choices or [])]
|
||||
if choices and str_val not in choices:
|
||||
raise ValueError(
|
||||
f"{label}={str_val!r} is not a supported choice for codec "
|
||||
f"{vcodec!r}; valid choices: {choices}"
|
||||
)
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
def _check_pixel_format(vcodec: str, pix_fmt: str) -> None:
|
||||
formats = _get_codec_video_formats(vcodec)
|
||||
if formats and pix_fmt not in formats:
|
||||
raise ValueError(
|
||||
f"pix_fmt={pix_fmt!r} is not supported by codec {vcodec!r}; "
|
||||
f"supported pixel formats: {list(formats)}"
|
||||
)
|
||||
|
||||
|
||||
def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
|
||||
"""Validate merged encoder options (typed) against the codec's published AVOptions."""
|
||||
supported_options = _get_codec_options_by_name(vcodec)
|
||||
for key, value in codec_options.items():
|
||||
# GOP size is not a codec-specific option, it has to be validated separately.
|
||||
if key == "g":
|
||||
if isinstance(value, bool) or not isinstance(value, int) or value < 1:
|
||||
raise ValueError(f"g={value!r} must be a positive integer for codec {vcodec!r}")
|
||||
continue
|
||||
if key not in supported_options:
|
||||
continue
|
||||
_check_option_value(vcodec, key, value, supported_options[key])
|
||||
|
||||
|
||||
def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options: dict[str, Any]) -> None:
|
||||
"""Verify *config* is compatible with the bundled FFmpeg build.
|
||||
|
||||
Checks pixel format, abstract tuning-field compatibility, and each merged
|
||||
encoder option from :meth:`~lerobot.configs.video.VideoEncoderConfig.get_codec_options`
|
||||
against PyAV (including numeric ``extra_options`` present in that dict).
|
||||
No-op when ``config.vcodec`` isn't in the local FFmpeg build.
|
||||
|
||||
Raises:
|
||||
ValueError: on the first incompatibility encountered.
|
||||
"""
|
||||
options = _get_codec_options_by_name(vcodec)
|
||||
if not options:
|
||||
raise ValueError(f"Codec {vcodec!r} is not available in the bundled FFmpeg build")
|
||||
_check_pixel_format(vcodec, pix_fmt)
|
||||
_check_codec_options(vcodec, codec_options)
|
||||
@@ -88,7 +88,6 @@ VIDEO_DIR = "videos"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
@@ -130,6 +129,9 @@ class DatasetInfo:
|
||||
# Optional metadata
|
||||
robot_type: str | None = None
|
||||
splits: dict[str, str] = field(default_factory=dict)
|
||||
# OpenAI-style tool schemas declared by the dataset. ``None`` means the
|
||||
# dataset doesn't declare any — readers fall back to ``DEFAULT_TOOLS``.
|
||||
tools: list[dict] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Coerce feature shapes from list to tuple — JSON deserialisation
|
||||
@@ -151,11 +153,15 @@ class DatasetInfo:
|
||||
"""Return a JSON-serialisable dict.
|
||||
|
||||
Converts tuple shapes back to lists so ``json.dump`` can handle them.
|
||||
Drops ``tools`` when unset so existing datasets keep a clean
|
||||
``info.json``.
|
||||
"""
|
||||
d = dataclasses.asdict(self)
|
||||
for ft in d["features"].values():
|
||||
if isinstance(ft.get("shape"), tuple):
|
||||
ft["shape"] = list(ft["shape"])
|
||||
if d.get("tools") is None:
|
||||
d.pop("tools", None)
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -17,12 +17,14 @@ import contextlib
|
||||
import glob
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from collections import OrderedDict
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from fractions import Fraction
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
@@ -33,90 +35,17 @@ import fsspec
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import torch
|
||||
import torchvision
|
||||
from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.utils.import_utils import get_safe_default_codec
|
||||
from lerobot.configs import (
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
)
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build.
|
||||
# Determines the order of preference for auto-selection when vcodec="auto" is used.
|
||||
HW_ENCODERS = [
|
||||
"h264_videotoolbox", # macOS
|
||||
"hevc_videotoolbox", # macOS
|
||||
"h264_nvenc", # NVIDIA GPU
|
||||
"hevc_nvenc", # NVIDIA GPU
|
||||
"h264_vaapi", # Linux Intel/AMD
|
||||
"h264_qsv", # Intel Quick Sync
|
||||
]
|
||||
|
||||
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto"} | set(HW_ENCODERS)
|
||||
|
||||
|
||||
def _get_codec_options(
|
||||
vcodec: str,
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
preset: int | None = None,
|
||||
) -> dict:
|
||||
"""Build codec-specific options dict for video encoding."""
|
||||
options = {}
|
||||
|
||||
# GOP size (keyframe interval) - supported by VideoToolbox and software encoders
|
||||
if g is not None and (vcodec in ("h264_videotoolbox", "hevc_videotoolbox") or vcodec not in HW_ENCODERS):
|
||||
options["g"] = str(g)
|
||||
|
||||
# Quality control (codec-specific parameter names)
|
||||
if crf is not None:
|
||||
if vcodec in ("h264", "hevc", "libsvtav1"):
|
||||
options["crf"] = str(crf)
|
||||
elif vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
|
||||
quality = max(1, min(100, int(100 - crf * 2)))
|
||||
options["q:v"] = str(quality)
|
||||
elif vcodec in ("h264_nvenc", "hevc_nvenc"):
|
||||
options["rc"] = "constqp"
|
||||
options["qp"] = str(crf)
|
||||
elif vcodec in ("h264_vaapi",):
|
||||
options["qp"] = str(crf)
|
||||
elif vcodec in ("h264_qsv",):
|
||||
options["global_quality"] = str(crf)
|
||||
|
||||
# Preset (only for libsvtav1)
|
||||
if vcodec == "libsvtav1":
|
||||
options["preset"] = str(preset) if preset is not None else "12"
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def detect_available_hw_encoders() -> list[str]:
|
||||
"""Probe PyAV/FFmpeg for available hardware video encoders."""
|
||||
available = []
|
||||
for codec_name in HW_ENCODERS:
|
||||
try:
|
||||
av.codec.Codec(codec_name, "w")
|
||||
available.append(codec_name)
|
||||
except Exception: # nosec B110
|
||||
logger.debug("HW encoder '%s' not available", codec_name) # nosec B110
|
||||
return available
|
||||
|
||||
|
||||
def resolve_vcodec(vcodec: str) -> str:
|
||||
"""Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1."""
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
if vcodec != "auto":
|
||||
logger.info(f"Using video codec: {vcodec}")
|
||||
return vcodec
|
||||
available = detect_available_hw_encoders()
|
||||
for encoder in HW_ENCODERS:
|
||||
if encoder in available:
|
||||
logger.info(f"Auto-selected video codec: {encoder}")
|
||||
return encoder
|
||||
logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
|
||||
return "libsvtav1"
|
||||
|
||||
|
||||
def decode_video_frames(
|
||||
video_path: Path | str,
|
||||
@@ -132,7 +61,9 @@ def decode_video_frames(
|
||||
video_path (Path): Path to the video file.
|
||||
timestamps (list[float]): List of timestamps to extract frames.
|
||||
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
|
||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav".
|
||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available
|
||||
in the platform; otherwise, defaults to "pyav". The legacy value "video_reader" is
|
||||
accepted for one release as an alias for "pyav" and will be removed in a future version.
|
||||
return_uint8 (bool): If True, return raw uint8 frames without float32 normalization.
|
||||
This reduces memory for DataLoader IPC; normalization can be done on GPU afterward.
|
||||
|
||||
@@ -142,88 +73,90 @@ def decode_video_frames(
|
||||
Currently supports torchcodec on cpu and pyav.
|
||||
"""
|
||||
if backend is None:
|
||||
backend = get_safe_default_codec()
|
||||
backend = get_safe_default_video_backend()
|
||||
if backend == "torchcodec":
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
elif backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(
|
||||
video_path, timestamps, tolerance_s, backend, return_uint8=return_uint8
|
||||
)
|
||||
elif backend == "pyav":
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
elif backend == "video_reader":
|
||||
logger.warning("backend='video_reader' is deprecated and now aliases to 'pyav'.")
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video backend: {backend}")
|
||||
|
||||
|
||||
def decode_video_frames_torchvision(
|
||||
def decode_video_frames_pyav(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str = "pyav",
|
||||
log_loaded_timestamps: bool = False,
|
||||
return_uint8: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Loads frames associated to the requested timestamps of a video
|
||||
"""Loads frames associated to the requested timestamps of a video using PyAV.
|
||||
|
||||
The backend can be either "pyav" (default) or "video_reader".
|
||||
"video_reader" requires installing torchvision from source, see:
|
||||
https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
|
||||
(note that you need to compile against ffmpeg<4.3)
|
||||
This is the fallback decoder for platforms where torchcodec has no wheel (currently macOS
|
||||
x86_64 and linux armv7l — see the torchcodec block in pyproject.toml for the full matrix).
|
||||
On supported platforms, prefer `decode_video_frames_torchcodec`, which is faster and supports
|
||||
accurate seek.
|
||||
|
||||
While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
|
||||
For more info on video decoding, see `benchmark/video/README.md`
|
||||
PyAV doesn't support accurate seek: we seek to the nearest preceding keyframe and decode
|
||||
forward until we have covered the requested timestamp range. The number of key frames in a
|
||||
video can be adjusted at encoding time to trade off decoding speed against file size.
|
||||
|
||||
See torchvision doc for more info on these two backends:
|
||||
https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend
|
||||
Args:
|
||||
video_path: Path to the video file.
|
||||
timestamps: List of timestamps (in seconds) to extract frames for.
|
||||
tolerance_s: Allowed deviation in seconds between a queried timestamp and the closest
|
||||
decoded frame.
|
||||
log_loaded_timestamps: When True, log every decoded frame's timestamp at INFO level.
|
||||
return_uint8: When True, return raw uint8 frames (C, H, W). Otherwise, return float32 in
|
||||
[0, 1] range.
|
||||
|
||||
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
||||
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
|
||||
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
|
||||
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
||||
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
||||
Returns:
|
||||
torch.Tensor of shape (len(timestamps), C, H, W).
|
||||
"""
|
||||
video_path = str(video_path)
|
||||
|
||||
# set backend
|
||||
keyframes_only = False
|
||||
torchvision.set_video_backend(backend)
|
||||
if backend == "pyav":
|
||||
keyframes_only = True # pyav doesn't support accurate seek
|
||||
|
||||
# set a video stream reader
|
||||
# TODO(rcadene): also load audio stream at the same time
|
||||
reader = torchvision.io.VideoReader(video_path, "video")
|
||||
video_path = str(video_path)
|
||||
|
||||
# set the first and last requested timestamps
|
||||
# Note: previous timestamps are usually loaded, since we need to access the previous key frame
|
||||
first_ts = min(timestamps)
|
||||
last_ts = max(timestamps)
|
||||
|
||||
# access closest key frame of the first requested frame
|
||||
# Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
|
||||
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
|
||||
reader.seek(first_ts, keyframes_only=keyframes_only)
|
||||
loaded_frames: list[torch.Tensor] = []
|
||||
loaded_ts: list[float] = []
|
||||
|
||||
# load all frames until last requested frame
|
||||
loaded_frames = []
|
||||
loaded_ts = []
|
||||
for frame in reader:
|
||||
current_ts = frame["pts"]
|
||||
if log_loaded_timestamps:
|
||||
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
loaded_frames.append(frame["data"])
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
break
|
||||
# Seek + decode. `container.seek(offset)` with no `stream` argument expects the offset in
|
||||
# av.time_base units (microseconds). `backward=True` lands us on the nearest keyframe at or
|
||||
# before `first_ts`, so we can then decode forward until we cover `last_ts`. See:
|
||||
# https://pyav.basswood-io.com/docs/stable/api/container.html#av.container.InputContainer.seek
|
||||
with av.open(video_path) as container:
|
||||
stream = container.streams.video[0]
|
||||
container.seek(int(first_ts * av.time_base), backward=True)
|
||||
|
||||
if backend == "pyav":
|
||||
reader.container.close()
|
||||
for frame in container.decode(stream):
|
||||
if frame.pts is None:
|
||||
continue
|
||||
current_ts = float(frame.pts * stream.time_base)
|
||||
if log_loaded_timestamps:
|
||||
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
# Convert to CHW uint8 to match torchcodec's output layout.
|
||||
arr = frame.to_ndarray(format="rgb24") # H, W, 3
|
||||
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
break
|
||||
|
||||
reader = None
|
||||
if not loaded_frames:
|
||||
raise FrameTimestampError(
|
||||
f"No frames could be decoded from {video_path} in the timestamp range [{first_ts}, {last_ts}]."
|
||||
)
|
||||
|
||||
query_ts = torch.tensor(timestamps)
|
||||
loaded_ts = torch.tensor(loaded_ts)
|
||||
loaded_ts_t = torch.tensor(loaded_ts)
|
||||
|
||||
# compute distances between each query timestamp and timestamps of all loaded frames
|
||||
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
|
||||
dist = torch.cdist(query_ts[:, None], loaded_ts_t[:, None], p=1)
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
@@ -234,14 +167,14 @@ def decode_video_frames_torchvision(
|
||||
" This might be due to synchronization issues with timestamps during data collection."
|
||||
" To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts_t}"
|
||||
f"\nvideo: {video_path}"
|
||||
f"\nbackend: {backend}"
|
||||
f"\nbackend: pyav"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
closest_ts = loaded_ts[argmin_]
|
||||
closest_ts = loaded_ts_t[argmin_]
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logger.info(f"{closest_ts=}")
|
||||
@@ -260,15 +193,70 @@ def decode_video_frames_torchvision(
|
||||
return closest_frames
|
||||
|
||||
|
||||
class VideoDecoderCache:
|
||||
"""Thread-safe cache for video decoders to avoid expensive re-initialization."""
|
||||
DEFAULT_DECODER_CACHE_SIZE = 100
|
||||
"""Default LRU capacity for :class:`VideoDecoderCache`.
|
||||
|
||||
def __init__(self):
|
||||
self._cache: dict[str, tuple[Any, Any]] = {}
|
||||
Sized to comfortably hold a small rolling window of episodes worth of decoders
|
||||
(typical recipes: 2-4 cameras per episode × tens of episodes in flight) while
|
||||
bounding host RAM. Each cached entry retains a torchcodec ``VideoDecoder`` plus
|
||||
an open ``fsspec`` file handle — on the order of a few MB per entry. Override
|
||||
via the ``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` env var or by passing ``max_size``
|
||||
to the constructor (``None`` restores the legacy unbounded behaviour).
|
||||
"""
|
||||
|
||||
|
||||
def _default_max_cache_size() -> int | None:
|
||||
raw = os.environ.get("LEROBOT_VIDEO_DECODER_CACHE_SIZE")
|
||||
if raw is None:
|
||||
return DEFAULT_DECODER_CACHE_SIZE
|
||||
raw = raw.strip().lower()
|
||||
if raw in ("", "none", "unbounded", "-1"):
|
||||
return None
|
||||
try:
|
||||
value = int(raw)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"LEROBOT_VIDEO_DECODER_CACHE_SIZE must be an integer, 'none', or '-1'; got {raw!r}"
|
||||
) from e
|
||||
if value <= 0:
|
||||
raise ValueError(f"LEROBOT_VIDEO_DECODER_CACHE_SIZE must be positive; got {value}")
|
||||
return value
|
||||
|
||||
|
||||
class VideoDecoderCache:
|
||||
"""Thread-safe LRU cache for torchcodec ``VideoDecoder`` instances.
|
||||
|
||||
Cached entries hold a ``VideoDecoder`` plus the open ``fsspec`` file handle
|
||||
backing it. When the cache is full and a new path is requested, the
|
||||
least-recently-used entry is evicted and its file handle is closed. This
|
||||
bounds host-RAM growth when iterating over datasets with many distinct
|
||||
video files (otherwise each ``DataLoader`` worker pins every decoder it has
|
||||
ever opened until the process exits).
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of decoders to retain. ``None`` disables
|
||||
eviction and restores legacy unbounded behaviour. Defaults to the
|
||||
value of ``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` if set, otherwise
|
||||
:data:`DEFAULT_DECODER_CACHE_SIZE`.
|
||||
"""
|
||||
|
||||
_SENTINEL: ClassVar[object] = object()
|
||||
|
||||
def __init__(self, max_size: int | None | object = _SENTINEL):
|
||||
if max_size is VideoDecoderCache._SENTINEL:
|
||||
max_size = _default_max_cache_size()
|
||||
if max_size is not None and max_size <= 0:
|
||||
raise ValueError(f"max_size must be positive or None; got {max_size}")
|
||||
self.max_size: int | None = max_size # type: ignore[assignment]
|
||||
self._cache: OrderedDict[str, tuple[Any, Any]] = OrderedDict()
|
||||
self._lock = Lock()
|
||||
|
||||
def __contains__(self, video_path: object) -> bool:
|
||||
with self._lock:
|
||||
return str(video_path) in self._cache
|
||||
|
||||
def get_decoder(self, video_path: str):
|
||||
"""Get a cached decoder or create a new one."""
|
||||
"""Get a cached decoder or create a new one, evicting LRU if at capacity."""
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
from torchcodec.decoders import VideoDecoder
|
||||
else:
|
||||
@@ -280,18 +268,36 @@ class VideoDecoderCache:
|
||||
video_path = str(video_path)
|
||||
|
||||
with self._lock:
|
||||
if video_path not in self._cache:
|
||||
file_handle = fsspec.open(video_path).__enter__()
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||
self._cache[video_path] = (decoder, file_handle)
|
||||
entry = self._cache.get(video_path)
|
||||
if entry is not None:
|
||||
self._cache.move_to_end(video_path)
|
||||
return entry[0]
|
||||
|
||||
return self._cache[video_path][0]
|
||||
file_handle = fsspec.open(video_path).__enter__()
|
||||
try:
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||
except Exception:
|
||||
file_handle.close()
|
||||
raise
|
||||
self._cache[video_path] = (decoder, file_handle)
|
||||
|
||||
# Evict LRU entries until we are back under the cap. We close
|
||||
# evicted file handles immediately; the associated ``VideoDecoder``
|
||||
# is released to the GC when its last reference goes away.
|
||||
if self.max_size is not None:
|
||||
while len(self._cache) > self.max_size:
|
||||
_evicted_path, (_evicted_decoder, evicted_handle) = self._cache.popitem(last=False)
|
||||
with contextlib.suppress(Exception):
|
||||
evicted_handle.close()
|
||||
|
||||
return decoder
|
||||
|
||||
def clear(self):
|
||||
"""Clear the cache and close file handles."""
|
||||
"""Clear the cache and close all file handles."""
|
||||
with self._lock:
|
||||
for _, file_handle in self._cache.values():
|
||||
file_handle.close()
|
||||
with contextlib.suppress(Exception):
|
||||
file_handle.close()
|
||||
self._cache.clear()
|
||||
|
||||
def size(self) -> int:
|
||||
@@ -400,18 +406,17 @@ def encode_video_frames(
|
||||
imgs_dir: Path | str,
|
||||
video_path: Path | str,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
fast_decode: int = 0,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
*,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
preset: int | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
if camera_encoder is None:
|
||||
camera_encoder = camera_encoder_defaults()
|
||||
vcodec = camera_encoder.vcodec
|
||||
pix_fmt = camera_encoder.pix_fmt
|
||||
|
||||
video_path = Path(video_path)
|
||||
imgs_dir = Path(imgs_dir)
|
||||
@@ -422,42 +427,18 @@ def encode_video_frames(
|
||||
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Encoders/pixel formats incompatibility check
|
||||
if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
|
||||
logger.warning(
|
||||
f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
|
||||
)
|
||||
pix_fmt = "yuv420p"
|
||||
|
||||
# Get input frames
|
||||
template = "frame-" + ("[0-9]" * 6) + ".png"
|
||||
input_list = sorted(
|
||||
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
|
||||
)
|
||||
|
||||
# Define video output frame size (assuming all input frames are the same size)
|
||||
if len(input_list) == 0:
|
||||
raise FileNotFoundError(f"No images found in {imgs_dir}.")
|
||||
with Image.open(input_list[0]) as dummy_image:
|
||||
width, height = dummy_image.size
|
||||
|
||||
# Define video codec options
|
||||
video_options = _get_codec_options(vcodec, g, crf, preset)
|
||||
|
||||
if fast_decode:
|
||||
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
|
||||
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
||||
video_options[key] = value
|
||||
|
||||
if encoder_threads is not None:
|
||||
if vcodec == "libsvtav1":
|
||||
lp_param = f"lp={encoder_threads}"
|
||||
if "svtav1-params" in video_options:
|
||||
video_options["svtav1-params"] += f":{lp_param}"
|
||||
else:
|
||||
video_options["svtav1-params"] = lp_param
|
||||
else:
|
||||
video_options["threads"] = str(encoder_threads)
|
||||
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
|
||||
# Set logging level
|
||||
if log_level is not None:
|
||||
@@ -493,8 +474,97 @@ def encode_video_frames(
|
||||
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
|
||||
|
||||
|
||||
def reencode_video(
|
||||
input_video_path: Path | str,
|
||||
output_video_path: Path | str,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
) -> None:
|
||||
"""Re-encode a video file using the given encoder configuration.
|
||||
|
||||
Args:
|
||||
input_video_path: Existing video file to read.
|
||||
output_video_path: Path for the re-encoded file.
|
||||
camera_encoder: Encoder configuration. Defaults to :func:`camera_encoder_defaults`.
|
||||
encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`.
|
||||
log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING.
|
||||
overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning.
|
||||
"""
|
||||
|
||||
camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
|
||||
output_video_path = Path(output_video_path)
|
||||
|
||||
if output_video_path.exists() and not overwrite:
|
||||
logger.warning(f"Video file already exists: {output_video_path}. Skipping re-encode.")
|
||||
return
|
||||
|
||||
output_video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
vcodec = camera_encoder.vcodec
|
||||
pix_fmt = camera_encoder.pix_fmt
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
|
||||
tmp_output_video_path = tmp_named_file.name
|
||||
|
||||
if log_level is not None:
|
||||
logging.getLogger("libav").setLevel(log_level)
|
||||
|
||||
try:
|
||||
with av.open(input_video_path, mode="r") as src:
|
||||
try:
|
||||
in_stream = src.streams.video[0]
|
||||
except IndexError as e:
|
||||
raise ValueError(f"No video stream in {input_video_path}") from e
|
||||
|
||||
fps = (
|
||||
in_stream.base_rate
|
||||
) # We allow fractional fps though LeRobotDataset only supports integer fps
|
||||
width = int(in_stream.width)
|
||||
height = int(in_stream.height)
|
||||
|
||||
with av.open(
|
||||
tmp_output_video_path,
|
||||
mode="w",
|
||||
options={
|
||||
"movflags": "faststart"
|
||||
}, # faststart is to move the metadata to the beginning of the file to speed up loading
|
||||
) as dst:
|
||||
out_stream = dst.add_stream(vcodec, fps, options=video_options)
|
||||
out_stream.pix_fmt = pix_fmt
|
||||
out_stream.width = width
|
||||
out_stream.height = height
|
||||
|
||||
for frame in src.decode(in_stream):
|
||||
frame = frame.reformat(width=width, height=height, format=pix_fmt)
|
||||
packet = out_stream.encode(frame)
|
||||
if packet:
|
||||
dst.mux(packet)
|
||||
|
||||
packet = out_stream.encode()
|
||||
if packet:
|
||||
dst.mux(packet)
|
||||
|
||||
shutil.move(tmp_output_video_path, output_video_path)
|
||||
except Exception:
|
||||
Path(tmp_output_video_path).unlink(missing_ok=True)
|
||||
raise
|
||||
finally:
|
||||
if log_level is not None:
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
if not output_video_path.exists():
|
||||
raise OSError(f"Video re-encoding did not work. File not found: {output_video_path}.")
|
||||
|
||||
|
||||
def concatenate_video_files(
|
||||
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
|
||||
input_video_paths: list[Path | str],
|
||||
output_video_path: Path,
|
||||
overwrite: bool = True,
|
||||
compatibility_check: bool = False,
|
||||
):
|
||||
"""
|
||||
Concatenate multiple video files into a single video file using pyav.
|
||||
@@ -507,6 +577,7 @@ def concatenate_video_files(
|
||||
input_video_paths: Ordered list of input video file paths to concatenate.
|
||||
output_video_path: Path to the output video file.
|
||||
overwrite: Whether to overwrite the output video file if it already exists. Default is True.
|
||||
compatibility_check: Whether to check if the input videos are compatible. Default is False.
|
||||
|
||||
Note:
|
||||
- Creates a temporary directory for intermediate files that is cleaned up after use.
|
||||
@@ -525,6 +596,22 @@ def concatenate_video_files(
|
||||
if len(input_video_paths) == 0:
|
||||
raise FileNotFoundError("No input video paths provided.")
|
||||
|
||||
# This check may be skipped at recording time as videos are encoded with the same encoder config.
|
||||
if compatibility_check:
|
||||
reference_video_info = get_video_info(input_video_paths[0])
|
||||
for input_path in input_video_paths[1:]:
|
||||
video_info = get_video_info(input_path)
|
||||
if (
|
||||
video_info["video.height"] != reference_video_info["video.height"]
|
||||
or video_info["video.width"] != reference_video_info["video.width"]
|
||||
or video_info["video.fps"] != reference_video_info["video.fps"]
|
||||
or video_info["video.codec"] != reference_video_info["video.codec"]
|
||||
or video_info["video.pix_fmt"] != reference_video_info["video.pix_fmt"]
|
||||
):
|
||||
raise ValueError(
|
||||
f"Input video {input_path} is not compatible with the reference video {input_video_paths[0]}."
|
||||
)
|
||||
|
||||
# Create a temporary .ffconcat file to list the input video paths
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
|
||||
tmp_concatenate_file.write("ffconcat version 1.0\n")
|
||||
@@ -591,26 +678,20 @@ class _CameraEncoderThread(threading.Thread):
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int | None,
|
||||
crf: int | None,
|
||||
preset: int | None,
|
||||
codec_options: dict[str, str],
|
||||
frame_queue: queue.Queue,
|
||||
result_queue: queue.Queue,
|
||||
stop_event: threading.Event,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
super().__init__(daemon=True)
|
||||
self.video_path = video_path
|
||||
self.fps = fps
|
||||
self.vcodec = vcodec
|
||||
self.pix_fmt = pix_fmt
|
||||
self.g = g
|
||||
self.crf = crf
|
||||
self.preset = preset
|
||||
self.codec_options = codec_options
|
||||
self.frame_queue = frame_queue
|
||||
self.result_queue = result_queue
|
||||
self.stop_event = stop_event
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
def run(self) -> None:
|
||||
from .compute_stats import RunningQuantileStats, auto_downsample_height_width
|
||||
@@ -646,19 +727,9 @@ class _CameraEncoderThread(threading.Thread):
|
||||
# Open container on first frame (to get width/height)
|
||||
if container is None:
|
||||
height, width = frame_data.shape[:2]
|
||||
video_options = _get_codec_options(self.vcodec, self.g, self.crf, self.preset)
|
||||
if self.encoder_threads is not None:
|
||||
if self.vcodec == "libsvtav1":
|
||||
lp_param = f"lp={self.encoder_threads}"
|
||||
if "svtav1-params" in video_options:
|
||||
video_options["svtav1-params"] += f":{lp_param}"
|
||||
else:
|
||||
video_options["svtav1-params"] = lp_param
|
||||
else:
|
||||
video_options["threads"] = str(self.encoder_threads)
|
||||
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
container = av.open(str(self.video_path), "w")
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=video_options)
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=self.codec_options)
|
||||
output_stream.pix_fmt = self.pix_fmt
|
||||
output_stream.width = width
|
||||
output_stream.height = height
|
||||
@@ -724,22 +795,24 @@ class StreamingVideoEncoder:
|
||||
def __init__(
|
||||
self,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
preset: int | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
fps: Frames per second for the output videos.
|
||||
camera_encoder: Video encoder settings applied to all cameras.
|
||||
When ``None``, :func:`camera_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global setting).
|
||||
``None`` lets the codec decide.
|
||||
queue_maxsize: Max frames to buffer per camera before
|
||||
back-pressure drops frames.
|
||||
"""
|
||||
self.fps = fps
|
||||
self.vcodec = resolve_vcodec(vcodec)
|
||||
self.pix_fmt = pix_fmt
|
||||
self.g = g
|
||||
self.crf = crf
|
||||
self.preset = preset
|
||||
self._camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
self._encoder_threads = encoder_threads
|
||||
self.queue_maxsize = queue_maxsize
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
self._frame_queues: dict[str, queue.Queue] = {}
|
||||
self._result_queues: dict[str, queue.Queue] = {}
|
||||
@@ -770,18 +843,17 @@ class StreamingVideoEncoder:
|
||||
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
|
||||
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
|
||||
|
||||
vcodec = self._camera_encoder.vcodec
|
||||
codec_options = self._camera_encoder.get_codec_options(self._encoder_threads, as_strings=True)
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=self.fps,
|
||||
vcodec=self.vcodec,
|
||||
pix_fmt=self.pix_fmt,
|
||||
g=self.g,
|
||||
crf=self.crf,
|
||||
preset=self.preset,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=self._camera_encoder.pix_fmt,
|
||||
codec_options=codec_options,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
encoder_threads=self.encoder_threads,
|
||||
)
|
||||
encoder_thread.start()
|
||||
|
||||
@@ -986,8 +1058,18 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
return audio_info
|
||||
|
||||
|
||||
def get_video_info(video_path: Path | str) -> dict:
|
||||
# Set logging level
|
||||
def get_video_info(
|
||||
video_path: Path | str,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
) -> dict:
|
||||
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
|
||||
|
||||
Args:
|
||||
video_path: Path to the encoded video file to probe.
|
||||
camera_encoder: If provided, record the exact encoder settings used to encode this
|
||||
video. Stream-derived values take precedence — encoder fields are only written for keys
|
||||
not already populated from the video file itself.
|
||||
"""
|
||||
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
||||
|
||||
# Getting video stream information
|
||||
@@ -1018,6 +1100,14 @@ def get_video_info(video_path: Path | str) -> dict:
|
||||
# Adding audio stream information
|
||||
video_info.update(**get_audio_info(video_path))
|
||||
|
||||
# Add additional encoder configuration if provided
|
||||
if camera_encoder is not None:
|
||||
for field_name, field_value in asdict(camera_encoder).items():
|
||||
# vcodec is already populated from the video stream
|
||||
if field_name == "vcodec":
|
||||
continue
|
||||
video_info.setdefault(f"video.{field_name}", field_value)
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
|
||||
@@ -18,12 +18,25 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.import_utils import _placo_available, require_package
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
if TYPE_CHECKING or _placo_available:
|
||||
_placo_runtime_error: ImportError | None = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import placo # type: ignore[import-not-found]
|
||||
else:
|
||||
placo = None
|
||||
try:
|
||||
import placo # type: ignore[import-not-found]
|
||||
except ImportError as _placo_import_err:
|
||||
placo = None
|
||||
_placo_runtime_error = _placo_import_err
|
||||
|
||||
|
||||
def _raise_if_placo_unusable() -> None:
|
||||
if placo is None and _placo_runtime_error is not None:
|
||||
raise ImportError(
|
||||
f"placo is installed but failed to import: {_placo_runtime_error!s}"
|
||||
) from _placo_runtime_error
|
||||
|
||||
|
||||
class RobotKinematics:
|
||||
@@ -44,6 +57,7 @@ class RobotKinematics:
|
||||
joint_names (list[str] | None): List of joint names to use for the kinematics solver
|
||||
"""
|
||||
require_package("placo", extra="placo-dep")
|
||||
_raise_if_placo_unusable()
|
||||
|
||||
self.robot = placo.RobotWrapper(urdf_path)
|
||||
self.solver = placo.KinematicsSolver(self.robot)
|
||||
|
||||
@@ -43,6 +43,7 @@ from .tables import (
|
||||
CAN_CMD_SET_ZERO,
|
||||
DEFAULT_BAUDRATE,
|
||||
DEFAULT_TIMEOUT_MS,
|
||||
HANDSHAKE_TIMEOUT_S,
|
||||
MODEL_RESOLUTION,
|
||||
MOTOR_LIMIT_PARAMS,
|
||||
NORMALIZED_DATA,
|
||||
@@ -215,14 +216,16 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
self._is_connected = False
|
||||
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
|
||||
|
||||
def _query_status_via_clear_fault(self, motor: NameOrID) -> tuple[bool, can.Message | None]:
|
||||
def _query_status_via_clear_fault(
|
||||
self, motor: NameOrID, timeout: float = RUNNING_TIMEOUT
|
||||
) -> tuple[bool, can.Message | None]:
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_id = self._get_motor_id(motor_name)
|
||||
recv_id = self._get_motor_recv_id(motor_name)
|
||||
data = [0xFF] * 7 + [CAN_CMD_CLEAR_FAULT]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self._bus().send(msg)
|
||||
return self._recv_status_via_clear_fault(expected_recv_id=recv_id)
|
||||
return self._recv_status_via_clear_fault(expected_recv_id=recv_id, timeout=timeout)
|
||||
|
||||
def _recv_status_via_clear_fault(
|
||||
self, expected_recv_id: int | None = None, timeout: float = RUNNING_TIMEOUT
|
||||
@@ -280,7 +283,7 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
faulted_motors = []
|
||||
|
||||
for motor_name in self.motors:
|
||||
has_fault, msg = self._query_status_via_clear_fault(motor_name)
|
||||
has_fault, msg = self._query_status_via_clear_fault(motor_name, timeout=HANDSHAKE_TIMEOUT_S)
|
||||
if msg is None:
|
||||
missing_motors.append(motor_name)
|
||||
elif has_fault:
|
||||
@@ -505,6 +508,87 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
|
||||
return responses
|
||||
|
||||
def _recv_all_messages_until_quiet(
|
||||
self,
|
||||
*,
|
||||
timeout: float = RUNNING_TIMEOUT,
|
||||
max_messages: int = 4096,
|
||||
) -> list[can.Message]:
|
||||
"""
|
||||
Receive frames until the bus goes quiet.
|
||||
|
||||
Args:
|
||||
timeout: Poll timeout used for each recv() call. Collection stops
|
||||
when one recv() times out (quiet gap).
|
||||
max_messages: Safety cap to prevent unbounded loops.
|
||||
"""
|
||||
out: list[can.Message] = []
|
||||
max_messages = max(1, max_messages)
|
||||
timeout = max(0.0, timeout)
|
||||
|
||||
try:
|
||||
while len(out) < max_messages:
|
||||
msg = self._bus().recv(timeout=timeout)
|
||||
if msg is None:
|
||||
break
|
||||
out.append(msg)
|
||||
except (can.CanError, OSError) as e:
|
||||
logger.debug(f"Error draining CAN RX queue on {self.port}: {e}")
|
||||
|
||||
return out
|
||||
|
||||
def _process_feedback_messages(self, messages: list[can.Message]) -> set[int]:
|
||||
"""
|
||||
Decode all received feedback frames and update cached motor states.
|
||||
|
||||
Returns:
|
||||
Set of payload recv_ids that were successfully mapped to motors.
|
||||
"""
|
||||
processed_recv_ids: set[int] = set()
|
||||
for msg in messages:
|
||||
if len(msg.data) < 1:
|
||||
logger.debug(
|
||||
f"Dropping short CAN frame on {self.port} "
|
||||
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()})"
|
||||
)
|
||||
continue
|
||||
|
||||
recv_id = int(msg.data[0])
|
||||
motor_name = self._recv_id_to_motor.get(recv_id)
|
||||
if motor_name is None:
|
||||
logger.debug(
|
||||
f"Unmapped CAN frame on {self.port} "
|
||||
f"(arb=0x{int(msg.arbitration_id):02X}, recv_id=0x{recv_id:02X}, data={bytes(msg.data).hex()})"
|
||||
)
|
||||
continue
|
||||
|
||||
self._process_response(motor_name, msg)
|
||||
processed_recv_ids.add(recv_id)
|
||||
|
||||
return processed_recv_ids
|
||||
|
||||
def flush_rx_queue(self, poll_timeout_s: float = 0.0005, max_messages: int = 4096) -> int:
|
||||
"""
|
||||
Drain pending RX frames from the CAN interface.
|
||||
|
||||
This is used by higher-level controllers to drop stale feedback before issuing
|
||||
a fresh read cycle, so subsequent state reads are based on most recent replies.
|
||||
It should also be called once when a controller instance is created/connected,
|
||||
to clear residual frames left on the interface from previous sessions.
|
||||
"""
|
||||
drained = 0
|
||||
poll_timeout_s = max(0.0, poll_timeout_s)
|
||||
max_messages = max(1, max_messages)
|
||||
try:
|
||||
while drained < max_messages:
|
||||
msg = self._bus().recv(timeout=poll_timeout_s)
|
||||
if msg is None:
|
||||
break
|
||||
drained += 1
|
||||
except (can.CanError, OSError) as e:
|
||||
logger.debug(f"Failed to flush CAN RX queue on {self.port}: {e}")
|
||||
return drained
|
||||
|
||||
def _speed_control(
|
||||
self,
|
||||
motor: NameOrID,
|
||||
@@ -644,11 +728,14 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self._bus().send(msg)
|
||||
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
|
||||
# Read every feedback frame until RX goes quiet, then decode all of them.
|
||||
# This avoids dropping useful frames when responses from different motors interleave.
|
||||
messages = self._recv_all_messages_until_quiet()
|
||||
processed_recv_ids = self._process_feedback_messages(messages)
|
||||
|
||||
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=RUNNING_TIMEOUT)
|
||||
for recv_id, motor_name in recv_id_to_motor.items():
|
||||
if msg := responses.get(recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
if recv_id not in processed_recv_ids:
|
||||
logger.warning(f"Packet drop: {motor_name} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||
|
||||
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
|
||||
"""Convert float to unsigned integer for CAN transmission."""
|
||||
@@ -711,7 +798,10 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
try:
|
||||
self._decode_motor_state(msg.data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to decode response from {motor}: {e}")
|
||||
logger.warning(
|
||||
f"Failed to decode response from {motor} "
|
||||
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()}): {e}"
|
||||
)
|
||||
|
||||
def _get_cached_value(self, motor: str, data_name: str) -> Value:
|
||||
"""Retrieve a specific value from the state cache."""
|
||||
@@ -848,20 +938,12 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
self._bus().send(msg)
|
||||
updated_motors.append(motor)
|
||||
|
||||
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in updated_motors]
|
||||
responses = self._recv_all_responses(expected_recv_ids, timeout=RUNNING_TIMEOUT)
|
||||
|
||||
for response in responses.values():
|
||||
payload_motor_name = self._recv_id_to_motor.get(response.data[0])
|
||||
if payload_motor_name is not None:
|
||||
self._process_response(payload_motor_name, response)
|
||||
else:
|
||||
# Fallback: still attempt to decode based on payload byte0 mapping.
|
||||
self._decode_motor_state(response.data)
|
||||
messages = self._recv_all_messages_until_quiet()
|
||||
processed_recv_ids = self._process_feedback_messages(messages)
|
||||
|
||||
for motor in updated_motors:
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
if recv_id not in responses:
|
||||
if recv_id not in processed_recv_ids:
|
||||
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||
|
||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||
|
||||
@@ -114,7 +114,8 @@ CAN_CMD_SAVE_PARAM = 0xAA
|
||||
CAN_PARAM_ID = 0x7FF
|
||||
|
||||
|
||||
RUNNING_TIMEOUT = 0.001
|
||||
RUNNING_TIMEOUT = 0.003
|
||||
HANDSHAKE_TIMEOUT_S = 0.05
|
||||
PARAM_TIMEOUT = 0.01
|
||||
|
||||
STATE_CACHE_TTL_S = 0.02
|
||||
|
||||
@@ -16,14 +16,15 @@ from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterp
|
||||
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config as EO1Config
|
||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
|
||||
from .sac.configuration_sac import SACConfig as SACConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .utils import make_robot_action, prepare_observation_for_inference
|
||||
@@ -31,20 +32,21 @@ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
from .wall_x.configuration_wall_x import WallXConfig as WallXConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
|
||||
|
||||
# NOTE: Policy modeling classes (e.g., SACPolicy) are intentionally NOT re-exported here.
|
||||
# NOTE: Policy modeling classes (e.g., GaussianActorPolicy) are intentionally NOT re-exported here.
|
||||
# They have heavy optional dependencies and are loaded lazily via get_policy_class().
|
||||
# Import directly: ``from lerobot.policies.sac.modeling_sac import SACPolicy``
|
||||
# Import directly: ``from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy``
|
||||
|
||||
__all__ = [
|
||||
# Configuration classes
|
||||
"ACTConfig",
|
||||
"DiffusionConfig",
|
||||
"EO1Config",
|
||||
"GaussianActorConfig",
|
||||
"GrootConfig",
|
||||
"MultiTaskDiTConfig",
|
||||
"PI0Config",
|
||||
"PI0FastConfig",
|
||||
"PI05Config",
|
||||
"SACConfig",
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
|
||||
@@ -100,8 +100,8 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
|
||||
# Inputs / output structure.
|
||||
n_obs_steps: int = 2
|
||||
horizon: int = 16
|
||||
n_action_steps: int = 8
|
||||
horizon: int = 64
|
||||
n_action_steps: int = 32
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
@@ -122,10 +122,10 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
crop_ratio: float = 1.0
|
||||
crop_shape: tuple[int, int] | None = None
|
||||
crop_is_random: bool = True
|
||||
pretrained_backbone_weights: str | None = None
|
||||
use_group_norm: bool = True
|
||||
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
|
||||
use_group_norm: bool = False
|
||||
spatial_softmax_num_keypoints: int = 32
|
||||
use_separate_rgb_encoder_per_camera: bool = False
|
||||
use_separate_rgb_encoder_per_camera: bool = True
|
||||
# Unet.
|
||||
down_dims: tuple[int, ...] = (512, 1024, 2048)
|
||||
kernel_size: int = 5
|
||||
|
||||
1
src/lerobot/policies/eo1/README.md
Symbolic link
1
src/lerobot/policies/eo1/README.md
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../docs/source/eo1.mdx
|
||||
7
src/lerobot/policies/eo1/__init__.py
Normal file
7
src/lerobot/policies/eo1/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from .configuration_eo1 import EO1Config
|
||||
from .modeling_eo1 import EO1Policy
|
||||
from .processor_eo1 import make_eo1_pre_post_processors
|
||||
|
||||
__all__ = ["EO1Config", "EO1Policy", "make_eo1_pre_post_processors"]
|
||||
193
src/lerobot/policies/eo1/configuration_eo1.py
Normal file
193
src/lerobot/policies/eo1/configuration_eo1.py
Normal file
@@ -0,0 +1,193 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||
Qwen2_5_VLConfig,
|
||||
Qwen2_5_VLTextConfig,
|
||||
Qwen2_5_VLVisionConfig,
|
||||
)
|
||||
else:
|
||||
Qwen2_5_VLConfig = None
|
||||
Qwen2_5_VLTextConfig = None
|
||||
Qwen2_5_VLVisionConfig = None
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("eo1")
|
||||
@dataclass
|
||||
class EO1Config(PreTrainedConfig):
|
||||
"""Configuration for native EO1 policy integration in LeRobot."""
|
||||
|
||||
vlm_base: str = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
vlm_config: dict | None = None
|
||||
|
||||
# Vision processor settings.
|
||||
image_min_pixels: int | None = 64 * 28 * 28
|
||||
image_max_pixels: int | None = 128 * 28 * 28
|
||||
use_fast_processor: bool = False
|
||||
|
||||
# Execution and action horizon.
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 8
|
||||
n_action_steps: int = 8
|
||||
|
||||
# State/action padding to match EO1 flow head dimensionality.
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Flow matching sampling.
|
||||
num_denoise_steps: int = 10
|
||||
num_action_layers: int = 2
|
||||
action_act: str = "linear"
|
||||
time_sampling_beta_alpha: float = 1.5
|
||||
time_sampling_beta_beta: float = 1.0
|
||||
time_sampling_scale: float = 0.999
|
||||
time_sampling_offset: float = 0.001
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
supervise_padding_action_dims: bool = True
|
||||
supervise_padding_actions: bool = True
|
||||
|
||||
# Policy-level dtype request for the Qwen backbone.
|
||||
# - "auto": follow the backbone config/checkpoint default dtype. For Qwen2.5-VL this resolves to bf16.
|
||||
# The EO1 flow-matching head still keeps its own parameters in fp32.
|
||||
# - "bfloat16": force the backbone to initialize/load in bf16 regardless of the saved config default.
|
||||
# - "float32": force the backbone to initialize/load in fp32 for maximum numerical conservatism.
|
||||
dtype: str = "auto" # Options: "auto", "bfloat16", "float32"
|
||||
force_fp32_autocast: bool = True
|
||||
|
||||
# Optional attention backend request passed through to the Qwen backbone.
|
||||
# Common values: None, "eager", "sdpa", "flash_attention_2".
|
||||
attn_implementation: str | None = None
|
||||
|
||||
# Training settings.
|
||||
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Optimizer settings aligned with EO1/experiments/2_libero/train.sh and EO1 TrainPipelineConfig defaults.
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.1
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
# Scheduler settings aligned with EO1 train.sh: cosine schedule with warmup_ratio=0.03.
|
||||
# Note: These will auto-scale if --steps < scheduler_decay_steps
|
||||
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
|
||||
scheduler_warmup_steps: int = 900 # 0.03 * 30_000 long-run steps
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 0.0
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
# Populate the serialized backbone config only when the caller did not provide one.
|
||||
if self.vlm_config is None:
|
||||
require_package("transformers", extra="eo1")
|
||||
self.vlm_config = Qwen2_5_VLConfig.from_pretrained(self.vlm_base).to_dict()
|
||||
|
||||
@property
|
||||
def vlm_backbone_config(self) -> Qwen2_5_VLConfig:
|
||||
require_package("transformers", extra="eo1")
|
||||
config_dict = deepcopy(self.vlm_config)
|
||||
if self.attn_implementation is not None:
|
||||
config_dict["attn_implementation"] = self.attn_implementation
|
||||
return Qwen2_5_VLConfig(**config_dict)
|
||||
|
||||
@property
|
||||
def text_config(self) -> Qwen2_5_VLTextConfig:
|
||||
return self.vlm_backbone_config.text_config
|
||||
|
||||
@property
|
||||
def vision_config(self) -> Qwen2_5_VLVisionConfig:
|
||||
return self.vlm_backbone_config.vision_config
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up EO1 input and output features."""
|
||||
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||
if not image_features:
|
||||
raise ValueError(
|
||||
"EO1 policy requires at least one visual input feature. "
|
||||
"No features of type FeatureType.VISUAL found in input_features."
|
||||
)
|
||||
|
||||
if OBS_STATE not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,),
|
||||
)
|
||||
self.input_features[OBS_STATE] = state_feature
|
||||
|
||||
if ACTION not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,),
|
||||
)
|
||||
self.output_features[ACTION] = action_feature
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
621
src/lerobot/policies/eo1/modeling_eo1.py
Normal file
621
src/lerobot/policies/eo1/modeling_eo1.py
Normal file
@@ -0,0 +1,621 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torch.utils.checkpoint
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from .configuration_eo1 import EO1Config
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
||||
from transformers.utils import torch_compilable_check
|
||||
else:
|
||||
ACT2FN = None
|
||||
Qwen2_5_VLForConditionalGeneration = None
|
||||
torch_compilable_check = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
Can be (batch_size x sequence_length x features_dimension)
|
||||
or (batch_size x features_dimension)
|
||||
"""
|
||||
if vector.shape[-1] >= new_dim:
|
||||
return vector
|
||||
return F.pad(vector, (0, new_dim - vector.shape[-1]))
|
||||
|
||||
|
||||
class EO1Policy(PreTrainedPolicy):
|
||||
"""EO1 policy wrapper for LeRobot robot-only training/evaluation."""
|
||||
|
||||
config_class = EO1Config
|
||||
name = "eo1"
|
||||
|
||||
def __init__(self, config: EO1Config, **kwargs):
|
||||
require_package("transformers", extra="eo1")
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
if config.pretrained_path is None:
|
||||
# Initialize from pretrained VLM
|
||||
vlm_backbone = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
config.vlm_base,
|
||||
dtype=config.dtype,
|
||||
attn_implementation=config.attn_implementation,
|
||||
)
|
||||
else:
|
||||
vlm_backbone = Qwen2_5_VLForConditionalGeneration._from_config(
|
||||
config.vlm_backbone_config,
|
||||
dtype=config.vlm_backbone_config.dtype if config.dtype == "auto" else config.dtype,
|
||||
)
|
||||
|
||||
self.model = EO1VisionFlowMatchingModel(config, vlm_backbone)
|
||||
if config.gradient_checkpointing:
|
||||
self.model.gradient_checkpointing_enable()
|
||||
|
||||
self.model.to(config.device)
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self._action_queue = deque(maxlen=self.config.n_action_steps)
|
||||
|
||||
@staticmethod
|
||||
def _get_model_inputs(batch: dict[str, Tensor], excluded_keys: set[str]) -> dict[str, Tensor]:
|
||||
return {key: value for key, value in batch.items() if key not in excluded_keys}
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
state = self.prepare_state(batch[OBS_STATE])
|
||||
actions = self.prepare_action(batch[ACTION])
|
||||
model_inputs = self._get_model_inputs(batch, {OBS_STATE, ACTION})
|
||||
loss = self.model(states=state, action=actions, **model_inputs)
|
||||
|
||||
loss_dict = {"loss": loss.item()}
|
||||
return loss, loss_dict
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||
self.eval()
|
||||
|
||||
states = self.prepare_state(batch[OBS_STATE])
|
||||
model_inputs = self._get_model_inputs(batch, {OBS_STATE})
|
||||
actions = self.model.sample_actions(states=states, **model_inputs).to(torch.float32)
|
||||
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
return actions[:, :, :original_action_dim]
|
||||
|
||||
def prepare_state(self, state: Tensor) -> Tensor:
|
||||
return pad_vector(state, self.config.max_state_dim)
|
||||
|
||||
def prepare_action(self, action: Tensor) -> Tensor:
|
||||
return pad_vector(action, self.config.max_action_dim)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
self.eval()
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
|
||||
def get_safe_dtype(target_dtype, device_type):
|
||||
"""Get a safe dtype for the given device type."""
|
||||
if device_type == "mps" and target_dtype == torch.float64:
|
||||
return torch.float32
|
||||
if device_type == "cpu":
|
||||
# CPU doesn't support bfloat16, use float32 instead
|
||||
if target_dtype == torch.bfloat16:
|
||||
return torch.float32
|
||||
if target_dtype == torch.float64:
|
||||
return torch.float64
|
||||
return target_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy)
|
||||
time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
period = min_period * (max_period / min_period) ** fraction
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
|
||||
alpha_t = torch.tensor(alpha, dtype=torch.float32)
|
||||
beta_t = torch.tensor(beta, dtype=torch.float32)
|
||||
dist = torch.distributions.Beta(alpha_t, beta_t)
|
||||
return dist.sample((bsize,)).to(device)
|
||||
|
||||
|
||||
class EO1VisionActionProjector(torch.nn.Sequential):
|
||||
"""This block implements the multi-layer perceptron (MLP) module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_layers: int = 2,
|
||||
activation_layer: str = "linear",
|
||||
bias: bool = True,
|
||||
device: Any = None,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
):
|
||||
layers = []
|
||||
in_dim = in_channels
|
||||
hidden_channels = [in_dim] * (num_layers - 1) + [out_channels]
|
||||
for hidden_dim in hidden_channels[:-1]:
|
||||
layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device))
|
||||
layers.append(ACT2FN[activation_layer])
|
||||
in_dim = hidden_dim
|
||||
layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias, dtype=dtype, device=device))
|
||||
super().__init__(*layers)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self[0].weight.dtype
|
||||
|
||||
|
||||
class EO1VisionFlowMatchingModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: EO1Config,
|
||||
vlm_backbone: Qwen2_5_VLForConditionalGeneration | None = None,
|
||||
):
|
||||
require_package("transformers", extra="eo1")
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
# Preserve the backbone dtype selected at construction time so Qwen's fp32 rotary buffers stay intact.
|
||||
self.vlm_backbone = vlm_backbone
|
||||
self.hidden_size = self.vlm_backbone.config.text_config.hidden_size
|
||||
max_state_dim = config.max_state_dim
|
||||
max_action_dim = config.max_action_dim
|
||||
self.state_proj = nn.Linear(max_state_dim, self.hidden_size, dtype=torch.float32)
|
||||
self.action_in_proj = nn.Linear(max_action_dim, self.hidden_size, dtype=torch.float32)
|
||||
self.action_out_proj = EO1VisionActionProjector(
|
||||
self.hidden_size,
|
||||
max_action_dim,
|
||||
config.num_action_layers,
|
||||
config.action_act,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
self.action_time_mlp_in = nn.Linear(self.hidden_size * 2, self.hidden_size, dtype=torch.float32)
|
||||
self.action_time_mlp_out = nn.Linear(self.hidden_size, self.hidden_size, dtype=torch.float32)
|
||||
self.gradient_checkpointing_enabled = False
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.vlm_backbone.get_input_embeddings()
|
||||
|
||||
def flow_head_autocast_context(self):
|
||||
if self.config.force_fp32_autocast:
|
||||
return torch.autocast(
|
||||
device_type=self.state_proj.weight.device.type,
|
||||
enabled=False,
|
||||
)
|
||||
return contextlib.nullcontext()
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for the Qwen2.5-VL backbone."""
|
||||
self.gradient_checkpointing_enabled = True
|
||||
self.vlm_backbone.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||
)
|
||||
logger.info("Enabled gradient checkpointing for EO1VisionFlowMatchingModel")
|
||||
|
||||
def gradient_checkpointing_disable(self):
|
||||
"""Disable gradient checkpointing for the Qwen2.5-VL backbone."""
|
||||
self.gradient_checkpointing_enabled = False
|
||||
self.vlm_backbone.gradient_checkpointing_disable()
|
||||
logger.info("Disabled gradient checkpointing for EO1VisionFlowMatchingModel")
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Apply manual gradient checkpointing to EO1 flow-head computations when training."""
|
||||
if self.gradient_checkpointing_enabled and self.training and torch.is_grad_enabled():
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
noise = torch.normal(
|
||||
mean=0.0,
|
||||
std=1.0,
|
||||
size=shape,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
return noise
|
||||
|
||||
def sample_time(self, bsize, device):
|
||||
time_beta = sample_beta(
|
||||
self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device
|
||||
)
|
||||
time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def get_placeholder_mask(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None,
|
||||
inputs_embeds: torch.FloatTensor | None,
|
||||
state_features: torch.FloatTensor | None = None,
|
||||
action_features: torch.FloatTensor | None = None,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
) -> tuple[torch.BoolTensor, torch.BoolTensor]:
|
||||
"""Return EO1 state/action placeholder masks, following Qwen's multimodal mask style."""
|
||||
if input_ids is None:
|
||||
special_state_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(state_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_state_mask = special_state_mask.all(-1)
|
||||
special_action_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(action_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_action_mask = special_action_mask.all(-1)
|
||||
else:
|
||||
special_state_mask = input_ids == state_token_id
|
||||
special_action_mask = input_ids == action_token_id
|
||||
|
||||
n_state_tokens = special_state_mask.sum()
|
||||
special_state_mask = (
|
||||
special_state_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
)
|
||||
if state_features is not None:
|
||||
torch_compilable_check(
|
||||
inputs_embeds[special_state_mask].numel() == state_features.numel(),
|
||||
f"State features and state tokens do not match, tokens: {n_state_tokens}, features: {state_features.shape[0]}",
|
||||
)
|
||||
|
||||
n_action_tokens = special_action_mask.sum()
|
||||
special_action_mask = (
|
||||
special_action_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
)
|
||||
if action_features is not None:
|
||||
torch_compilable_check(
|
||||
inputs_embeds[special_action_mask].numel() == action_features.numel(),
|
||||
f"Action features and action tokens do not match, tokens: {n_action_tokens}, features: {action_features.shape[0]}",
|
||||
)
|
||||
|
||||
return special_state_mask, special_action_mask
|
||||
|
||||
def embed_prefix(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
states: torch.Tensor,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
) -> torch.FloatTensor:
|
||||
"""Embed the EO1 prefix tokens before native Qwen injects multimodal features."""
|
||||
|
||||
# Get the input embeddings for the input IDs
|
||||
def input_embed_func(input_ids: torch.LongTensor) -> torch.FloatTensor:
|
||||
return self.get_input_embeddings()(input_ids)
|
||||
|
||||
inputs_embeds = self._apply_checkpoint(input_embed_func, input_ids)
|
||||
|
||||
# Project the states to the hidden size
|
||||
def state_proj_func(states: torch.Tensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
states = states.to(dtype=self.state_proj.weight.dtype)
|
||||
return self.state_proj(states)
|
||||
|
||||
state_embs = self._apply_checkpoint(state_proj_func, states)
|
||||
state_mask, _ = self.get_placeholder_mask(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
state_features=state_embs,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
state_embs = state_embs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(state_mask, state_embs)
|
||||
return inputs_embeds
|
||||
|
||||
def embed_suffix(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
noisy_actions: torch.Tensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""Embed the suffix"""
|
||||
|
||||
def action_proj_func(noisy_actions: torch.Tensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
noisy_actions = noisy_actions.to(dtype=self.action_in_proj.weight.dtype)
|
||||
return self.action_in_proj(noisy_actions)
|
||||
|
||||
action_embs = self._apply_checkpoint(action_proj_func, noisy_actions)
|
||||
time_embs = create_sinusoidal_pos_embedding(
|
||||
timestep,
|
||||
self.hidden_size,
|
||||
min_period=self.config.min_period,
|
||||
max_period=self.config.max_period,
|
||||
device=action_embs.device,
|
||||
)
|
||||
time_embs = time_embs.to(dtype=action_embs.dtype)
|
||||
time_embs = time_embs[:, None, :].expand_as(action_embs)
|
||||
action_time_embs = torch.cat([action_embs, time_embs], dim=2)
|
||||
|
||||
def mlp_func(action_time_embs: torch.Tensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
action_time_embs = action_time_embs.to(dtype=self.action_time_mlp_in.weight.dtype)
|
||||
action_time_embs = self.action_time_mlp_in(action_time_embs)
|
||||
action_time_embs = F.silu(action_time_embs)
|
||||
return self.action_time_mlp_out(action_time_embs)
|
||||
|
||||
action_time_embs = self._apply_checkpoint(mlp_func, action_time_embs)
|
||||
return action_time_embs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.LongTensor | None = None,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
image_grid_thw: torch.LongTensor | None = None,
|
||||
mm_token_type_ids: torch.IntTensor | None = None,
|
||||
states: torch.FloatTensor | None = None,
|
||||
action: torch.FloatTensor | None = None,
|
||||
action_is_pad: torch.BoolTensor | None = None,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
**kwargs,
|
||||
) -> Tensor:
|
||||
"""Run the EO1 training forward pass and compute the flow-matching loss."""
|
||||
|
||||
# 1. Build the EO1 prefix with state placeholders resolved.
|
||||
inputs_embeds = self.embed_prefix(
|
||||
input_ids,
|
||||
states=states,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
|
||||
# 2. Sample the diffusion target and replace the action placeholders.
|
||||
time = self.sample_time(action.shape[0], inputs_embeds.device)
|
||||
noise = self.sample_noise(action.shape, inputs_embeds.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * action
|
||||
u_t = noise - action
|
||||
action_time_embs = self.embed_suffix(time, x_t)
|
||||
_, action_mask = self.get_placeholder_mask(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
action_features=action_time_embs,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
action_time_embs = action_time_embs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(action_mask, action_time_embs)
|
||||
|
||||
# 3. Optionally drop padded action tokens from backbone attention.
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
if not self.config.supervise_padding_actions:
|
||||
action_is_pad = action_is_pad.to(device=inputs_embeds.device, dtype=torch.bool)
|
||||
action_token_mask = action_mask[..., 0]
|
||||
action_padding_mask = torch.zeros_like(action_token_mask)
|
||||
action_padding_mask = action_padding_mask.masked_scatter(
|
||||
action_token_mask,
|
||||
action_is_pad.reshape(-1),
|
||||
)
|
||||
attention_mask = attention_mask.masked_fill(action_padding_mask, 0)
|
||||
|
||||
# 4. Run the Qwen backbone on the fused EO1 sequence.
|
||||
def vlm_forward_func(
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
inputs_embeds: torch.FloatTensor,
|
||||
pixel_values: torch.Tensor | None,
|
||||
image_grid_thw: torch.LongTensor | None,
|
||||
mm_token_type_ids: torch.IntTensor | None,
|
||||
) -> torch.FloatTensor:
|
||||
outputs = self.vlm_backbone.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw,
|
||||
mm_token_type_ids=mm_token_type_ids,
|
||||
use_cache=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
)
|
||||
return outputs.last_hidden_state
|
||||
|
||||
hidden_states = self._apply_checkpoint(
|
||||
vlm_forward_func,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
inputs_embeds,
|
||||
pixel_values,
|
||||
image_grid_thw,
|
||||
mm_token_type_ids,
|
||||
)
|
||||
action_hidden_states = hidden_states[action_mask[..., 0]]
|
||||
|
||||
# 5. Project the action-token hidden states back to the flow target space.
|
||||
def action_out_proj_func(action_hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
action_hidden_states = action_hidden_states.to(dtype=self.action_out_proj.dtype)
|
||||
return self.action_out_proj(action_hidden_states)
|
||||
|
||||
v_t = self._apply_checkpoint(action_out_proj_func, action_hidden_states)
|
||||
v_t = v_t.reshape(u_t.shape).to(dtype=u_t.dtype)
|
||||
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||
|
||||
# 6. Apply the configured supervision mask and reduce the loss.
|
||||
if not self.config.supervise_padding_action_dims:
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
losses = losses[..., :original_action_dim]
|
||||
|
||||
if not self.config.supervise_padding_actions:
|
||||
losses = losses[~action_is_pad]
|
||||
|
||||
return losses.mean()
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_actions(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
pixel_values: torch.Tensor | None = None,
|
||||
image_grid_thw: torch.LongTensor | None = None,
|
||||
mm_token_type_ids: torch.IntTensor | None = None,
|
||||
states: torch.Tensor | None = None,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
**kwargs,
|
||||
) -> Tensor:
|
||||
"""Sample actions from the model."""
|
||||
if states is None:
|
||||
raise ValueError("states are required for EO1 action sampling.")
|
||||
if mm_token_type_ids is None:
|
||||
raise ValueError("mm_token_type_ids are required for EO1 action sampling.")
|
||||
|
||||
# 1. Resolve the left-padded rollout prompt and locate the action span.
|
||||
chunk_size = self.config.chunk_size
|
||||
|
||||
inputs_embeds = self.embed_prefix(
|
||||
input_ids,
|
||||
states=states,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
).clone()
|
||||
_, action_placeholder_mask = self.get_placeholder_mask(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
action_mask = action_placeholder_mask[..., 0]
|
||||
token_counts = action_mask.sum(dim=1)
|
||||
if not torch.all(token_counts == chunk_size):
|
||||
raise ValueError(
|
||||
f"Each sample must contain exactly {chunk_size} action tokens, got {token_counts.tolist()}."
|
||||
)
|
||||
if action_mask.ne(action_mask[:1]).any():
|
||||
raise ValueError(
|
||||
"Batch inference expects all samples to share the same action token mask after left padding."
|
||||
)
|
||||
act_start = int(action_mask[0].to(torch.int64).argmax().item())
|
||||
act_end = act_start + self.config.chunk_size
|
||||
if not torch.all(action_mask[:, act_start:act_end]):
|
||||
raise ValueError("Action tokens must form a contiguous chunk of length chunk_size.")
|
||||
act_slice = slice(act_start, act_end)
|
||||
|
||||
# 2. Encode the fixed prefix once and cache its KV state.
|
||||
batch_size = input_ids.shape[0]
|
||||
device = inputs_embeds.device
|
||||
attention_mask = attention_mask.to(device)
|
||||
mm_token_type_ids = mm_token_type_ids.to(device)
|
||||
position_ids, _ = self.vlm_backbone.model.get_rope_index(
|
||||
input_ids,
|
||||
image_grid_thw=image_grid_thw,
|
||||
attention_mask=attention_mask,
|
||||
mm_token_type_ids=mm_token_type_ids,
|
||||
)
|
||||
position_ids = position_ids.to(device)
|
||||
|
||||
outputs = self.vlm_backbone.model(
|
||||
input_ids=input_ids[:, :act_start],
|
||||
attention_mask=attention_mask[:, :act_start],
|
||||
position_ids=position_ids[..., :act_start],
|
||||
inputs_embeds=inputs_embeds[:, :act_start],
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw,
|
||||
mm_token_type_ids=mm_token_type_ids[:, :act_start],
|
||||
use_cache=True,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
x_t = self.sample_noise(
|
||||
(batch_size, chunk_size, self.config.max_action_dim),
|
||||
device,
|
||||
).to(dtype=self.action_in_proj.weight.dtype)
|
||||
dt = -1.0 / self.config.num_denoise_steps
|
||||
past_key_values = outputs.past_key_values
|
||||
|
||||
# 3. Denoise only the action chunk while keeping the prefix cache invariant.
|
||||
for step in range(self.config.num_denoise_steps):
|
||||
time = torch.full(
|
||||
(batch_size,),
|
||||
1.0 + step * dt,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
action_time_embs = self.embed_suffix(time, x_t)
|
||||
inputs_embeds[:, act_slice] = action_time_embs.to(inputs_embeds.dtype)
|
||||
|
||||
# Keep the prefix KV cache invariant across denoising steps.
|
||||
past_key_values.crop(act_start)
|
||||
outputs = self.vlm_backbone.model(
|
||||
attention_mask=attention_mask[:, :act_end],
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds[:, act_slice],
|
||||
position_ids=position_ids[..., act_slice],
|
||||
use_cache=True,
|
||||
return_dict=True,
|
||||
)
|
||||
with self.flow_head_autocast_context():
|
||||
hidden_states = outputs.last_hidden_state[:, :chunk_size]
|
||||
hidden_states = hidden_states.to(dtype=self.action_out_proj.dtype)
|
||||
v_t = self.action_out_proj(hidden_states)
|
||||
|
||||
x_t += dt * v_t.reshape(x_t.shape)
|
||||
|
||||
return x_t
|
||||
283
src/lerobot/policies/eo1/processor_eo1.py
Normal file
283
src/lerobot/policies/eo1/processor_eo1.py
Normal file
@@ -0,0 +1,283 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
from .configuration_eo1 import EO1Config
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
||||
else:
|
||||
Qwen2_5_VLProcessor = None
|
||||
|
||||
SYSTEM_MESSAGE = "You are a helpful physical assistant."
|
||||
|
||||
# EO-1 special tokens
|
||||
ACTION_START_TOKEN = "<|action_start|>" # nosec B105
|
||||
DEFAULT_ACTION_TOKEN = "<|action_pad|>" # nosec B105
|
||||
ACTION_END_TOKEN = "<|action_end|>" # nosec B105
|
||||
STATE_START_TOKEN = "<|state_start|>" # nosec B105
|
||||
DEFAULT_STATE_TOKEN = "<|state_pad|>" # nosec B105
|
||||
STATE_END_TOKEN = "<|state_end|>" # nosec B105
|
||||
TASK_VLA_TOKEN = "<|vla|>" # nosec B105
|
||||
|
||||
EO1_SPECIAL_TOKENS = [
|
||||
ACTION_START_TOKEN,
|
||||
DEFAULT_ACTION_TOKEN,
|
||||
ACTION_END_TOKEN,
|
||||
STATE_START_TOKEN,
|
||||
DEFAULT_STATE_TOKEN,
|
||||
STATE_END_TOKEN,
|
||||
TASK_VLA_TOKEN,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="eo1_conversation_template_processor")
|
||||
class EO1ConversationTemplateStep(ComplementaryDataProcessorStep):
|
||||
input_features: dict[str, PolicyFeature] | dict[str, dict[str, Any]]
|
||||
chunk_size: int
|
||||
|
||||
_image_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
# Robust JSON deserialization handling (guard empty maps).
|
||||
if self.input_features:
|
||||
first_val = next(iter(self.input_features.values()))
|
||||
if isinstance(first_val, dict):
|
||||
reconstructed = {}
|
||||
for key, ft_dict in self.input_features.items():
|
||||
reconstructed[key] = PolicyFeature(
|
||||
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
|
||||
)
|
||||
self.input_features = reconstructed
|
||||
|
||||
self._image_keys = [
|
||||
key for key, value in self.input_features.items() if value.type == FeatureType.VISUAL
|
||||
]
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
tasks = complementary_data.get("task")
|
||||
if tasks is None:
|
||||
raise ValueError("Task is required for EO1ConversationTemplateStep.")
|
||||
|
||||
observation = self.transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None:
|
||||
raise ValueError("Observation is required for EO1ConversationTemplateStep.")
|
||||
|
||||
if OBS_STATE in observation and observation[OBS_STATE].shape[0] != len(tasks):
|
||||
raise ValueError("Batch size mismatch between observation.state and task list.")
|
||||
|
||||
# LeRobot visual observations reach in processor as float32 tensors in [0, 1].
|
||||
# Convert to uint8 in [0, 255] to meet the input requirement of Qwen2.5-VL-3B-Instruct.
|
||||
images = {
|
||||
key: observation[key].clamp(0, 1).mul(255.0).round().to(torch.uint8) for key in self._image_keys
|
||||
}
|
||||
messages = []
|
||||
for i in range(len(tasks)):
|
||||
content = [
|
||||
*[{"type": "image", "image": images[key][i]} for key in self._image_keys],
|
||||
{
|
||||
"type": "text",
|
||||
"text": (
|
||||
f"{STATE_START_TOKEN}{DEFAULT_STATE_TOKEN}{STATE_END_TOKEN}{tasks[i]}{TASK_VLA_TOKEN}"
|
||||
),
|
||||
},
|
||||
]
|
||||
messages.append(
|
||||
[
|
||||
{"role": "system", "content": [{"type": "text", "text": SYSTEM_MESSAGE}]},
|
||||
{"role": "user", "content": content},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"{ACTION_START_TOKEN}{DEFAULT_ACTION_TOKEN * self.chunk_size}{ACTION_END_TOKEN}",
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
complementary_data["messages"] = messages
|
||||
|
||||
return complementary_data
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
This step only materializes EO1-specific message objects in complementary_data.
|
||||
PipelineFeatureType tracks only ACTION and OBSERVATION, so there is no static
|
||||
feature contract change to record here.
|
||||
"""
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"input_features": {
|
||||
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.input_features.items()
|
||||
},
|
||||
"chunk_size": self.chunk_size,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="eo1_qwen_processor")
|
||||
class EO1QwenProcessorStep(ComplementaryDataProcessorStep):
|
||||
processor_name: str = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
image_min_pixels: int | None = 64 * 28 * 28
|
||||
image_max_pixels: int | None = 128 * 28 * 28
|
||||
use_fast_processor: bool = False
|
||||
|
||||
_processor: Qwen2_5_VLProcessor | None = field(default=None, init=False, repr=False)
|
||||
_state_token_id: int | None = field(default=None, init=False, repr=False)
|
||||
_action_token_id: int | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
require_package("transformers", extra="eo1")
|
||||
self._processor = Qwen2_5_VLProcessor.from_pretrained(
|
||||
self.processor_name,
|
||||
use_fast=self.use_fast_processor,
|
||||
)
|
||||
self._processor.tokenizer.add_tokens(EO1_SPECIAL_TOKENS, special_tokens=True)
|
||||
self._state_token_id = self._processor.tokenizer.convert_tokens_to_ids(DEFAULT_STATE_TOKEN)
|
||||
self._action_token_id = self._processor.tokenizer.convert_tokens_to_ids(DEFAULT_ACTION_TOKEN)
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
messages = complementary_data.pop("messages", None)
|
||||
if messages is None:
|
||||
raise ValueError("Messages are required for EO1QwenProcessorStep.")
|
||||
|
||||
# Rollout batches use left padding so action spans stay aligned across samples.
|
||||
# Supervised batches use right padding to match standard training collation.
|
||||
padding_side = "right" if self.transition.get(TransitionKey.ACTION) is not None else "left"
|
||||
|
||||
inputs = self._processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
padding=True,
|
||||
padding_side=padding_side,
|
||||
min_pixels=self.image_min_pixels,
|
||||
max_pixels=self.image_max_pixels,
|
||||
add_generation_prompt=False,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
complementary_data["input_ids"] = inputs["input_ids"]
|
||||
complementary_data["pixel_values"] = inputs["pixel_values"]
|
||||
complementary_data["image_grid_thw"] = inputs["image_grid_thw"]
|
||||
complementary_data["attention_mask"] = inputs["attention_mask"]
|
||||
complementary_data["mm_token_type_ids"] = inputs["mm_token_type_ids"]
|
||||
complementary_data["state_token_id"] = self._state_token_id
|
||||
complementary_data["action_token_id"] = self._action_token_id
|
||||
|
||||
return complementary_data
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"processor_name": self.processor_name,
|
||||
"image_min_pixels": self.image_min_pixels,
|
||||
"image_max_pixels": self.image_max_pixels,
|
||||
"use_fast_processor": self.use_fast_processor,
|
||||
}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
This step only converts the messages to the model input format.
|
||||
"""
|
||||
return features
|
||||
|
||||
|
||||
def make_eo1_pre_post_processors(
|
||||
config: EO1Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Build pre/post processor pipelines for EO1."""
|
||||
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
EO1ConversationTemplateStep(input_features=config.input_features, chunk_size=config.chunk_size),
|
||||
EO1QwenProcessorStep(
|
||||
processor_name=config.vlm_base,
|
||||
image_min_pixels=config.image_min_pixels,
|
||||
image_max_pixels=config.image_max_pixels,
|
||||
use_fast_processor=config.use_fast_processor,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
output_steps: list[ProcessorStep] = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -46,12 +46,13 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
|
||||
|
||||
from .act.configuration_act import ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config
|
||||
from .pi05.configuration_pi05 import PI05Config
|
||||
from .pretrained import PreTrainedPolicy
|
||||
from .sac.configuration_sac import SACConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from .utils import validate_visual_features_consistency
|
||||
@@ -87,7 +88,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x".
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x".
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
|
||||
@@ -126,10 +127,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .pi05.modeling_pi05 import PI05Policy
|
||||
|
||||
return PI05Policy
|
||||
elif name == "sac":
|
||||
from .sac.modeling_sac import SACPolicy
|
||||
elif name == "gaussian_actor":
|
||||
from .gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
|
||||
return SACPolicy
|
||||
return GaussianActorPolicy
|
||||
elif name == "smolvla":
|
||||
from .smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
@@ -146,6 +147,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .wall_x.modeling_wall_x import WallXPolicy
|
||||
|
||||
return WallXPolicy
|
||||
elif name == "eo1":
|
||||
from .eo1.modeling_eo1 import EO1Policy
|
||||
|
||||
return EO1Policy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
@@ -162,7 +167,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
|
||||
"smolvla", "wall_x".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
@@ -186,8 +191,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi05":
|
||||
return PI05Config(**kwargs)
|
||||
elif policy_type == "sac":
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "gaussian_actor":
|
||||
return GaussianActorConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "groot":
|
||||
@@ -196,6 +201,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return XVLAConfig(**kwargs)
|
||||
elif policy_type == "wall_x":
|
||||
return WallXConfig(**kwargs)
|
||||
elif policy_type == "eo1":
|
||||
return EO1Config(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
@@ -358,10 +365,10 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SACConfig):
|
||||
from .sac.processor_sac import make_sac_pre_post_processors
|
||||
elif isinstance(policy_cfg, GaussianActorConfig):
|
||||
from .gaussian_actor.processor_gaussian_actor import make_gaussian_actor_pre_post_processors
|
||||
|
||||
processors = make_sac_pre_post_processors(
|
||||
processors = make_gaussian_actor_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
@@ -399,6 +406,13 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
elif isinstance(policy_cfg, EO1Config):
|
||||
from .eo1.processor_eo1 import make_eo1_pre_post_processors
|
||||
|
||||
processors = make_eo1_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
|
||||
@@ -12,8 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_sac import SACConfig
|
||||
from .modeling_sac import SACPolicy
|
||||
from .processor_sac import make_sac_pre_post_processors
|
||||
from .configuration_gaussian_actor import GaussianActorConfig
|
||||
from .modeling_gaussian_actor import GaussianActorPolicy
|
||||
from .processor_gaussian_actor import make_gaussian_actor_pre_post_processors
|
||||
|
||||
__all__ = ["SACConfig", "SACPolicy", "make_sac_pre_post_processors"]
|
||||
__all__ = ["GaussianActorConfig", "GaussianActorPolicy", "make_gaussian_actor_pre_post_processors"]
|
||||
@@ -1,4 +1,4 @@
|
||||
# !/usr/bin/env python
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
@@ -75,18 +75,19 @@ class PolicyConfig:
|
||||
init_final: float = 0.05
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("sac")
|
||||
@PreTrainedConfig.register_subclass("gaussian_actor")
|
||||
@dataclass
|
||||
class SACConfig(PreTrainedConfig):
|
||||
"""Soft Actor-Critic (SAC) configuration.
|
||||
class GaussianActorConfig(PreTrainedConfig):
|
||||
"""Gaussian actor configuration.
|
||||
|
||||
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
|
||||
reinforcement learning framework. It learns a policy and a Q-function simultaneously
|
||||
using experience collected from the environment.
|
||||
This configures the policy-side (actor + observation encoder) of a Gaussian
|
||||
policy, as used by SAC and related maximum-entropy continuous-control algorithms.
|
||||
By default the actor output is a tanh-squashed diagonal Gaussian
|
||||
(``TanhMultivariateNormalDiag``); the tanh squashing can be disabled via
|
||||
``policy_kwargs.use_tanh_squash``. The critics, temperature, and Bellman-update
|
||||
logic live on the algorithm side (see ``lerobot.rl.algorithms.sac``).
|
||||
|
||||
This configuration class contains all the parameters needed to define a SAC agent,
|
||||
including network architectures, optimization settings, and algorithm-specific
|
||||
hyperparameters.
|
||||
CLI: ``--policy.type=gaussian_actor``.
|
||||
"""
|
||||
|
||||
# Mapping of feature types to normalization modes
|
||||
@@ -122,7 +123,7 @@ class SACConfig(PreTrainedConfig):
|
||||
device: str = "cpu"
|
||||
# Device to store the model on
|
||||
storage_device: str = "cpu"
|
||||
# Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
|
||||
# Name of the vision encoder model (Set to "lerobot/resnet10" for hil serl resnet10)
|
||||
vision_encoder_name: str | None = None
|
||||
# Whether to freeze the vision encoder during training
|
||||
freeze_vision_encoder: bool = True
|
||||
@@ -135,7 +136,13 @@ class SACConfig(PreTrainedConfig):
|
||||
# Dimension of the image embedding pooling
|
||||
image_embedding_pooling_dim: int = 8
|
||||
|
||||
# Training parameter
|
||||
# Encoder architecture
|
||||
# Hidden dimension size for the state encoder
|
||||
state_encoder_hidden_dim: int = 256
|
||||
# Dimension of the latent space
|
||||
latent_dim: int = 256
|
||||
|
||||
# Online training (TODO(Khalil): relocate to TrainRLServerPipelineConfig)
|
||||
# Number of steps for online training
|
||||
online_steps: int = 1000000
|
||||
# Capacity of the online replay buffer
|
||||
@@ -146,67 +153,38 @@ class SACConfig(PreTrainedConfig):
|
||||
async_prefetch: bool = False
|
||||
# Number of steps before learning starts
|
||||
online_step_before_learning: int = 100
|
||||
# Frequency of policy updates
|
||||
policy_update_freq: int = 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
# Discount factor for the SAC algorithm
|
||||
discount: float = 0.99
|
||||
# Initial temperature value
|
||||
temperature_init: float = 1.0
|
||||
# Number of critics in the ensemble
|
||||
num_critics: int = 2
|
||||
# Number of subsampled critics for training
|
||||
num_subsample_critics: int | None = None
|
||||
# Learning rate for the critic network
|
||||
critic_lr: float = 3e-4
|
||||
# Learning rate for the actor network
|
||||
actor_lr: float = 3e-4
|
||||
# Learning rate for the temperature parameter
|
||||
temperature_lr: float = 3e-4
|
||||
# Weight for the critic target update
|
||||
critic_target_update_weight: float = 0.005
|
||||
# Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1)
|
||||
utd_ratio: int = 1
|
||||
# Hidden dimension size for the state encoder
|
||||
state_encoder_hidden_dim: int = 256
|
||||
# Dimension of the latent space
|
||||
latent_dim: int = 256
|
||||
# Target entropy for the SAC algorithm
|
||||
target_entropy: float | None = None
|
||||
# Whether to use backup entropy for the SAC algorithm
|
||||
use_backup_entropy: bool = True
|
||||
# Gradient clipping norm for the SAC algorithm
|
||||
grad_clip_norm: float = 40.0
|
||||
|
||||
# Network configuration
|
||||
# Configuration for the critic network architecture
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Configuration for the actor network architecture
|
||||
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||
# Configuration for the policy parameters
|
||||
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
||||
# Configuration for the discrete critic network
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Actor-learner transport (TODO(Khalil): relocate to TrainRLServerPipelineConfig).
|
||||
# Configuration for actor-learner architecture
|
||||
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
||||
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
|
||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||
|
||||
# Optimizations
|
||||
use_torch_compile: bool = True
|
||||
# Network architecture
|
||||
# Configuration for the actor network architecture
|
||||
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||
# Configuration for the policy parameters (Gaussian head)
|
||||
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
||||
# Configuration for the discrete critic network
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Any validation specific to SAC configuration
|
||||
# Any validation specific to GaussianActor configuration
|
||||
|
||||
def get_optimizer_preset(self) -> MultiAdamConfig:
|
||||
# Default learning rate used to satisfy the abstract ``get_optimizer_preset()``
|
||||
# contract from ``PreTrainedConfig``. The actual optimizers used during RL
|
||||
# training are built by ``SACAlgorithm.make_optimizers_and_scheduler()`` from
|
||||
# ``SACAlgorithmConfig.{actor_lr,critic_lr,temperature_lr}`` and fully bypass
|
||||
# this preset.
|
||||
default_lr = 3e-4
|
||||
return MultiAdamConfig(
|
||||
weight_decay=0.0,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": self.actor_lr},
|
||||
"critic": {"lr": self.critic_lr},
|
||||
"temperature": {"lr": self.temperature_lr},
|
||||
"actor": {"lr": default_lr},
|
||||
"critic": {"lr": default_lr},
|
||||
"temperature": {"lr": default_lr},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -15,16 +15,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
from typing import Literal
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
||||
|
||||
@@ -32,20 +27,20 @@ from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from ..utils import get_device_from_parameters
|
||||
from .configuration_sac import SACConfig, is_image_feature
|
||||
from .configuration_gaussian_actor import GaussianActorConfig, is_image_feature
|
||||
|
||||
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
|
||||
|
||||
|
||||
class SACPolicy(
|
||||
class GaussianActorPolicy(
|
||||
PreTrainedPolicy,
|
||||
):
|
||||
config_class = SACConfig
|
||||
name = "sac"
|
||||
config_class = GaussianActorConfig
|
||||
name = "gaussian_actor"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SACConfig | None = None,
|
||||
config: GaussianActorConfig | None = None,
|
||||
):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
@@ -54,9 +49,8 @@ class SACPolicy(
|
||||
# Determine action dimension and initialize all components
|
||||
continuous_action_dim = config.output_features[ACTION].shape[0]
|
||||
self._init_encoders()
|
||||
self._init_critics(continuous_action_dim)
|
||||
self._init_actor(continuous_action_dim)
|
||||
self._init_temperature()
|
||||
self._init_discrete_critic()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
optim_params = {
|
||||
@@ -65,11 +59,7 @@ class SACPolicy(
|
||||
for n, p in self.actor.named_parameters()
|
||||
if not n.startswith("encoder") or not self.shared_encoder
|
||||
],
|
||||
"critic": self.critic_ensemble.parameters(),
|
||||
"temperature": self.log_alpha,
|
||||
}
|
||||
if self.config.num_discrete_actions is not None:
|
||||
optim_params["discrete_critic"] = self.discrete_critic.parameters()
|
||||
return optim_params
|
||||
|
||||
def reset(self):
|
||||
@@ -79,7 +69,9 @@ class SACPolicy(
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
|
||||
raise NotImplementedError(
|
||||
"GaussianActorPolicy does not support action chunking. It returns single actions!"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
@@ -92,360 +84,43 @@ class SACPolicy(
|
||||
actions, _, _ = self.actor(batch, observations_features)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
discrete_action_value = self.discrete_critic(batch, observations_features)
|
||||
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
|
||||
if self.discrete_critic is not None:
|
||||
discrete_action_value = self.discrete_critic(batch, observations_features)
|
||||
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
|
||||
else:
|
||||
discrete_action = torch.ones(
|
||||
(*actions.shape[:-1], 1), device=actions.device, dtype=actions.dtype
|
||||
)
|
||||
actions = torch.cat([actions, discrete_action], dim=-1)
|
||||
|
||||
return actions
|
||||
|
||||
def critic_forward(
|
||||
self,
|
||||
observations: dict[str, Tensor],
|
||||
actions: Tensor,
|
||||
use_target: bool = False,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
def forward(self, batch: dict[str, Tensor | dict[str, Tensor]]) -> dict[str, Tensor]:
|
||||
"""Actor forward pass: sample actions and return log-probabilities.
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
actions: Action tensor
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
batch: A flat observation dict, or a training dict containing
|
||||
``"state"`` (observations) and optionally ``"observation_feature"``
|
||||
(pre-computed encoder features).
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from all critics
|
||||
Dict with ``"action"``, ``"log_prob"``, and ``"action_mean"`` tensors.
|
||||
"""
|
||||
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = critics(observations, actions, observation_features)
|
||||
return q_values
|
||||
|
||||
def discrete_critic_forward(
|
||||
self, observations, use_target=False, observation_features=None
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass through a discrete critic network
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from the discrete critic network
|
||||
"""
|
||||
discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic
|
||||
q_values = discrete_critic(observations, observation_features)
|
||||
return q_values
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict[str, Tensor | dict[str, Tensor]],
|
||||
model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic",
|
||||
) -> dict[str, Tensor]:
|
||||
"""Compute the loss for the given model
|
||||
|
||||
Args:
|
||||
batch: Dictionary containing:
|
||||
- action: Action tensor
|
||||
- reward: Reward tensor
|
||||
- state: Observations tensor dict
|
||||
- next_state: Next observations tensor dict
|
||||
- done: Done mask tensor
|
||||
- observation_feature: Optional pre-computed observation features
|
||||
- next_observation_feature: Optional pre-computed next observation features
|
||||
model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature")
|
||||
|
||||
Returns:
|
||||
The computed loss tensor
|
||||
"""
|
||||
# Extract common components from batch
|
||||
actions: Tensor = batch[ACTION]
|
||||
observations: dict[str, Tensor] = batch["state"]
|
||||
observation_features: Tensor = batch.get("observation_feature")
|
||||
|
||||
if model == "critic":
|
||||
# Extract critic-specific components
|
||||
rewards: Tensor = batch["reward"]
|
||||
next_observations: dict[str, Tensor] = batch["next_state"]
|
||||
done: Tensor = batch["done"]
|
||||
next_observation_features: Tensor = batch.get("next_observation_feature")
|
||||
|
||||
loss_critic = self.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
return {"loss_critic": loss_critic}
|
||||
|
||||
if model == "discrete_critic" and self.config.num_discrete_actions is not None:
|
||||
# Extract critic-specific components
|
||||
rewards: Tensor = batch["reward"]
|
||||
next_observations: dict[str, Tensor] = batch["next_state"]
|
||||
done: Tensor = batch["done"]
|
||||
next_observation_features: Tensor = batch.get("next_observation_feature")
|
||||
complementary_info = batch.get("complementary_info")
|
||||
loss_discrete_critic = self.compute_loss_discrete_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
return {"loss_discrete_critic": loss_discrete_critic}
|
||||
if model == "actor":
|
||||
return {
|
||||
"loss_actor": self.compute_loss_actor(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
}
|
||||
|
||||
if model == "temperature":
|
||||
return {
|
||||
"loss_temperature": self.compute_loss_temperature(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
}
|
||||
|
||||
raise ValueError(f"Unknown model type: {model}")
|
||||
|
||||
def update_target_networks(self):
|
||||
"""Update target networks with exponential moving average"""
|
||||
for target_param, param in zip(
|
||||
self.critic_target.parameters(),
|
||||
self.critic_ensemble.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
if self.config.num_discrete_actions is not None:
|
||||
for target_param, param in zip(
|
||||
self.discrete_critic_target.parameters(),
|
||||
self.discrete_critic.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
@property
|
||||
def temperature(self) -> float:
|
||||
"""Return the current temperature value, always in sync with log_alpha."""
|
||||
return self.log_alpha.exp().item()
|
||||
|
||||
def compute_loss_critic(
|
||||
self,
|
||||
observations,
|
||||
actions,
|
||||
rewards,
|
||||
next_observations,
|
||||
done,
|
||||
observation_features: Tensor | None = None,
|
||||
next_observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self.critic_forward(
|
||||
observations=next_observations,
|
||||
actions=next_action_preds,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
# TODO: Get indices before forward pass to avoid unnecessary computation
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# critics subsample size
|
||||
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
||||
if self.config.use_backup_entropy:
|
||||
min_q = min_q - (self.temperature * next_log_probs)
|
||||
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
if self.config.num_discrete_actions is not None:
|
||||
# NOTE: We only want to keep the continuous action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
|
||||
q_preds = self.critic_forward(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
||||
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
input=q_preds,
|
||||
target=td_target_duplicate,
|
||||
reduction="none",
|
||||
).mean(dim=1)
|
||||
).sum()
|
||||
return critics_loss
|
||||
|
||||
def compute_loss_discrete_critic(
|
||||
self,
|
||||
observations,
|
||||
actions,
|
||||
rewards,
|
||||
next_observations,
|
||||
done,
|
||||
observation_features=None,
|
||||
next_observation_features=None,
|
||||
complementary_info=None,
|
||||
):
|
||||
# NOTE: We only want to keep the discrete action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
|
||||
actions_discrete = torch.round(actions_discrete)
|
||||
actions_discrete = actions_discrete.long()
|
||||
|
||||
discrete_penalties: Tensor | None = None
|
||||
if complementary_info is not None:
|
||||
discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty")
|
||||
|
||||
with torch.no_grad():
|
||||
# For DQN, select actions using online network, evaluate with target network
|
||||
next_discrete_qs = self.discrete_critic_forward(
|
||||
next_observations, use_target=False, observation_features=next_observation_features
|
||||
)
|
||||
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
|
||||
|
||||
# Get target Q-values from target network
|
||||
target_next_discrete_qs = self.discrete_critic_forward(
|
||||
observations=next_observations,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for best actions
|
||||
target_next_discrete_q = torch.gather(
|
||||
target_next_discrete_qs, dim=1, index=best_next_discrete_action
|
||||
).squeeze(-1)
|
||||
|
||||
# Compute target Q-value with Bellman equation
|
||||
rewards_discrete = rewards
|
||||
if discrete_penalties is not None:
|
||||
rewards_discrete = rewards + discrete_penalties
|
||||
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
|
||||
|
||||
# Get predicted Q-values for current observations
|
||||
predicted_discrete_qs = self.discrete_critic_forward(
|
||||
observations=observations, use_target=False, observation_features=observation_features
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for taken actions
|
||||
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
|
||||
|
||||
# Compute MSE loss between predicted and target Q-values
|
||||
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
|
||||
return discrete_critic_loss
|
||||
|
||||
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
|
||||
"""Compute the temperature loss"""
|
||||
# calculate temperature loss
|
||||
with torch.no_grad():
|
||||
_, log_probs, _ = self.actor(observations, observation_features)
|
||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
|
||||
return temperature_loss
|
||||
|
||||
def compute_loss_actor(
|
||||
self,
|
||||
observations,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
||||
|
||||
q_preds = self.critic_forward(
|
||||
observations=observations,
|
||||
actions=actions_pi,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
observations = batch.get("state", batch)
|
||||
observation_features = batch.get("observation_feature") if isinstance(batch, dict) else None
|
||||
actions, log_probs, means = self.actor(observations, observation_features)
|
||||
return {"action": actions, "log_prob": log_probs, "action_mean": means}
|
||||
|
||||
def _init_encoders(self):
|
||||
"""Initialize shared or separate encoders for actor and critic."""
|
||||
self.shared_encoder = self.config.shared_encoder
|
||||
self.encoder_critic = SACObservationEncoder(self.config)
|
||||
self.encoder_critic = GaussianActorObservationEncoder(self.config)
|
||||
self.encoder_actor = (
|
||||
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
|
||||
self.encoder_critic if self.shared_encoder else GaussianActorObservationEncoder(self.config)
|
||||
)
|
||||
|
||||
def _init_critics(self, continuous_action_dim):
|
||||
"""Build critic ensemble, targets, and optional discrete critic."""
|
||||
heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
|
||||
target_heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
if self.config.use_torch_compile:
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
self._init_discrete_critics()
|
||||
|
||||
def _init_discrete_critics(self):
|
||||
"""Build discrete discrete critic ensemble and target networks."""
|
||||
self.discrete_critic = DiscreteCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
self.discrete_critic_target = DiscreteCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
|
||||
# TODO: (maractingi, azouitine) Compile the discrete critic
|
||||
self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict())
|
||||
|
||||
def _init_actor(self, continuous_action_dim):
|
||||
"""Initialize policy actor network and default target entropy."""
|
||||
"""Initialize policy actor network."""
|
||||
# NOTE: The actor select only the continuous action part
|
||||
self.actor = Policy(
|
||||
encoder=self.encoder_actor,
|
||||
@@ -455,21 +130,25 @@ class SACPolicy(
|
||||
**asdict(self.config.policy_kwargs),
|
||||
)
|
||||
|
||||
self.target_entropy = self.config.target_entropy
|
||||
if self.target_entropy is None:
|
||||
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
||||
self.target_entropy = -np.prod(dim) / 2
|
||||
def _init_discrete_critic(self) -> None:
|
||||
"""Initialize discrete critic network."""
|
||||
if self.config.num_discrete_actions is None:
|
||||
self.discrete_critic = None
|
||||
return
|
||||
|
||||
def _init_temperature(self) -> None:
|
||||
"""Set up temperature parameter (log_alpha)."""
|
||||
temp_init = self.config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||
# TODO(Khalil): Compile the discrete critic
|
||||
self.discrete_critic = DiscreteCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
class GaussianActorObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: SACConfig) -> None:
|
||||
def __init__(self, config: GaussianActorConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self._init_image_layers()
|
||||
@@ -677,84 +356,6 @@ class MLP(nn.Module):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class CriticHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
hidden_dims: list[int],
|
||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||
activate_final: bool = False,
|
||||
dropout_rate: float | None = None,
|
||||
init_final: float | None = None,
|
||||
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.net = MLP(
|
||||
input_dim=input_dim,
|
||||
hidden_dims=hidden_dims,
|
||||
activations=activations,
|
||||
activate_final=activate_final,
|
||||
dropout_rate=dropout_rate,
|
||||
final_activation=final_activation,
|
||||
)
|
||||
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.output_layer(self.net(x))
|
||||
|
||||
|
||||
class CriticEnsemble(nn.Module):
|
||||
"""
|
||||
CriticEnsemble wraps multiple CriticHead modules into an ensemble.
|
||||
|
||||
Args:
|
||||
encoder (SACObservationEncoder): encoder for observations.
|
||||
ensemble (List[CriticHead]): list of critic heads.
|
||||
init_final (float | None): optional initializer scale for final layers.
|
||||
|
||||
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: SACObservationEncoder,
|
||||
ensemble: list[CriticHead],
|
||||
init_final: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.init_final = init_final
|
||||
self.critics = nn.ModuleList(ensemble)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: dict[str, torch.Tensor],
|
||||
actions: torch.Tensor,
|
||||
observation_features: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device
|
||||
observations = {k: v.to(device) for k, v in observations.items()}
|
||||
|
||||
obs_enc = self.encoder(observations, cache=observation_features)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
|
||||
# Loop through critics and collect outputs
|
||||
q_values = []
|
||||
for critic in self.critics:
|
||||
q_values.append(critic(inputs))
|
||||
|
||||
# Stack outputs to match expected shape [num_critics, batch_size]
|
||||
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
|
||||
return q_values
|
||||
|
||||
|
||||
class DiscreteCritic(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -800,7 +401,7 @@ class DiscreteCritic(nn.Module):
|
||||
class Policy(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: SACObservationEncoder,
|
||||
encoder: GaussianActorObservationEncoder,
|
||||
network: nn.Module,
|
||||
action_dim: int,
|
||||
std_min: float = -5,
|
||||
@@ -811,7 +412,7 @@ class Policy(nn.Module):
|
||||
encoder_is_shared: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder: SACObservationEncoder = encoder
|
||||
self.encoder: GaussianActorObservationEncoder = encoder
|
||||
self.network = network
|
||||
self.action_dim = action_dim
|
||||
self.std_min = std_min
|
||||
@@ -885,7 +486,7 @@ class Policy(nn.Module):
|
||||
|
||||
|
||||
class DefaultImageEncoder(nn.Module):
|
||||
def __init__(self, config: SACConfig):
|
||||
def __init__(self, config: GaussianActorConfig):
|
||||
super().__init__()
|
||||
image_key = next(key for key in config.input_features if is_image_feature(key))
|
||||
self.image_enc_layers = nn.Sequential(
|
||||
@@ -931,12 +532,12 @@ def freeze_image_encoder(image_encoder: nn.Module):
|
||||
|
||||
|
||||
class PretrainedImageEncoder(nn.Module):
|
||||
def __init__(self, config: SACConfig):
|
||||
def __init__(self, config: GaussianActorConfig):
|
||||
super().__init__()
|
||||
|
||||
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
||||
|
||||
def _load_pretrained_vision_encoder(self, config: SACConfig):
|
||||
def _load_pretrained_vision_encoder(self, config: GaussianActorConfig):
|
||||
"""Set up CNN encoder"""
|
||||
from transformers import AutoModel
|
||||
|
||||
@@ -32,18 +32,18 @@ from lerobot.processor import (
|
||||
)
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
from .configuration_sac import SACConfig
|
||||
from .configuration_gaussian_actor import GaussianActorConfig
|
||||
|
||||
|
||||
def make_sac_pre_post_processors(
|
||||
config: SACConfig,
|
||||
def make_gaussian_actor_pre_post_processors(
|
||||
config: GaussianActorConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the SAC policy.
|
||||
Constructs pre-processor and post-processor pipelines for the Gaussian actor policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
@@ -56,7 +56,7 @@ def make_sac_pre_post_processors(
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the SAC policy.
|
||||
config: The configuration object for the tanh-Gaussian policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
|
||||
Returns:
|
||||
@@ -60,6 +60,7 @@ class Eagle25VLPreTrainedModel(PreTrainedModel):
|
||||
"SiglipEncoderLayer",
|
||||
]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
@@ -124,7 +124,6 @@ class Eagle25VLProcessor(ProcessorMixin):
|
||||
"videos_kwargs",
|
||||
"text_kwargs",
|
||||
]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -26,9 +26,14 @@ from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from huggingface_hub.dataclasses import strict
|
||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
|
||||
def strict(cls):
|
||||
return cls
|
||||
|
||||
AutoConfig = None
|
||||
AutoModel = None
|
||||
PretrainedConfig = object
|
||||
@@ -173,19 +178,20 @@ N_COLOR_CHANNELS = 3
|
||||
|
||||
|
||||
# config
|
||||
@strict
|
||||
class GR00TN15Config(PretrainedConfig):
|
||||
model_type = "gr00t_n1_5"
|
||||
|
||||
backbone_cfg: dict
|
||||
action_head_cfg: dict
|
||||
action_horizon: int
|
||||
action_dim: int
|
||||
backbone_cfg: dict[str, Any] | None = None
|
||||
action_head_cfg: dict[str, Any] | None = None
|
||||
action_horizon: int = 0
|
||||
action_dim: int = 0
|
||||
compute_dtype: str = "float32"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
def __post_init__(self, **kwargs):
|
||||
self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg
|
||||
self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg
|
||||
super().__post_init__(**kwargs)
|
||||
|
||||
|
||||
# real model
|
||||
|
||||
@@ -206,7 +206,11 @@ def _build_eagle_processor(tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS
|
||||
"Vendor files are copied during model creation. Create the policy/model first, "
|
||||
"or call ensure_eagle_cache_ready() before building processors."
|
||||
)
|
||||
proc = AutoProcessor.from_pretrained(str(cache_dir), trust_remote_code=True, use_fast=True)
|
||||
proc = AutoProcessor.from_pretrained(
|
||||
str(cache_dir),
|
||||
trust_remote_code=True,
|
||||
fix_mistral_regex=False,
|
||||
)
|
||||
proc.tokenizer.padding_side = "left"
|
||||
return proc
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
DynamicCache = None
|
||||
modeling_gemma = None
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
@@ -141,6 +142,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
def clone_past_key_values(past_key_values):
|
||||
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
|
||||
return DynamicCache(
|
||||
tuple(
|
||||
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
@@ -227,16 +237,13 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
@@ -258,15 +265,16 @@ def compute_layer_complete(
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
cos, sin = rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
paligemma_layer = layers[0]
|
||||
scaling = paligemma_layer.self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
paligemma_layer.self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -274,13 +282,13 @@ def compute_layer_complete(
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
head_dim = paligemma_layer.self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
@@ -488,8 +496,9 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
paligemma_layers = self.paligemma.model.language_model.layers
|
||||
gemma_expert_layers = self.gemma_expert.model.layers
|
||||
rotary_emb = self.paligemma.model.language_model.rotary_emb
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
@@ -499,36 +508,39 @@ class PaliGemmaWithExpertModel(
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layer_idx in range(num_layers):
|
||||
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
|
||||
# final norm
|
||||
final_norms = (
|
||||
self.paligemma.model.language_model.norm,
|
||||
self.gemma_expert.model.norm,
|
||||
)
|
||||
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -907,7 +919,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
past_key_values = clone_past_key_values(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
DynamicCache = None
|
||||
modeling_gemma = None
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
@@ -138,6 +139,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
def clone_past_key_values(past_key_values):
|
||||
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
|
||||
return DynamicCache(
|
||||
tuple(
|
||||
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
@@ -224,16 +234,13 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
@@ -255,15 +262,16 @@ def compute_layer_complete(
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
cos, sin = rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
paligemma_layer = layers[0]
|
||||
scaling = paligemma_layer.self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
paligemma_layer.self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -271,13 +279,13 @@ def compute_layer_complete(
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
head_dim = paligemma_layer.self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
@@ -441,13 +449,13 @@ class PaliGemmaWithExpertModel(
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||
features = image_outputs.pooler_output
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -485,8 +493,9 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
paligemma_layers = self.paligemma.model.language_model.layers
|
||||
gemma_expert_layers = self.gemma_expert.model.layers
|
||||
rotary_emb = self.paligemma.model.language_model.rotary_emb
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
@@ -496,36 +505,39 @@ class PaliGemmaWithExpertModel(
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layer_idx in range(num_layers):
|
||||
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
|
||||
# final norm
|
||||
final_norms = (
|
||||
self.paligemma.model.language_model.norm,
|
||||
self.gemma_expert.model.norm,
|
||||
)
|
||||
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -662,8 +674,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Process language tokens
|
||||
def lang_embed_func(tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
return lang_emb * math.sqrt(lang_emb_dim)
|
||||
return lang_emb
|
||||
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
||||
embs.append(lang_emb)
|
||||
@@ -881,7 +892,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
past_key_values = clone_past_key_values(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
|
||||
@@ -97,8 +97,8 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
vision_backbone: str = "resnet18"
|
||||
crop_shape: tuple[int, int] | None = (84, 84)
|
||||
crop_is_random: bool = True
|
||||
pretrained_backbone_weights: str | None = None
|
||||
use_group_norm: bool = True
|
||||
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
|
||||
use_group_norm: bool = False
|
||||
spatial_softmax_num_keypoints: int = 32
|
||||
# VQ-VAE
|
||||
n_vqvae_training_steps: int = 20000
|
||||
|
||||
@@ -939,7 +939,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
target_dtype = torch.get_autocast_dtype(query_states.device.type)
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
|
||||
@@ -985,7 +985,7 @@ class Florence2FlashAttention2(Florence2Attention):
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
target_dtype = torch.get_autocast_dtype(query_states.device.type)
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
|
||||
@@ -95,6 +95,13 @@ from .relative_action_processor import (
|
||||
from .rename_processor import RenameObservationsProcessorStep, rename_stats
|
||||
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
|
||||
|
||||
# RenderMessagesStep is intentionally NOT re-exported here: it pulls in
|
||||
# `lerobot.datasets.language`, which requires the `[dataset]` extra
|
||||
# (`datasets`, `pyarrow`). Importing it from the processor package would
|
||||
# break every base-install consumer of `lerobot.processor`. Users that
|
||||
# need it import directly:
|
||||
# from lerobot.processor.render_messages_processor import RenderMessagesStep
|
||||
|
||||
__all__ = [
|
||||
"ActionProcessorStep",
|
||||
"AddTeleopActionAsComplimentaryDataStep",
|
||||
|
||||
@@ -174,6 +174,24 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
|
||||
task_index_value = complementary_data["task_index"]
|
||||
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
|
||||
complementary_data["task_index"] = task_index_value.unsqueeze(0)
|
||||
|
||||
complementary_data.pop("language_persistent", None)
|
||||
complementary_data.pop("language_events", None)
|
||||
|
||||
if "messages" in complementary_data:
|
||||
messages = complementary_data["messages"]
|
||||
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
|
||||
complementary_data["messages"] = [messages]
|
||||
|
||||
if "message_streams" in complementary_data:
|
||||
streams = complementary_data["message_streams"]
|
||||
if isinstance(streams, list) and (not streams or isinstance(streams[0], str)):
|
||||
complementary_data["message_streams"] = [streams]
|
||||
|
||||
if "target_message_indices" in complementary_data:
|
||||
indices = complementary_data["target_message_indices"]
|
||||
if isinstance(indices, list) and (not indices or isinstance(indices[0], int)):
|
||||
complementary_data["target_message_indices"] = [indices]
|
||||
return complementary_data
|
||||
|
||||
def transform_features(
|
||||
|
||||
@@ -153,26 +153,30 @@ def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | An
|
||||
return x
|
||||
|
||||
|
||||
_COMPLEMENTARY_KEYS = (
|
||||
"task",
|
||||
"index",
|
||||
"task_index",
|
||||
"episode_index",
|
||||
"timestamp",
|
||||
"language_persistent",
|
||||
"language_events",
|
||||
"messages",
|
||||
"message_streams",
|
||||
"target_message_indices",
|
||||
)
|
||||
|
||||
|
||||
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Extract complementary data from a batch dictionary.
|
||||
"""Extract complementary data from a batch dictionary.
|
||||
|
||||
This includes padding flags, task description, and indices.
|
||||
|
||||
Args:
|
||||
batch: The batch dictionary.
|
||||
|
||||
Returns:
|
||||
A dictionary with the extracted complementary data.
|
||||
Includes padding flags (any key containing ``_is_pad``) plus the fixed
|
||||
set of metadata / language keys defined in ``_COMPLEMENTARY_KEYS`` —
|
||||
each only when present in ``batch``.
|
||||
"""
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
||||
|
||||
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
|
||||
extras = {k: batch[k] for k in _COMPLEMENTARY_KEYS if k in batch}
|
||||
return {**pad_keys, **extras}
|
||||
|
||||
|
||||
def create_transition(
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
@@ -321,6 +320,7 @@ class GymHILAdapterProcessorStep(ProcessorStep):
|
||||
This step normalizes the `transition` object by:
|
||||
1. Copying `teleop_action` from `info` to `complementary_data`.
|
||||
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
|
||||
3. Copying `discrete_penalty` from `info` to `complementary_data`.
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
@@ -330,6 +330,9 @@ class GymHILAdapterProcessorStep(ProcessorStep):
|
||||
if TELEOP_ACTION_KEY in info:
|
||||
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
|
||||
|
||||
if DISCRETE_PENALTY_KEY in info:
|
||||
complementary_data[DISCRETE_PENALTY_KEY] = info[DISCRETE_PENALTY_KEY]
|
||||
|
||||
if "is_intervention" in info:
|
||||
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
|
||||
|
||||
@@ -348,18 +351,24 @@ class GymHILAdapterProcessorStep(ProcessorStep):
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Applies a penalty for inefficient gripper usage.
|
||||
Applies a small per-transition cost on the discrete gripper action.
|
||||
|
||||
This step penalizes actions that attempt to close an already closed gripper or
|
||||
open an already open one, based on position thresholds.
|
||||
Fires only when the commanded action would actually transition the gripper
|
||||
from one extreme to the other (close-while-open or open-while-closed).
|
||||
This discourages gripper oscillation while leaving "stay" and saturating-further
|
||||
commands unpenalized.
|
||||
|
||||
Attributes:
|
||||
penalty: The negative reward value to apply.
|
||||
max_gripper_pos: The maximum position value for the gripper, used for normalization.
|
||||
open_threshold: Normalized state below which the gripper is considered "open".
|
||||
closed_threshold: Normalized state above which the gripper is considered "closed".
|
||||
"""
|
||||
|
||||
penalty: float = -0.01
|
||||
penalty: float = -0.02
|
||||
max_gripper_pos: float = 30.0
|
||||
open_threshold: float = 0.1
|
||||
closed_threshold: float = 0.9
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
@@ -379,11 +388,15 @@ class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
if raw_joint_positions is None:
|
||||
return new_transition
|
||||
|
||||
current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None)
|
||||
current_gripper_pos = raw_joint_positions.get(f"{GRIPPER_KEY}.pos", None)
|
||||
if current_gripper_pos is None:
|
||||
return new_transition
|
||||
|
||||
# Gripper action is a PolicyAction at this stage
|
||||
# During reset, the transition may not carry any action yet.
|
||||
if action is None:
|
||||
return new_transition
|
||||
|
||||
# Gripper action is expected as the last action dimension.
|
||||
gripper_action = action[-1].item()
|
||||
gripper_action_normalized = gripper_action / self.max_gripper_pos
|
||||
|
||||
@@ -391,9 +404,13 @@ class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
|
||||
|
||||
# Calculate penalty boolean as in original
|
||||
gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
|
||||
gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
|
||||
)
|
||||
# - currently open AND target is closed -> close transition
|
||||
# - currently closed AND target is open -> open transition
|
||||
is_open = gripper_state_normalized < self.open_threshold
|
||||
is_closed = gripper_state_normalized > self.closed_threshold
|
||||
cmd_close = gripper_action_normalized > self.closed_threshold
|
||||
cmd_open = gripper_action_normalized < self.open_threshold
|
||||
gripper_penalty_bool = (is_open and cmd_close) or (is_closed and cmd_open)
|
||||
|
||||
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
||||
|
||||
@@ -409,11 +426,14 @@ class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
Returns the configuration of the step for serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the penalty value and max gripper position.
|
||||
A dictionary containing the penalty value, max gripper position,
|
||||
and the open/closed thresholds.
|
||||
"""
|
||||
return {
|
||||
"penalty": self.penalty,
|
||||
"max_gripper_pos": self.max_gripper_pos,
|
||||
"open_threshold": self.open_threshold,
|
||||
"closed_threshold": self.closed_threshold,
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
|
||||
@@ -134,6 +134,24 @@ class _NormalizationMixin:
|
||||
if self.dtype is None:
|
||||
self.dtype = torch.float32
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||
self._reshape_visual_stats()
|
||||
|
||||
def _reshape_visual_stats(self) -> None:
|
||||
"""Reshape flat ``(C,)`` visual stats to ``(C, 1, 1)`` for image broadcasting.
|
||||
|
||||
No-op for stats from :func:`~lerobot.datasets.compute_stats.compute_stats`
|
||||
(already ``(C, 1, 1)``). Needed by RL training, which can start without
|
||||
a dataset and supplies stats manually via JSON config.
|
||||
"""
|
||||
for key, feature in self.features.items():
|
||||
if feature.type != FeatureType.VISUAL:
|
||||
continue
|
||||
if key not in self._tensor_stats:
|
||||
continue
|
||||
for stat_name, stat_tensor in self._tensor_stats[key].items():
|
||||
if not isinstance(stat_tensor, Tensor) or stat_tensor.ndim != 1:
|
||||
continue
|
||||
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
|
||||
|
||||
def to(
|
||||
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
|
||||
@@ -152,6 +170,7 @@ class _NormalizationMixin:
|
||||
if dtype is not None:
|
||||
self.dtype = dtype
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||
self._reshape_visual_stats()
|
||||
return self
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
@@ -201,6 +220,7 @@ class _NormalizationMixin:
|
||||
# Don't load from state_dict, keep the explicitly provided stats
|
||||
# But ensure _tensor_stats is properly initialized
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
|
||||
self._reshape_visual_stats()
|
||||
return
|
||||
|
||||
# Normal behavior: load stats from state_dict
|
||||
@@ -211,6 +231,7 @@ class _NormalizationMixin:
|
||||
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
|
||||
dtype=torch.float32, device=self.device
|
||||
)
|
||||
self._reshape_visual_stats()
|
||||
|
||||
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
|
||||
# and other functions that rely on self.stats
|
||||
|
||||
84
src/lerobot/processor/render_messages_processor.py
Normal file
84
src/lerobot/processor/render_messages_processor.py
Normal file
@@ -0,0 +1,84 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.configs.recipe import TrainingRecipe
|
||||
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT
|
||||
from lerobot.datasets.language_render import render_sample
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.utils import unwrap_scalar
|
||||
|
||||
from .pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="render_messages_processor")
|
||||
class RenderMessagesStep(ProcessorStep):
|
||||
"""Processor step that turns raw language columns into rendered chat messages.
|
||||
|
||||
Reads ``language_persistent`` and ``language_events`` from the transition's
|
||||
complementary data, renders them through ``recipe`` at the sample timestamp,
|
||||
and replaces the raw columns with the resulting ``messages`` /
|
||||
``message_streams`` / ``target_message_indices`` keys.
|
||||
"""
|
||||
|
||||
recipe: TrainingRecipe
|
||||
dataset_ctx: Any | None = None
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
|
||||
"""Render messages for a single transition; return ``None`` to drop it."""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
persistent = complementary_data.get(LANGUAGE_PERSISTENT) or []
|
||||
events = complementary_data.get(LANGUAGE_EVENTS) or []
|
||||
|
||||
if not persistent and not events:
|
||||
return transition
|
||||
|
||||
timestamp = complementary_data.get("timestamp")
|
||||
if timestamp is None:
|
||||
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
|
||||
|
||||
sample_idx = complementary_data.get("index", 0)
|
||||
rendered = render_sample(
|
||||
recipe=self.recipe,
|
||||
persistent=persistent,
|
||||
events=events,
|
||||
t=unwrap_scalar(timestamp),
|
||||
sample_idx=int(unwrap_scalar(sample_idx)),
|
||||
task=complementary_data.get("task"),
|
||||
dataset_ctx=self.dataset_ctx,
|
||||
)
|
||||
if rendered is None:
|
||||
return None
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
|
||||
new_complementary_data.pop(LANGUAGE_EVENTS, None)
|
||||
new_complementary_data.update(rendered)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Pass features through unchanged; rendering only touches complementary data."""
|
||||
return features
|
||||
@@ -30,7 +30,7 @@ class RewardClassifierConfig(RewardModelConfig):
|
||||
latent_dim: int = 256
|
||||
image_embedding_pooling_dim: int = 8
|
||||
dropout_rate: float = 0.1
|
||||
model_name: str = "helper2424/resnet10" # TODO: This needs to be updated. The model on the Hub doesn't call self.post_init() in its __init__, which is required by transformers v5 to set all_tied_weights_keys. The from_pretrained call fails when it tries to access this attribute during _finalize_model_loading.
|
||||
model_name: str = "lerobot/resnet10"
|
||||
device: str = "cpu"
|
||||
model_type: str = "cnn" # "transformer" or "cnn"
|
||||
num_cameras: int = 2
|
||||
|
||||
@@ -17,10 +17,11 @@ import logging
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.utils.constants import OBS_IMAGE, REWARD
|
||||
|
||||
from ..pretrained import PreTrainedRewardModel
|
||||
from .configuration_classifier import RewardClassifierConfig
|
||||
|
||||
|
||||
class ClassifierOutput:
|
||||
"""Wrapper for classifier outputs with additional metadata."""
|
||||
@@ -105,6 +106,7 @@ class Classifier(PreTrainedRewardModel):
|
||||
def __init__(
|
||||
self,
|
||||
config: RewardClassifierConfig,
|
||||
**kwargs,
|
||||
):
|
||||
from transformers import AutoModel
|
||||
|
||||
|
||||
@@ -25,7 +25,8 @@ from lerobot.processor import (
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
from .configuration_classifier import RewardClassifierConfig
|
||||
|
||||
|
||||
def make_classifier_processor(
|
||||
|
||||
@@ -22,9 +22,10 @@ import torch
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.rewards.sarm.configuration_sarm import SARMConfig
|
||||
|
||||
from .classifier.configuration_classifier import RewardClassifierConfig
|
||||
from .pretrained import PreTrainedRewardModel
|
||||
from .sarm.configuration_sarm import SARMConfig
|
||||
|
||||
|
||||
def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user