Compare commits

...

47 Commits

Author SHA1 Message Date
Caroline Pascal
3ab08a5318 fix(imports): fixing av import in test_depth.py 2026-05-22 15:13:15 +02:00
CarolinePascal
5e53e6bd2f tests(typos): fixing typos in tests 2026-05-22 13:09:56 +02:00
CarolinePascal
a94d9f119c fix(info): fixing info metadata update when is_depth_map was set 2026-05-22 02:48:30 +02:00
CarolinePascal
8a615070e7 fix(pre-commit): fixing mutable defautl value 2026-05-22 02:07:33 +02:00
CarolinePascal
8e56797287 feat(refactor): refactor DepthEncoderConfig quantization pipeline, so that the methods do not live in the config class. Add pixel format - channels validation.Move the default pixel format for depth in the config file. 2026-05-22 02:06:37 +02:00
CarolinePascal
7498f1cf61 feat(pix_fmt channels): use PyAv to check get pixel formats number of channels 2026-05-22 02:03:23 +02:00
CarolinePascal
72a429764a tests(depth): adding new tests for depth integration validation 2026-05-21 20:20:40 +02:00
CarolinePascal
4ea8653ca3 test(fix): fixing exisiting tests to still work with latest features 2026-05-21 19:56:00 +02:00
CarolinePascal
eeabb4d258 chore(typos): fixing typos 2026-05-21 19:55:33 +02:00
CarolinePascal
2b8d7b3c06 fix(plumbing): fixing missing parts in the depth maps pipeline 2026-05-21 16:11:01 +02:00
CarolinePascal
4a49f4a391 fix(stop_event): fixing stop_event race condition in camera classes 2026-05-21 15:51:12 +02:00
CarolinePascal
15647f50a2 feat(is_depth): simplifying is_depth nested name + legacy support 2026-05-21 14:26:16 +02:00
CarolinePascal
e87933302d feat(depth shape): ensuring depth maps shape is always including the channel 2026-05-21 14:25:42 +02:00
CarolinePascal
3cf5e3c8cb chore(format): format code 2026-05-20 16:47:22 +02:00
CarolinePascal
33a3b5a982 feat(depth maps writer): adding support for raw depth maps recording with image writer 2026-05-20 16:42:16 +02:00
CarolinePascal
1dafb4acf6 feat(viz): render depth observations as rr.DepthImage in Viridis 2026-05-20 16:22:34 +02:00
CarolinePascal
14df709201 feat(record): plumb DepthEncoderConfig through lerobot-record 2026-05-20 16:14:14 +02:00
CarolinePascal
d6f97ae17f feat(robots/so_follower): emit + populate depth keys when use_depth 2026-05-20 16:09:53 +02:00
CarolinePascal
085f574301 feat(features): route 2D camera shapes to observation.depth.<key> 2026-05-20 15:50:46 +02:00
CarolinePascal
f15348e769 feat(cameras/realsense): expose async depth in metric meters 2026-05-20 15:24:47 +02:00
CarolinePascal
e51d45dd2c feat(depth): wire DatasetReader to decode_depth_frames 2026-05-19 23:46:28 +02:00
CarolinePascal
d39698da0f feat(depth): wire StreamingVideoEncoder + writer to depth encoder 2026-05-19 23:23:27 +02:00
CarolinePascal
b4c31f0f67 feat(depth): plumb DepthEncoderConfig through LeRobotDataset and DatasetWriter 2026-05-19 22:50:19 +02:00
CarolinePascal
0cc5162078 feat(depth): extend quantization tools to better fit the encoding/decoding pipeline 2026-05-19 17:10:47 +02:00
CarolinePascal
b960524d93 feat(depth): persist depth metadata 2026-05-19 16:13:14 +02:00
CarolinePascal
088352383d feat(video): add ffv1 to supported codecs 2026-05-19 16:13:01 +02:00
CarolinePascal
42214d1c7a feat(depth): add depth quantization helpers and tests 2026-05-18 18:09:37 +02:00
Quentin Lhoest
5ebbdf3d05 Mention the new Lance LeRobotDataset implementation in the docs (#3609)
* Enhance documentation with Lance format details

Added information about Lance format and `lerobot-lancedb` package for multimodal AI datasets.

Signed-off-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
2026-05-18 14:51:26 +02:00
Khalil Meftah
6e035fb169 Update reward config and model card template (#3625) 2026-05-18 13:12:15 +02:00
Haoming Song
01dcb4c292 fix(pi05): update pi05 with transformers v5.4.0 interface (#3603) 2026-05-15 11:37:05 +02:00
Caroline Pascal
bd9619dfc3 feat(encoding parameters): adding support for user provided video encoding parameters (#3455)
* chore(video backend): renaming codec into video_backend in get_safe_default_video_backend()

* feat(pyav utils): adding suport for PyAV encoding parameters validation

* feat(VideoEncoderConfig): creating a VideoEncoderConfig to encapsulate encoding parameters

* feat(VideoEncoderConfig): propagating the VideoEncoderConfig in the codebase

* chore(docs): updating the docs

* feat(metadata): adding encoding parameters in dataset metadata

* fix(concatenation compatibility): adding compatibility check when concatenating video files

* feat(VideoEncoderConfig init): making VideoEncoderConfig more robust and adaptable to multiple backends

* feat(pyav checks): making pyav parameters checks more robust

* chore(duplicate): removing duplicate get_codec_options definition

* test(existing): adapting existing tests

* test(new): adding new tests for encoding related features

* chore(format): fixing formatting issues

* chore(PyAV): cleaning up PyAV utils and encoding parameters checks to stick to the minimun required tooling.

* chore(format): formatting code

* chore(doctrings): updating docstrings

* fix(camera_encoder_config): Removing camera_encoder_config from LeRobotDataset, as it's only required in LeRobotDatasetWriter.

* feat(default values): applying a consistent naming convention for default RGB cameras video encoder parameters

* fix(rollout): propagating VideoEncoderConfig to the latest recording modes

* chore(format): formatting code, fixing error messages and variable names

* fix(arguments order): reverting changes in arguments order in StreamingVideoEncoder

* chore(relative imports): switching to relative local imports within lerobot.datasets

* test(artifacts): cleaning up artifacts for the video encoding tests

* chore(docs): updating docs

* chore(fromat): formatting code

* fix(imports): refactoring the file architecture to avoid circular imports. VideoEncoderConfig is now defined in lerobot.configs and lazily imports av at runtime.

* fix(typos): fixing typos and small mistakes

* test(factories): updating factories

* feat(aggregate): updating dataset aggregation procedure. Encoding tuning paramters (crf, g,...) are ignored for validation and changed to None in the aggregated dataset if incompatible.

* docs(typos): fixing typos

* fix(deletion): reverting unwanted deletion

* fix(typos): fixing multiple typos

* feat(codec options): passing codec options to lerobot_edit_dataset episode deletion tool

* typo(typo): typo

* fix(typos): fixing remaining typos

* chore(rename): renaming camera_encoder_config to camera_encoder

* docs(clean): cleaning and formating docs

* docs(dataset): addind details about datasets

* chore(format): formatting code

* docs(warning): adding warning regarding encoding parameters modification

* fix(re-encoding): removing inconsistent re-encoding option in lerobot_edit_dataset

* typos(typos): typos

* chore(format): resolving prettier issues

* fix(h264_nvenc): fixing crf handling for h264_nvenc

* docs(clean): removing too technical parts of the docs

* fix(imports): fixing imports at the __init__ level

* fix(imports): fixing not very pretty imports in video config file
2026-05-14 23:46:42 +02:00
Nikodem Bartnik
0a4a7c40ad docs(cheat sheet): create cheat sheet (#3602)
* add comprehensive CLI cheat sheet for quick reference
2026-05-14 15:11:35 +02:00
Nikodem Bartnik
ca9028ad64 docs(quickstart): adding rollout (#3598)
* fix whoami command

* include lerobot-rollout in inference section
2026-05-14 12:32:39 +02:00
Cheng Yin
9db9c35cb4 fix(config): add lora_alpha to PeftConfig (#3573)
* fix(config): add lora_alpha to PeftConfig

PeftConfig was missing the lora_alpha field, causing the PEFT library
to default to alpha=8 regardless of the LoRA rank, which dampens the
adaptation signal for high-rank adapters (e.g., r=128).

This adds lora_alpha: int | None = None to PeftConfig, allowing users
to specify --peft.lora_alpha <value> on the CLI.

Closes #3551

* fix(docs): add lora_alpha to peft training example + clarify scaling formula

- Add --peft.lora_alpha=64 to docs/source/peft_training.mdx example to
  prevent new users from hitting the alpha=8 default dampening bug
- Clarify lora_alpha comment in default.py with scaling = lora_alpha / r

* docs: mention both --peft.r and --peft.lora_alpha in LoRA description

---------

Co-authored-by: Cheng Yin <yin@users.noreply.github.com>
2026-05-13 11:09:19 +02:00
Jash Shah
fe96b28c74 Fix policy.path not working in YAML config files (#3145)
* fix(config): support policy.path in YAML config files

policy.path was only handled via CLI args (filtered from sys.argv before
draccus, then retrieved in validate()). When specified in YAML, draccus
would crash because 'path' is not a valid field on PreTrainedConfig.

Extract path fields from the YAML/JSON config before draccus processes
it, store them in a module-level dict, and fall back to it in
get_path_arg() when the CLI doesn't have the path.

Fixes #2957

* fix(parser): preserve YAML policy overrides when loading from pretrained

When policy.path is set in YAML, validate() was calling from_pretrained
with only CLI overrides, discarding any YAML policy fields (e.g. lr,
batch_size) that draccus had already parsed. Fix by capturing the
remaining YAML fields as CLI-style args in _config_yaml_overrides and
merging them into the overrides passed to from_pretrained in train.py,
eval.py, and lerobot_record.py (CLI args still take precedence).

Also fix the NamedTemporaryFile SIM115 ruff warning and add types-PyYAML
to the mypy pre-commit hook.

* fix(parser): serialize bool/None values correctly in YAML policy overrides

Bool values from YAML configs (e.g. push_to_hub: true) were passed as
Python "True"/"False" strings instead of lowercase "true"/"false" that
draccus expects. Also skip None values to avoid passing "None" strings.

* revert: remove types-PyYAML from .pre-commit-config.yaml

* chore: fix quality check caused by untyped YAML import

Co-authored-by: masato-ka <jp6uzv@gmail.com>
Signed-off-by: Khalil Meftah <khalil.meftah@huggingface.co>

---------

Signed-off-by: Khalil Meftah <khalil.meftah@huggingface.co>
Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
Co-authored-by: masato-ka <jp6uzv@gmail.com>
2026-05-13 09:45:27 +02:00
Steven Palma
2438df1307 chore(dependencies): update uv.lock (#3561) 2026-05-12 21:20:26 +02:00
Caroline Pascal
f218d5ab30 feat(episodes): adding support for metadata based episodes filtering (#3530)
* feat(episode filtering): adding support for episodes filtering at initialization time in LeRobotDataset

* test(tests): adding tests

* chore(format): formatting code

* feat(performance): improving implementation for better performances on big datasets

* chores(warning): improving warnings and errors for episodes filtering

* test(invalid key): adding test for invalid filtering key

* chore(format): formatting code
2026-05-12 20:44:11 +02:00
Steven Palma
04125492e4 fix(datasets): expand torchcodec platform coverage + rewrite pyav fallback for torchvision >0.26 (#3588)
* fix(deps): better versioning control for torchcodec

* refactor(video_utils): replace torchvision with pyav

* adding Torchcodec version to lerobot-info

* chore(benchmarks): delete video benchmark

---------

Co-authored-by: Maximellerbach <maxime.ellerbach@huggingface.co>
2026-05-12 16:59:11 +02:00
Khalil Meftah
e963e5a0c4 RL stack refactoring (#3075)
* refactor: RL stack refactoring — RLAlgorithm, RLTrainer, DataMixer, and SAC restructuring

* chore: clarify torch.compile disabled note in SACAlgorithm

* fix(teleop): keyboard EE teleop not registering special keys and losing intervention state

Fixes #2345

Co-authored-by: jpizarrom <jpizarrom@gmail.com>

* fix: remove leftover normalization calls from reward classifier predict_reward

Fixes #2355

* fix: add thread synchronization to ReplayBuffer to prevent race condition between add() and sample()

* refactor: update SACAlgorithm to pass action_dim to _init_critics and fix encoder reference

* perf: remove redundant CPU→GPU→CPU transition move in learner

* Fix: add kwargs in reward classifier __init__()

* fix: include IS_INTERVENTION in complementary_info sent to learner for offline replay buffer

* fix: add try/finally to control_loop to ensure image writer cleanup on exit

* fix: use string key for IS_INTERVENTION in complementary_info to avoid torch.load serialization error

* fix: skip tests that require grpc if not available

* fix(tests): ensure tensor stats comparison accounts for reshaping in normalization tests

* fix(tests): skip tests that require grpc if not available

* refactor(rl): expose public API in rl/__init__ and use relative imports in sub-packages

* fix(config): update vision encoder model name to lerobot/resnet10

* fix(sac): clarify torch.compile status

* refactor(rl): update shutdown_event type hints from 'any' to 'Any' for consistency and clarity

* refactor(sac): simplify optimizer return structure

* perf(rl): use async iterators in OnlineOfflineMixer.get_iterator

* refactor(sac): decouple algorithm hyperparameters from policy config

* update losses names in tests

* fix docstring

* remove unused type alias

* fix test for flat dict structure

* refactor(policies): rename policies/sac → policies/gaussian_actor

* refactor(rl/sac): consolidate hyperparameter ownership and clean up discrete critic

* perf(observation_processor): add CUDA support for image processing

* fix(rl): correctly wire HIL-SERL gripper penalty through processor pipeline

(cherry picked from commit 9c2af818ff)

* fix(rl): add time limit processor to environment pipeline

(cherry picked from commit cd105f65cb)

* fix(rl): clarify discrete gripper action mapping in GripperVelocityToJoint for SO100

(cherry picked from commit 494f469a2b)

* fix(rl): update neutral gripper action

(cherry picked from commit 9c9064e5be)

* fix(rl): merge environment and action-processor info in transition processing

(cherry picked from commit 30e1886b64)

* fix(rl): mirror gym_manipulator in actor

(cherry picked from commit d2a046dfc5)

* fix(rl): postprocess action in actor

(cherry picked from commit c2556439e5)

* fix(rl): improve action processing for discrete and continuous actions

(cherry picked from commit f887ab3f6a)

* fix(rl): enhance intervention handling in actor and learner

(cherry picked from commit ef8bfffbd7)

* Revert "perf(observation_processor): add CUDA support for image processing"

This reverts commit 38b88c414c.

* refactor(rl): make algorithm a nested config so all SAC hyperparameters are JSON-addressable

* refactor(rl): add make_algorithm_config function for RLAlgorithmConfig instantiation

* refactor(rl): add type property to RLAlgorithmConfig for better clarity

* refactor(rl): make RLAlgorithmConfig an abstract base class for better extensibility

* refactor(tests): remove grpc import checks from test files for cleaner code

* fix(tests): gate RL tests on the `datasets` extra

* refactor: simplify docstrings for clarity and conciseness across multiple files

* fix(rl): update gripper position key and handle action absence during reset

* fix(rl): record pre-step observation so (obs, action, next.reward) align in gym_manipulator dataset

* refactor: clean up import statements

* chore: address reviewer comments

* chore: improve visual stats reshaping logic and update docstring for clarity

* refactor: enforce mandatory config_class and name attributes in RLAlgorithm

* refactor: implement NotImplementedError for abstract methods in RLAlgorithm and DataMixer

* refactor: replace build_algorithm with make_algorithm for SACAlgorithmConfig and update related tests

* refactor: add require_package calls for grpcio and gym-hil in relevant modules

* refactor(rl): move grpcio guards to runtime entry points

* feat(rl): consolidate HIL-SERL checkpoint into HF-style components

Make `RLAlgorithmConfig` and `RLAlgorithm` `HubMixin`s, add abstract
`state_dict()` / `load_state_dict()` for critic ensemble, target nets
and `log_alpha`, and persist them as a sibling `algorithm/` component
next to `pretrained_model/`. Replace the pickled `training_state.pt`
with an enriched `training_step.json` carrying `step` and
`interaction_step`, so resume restores actor + critics + target nets +
temperature + optimizers + RNG + counters from HF-standard files.

* refactor(rl): move actor weight-sync wire format from policy to algorithm

* refactor(rl): update type hints for learner and actor functions

* refactor(rl): hoist grpcio guard to module top in actor/learner

* chore(rl): manage import pattern in actor (#3564)

* chore(rl): manage import pattern in actor

* chore(rl): optional grpc imports in learner; quote grpc ServicerContext types

---------

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>

* update uv.lock

* chore(doc): update doc

---------

Co-authored-by: jpizarrom <jpizarrom@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-05-12 15:49:54 +02:00
Steven Palma
26ff40ddd7 chore(deps): cap torch ceiling at <2.12, pin Linux wheels to cu128 (#3570)
* chore(deps): ceiling + cuda

* ci: bump cuda version docker image

* ci: add cpu wheel to release workflow

* chore(deps): update uv.lock

* docs: update installation with cuda note
2026-05-11 19:47:55 +02:00
Maxime Ellerbach
6d269b28c8 docs(omx): adding some examples and scripts (#3566)
* docs(omx): adding some examples and scripts

* cleaning up and reviewing the cli args

* adding __init__.py to example folder, adjusting the examples

* adding reference to pretrained act policy

* moving `.send_action` before `dataset.add_frame` for consistency

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* adjusting docstring

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* adressing hardcoded dataset fps

* removed init as it worked without

---------

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
2026-05-11 15:36:32 +02:00
Steven Palma
b607c8458e docs: add policy & compute guide (#3534)
* docs(policy): contributing a policy guide

* docs(training): HW compute guide

* chore(docs): add to readme and index

* Apply suggestions from code review

Co-authored-by: Haoming Song <1847575517@qq.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* chore(docs): slight improvements

* refactor(docs): consolidate add policy docs

* chore(style): fix pre-commit

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Haoming Song <1847575517@qq.com>
2026-05-11 15:19:12 +02:00
Jash Shah
9e83510c99 fix(datasets): close file handle on VideoDecoder init failure in cache (#3542)
If VideoDecoder() raises during initialization, the fsspec file handle
was leaked since it was opened via __enter__() but never closed on the
exception path. Now explicitly closes the handle before re-raising.
2026-05-10 17:30:37 +02:00
Anthony Shoumikhin
1f7b03f5f2 chore(deps): allow torch 2.11/2.12 and fix autocast deprecation (#3435)
* chore(deps): allow torch 2.11/2.12 and fix autocast deprecation

- Bump torch to >=2.7,<2.13 (was <2.11), torchvision to <0.28 (was <0.26),
  and torchcodec to <0.13 (was <0.11) to allow installs against the latest
  stable torch 2.11 and the upcoming 2.12 line.
- Replace removed torch.get_autocast_gpu_dtype() with torch.get_autocast_dtype("cuda")
  in Florence2 and Qwen2.5-VL-MoE FlashAttention paths (the former is removed in 2.11+).
- Refresh uv.lock for the new resolution (torch 2.11.0+cu130, torchvision 0.26.0+cu130,
  torchcodec 0.11.1, full CUDA 13 stack).

Verified locally with `uv sync --locked` from a clean .venv and the lerobot
test suite (pytest -n 8 --dist=loadfile --timeout=300). Failure set is
identical to the pre-bump baseline: 18 pre-existing failures
(test_sac_policy*, test_pi0_rtc*, test_pi05_rtc*, test_replay_buffer*),
0 new, 0 fixed.

AI assistance: this change was authored with Claude Code per AI_POLICY.md.

* fix(policies): use device-agnostic autocast dtype lookup

Pass query_states.device.type to torch.get_autocast_dtype() instead of
hardcoding 'cuda', so the cast matches the active autocast context when
running under CPU/MPS/XPU autocast.

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-05-10 13:05:35 +02:00
Steven Palma
cb8edf17e6 chore(dependencies): update uv.lock (#3475) 2026-05-10 12:24:22 +02:00
Steven Palma
5699f6cbf4 chore(ci): disable auto-stale (#3550) 2026-05-10 11:49:31 +02:00
masato-ka
0e6114ac36 fix(train): restrict legacy RA-BC migration to JSON checkpoints only (#3490)
* fix(train): restrict legacy RA-BC migration to JSON checkpoints only

_migrate_legacy_rabc_fields was called for all config files, causing
json.load to raise DecodeError when a YAML/TOML config was passed to
lerobot-train for a new training run. Guard the block with an
.endswith(".json") check so migration only runs when resuming from
a JSON checkpoint.
2026-05-08 20:27:01 +02:00
144 changed files with 9411 additions and 3949 deletions

View File

@@ -152,13 +152,14 @@ jobs:
BASE_VERSION="${VERSION%%-*}"
echo "Installing pre-release version $BASE_VERSION from TestPyPI..."
uv pip install \
--torch-backend cpu \
--index-url https://test.pypi.org/simple/ \
--extra-index-url https://pypi.org/simple \
--index-strategy unsafe-best-match \
"lerobot[all]==$BASE_VERSION"
else
echo "Installing release version $VERSION from PyPI..."
uv pip install "lerobot[all]==$VERSION"
uv pip install --torch-backend cpu "lerobot[all]==$VERSION"
fi
- name: Check lerobot version
run: uv run python -c "import lerobot; print(lerobot.__version__)"

View File

@@ -19,8 +19,8 @@ on:
workflow_dispatch:
# Runs at 02:00
schedule:
- cron: "0 2 * * *"
# schedule:
# - cron: "0 2 * * *"
env:
CLOSE_ISSUE_MESSAGE: >

View File

@@ -232,6 +232,8 @@ Match the policy to the user's **GPU memory** and **time budget**. Numbers below
All policies typically train for **510 epochs** (see §7).
> **Human-facing version:** the [Compute Hardware Guide](./docs/source/hardware_guide.mdx) reuses the table below and adds a cloud-GPU tier guide and a Hugging Face Jobs pointer.
| Policy | Batch | Update (ms) | Peak GPU mem (GB) | Best for |
| ----------- | ----: | ----------: | ----------------: | ------------------------------------------------------------------------------------------------ |
| `act` | 4 | **83.9** | **0.94** | First-time users, laptops, single-task. Fast and reliable. |

View File

@@ -109,7 +109,7 @@ lerobot-train \
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies).
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies). For GPU/RAM requirements and expected training time per policy, see the [Compute Hardware Guide](https://huggingface.co/docs/lerobot/hardware_guide).
## Inference & Evaluation

View File

@@ -1,288 +0,0 @@
# Video benchmark
## Questions
What is the optimal trade-off between:
- maximizing loading time with random access,
- minimizing memory space on disk,
- maximizing success rate of policies,
- compatibility across devices/platforms for decoding videos (e.g. video players, web browsers).
How to encode videos?
- Which video codec (`-vcodec`) to use? h264, h265, AV1?
- What pixel format to use (`-pix_fmt`)? `yuv444p` or `yuv420p`?
- How much compression (`-crf`)? No compression with `0`, intermediate compression with `25` or extreme with `50+`?
- Which frequency to chose for key frames (`-g`)? A key frame every `10` frames?
How to decode videos?
- Which `decoder`? `torchvision`, `torchaudio`, `ffmpegio`, `decord`, or `nvc`?
- What scenarios to use for the requesting timestamps during benchmark? (`timestamps_mode`)
## Variables
**Image content & size**
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an apartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
For these reasons, we run this benchmark on four representative datasets:
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
Note: The datasets used for this benchmark need to be image datasets, not video datasets.
**Data augmentations**
We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robust (e.g. robust to color changes, compression, etc.).
### Encoding parameters
| parameter | values |
| ----------- | ------------------------------------------------------------ |
| **vcodec** | `libx264`, `libx265`, `libsvtav1` |
| **pix_fmt** | `yuv444p`, `yuv420p` |
| **g** | `1`, `2`, `3`, `4`, `5`, `6`, `10`, `15`, `20`, `40`, `None` |
| **crf** | `0`, `5`, `10`, `15`, `20`, `25`, `30`, `40`, `50`, `None` |
Note that `crf` value might be interpreted differently by various video codecs. In other words, the same value used with one codec doesn't necessarily translate into the same compression level with another codec. In fact, the default value (`None`) isn't the same amongst the different video codecs. Importantly, it is also the case for many other ffmpeg arguments like `g` which specifies the frequency of the key frames.
For a comprehensive list and documentation of these parameters, see the ffmpeg documentation depending on the video codec used:
- h264: https://trac.ffmpeg.org/wiki/Encode/H.264
- h265: https://trac.ffmpeg.org/wiki/Encode/H.265
- AV1: https://trac.ffmpeg.org/wiki/Encode/AV1
### Decoding parameters
**Decoder**
We tested two video decoding backends from torchvision:
- `pyav`
- `video_reader` (requires to build torchvision from source)
**Requested timestamps**
Given the way video decoding works, once a keyframe has been loaded, the decoding of subsequent frames is fast.
This of course is affected by the `-g` parameter during encoding, which specifies the frequency of the keyframes. Given our typical use cases in robotics policies which might request a few timestamps in different random places, we want to replicate these use cases with the following scenarios:
- `1_frame`: 1 frame,
- `2_frames`: 2 consecutive frames (e.g. `[t, t + 1 / fps]`),
- `6_frames`: 6 consecutive frames (e.g. `[t + i / fps for i in range(6)]`)
Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.
Additionally, because some policies might request single timestamps that are a few frames apart, we also have the following scenario:
- `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`),
However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded.
## Metrics
**Data compression ratio (lower is better)**
`video_images_size_ratio` is the ratio of the memory space on disk taken by the encoded video over the memory space taken by the original images. For instance, `video_images_size_ratio=25%` means that the video takes 4 times less memory space on disk compared to the original images.
**Loading time ratio (lower is better)**
`video_images_load_time_ratio` is the ratio of the time it takes to decode frames from the video at a given timestamps over the time it takes to load the exact same original images. Lower is better. For instance, `video_images_load_time_ratio=200%` means that decoding from video is 2 times slower than loading the original images.
**Average Mean Square Error (lower is better)**
`avg_mse` is the average mean square error between each decoded frame and its corresponding original image over all requested timestamps, and also divided by the number of pixels in the image to be comparable when switching to different image sizes.
**Average Peak Signal to Noise Ratio (higher is better)**
`avg_psnr` measures the ratio between the maximum possible power of a signal and the power of corrupting noise that affects the fidelity of its representation. Higher PSNR indicates better quality.
**Average Structural Similarity Index Measure (higher is better)**
`avg_ssim` evaluates the perceived quality of images by comparing luminance, contrast, and structure. SSIM values range from -1 to 1, where 1 indicates perfect similarity.
One aspect that can't be measured here with those metrics is the compatibility of the encoding across platforms, in particular on web browser, for visualization purposes.
h264, h265 and AV1 are all commonly used codecs and should not pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
- `yuv420p` is more widely supported across various platforms, including web browsers.
- `yuv444p` offers higher color fidelity but might not be supported as broadly.
<!-- **Loss of a pretrained policy (higher is better)** (not available)
`loss_pretrained` is the result of evaluating with the selected encoding/decoding settings a policy pretrained on original images. It is easier to understand than `avg_l2_error`.
**Success rate after retraining (higher is better)** (not available)
`success_rate` is the result of training and evaluating a policy with the selected encoding/decoding settings. It is the most difficult metric to get but also the very best. -->
## How the benchmark works
The benchmark evaluates both encoding and decoding of video frames on the first episode of each dataset.
**Encoding:** for each `vcodec` and `pix_fmt` pair, we use a default value for `g` and `crf` upon which we change a single value (either `g` or `crf`) to one of the specified values (we don't test every combination of those as this would be computationally too heavy).
This gives a unique set of encoding parameters which is used to encode the episode.
**Decoding:** Then, for each of those unique encodings, we iterate through every combination of the decoding parameters `backend` and `timestamps_mode`. For each of them, we record the metrics of a number of samples (given by `--num-samples`). This is parallelized for efficiency and the number of processes can be controlled with `--num-workers`. Ideally, it's best to have a `--num-samples` that is divisible by `--num-workers`.
Intermediate results saved for each `vcodec` and `pix_fmt` combination in csv tables.
These are then all concatenated to a single table ready for analysis.
## Caveats
We tried to measure the most impactful parameters for both encoding and decoding. However, for computational reasons we can't test out every combination.
Additional encoding parameters exist that are not included in this benchmark. In particular:
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
- `-tune` which allows to optimize the encoding for certain aspects (e.g. film quality, fast decoding, etc.).
See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few:
- `torchaudio`
- `ffmpegio`
- `decord`
- `nvc`
Note as well that since we are mostly interested in the performance at decoding time (also because encoding is done only once before uploading a dataset), we did not measure encoding times nor have any metrics regarding encoding.
However, besides the necessity to build ffmpeg from source, encoding did not pose any issue and it didn't take a significant amount of time during this benchmark.
## Install
Building ffmpeg from source is required to include libx265 and libaom/libsvtav1 (av1) video codecs ([compilation guide](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu)).
**Note:** While you still need to build torchvision with a conda-installed `ffmpeg<4.3` to use the `video_reader` decoder (as described in [#220](https://github.com/huggingface/lerobot/pull/220)), you also need another version which is custom-built with all the video codecs for encoding. For the script to then use that version, you can prepend the command above with `PATH="$HOME/bin:$PATH"`, which is where ffmpeg should be built.
## Adding a video decoder
Right now, we're only benchmarking the two video decoder available with torchvision: `pyav` and `video_reader`.
You can easily add a new decoder to benchmark by adding it to this function in the script:
```diff
def decode_video_frames(
video_path: str,
timestamps: list[float],
tolerance_s: float,
backend: str,
) -> torch.Tensor:
if backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(
video_path, timestamps, tolerance_s, backend
)
+ elif backend == ["your_decoder"]:
+ return your_decoder_function(
+ video_path, timestamps, tolerance_s, backend
+ )
else:
raise NotImplementedError(backend)
```
## Example
For a quick run, you can try these parameters:
```bash
python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
--repo-ids \
lerobot/pusht_image \
lerobot/aloha_mobile_shrimp_image \
--vcodec libx264 libx265 \
--pix-fmt yuv444p yuv420p \
--g 2 20 None \
--crf 10 40 None \
--timestamps-modes 1_frame 2_frames \
--backends pyav video_reader \
--num-samples 5 \
--num-workers 5 \
--save-frames 0
```
## Results
### Reproduce
We ran the benchmark with the following parameters:
```bash
# h264 and h265 encodings
python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
--repo-ids \
lerobot/pusht_image \
lerobot/aloha_mobile_shrimp_image \
lerobot/paris_street \
lerobot/kitchen \
--vcodec libx264 libx265 \
--pix-fmt yuv444p yuv420p \
--g 1 2 3 4 5 6 10 15 20 40 None \
--crf 0 5 10 15 20 25 30 40 50 None \
--timestamps-modes 1_frame 2_frames 6_frames \
--backends pyav video_reader \
--num-samples 50 \
--num-workers 5 \
--save-frames 1
# av1 encoding (only compatible with yuv420p and pyav decoder)
python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
--repo-ids \
lerobot/pusht_image \
lerobot/aloha_mobile_shrimp_image \
lerobot/paris_street \
lerobot/kitchen \
--vcodec libsvtav1 \
--pix-fmt yuv420p \
--g 1 2 3 4 5 6 10 15 20 40 None \
--crf 0 5 10 15 20 25 30 40 50 None \
--timestamps-modes 1_frame 2_frames 6_frames \
--backends pyav \
--num-samples 50 \
--num-workers 5 \
--save-frames 1
```
The full results are available [here](https://docs.google.com/spreadsheets/d/1OYJB43Qu8fC26k_OyoMFgGBBKfQRCi4BIuYitQnq3sw/edit?usp=sharing)
### Parameters selected for LeRobotDataset
Considering these results, we chose what we think is the best set of encoding parameter:
- vcodec: `libsvtav1`
- pix-fmt: `yuv420p`
- g: `2`
- crf: `30`
Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_reader` does not support it (and `pyav` doesn't require a custom build of `torchvision`).
### Summary
These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`
| video_images_size_ratio | vcodec | pix_fmt | | | |
| --------------------------------- | ---------- | ------- | --------- | --------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
| --------------------------------- | ------- | ------- | -------- | ------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
| | | vcodec | pix_fmt | | | |
| --------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
| | | libx264 | | libx265 | | libsvtav1 |
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |

View File

@@ -1,488 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Assess the performance of video decoding in various configurations.
This script will benchmark different video encoding and decoding parameters.
See the provided README.md or run `python benchmark/video/run_video_benchmark.py --help` for usage info.
"""
import argparse
import datetime as dt
import itertools
import random
import shutil
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from threading import Lock
import einops
import numpy as np
import pandas as pd
import PIL
import torch
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.video_utils import (
decode_video_frames,
encode_video_frames,
)
from lerobot.utils.constants import OBS_IMAGE
from lerobot.utils.utils import TimerManager
BASE_ENCODING = OrderedDict(
[
("vcodec", "libx264"),
("pix_fmt", "yuv444p"),
("g", 2),
("crf", None),
# TODO(aliberts): Add fastdecode
# ("fastdecode", 0),
]
)
# TODO(rcadene, aliberts): move to `utils.py` folder when we want to refactor
def parse_int_or_none(value) -> int | None:
if value.lower() == "none":
return None
try:
return int(value)
except ValueError as e:
raise argparse.ArgumentTypeError(f"Invalid int or None: {value}") from e
def check_datasets_formats(repo_ids: list) -> None:
for repo_id in repo_ids:
dataset = LeRobotDataset(repo_id)
if len(dataset.meta.video_keys) > 0:
raise ValueError(
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
)
def get_directory_size(directory: Path) -> int:
total_size = 0
for item in directory.rglob("*"):
if item.is_file():
total_size += item.stat().st_size
return total_size
def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor:
frames = []
for ts in timestamps:
idx = int(ts * fps)
frame = PIL.Image.open(imgs_dir / f"frame-{idx:06d}.png")
frame = torch.from_numpy(np.array(frame))
frame = frame.type(torch.float32) / 255
frame = einops.rearrange(frame, "h w c -> c h w")
frames.append(frame)
return torch.stack(frames)
def save_decoded_frames(
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
) -> None:
if save_dir.exists() and len(list(save_dir.glob("frame-*.png"))) == len(timestamps):
return
save_dir.mkdir(parents=True, exist_ok=True)
for i, ts in enumerate(timestamps):
idx = int(ts * fps)
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame-{idx:06d}_decoded.png")
shutil.copyfile(imgs_dir / f"frame-{idx:06d}.png", save_dir / f"frame-{idx:06d}_original.png")
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
episode_index = 0
ep_num_images = dataset.meta.episodes["length"][episode_index]
if imgs_dir.exists() and len(list(imgs_dir.glob("frame-*.png"))) == ep_num_images:
return
imgs_dir.mkdir(parents=True, exist_ok=True)
hf_dataset = dataset.hf_dataset.with_format(None)
# We only save images from the first camera
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
imgs_dataset = hf_dataset.select_columns(img_keys[0])
for i, item in enumerate(
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
):
img = item[img_keys[0]]
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
if i >= ep_num_images - 1:
break
def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]:
# Start at 5 to allow for 2_frames_4_space and 6_frames
idx = random.randint(5, ep_num_images - 1)
match timestamps_mode:
case "1_frame":
frame_indexes = [idx]
case "2_frames":
frame_indexes = [idx - 1, idx]
case "2_frames_4_space":
frame_indexes = [idx - 5, idx]
case "6_frames":
frame_indexes = [idx - i for i in range(6)][::-1]
case _:
raise ValueError(timestamps_mode)
return [idx / fps for idx in frame_indexes]
def benchmark_decoding(
imgs_dir: Path,
video_path: Path,
timestamps_mode: str,
backend: str,
ep_num_images: int,
fps: int,
num_samples: int = 50,
num_workers: int = 4,
save_frames: bool = False,
) -> dict:
def process_sample(sample: int, lock: Lock):
time_benchmark = TimerManager(log=False)
timestamps = sample_timestamps(timestamps_mode, ep_num_images, fps)
num_frames = len(timestamps)
result = {
"psnr_values": [],
"ssim_values": [],
"mse_values": [],
}
with time_benchmark, lock:
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
result["load_time_video_ms"] = (time_benchmark.last * 1000) / num_frames
with time_benchmark:
original_frames = load_original_frames(imgs_dir, timestamps, fps)
result["load_time_images_ms"] = (time_benchmark.last * 1000) / num_frames
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
for i in range(num_frames):
result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
result["psnr_values"].append(
peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0)
)
result["ssim_values"].append(
structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0)
)
if save_frames and sample == 0:
save_dir = video_path.with_suffix("") / f"{timestamps_mode}_{backend}"
save_decoded_frames(imgs_dir, save_dir, frames, timestamps, fps)
return result
load_times_video_ms = []
load_times_images_ms = []
mse_values = []
psnr_values = []
ssim_values = []
# A sample is a single set of decoded frames specified by timestamps_mode (e.g. a single frame, 2 frames, etc.).
# For each sample, we record metrics (loading time and quality metrics) which are then averaged over all samples.
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
# Use a single shared lock for all worker threads
shared_lock = Lock()
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(process_sample, i, shared_lock) for i in range(num_samples)]
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
result = future.result()
load_times_video_ms.append(result["load_time_video_ms"])
load_times_images_ms.append(result["load_time_images_ms"])
psnr_values.extend(result["psnr_values"])
ssim_values.extend(result["ssim_values"])
mse_values.extend(result["mse_values"])
avg_load_time_video_ms = float(np.array(load_times_video_ms).mean())
avg_load_time_images_ms = float(np.array(load_times_images_ms).mean())
video_images_load_time_ratio = avg_load_time_video_ms / avg_load_time_images_ms
return {
"avg_load_time_video_ms": avg_load_time_video_ms,
"avg_load_time_images_ms": avg_load_time_images_ms,
"video_images_load_time_ratio": video_images_load_time_ratio,
"avg_mse": float(np.mean(mse_values)),
"avg_psnr": float(np.mean(psnr_values)),
"avg_ssim": float(np.mean(ssim_values)),
}
def benchmark_encoding_decoding(
dataset: LeRobotDataset,
video_path: Path,
imgs_dir: Path,
encoding_cfg: dict,
decoding_cfg: dict,
num_samples: int,
num_workers: int,
save_frames: bool,
overwrite: bool = False,
seed: int = 1337,
) -> list[dict]:
fps = dataset.fps
if overwrite or not video_path.is_file():
tqdm.write(f"encoding {video_path}")
encode_video_frames(
imgs_dir=imgs_dir,
video_path=video_path,
fps=fps,
vcodec=encoding_cfg["vcodec"],
pix_fmt=encoding_cfg["pix_fmt"],
g=encoding_cfg.get("g"),
crf=encoding_cfg.get("crf"),
# fast_decode=encoding_cfg.get("fastdecode"),
overwrite=True,
)
episode_index = 0
ep_num_images = dataset.meta.episodes["length"][episode_index]
width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
num_pixels = width * height
video_size_bytes = video_path.stat().st_size
images_size_bytes = get_directory_size(imgs_dir)
video_images_size_ratio = video_size_bytes / images_size_bytes
random.seed(seed)
benchmark_table = []
for timestamps_mode in tqdm(
decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
):
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
benchmark_row = benchmark_decoding(
imgs_dir,
video_path,
timestamps_mode,
backend,
ep_num_images,
fps,
num_samples,
num_workers,
save_frames,
)
benchmark_row.update(
**{
"repo_id": dataset.repo_id,
"resolution": f"{width} x {height}",
"num_pixels": num_pixels,
"video_size_bytes": video_size_bytes,
"images_size_bytes": images_size_bytes,
"video_images_size_ratio": video_images_size_ratio,
"timestamps_mode": timestamps_mode,
"backend": backend,
},
**encoding_cfg,
)
benchmark_table.append(benchmark_row)
return benchmark_table
def main(
output_dir: Path,
repo_ids: list[str],
vcodec: list[str],
pix_fmt: list[str],
g: list[int],
crf: list[int],
# fastdecode: list[int],
timestamps_modes: list[str],
backends: list[str],
num_samples: int,
num_workers: int,
save_frames: bool,
):
check_datasets_formats(repo_ids)
encoding_benchmarks = {
"g": g,
"crf": crf,
# "fastdecode": fastdecode,
}
decoding_benchmarks = {
"timestamps_modes": timestamps_modes,
"backends": backends,
}
headers = ["repo_id", "resolution", "num_pixels"]
headers += list(BASE_ENCODING.keys())
headers += [
"timestamps_mode",
"backend",
"video_size_bytes",
"images_size_bytes",
"video_images_size_ratio",
"avg_load_time_video_ms",
"avg_load_time_images_ms",
"video_images_load_time_ratio",
"avg_mse",
"avg_psnr",
"avg_ssim",
]
file_paths = []
for video_codec in tqdm(vcodec, desc="encodings (vcodec)"):
for pixel_format in tqdm(pix_fmt, desc="encodings (pix_fmt)", leave=False):
benchmark_table = []
for repo_id in tqdm(repo_ids, desc="encodings (datasets)", leave=False):
dataset = LeRobotDataset(repo_id)
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
# We only use the first episode
save_first_episode(imgs_dir, dataset)
for duet in [
dict(zip(encoding_benchmarks.keys(), unique_combination, strict=False))
for unique_combination in itertools.product(*encoding_benchmarks.values())
]:
encoding_cfg = BASE_ENCODING.copy()
encoding_cfg["vcodec"] = video_codec
encoding_cfg["pix_fmt"] = pixel_format
for key, value in duet.items():
encoding_cfg[key] = value
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
benchmark_table += benchmark_encoding_decoding(
dataset,
video_path,
imgs_dir,
encoding_cfg,
decoding_benchmarks,
num_samples,
num_workers,
save_frames,
)
# Save intermediate results
benchmark_df = pd.DataFrame(benchmark_table, columns=headers)
now = dt.datetime.now()
csv_path = (
output_dir
/ f"{now:%Y-%m-%d}_{now:%H-%M-%S}_{video_codec}_{pixel_format}_{num_samples}-samples.csv"
)
benchmark_df.to_csv(csv_path, header=True, index=False)
file_paths.append(csv_path)
del benchmark_df
# Concatenate all results
df_list = [pd.read_csv(csv_path) for csv_path in file_paths]
concatenated_df = pd.concat(df_list, ignore_index=True)
concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
concatenated_df.to_csv(concatenated_path, header=True, index=False)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--output-dir",
type=Path,
default=Path("outputs/video_benchmark"),
help="Directory where the video benchmark outputs are written.",
)
parser.add_argument(
"--repo-ids",
type=str,
nargs="*",
default=[
"lerobot/pusht_image",
"lerobot/aloha_mobile_shrimp_image",
"lerobot/paris_street",
"lerobot/kitchen",
],
help="Datasets repo-ids to test against. First episodes only are used. Must be images.",
)
parser.add_argument(
"--vcodec",
type=str,
nargs="*",
default=["h264", "hevc", "libsvtav1"],
help="Video codecs to be tested",
)
parser.add_argument(
"--pix-fmt",
type=str,
nargs="*",
default=["yuv444p", "yuv420p"],
help="Pixel formats (chroma subsampling) to be tested",
)
parser.add_argument(
"--g",
type=parse_int_or_none,
nargs="*",
default=[1, 2, 3, 4, 5, 6, 10, 15, 20, 40, 100, None],
help="Group of pictures sizes to be tested.",
)
parser.add_argument(
"--crf",
type=parse_int_or_none,
nargs="*",
default=[0, 5, 10, 15, 20, 25, 30, 40, 50, None],
help="Constant rate factors to be tested.",
)
# parser.add_argument(
# "--fastdecode",
# type=int,
# nargs="*",
# default=[0, 1],
# help="Use the fastdecode tuning option. 0 disables it. "
# "For libx264 and libx265/hevc, only 1 is possible. "
# "For libsvtav1, 1, 2 or 3 are possible values with a higher number meaning a faster decoding optimization",
# )
parser.add_argument(
"--timestamps-modes",
type=str,
nargs="*",
default=[
"1_frame",
"2_frames",
"2_frames_4_space",
"6_frames",
],
help="Timestamps scenarios to be tested.",
)
parser.add_argument(
"--backends",
type=str,
nargs="*",
default=["torchcodec", "pyav"],
help="Torchvision decoding backend to be tested.",
)
parser.add_argument(
"--num-samples",
type=int,
default=50,
help="Number of samples for each encoding x decoding config.",
)
parser.add_argument(
"--num-workers",
type=int,
default=10,
help="Number of processes for parallelized sample processing.",
)
parser.add_argument(
"--save-frames",
type=int,
default=0,
help="Whether to save decoded frames or not. Enter a non-zero number for true.",
)
args = parser.parse_args()
main(**vars(args))

View File

@@ -35,7 +35,7 @@ USER root
ARG ROBOTWIN_SHA=0aeea2d669c0f8516f4d5785f0aa33ba812c14b4
RUN apt-get update \
&& apt-get install -y --no-install-recommends \
cuda-nvcc-12-6 cuda-cudart-dev-12-6 \
cuda-nvcc-12-8 cuda-cudart-dev-12-8 \
libvulkan1 vulkan-tools \
&& mkdir -p /usr/share/vulkan/icd.d \
&& echo '{"file_format_version":"1.0.0","ICD":{"library_path":"libGLX_nvidia.so.0","api_version":"1.3.0"}}' \

View File

@@ -18,7 +18,7 @@
# docker build -f docker/Dockerfile.internal -t lerobot-internal .
# Configure the base image for CI with GPU access
ARG CUDA_VERSION=12.6.3
ARG CUDA_VERSION=12.8.1
ARG OS_VERSION=24.04
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}

View File

@@ -3,12 +3,14 @@
title: LeRobot
- local: installation
title: Installation
- local: cheat-sheet
title: Cheat sheet
title: Get started
- sections:
- local: il_robots
title: Imitation Learning for Robots
- local: bring_your_own_policies
title: Bring Your Own Policies
title: Adding a Policy
- local: integrate_hardware
title: Bring Your Own Hardware
- local: hilserl
@@ -24,6 +26,12 @@
- local: rename_map
title: Using Rename Map and Empty Cameras
title: "Tutorials"
- sections:
- local: hardware_guide
title: Compute Hardware Guide
- local: torch_accelerators
title: PyTorch accelerators
title: "Compute & Hardware"
- sections:
- local: lerobot-dataset-v3
title: Using LeRobotDataset
@@ -33,6 +41,8 @@
title: Using the Dataset Tools
- local: dataset_subtask
title: Using Subtasks in the Dataset
- local: video_encoding_parameters
title: Video encoding parameters
- local: streaming_video_encoding
title: Streaming Video Encoding
title: "Datasets"
@@ -142,10 +152,6 @@
- local: cameras
title: Cameras
title: "Sensors"
- sections:
- local: torch_accelerators
title: PyTorch accelerators
title: "Supported Hardware"
- sections:
- local: notebooks
title: Notebooks

View File

@@ -90,6 +90,6 @@ lerobot-record \
--dataset.single_task="Your task description" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# --dataset.camera_encoder.vcodec=auto \
--policy.path=${HF_USER}/act_policy
```

View File

@@ -1,60 +1,37 @@
# Bring Your Own Policies
# Adding a Policy
This tutorial explains how to integrate your own custom policy implementations into the LeRobot ecosystem, allowing you to leverage all LeRobot tools for training, evaluation, and deployment while using your own algorithms.
This guide walks you through implementing a custom policy and getting it to work with LeRobot's training, evaluation, and deployment tools. There are two paths:
## Step 1: Create a Policy Package
- **Plugin (out-of-tree)** — ship your policy as a standalone `lerobot_policy_*` package. Faster, no PR required, easy to iterate. Right for experimentation, internal use, or when you want to publish independently.
- **In-tree (contributed to LeRobot)** — land your policy directly in `src/lerobot/policies/`. Requires a PR, but makes your policy a first-class citizen of the library.
Your custom policy should be organized as an installable Python package following LeRobot's plugin conventions.
The plugin route is usually the right starting point — promote to in-tree once the policy has stabilized and there's clear value in shipping it with the library.
### Package Structure
Either way, the building blocks are the same: a configuration class, a policy class, and a processor factory. The first half of this guide covers those shared pieces; the second half covers the path-specific scaffolding ([Path A](#path-a-out-of-tree-plugin), [Path B](#path-b-contributing-in-tree)).
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
A note on tone: robot-learning is an actively evolving field, and "what a policy looks like" can shift with each new architecture. The conventions described here exist because they let `lerobot-train` and `lerobot-eval` work uniformly across very different models. When a new policy genuinely doesn't fit them, raise it (in your PR, or an issue) — the conventions are not sacred.
```bash
lerobot_policy_my_custom_policy/
├── pyproject.toml
└── src/
└── lerobot_policy_my_custom_policy/
├── __init__.py
├── configuration_my_custom_policy.py
├── modeling_my_custom_policy.py
└── processor_my_custom_policy.py
```
---
### Package Configuration
## Anatomy of a policy
Set up your `pyproject.toml`:
Three building blocks make up every policy. The names below use `my_policy` as a placeholder — replace with your policy's name. That name is load-bearing: it must match the string you pass to `@PreTrainedConfig.register_subclass`, the `MyPolicy.name` class attribute, and the `make_<name>_pre_post_processors` factory function (more on each below).
```toml
[project]
name = "lerobot_policy_my_custom_policy"
version = "0.1.0"
dependencies = [
# your policy-specific dependencies
]
requires-python = ">= 3.12"
### Configuration class
[build-system]
build-backend = # your-build-backend
requires = # your-build-system
```
## Step 2: Define the Policy Configuration
Create a configuration class that inherits from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and registers your policy type:
Here is a template to get you started, customize the parameters and methods as needed for your policy's architecture and training requirements.
Inherit from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and register your policy type. Here is a template — customize the parameters and methods as needed for your policy's architecture and training requirements.
```python
# configuration_my_custom_policy.py
# configuration_my_policy.py
from dataclasses import dataclass, field
from lerobot.configs import PreTrainedConfig
from lerobot.optim import AdamWConfig
from lerobot.optim import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("my_custom_policy")
@PreTrainedConfig.register_subclass("my_policy")
@dataclass
class MyCustomPolicyConfig(PreTrainedConfig):
"""Configuration class for MyCustomPolicy.
class MyPolicyConfig(PreTrainedConfig):
"""Configuration class for MyPolicy.
Args:
n_obs_steps: Number of observation steps to use as input
@@ -77,16 +54,20 @@ class MyCustomPolicyConfig(PreTrainedConfig):
raise ValueError("n_action_steps cannot exceed horizon")
def validate_features(self) -> None:
"""Validate input/output feature compatibility."""
"""Validate input/output feature compatibility.
Call this explicitly from your policy's __init__ — the base class does not.
"""
if not self.image_features:
raise ValueError("MyCustomPolicy requires at least one image feature.")
raise ValueError("MyPolicy requires at least one image feature.")
if self.action_feature is None:
raise ValueError("MyCustomPolicy requires 'action' in output_features.")
raise ValueError("MyPolicy requires 'action' in output_features.")
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay)
def get_scheduler_preset(self):
"""Return a LRSchedulerConfig from lerobot.optim, or None."""
return None
@property
@@ -101,8 +82,7 @@ class MyCustomPolicyConfig(PreTrainedConfig):
@property
def action_delta_indices(self) -> list[int]:
"""Relative timestep offsets for the action chunk the dataset loader returns.
"""
"""Relative timestep offsets for the action chunk the dataset loader returns."""
return list(range(self.horizon))
@property
@@ -110,32 +90,34 @@ class MyCustomPolicyConfig(PreTrainedConfig):
return None
```
## Step 3: Implement the Policy Class
The string you pass to `@register_subclass` must match `MyPolicy.name` (next section) and is what users supply as `--policy.type` on the CLI. Default to `AdamW` from `lerobot.optim` for `get_optimizer_preset` unless you genuinely need otherwise.
Create your policy implementation by inheriting from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py):
### Policy class
Inherit from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py) and set two class attributes — both are checked by `__init_subclass__`:
```python
# modeling_my_custom_policy.py
# modeling_my_policy.py
import torch
import torch.nn as nn
from typing import Any
from lerobot.policies import PreTrainedPolicy
from lerobot.utils.constants import ACTION
from .configuration_my_custom_policy import MyCustomPolicyConfig
from .configuration_my_policy import MyPolicyConfig
class MyCustomPolicy(PreTrainedPolicy):
config_class = MyCustomPolicyConfig # must match the string in @register_subclass
name = "my_custom_policy"
class MyPolicy(PreTrainedPolicy):
config_class = MyPolicyConfig # must match the string in @register_subclass
name = "my_policy"
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
def __init__(self, config: MyPolicyConfig, dataset_stats: dict[str, Any] = None):
super().__init__(config, dataset_stats)
config.validate_features() # not called automatically by the base class
self.config = config
self.model = ... # your nn.Module here
def reset(self):
"""Reset episode state."""
"""Reset per-episode state. Called by lerobot-eval at the start of each episode."""
...
def get_optim_params(self) -> dict:
@@ -147,35 +129,51 @@ class MyCustomPolicy(PreTrainedPolicy):
...
def select_action(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
"""Return a single action for the current timestep (called at inference)."""
"""Return a single action for the current timestep (called every step at inference)."""
...
def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
def forward(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict | None]:
"""Compute the training loss.
Returns `(loss, output_dict)`. `output_dict` may be `None`; everything in it must be
logging-friendly Python natives (no tensors with gradients).
`batch["action_is_pad"]` is a bool mask of shape (B, horizon) that marks
timesteps padded because the episode ended before `horizon` steps, you
timesteps padded because the episode ended before `horizon` steps; you
can exclude those from your loss.
"""
actions = batch[ACTION]
action_is_pad = batch.get("action_is_pad")
...
return {"loss": ...}
return loss, {"some_loss_component": some_loss_component.item()}
```
## Step 4: Add Data Processors
The methods called by the train/eval loops:
Create processor functions. For a concrete reference, see [processor_act.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [processor_diffusion.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
| Method | Used by | What it does |
| ----------------------------------------------------------------- | ----------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `reset() -> None` | `lerobot-eval` | Clear per-episode state at the start of each episode. |
| `select_action(batch, **kwargs) -> Tensor` | `lerobot-eval` | Return the next action `(B, action_dim)`. Called every step. |
| `predict_action_chunk(batch, **kwargs) -> Tensor` | the policy itself | Return an action chunk `(B, chunk_size, action_dim)`. Currently abstract on the base class — raise `NotImplementedError` if your policy doesn't chunk. |
| `forward(batch, reduction="mean") -> tuple[Tensor, dict \| None]` | `lerobot-train` | Return `(loss, output_dict)`. Accept `reduction="none"` if you want to support per-sample weighting. |
| `get_optim_params() -> dict` | the optimizer | Return `self.parameters()` for simple policies; return a named parameter dict for [multi-optimizer policies](https://github.com/huggingface/lerobot/blob/ecd38c50d7d15b4184cf42649ff1185ee2e11eeb/src/lerobot/policies/sac/modeling_sac.py#L61-L73). |
| `update() -> None` _(optional)_ | `lerobot-train` | Called after each optimizer step _if defined_. Use for EMA, target nets, replay buffers (TDMPC uses this). |
Batches are flat dictionaries keyed by the constants in [`lerobot.utils.constants`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/utils/constants.py): `OBS_STATE` (`observation.state.<motor>`), `OBS_IMAGES` (`observation.images.<camera>`), `OBS_LANGUAGE`, `ACTION`, etc. Reuse the constants — don't invent new prefixes.
### Processor functions
LeRobot uses `PolicyProcessorPipeline`s to normalize inputs and de-normalize outputs around your policy. For a concrete reference, see [`processor_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [`processor_diffusion.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
```python
# processor_my_custom_policy.py
# processor_my_policy.py
from typing import Any
import torch
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
def make_my_custom_policy_pre_post_processors(
def make_my_policy_pre_post_processors(
config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
@@ -187,11 +185,48 @@ def make_my_custom_policy_pre_post_processors(
return preprocessor, postprocessor
```
**Important - function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
**Important function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
## Step 5: Package Initialization
---
Expose your classes in the package's `__init__.py`:
## Path A: Out-of-tree plugin
The fastest way to ship a policy: package it as a standalone Python distribution and install it alongside LeRobot. No PR required, you own the release cycle, and you can publish to PyPI under your own namespace.
### Package structure
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
```bash
lerobot_policy_my_policy/
├── pyproject.toml
└── src/
└── lerobot_policy_my_policy/
├── __init__.py
├── configuration_my_policy.py
├── modeling_my_policy.py
└── processor_my_policy.py
```
### `pyproject.toml`
```toml
[project]
name = "lerobot_policy_my_policy"
version = "0.1.0"
dependencies = [
# your policy-specific dependencies
]
requires-python = ">= 3.12"
[build-system]
build-backend = # your-build-backend
requires = # your-build-system
```
### Package `__init__.py`
Expose your classes in the package's `__init__.py` and guard against missing `lerobot`:
```python
# __init__.py
@@ -204,44 +239,148 @@ except ImportError:
"lerobot is not installed. Please install lerobot to use this policy package."
)
from .configuration_my_custom_policy import MyCustomPolicyConfig
from .modeling_my_custom_policy import MyCustomPolicy
from .processor_my_custom_policy import make_my_custom_policy_pre_post_processors
from .configuration_my_policy import MyPolicyConfig
from .modeling_my_policy import MyPolicy
from .processor_my_policy import make_my_policy_pre_post_processors
__all__ = [
"MyCustomPolicyConfig",
"MyCustomPolicy",
"make_my_custom_policy_pre_post_processors",
"MyPolicyConfig",
"MyPolicy",
"make_my_policy_pre_post_processors",
]
```
## Step 6: Installation and Usage
### Install Your Policy Package
### Install and use
```bash
cd lerobot_policy_my_custom_policy
cd lerobot_policy_my_policy
pip install -e .
# Or install from PyPI if published
pip install lerobot_policy_my_custom_policy
pip install lerobot_policy_my_policy
```
### Use Your Policy
Once installed, your policy automatically integrates with LeRobot's training and evaluation tools:
```bash
lerobot-train \
--policy.type my_custom_policy \
--policy.type my_policy \
--env.type pusht \
--steps 200000
```
## Examples and Community Contributions
---
## Path B: Contributing in-tree
When your policy has stabilized and there's clear value in shipping it with the library, you can land it directly in LeRobot. Read the general [contribution guide](./contributing) and the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md) first — that's where you'll find the testing/quality expectations every PR has to meet (`pre-commit run -a`, `pytest`, the community-review rule, etc.). What's below is the policy-specific layer on top of that.
### In-tree layout
```
src/lerobot/policies/my_policy/
├── __init__.py # re-exports config + modeling + processor factory
├── configuration_my_policy.py # MyPolicyConfig + @register_subclass
├── modeling_my_policy.py # MyPolicy(PreTrainedPolicy)
├── processor_my_policy.py # make_my_policy_pre_post_processors
└── README.md # symlink → ../../../../docs/source/policy_my_policy_README.md
```
Two notes:
- The `README.md` next to the source is a **symlink** into `docs/source/policy_<name>_README.md` — the actual file lives under `docs/`. Existing policies (act, smolvla, diffusion, …) all do this; copy one of those symlinks. The policy README is conventionally minimal: paper link + BibTeX citation.
- The user-facing tutorial — what to install, how to train, hyperparameters, benchmark numbers — lives separately at `docs/source/<my_policy>.mdx` and is registered in `_toctree.yml` under "Policies".
The file names are load-bearing: the factory does lazy imports by name, and the processor is discovered by the `make_<policy_name>_pre_post_processors` convention.
### Wiring
Three places need to know about your policy. All by name.
1. **`policies/__init__.py`** — re-export `MyPolicyConfig` and add it to `__all__`. **Don't** re-export the modeling class; it loads lazily through the factory (so `import lerobot` stays fast).
2. **`factory.py:get_policy_class`** — add a branch returning `MyPolicy` from a lazy import.
3. **`factory.py:make_policy_config`** and **`factory.py:make_pre_post_processors`** — same idea, two more branches.
Mirror an existing policy that's structurally similar to yours; the diff is small.
### Heavy / optional dependencies
Most policies need a heavy backbone (transformers, diffusers, a specific VLM SDK). The convention is **two-step gating**: a `TYPE_CHECKING`-guarded import at module top, and a `require_package` runtime check in the constructor. [`modeling_diffusion.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/modeling_diffusion.py) is the canonical reference:
```python
from typing import TYPE_CHECKING
from lerobot.utils.import_utils import _diffusers_available, require_package
if TYPE_CHECKING or _diffusers_available:
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
else:
DDIMScheduler = None # keeps the symbol bindable at import time
class DiffusionPolicy(PreTrainedPolicy):
def __init__(self, config):
require_package("diffusers", extra="diffusion")
super().__init__(config)
...
```
This way:
- `import lerobot.policies` keeps working without the extra installed (the symbol is just bound to `None`).
- Type checkers see the real symbol.
- Instantiating the policy without the extra raises a clear `ImportError` pointing at `pip install 'lerobot[diffusion]'`.
Add a matching extra to [`pyproject.toml`](https://github.com/huggingface/lerobot/blob/main/pyproject.toml) `[project.optional-dependencies]` and include it in the `all` extra so `pip install 'lerobot[all]'` keeps installing everything.
### Benchmarks and a published checkpoint
A new policy is much easier to review — and far more useful — when it ships with a working checkpoint and at least one number you can reproduce.
**Pick at least one in-tree benchmark.** LeRobot ships sim benchmarks with per-benchmark Docker images (LIBERO, LIBERO-plus, Meta-World, RoboTwin 2.0, RoboCasa365, RoboCerebra, RoboMME, VLABench and more). Pick the one that matches your policy's modality — VLAs usually go to LIBERO or VLABench; image-only BC to LIBERO or Meta-World. The full list lives under [Benchmarks](./libero) in the docs sidebar.
**Push the checkpoint & processors** to the Hub under `lerobot/<policy>_<benchmark>` (or your namespace if you don't have write access; a maintainer can mirror it). Use `PreTrainedPolicy.push_model_to_hub` so the repo gets `config.json`, `model.safetensors`, and a model card.
**Report results in your policy's MDX**, with the exact `lerobot-eval` command and hardware so anyone can re-run:
```markdown
## Results
Evaluated on LIBERO with `lerobot/<policy>_libero`:
| Suite | Success rate | n_episodes |
| -------------- | -----------: | ---------: |
| libero_spatial | 87.5% | 50 |
| libero_object | 93.0% | 50 |
| libero_goal | 81.5% | 50 |
| libero_10 | 62.0% | 50 |
| **average** | **81.0%** | 200 |
Reproduce: `lerobot-eval --policy.path=lerobot/<policy>_libero --env.type=libero --env.task=libero_spatial --eval.n_episodes=50` (1× A100 40 GB).
```
Use `n_episodes ≥ 50` per suite for stable success-rate estimates.
If your policy is real-robot-only and no sim benchmark applies, swap the sim eval for: a public training dataset on the Hub, the `lerobot-train` command, the checkpoint, and a real-robot success rate over ≥10 episodes via `lerobot-rollout --policy.path=...`.
### PR checklist
The general expectations are in [`CONTRIBUTING.md`](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md) and the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md). On top of those, reviewers will look for:
- [ ] `MyPolicy` and `MyPolicyConfig` cover the surface above; `__init_subclass__` accepts the class.
- [ ] `factory.py` and `policies/__init__.py` are wired (lazy imports for modeling).
- [ ] `make_my_policy_pre_post_processors` follows the naming convention.
- [ ] Optional deps live behind a `[project.optional-dependencies]` extra and the `TYPE_CHECKING + require_package` guard.
- [ ] `tests/policies/` updated; backward-compat artifact committed & policy-specific tests.
- [ ] `src/lerobot/policies/<name>/README.md` symlinked into `docs/source/policy_<name>_README.md`; user-facing `docs/source/<name>.mdx` written and added to `_toctree.yml`.
- [ ] At least one reproducible benchmark eval in the policy MDX with a published checkpoint (sim benchmark, or real-robot dataset + checkpoint).
The fastest way to get a clean PR is to copy the directory of the existing policy closest to yours, rename, and replace contents method by method. Don't wait until everything is polished — open a draft PR early and iterate with us; reviewers would much rather give feedback on a half-finished branch than a fully-merged one.
---
## Examples and community contributions
Check out these example policy implementations:
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) - Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
Share your policy implementations with the community! 🤗
Thanks for taking the time to bring a new policy into LeRobot. Every architecture that lands in `main` — and every plugin published by the community — makes the library a little more useful for the next person, and a little more representative of where robot learning is going. We're looking forward to seeing what you ship. 🤗

139
docs/source/cheat-sheet.mdx Normal file
View File

@@ -0,0 +1,139 @@
# Cheat sheet
All of the LeRobot commands in one place. If you forgot how to use a specific command or want to learn about a new one you can do it here.
> [!WARNING]
> For all of the commands listed below remember to change the ports/names/ids to your own values!
> [!TIP]
> Another great way to look at all the commands and get them configured for your specific setup is to use this [Jupyter Notebook](https://github.com/huggingface/lerobot/blob/main/examples/notebooks/quickstart.ipynb).
### Setup and installation
For installation please look at [LeRobot Installation](https://huggingface.co/docs/lerobot/main/en/installation).
### Useful tools
###### Find port
Use this to identify which serial ports your robots are connected to. Follow the instructions in your terminal: you will be asked to unplug the USB cable and press Enter. The script will then detect and print the correct serial port for that robot.
```bash
lerobot-find-port
```
###### Find cameras
Quickly find camera indices and verify their output. This command prints camera information to the terminal and saves test frames from each detected camera to `lerobot/outputs/captured_images`
```bash
lerobot-find-cameras
```
### Calibration
In most cases you will need to perform calibration just once for each robot and teleoperation device. Before performing the calibration make sure that all the joints are roughly in the middle position.
```bash
lerobot-calibrate \
--robot.type=so101_follower \
--robot.port=/dev/ttyACM0 \
--robot.id=my_follower_arm
```
Make sure that you use the same IDs used during calibration later for the other scripts. That's how LeRobot finds the calibration files.
### Teleoperation
Teleoperating with two cameras and displaying the data with Rerun.
```bash
lerobot-teleoperate \
--robot.type=so101_follower \
--robot.port=/dev/ttyACM0 \
--robot.id=my_follower_arm \
--robot.cameras="{ top: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30} }" \
--teleop.type=so101_leader \
--teleop.port=/dev/ttyACM1 \
--teleop.id=my_leader_arm \
--display_data=true
```
### Recording a dataset
The dataset is automatically uploaded to the server and saved under repo_id, make sure you are logged in to your HF account with CLI:
`hf auth login`
You can get the token from: [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
```bash
lerobot-record \
--robot.type=so101_follower \
--robot.port=/dev/ttyACM0 \
--robot.id=my_follower_arm \
--robot.cameras="{ top: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30} }" \
--teleop.type=so101_leader \
--teleop.port=/dev/ttyACM1 \
--teleop.id=my_leader_arm \
--dataset.repo_id=${HF_USER}/so101_dataset_test \
--dataset.num_episodes=30 \
--dataset.single_task="put the red brick in a bowl" \
--dataset.streaming_encoding=true \
--display_data=true
```
While collecting the dataset you can control the process with your keyboard:
Control the data recording flow using keyboard shortcuts:
- Press **Right Arrow (`→`)**: Save episode and move to the next.
- Press **Left Arrow (`←`)**: Delete current episode and retry.
- Press **Escape (`ESC`)**: Stop, encode videos, and upload.
### Training
Depending on your hardware training the policy might take a few hours. That's how you train simple `ACT` policy:
```bash
lerobot-train \
--dataset.repo_id=${HF_USER}/so101_dataset_test \
--policy.type=act \
--output_dir=outputs/train/act_so101_test \
--job_name=act_so101_test \
--policy.device=cuda \
--wandb.enable=true \
--policy.repo_id=${HF_USER}/policy_test \
--steps=20000
```
- Policy Types: `act`, `diffusion`, `smolvla`, `pi05`
- Devices: `cuda` (NVIDIA), `mps` (Apple Silicon), `cpu`
If you want to fine-tune a specific model you can provide the path to the model. In this case path is enough and type can be skipped.
```bash
lerobot-train \
--dataset.repo_id=${HF_USER}/so101_dataset_test \
--policy.path=username/the_policy_to_finetune \
--policy.device=cuda \
--policy.repo_id=${HF_USER}/policy_test \
--output_dir=outputs/train/act_so101_test \
--steps=20000
```
### Inference
Inference means running the trained policy/model on a robot. For that we use `lerobot-rollout`. You will need to provide a path to your policy. It can be a local path or a path to Hugging Face for example "lerobot/folding_latest". Your cameras configuration needs to match what was used when collecting the dataset. Duration is in seconds if unspecified, it will run forever.
> [!TIP]
> If you are using the previous release V0.5.1 instead of `lerobot-rollout` you need to use `lerobot-record`. More information [here](https://huggingface.co/docs/lerobot/v0.5.1/en/il_robots#run-inference-and-evaluate-your-policy).
```bash
lerobot-rollout \
--strategy.type=base \
--policy.path=${HF_USER}/my_policy \
--robot.type=so101_follower \
--robot.port=/dev/ttyACM1 \
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video1, width: 640, height: 480, fps: 30}, side: {type: opencv, index_or_path: /dev/video5, width: 640, height: 480, fps: 30}}" \
--task="Put lego brick into the transparent box" \
--duration=60
```

View File

@@ -194,7 +194,7 @@ lerobot-record \
--dataset.single_task="Navigate around obstacles" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# --dataset.camera_encoder.vcodec=auto \
--display_data=true
```

View File

@@ -123,7 +123,7 @@ lerobot-record \
--dataset.single_task="Grab and handover the red cube to the other arm" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# --dataset.camera_encoder.vcodec=auto \
--policy.path=<user>/groot-bimanual \ # your trained model
--dataset.episode_time_s=30 \
--dataset.reset_time_s=10

View File

@@ -0,0 +1,98 @@
# Compute HW Guide for LeRobot Training
Rough sizing for training a LeRobot policy: how much VRAM each policy needs, what training time looks like, and where to run when local hardware isn't enough.
The numbers below are **indicative** — order-of-magnitude figures for picking hardware, not exact predictions. Throughput depends heavily on dataset I/O, image resolution, batch size, and number of GPUs.
## Memory by policy group
Policies cluster by backbone size; the groupings below give a single VRAM envelope per group instead of repeating numbers per policy. Memory scales roughly linearly with batch size; AdamW (the LeRobot default) carries optimizer state that adds ~30100% over a forward+backward pass alone.
| Group | Policies | Peak VRAM (BS 8, AdamW) | Suitable starter GPUs |
| ---------- | ------------------------------------------- | ----------------------: | --------------------------------- |
| Light BC | `act`, `vqbet`, `tdmpc` | ~26GB | Laptop GPU (RTX 3060), L4, A10G |
| Diffusion | `diffusion`, `multi_task_dit` | ~814GB | RTX 4070+ / L4 / A10G |
| Small VLA | `smolvla` | ~1016GB | RTX 4080+ / L4 / A10G |
| Large VLA | `pi0`, `pi0_fast`, `pi05`, `xvla`, `wall_x` | ~2440GB | A100 40 GB+ (24 GB tight at BS 1) |
| Multimodal | `groot`, `eo1` | ~2440GB | A100 40 GB+ |
| RL | `sac` | config-dep. | See [HIL-SERL guide](./hilserl) |
Memory-bound? Drop the batch size (~linear), use gradient accumulation to recover effective batch, or for SmolVLA leave `freeze_vision_encoder=True`.
## Training time
Robotics imitation learning typically converges in **510 epochs over the dataset**, not hundreds of thousands of raw steps. Once you know your epoch count, wall-clock is essentially:
```text
total_frames = sum of frames over all episodes # 50 ep × 30 fps × 30 s ≈ 45,000
steps_per_epoch = ceil(total_frames / (num_gpus × batch_size))
total_steps = epochs × steps_per_epoch
wall_clock ≈ total_steps × per_step_time
```
Per-step time depends on the policy and the GPU. The numbers in the table below are anchors — pick the row closest to your setup and scale linearly with `total_steps` if you train longer or shorter.
### Common scenarios
Indicative wall-clock for **5 epochs on a ~50-episode dataset (~45k frames at 30 fps × 30 s)**, default optimizer (AdamW), 640×480 images:
| Setup | Policy | Batch | Wall-clock |
| ------------------------------------ | -------------- | ----- | ---------: |
| Single RTX 4090 / RTX 3090 (24 GB) | `act` | 8 | ~3060min |
| Single RTX 4090 / RTX 3090 (24 GB) | `diffusion` | 8 | ~24h |
| Single L4 / A10G (24 GB) | `act` | 8 | ~12h |
| Single L4 / A10G (24 GB) | `smolvla` | 4 | ~36h |
| Single A100 40 GB | `smolvla` | 16 | ~12h |
| Single A100 40 GB | `pi0` / `pi05` | 4 | ~48h |
| 4× H100 80 GB cluster (`accelerate`) | `diffusion` | 32 | ~3060min |
| 4× H100 80 GB cluster (`accelerate`) | `smolvla` | 32 | ~12h |
| Apple Silicon M1/M2/M3 Max (MPS) | `act` | 4 | ~614h |
These are order-of-magnitude figures. Real runs deviate by ±50% depending on image resolution, dataset I/O, dataloader threading, and exact GPU SKU. They are useful as "is this run going to take an hour or a day?" intuition, not as SLAs.
### Multi-GPU matters a lot
`accelerate launch --num_processes=N` is the easiest way to cut training time. Each optimizer step processes `N × batch_size` samples in roughly the same wall-clock as a single-GPU step, so 4 GPUs ≈ 4× speedup for compute-bound runs. See the [Multi GPU training](./multi_gpu_training) guide for the full setup.
Reference data points on a 4×H100 80 GB cluster (`accelerate launch --num_processes=4`), 5000 steps, batch 32, AdamW, dataset [`imstevenpmwork/super_poulain_draft`](https://huggingface.co/datasets/imstevenpmwork/super_poulain_draft) (~50 episodes, ~640×480 images):
| Policy | Wall-clock | `update_s` | `dataloading_s` | GPU util | Notable flags |
| ----------- | ---------- | ---------: | --------------: | -------- | ------------------------------------------------------------------------------------------------------------------------------ |
| `diffusion` | 16m 17s | 0.167 | 0.015 | ~90% | defaults (training from scratch) |
| `smolvla` | 27m 49s | 0.312 | 0.011 | ~80% | `--policy.path=lerobot/smolvla_base`, `freeze_vision_encoder=false`, `train_expert_only=false` |
| `pi05` | 3h 41m | 2.548 | 0.014 | ~95% | `--policy.pretrained_path=lerobot/pi05_base`, `gradient_checkpointing=true`, `dtype=bfloat16`, vision encoder + expert trained |
The `dataloading_s` vs. `update_s` ratio is the diagnostic that matters: when `dataloading_s` approaches `update_s`, more GPUs stop helping — your dataloader is the bottleneck and you should look at `--num_workers`, image resolution, and disk speed before adding compute.
### Schedule and checkpoints
If you shorten training (e.g. 5k10k steps on a small dataset), also shorten the LR schedule with `--policy.scheduler_decay_steps≈--steps`. Otherwise the LR stays near its peak and never decays. Same for `--save_freq`.
## Where to run
VRAM is the first filter. Within a tier, pick by budget and availability — the `$``$$$$` columns are relative; check current pricing on the provider you actually use.
| Class | VRAM | Tier | Comfortable for |
| -------------------------- | ----- | ------ | ----------------------------------------------------------- |
| RTX 3090 / 4090 (consumer) | 24 GB | `$` | Light BC, Diffusion, SmolVLA. Tight for VLAs at batch 1. |
| L4 / A10G (cloud) | 24 GB | `$$$` | Same envelope; common on Google Cloud, RunPod, AWS `g5/g6`. |
| A100 40 GB | 40 GB | `$$$` | Any policy at reasonable batch sizes. |
| A100 80 GB / H100 80 GB | 80 GB | `$$$$` | Multi-GPU clusters; large batches for VLAs. |
| **CPU only** | — | — | Don't train. Use Colab or rent a GPU. |
### Hugging Face Jobs
[Hugging Face Jobs](https://huggingface.co/docs/hub/jobs) lets you run training on managed HF infrastructure, billed by the second. The repo publishes a ready-to-use image: **`huggingface/lerobot-gpu:latest`**, rebuilt **every night at 02:00 UTC from `main`** ([`docker_publish.yml`](https://github.com/huggingface/lerobot/blob/main/.github/workflows/docker_publish.yml)) — so it tracks the current state of the repo, not a tagged release.
```bash
hf jobs run --flavor a10g-large huggingface/lerobot-gpu:latest \
bash -c "nvidia-smi && lerobot-train \
--policy.type=act --dataset.repo_id=<USER>/<DATASET> \
--policy.repo_id=<USER>/act_<task> --batch_size=8 --steps=50000"
```
Notes:
- The leading `nvidia-smi` is a quick sanity check that CUDA is visible inside the container — useful to fail fast if the flavor or driver mismatched.
- The default Job timeout is 30 minutes; pass `--timeout 4h` (or longer) for real training.
- `--flavor` maps onto the table above: `t4-small`/`t4-medium` (T4, ACT only), `l4x1`/`l4x4` (L4 24 GB), `a10g-small/large/largex2/largex4` (A10G 24 GB scaled out), `a100-large` (A100). For the current full catalogue + pricing see [https://huggingface.co/docs/hub/jobs](https://huggingface.co/docs/hub/jobs).

View File

@@ -62,7 +62,7 @@ pip install -e ".[hilserl]"
### Understanding Configuration
The training process begins with proper configuration for the HILSerl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
The training process begins with proper configuration for the HILSERl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` (defined in `lerobot/envs/configs.py`) and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
<!-- prettier-ignore-start -->
```python
@@ -95,6 +95,7 @@ class HILSerlProcessorConfig:
class ObservationConfig:
add_joint_velocity_to_observation: bool = False # Add joint velocities to state
add_current_to_observation: bool = False # Add motor currents to state
add_ee_pose_to_observation: bool = False # Add end-effector pose to state
display_cameras: bool = False # Display camera feeds during execution
class ImagePreprocessingConfig:
@@ -326,14 +327,22 @@ lerobot-find-joint-limits \
Max joint positions [-20.0, -20.0, -20.0, -20.0, -20.0, -20.0]
Min joint positions [50.0, 50.0, 50.0, 50.0, 50.0, 50.0]
```
3. Use these values in the configuration of your teleoperation device (TeleoperatorConfig) under the `end_effector_bounds` field
3. Use these values in your environment configuration under `env.processor.inverse_kinematics.end_effector_bounds` (see `InverseKinematicsConfig` in `lerobot/envs/configs.py`)
**Example Configuration**
```json
"end_effector_bounds": {
"max": [0.24, 0.20, 0.10],
"min": [0.16, -0.08, 0.03]
{
"env": {
"processor": {
"inverse_kinematics": {
"end_effector_bounds": {
"max": [0.24, 0.2, 0.1],
"min": [0.16, -0.08, 0.03]
}
}
}
}
}
```
@@ -404,30 +413,24 @@ We support using a gamepad or a keyboard or the leader arm of the robot.
HIL-Serl learns actions in the end-effector space of the robot. Therefore, the teleoperation will control the end-effector's x,y,z displacements.
For that we need to define a version of the robot that takes actions in the end-effector space. Check the robot class `SO100FollowerEndEffector` and its configuration `SO100FollowerEndEffectorConfig` for the default parameters related to the end-effector space.
The end-effector transformation is applied by the processor pipeline (`InverseKinematicsRLStep`, `EEBoundsAndSafety`, `EEReferenceAndDelta`, `GripperVelocityToJoint`) configured under `env.processor.inverse_kinematics` (`InverseKinematicsConfig`) and `env.processor.gripper` / `env.processor.max_gripper_pos`. The defaults related to the end-effector space are:
<!-- prettier-ignore-start -->
```python
class SO100FollowerEndEffectorConfig(SO100FollowerConfig):
"""Configuration for the SO100FollowerEndEffector robot."""
class InverseKinematicsConfig:
"""Configuration for inverse kinematics processing."""
# Default bounds for the end-effector position (in meters)
end_effector_bounds: dict[str, list[float]] = field( # bounds for the end-effector in x,y,z direction
default_factory=lambda: {
"min": [-1.0, -1.0, -1.0], # min x, y, z
"max": [1.0, 1.0, 1.0], # max x, y, z
}
)
urdf_path: str | None = None
target_frame_name: str | None = None
# bounds for the end-effector in x,y,z direction
end_effector_bounds: dict[str, list[float]] | None = None
# maximum step size for the end-effector in x,y,z direction
end_effector_step_sizes: dict[str, float] | None = None
max_gripper_pos: float = 50 # maximum gripper position that the gripper will be open at
end_effector_step_sizes: dict[str, float] = field( # maximum step size for the end-effector in x,y,z direction
default_factory=lambda: {
"x": 0.02,
"y": 0.02,
"z": 0.02,
}
)
class HILSerlProcessorConfig:
...
# maximum gripper position that the gripper will be open at
max_gripper_pos: float | None = 100.0
```
<!-- prettier-ignore-end -->
@@ -606,11 +609,11 @@ This guide explains how to train a reward classifier for human-in-the-loop reinf
**Note**: Training a reward classifier is optional. You can start the first round of RL experiments by annotating the success manually with your gamepad or keyboard device.
The reward classifier implementation in `modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings.
The reward classifier implementation in `lerobot/rewards/classifier/modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings.
**Collecting a Dataset for the reward classifier**
Before training, you need to collect a dataset with labeled examples. The `record_dataset` function in `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards.
Before training, you need to collect a dataset with labeled examples. Setting `mode: "record"` in your config and running `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards.
To collect a dataset, you need to modify some parameters in the environment configuration based on HILSerlRobotEnvConfig.
@@ -658,7 +661,7 @@ Example configuration section for data collection:
},
"dataset": {
"repo_id": "hf_username/dataset_name",
"dataset_root": "data/your_dataset",
"root": "data/your_dataset",
"task": "reward_classifier_task",
"num_episodes_to_record": 20,
"replay_episode": null,
@@ -671,7 +674,7 @@ Example configuration section for data collection:
**Reward Classifier Configuration**
The reward classifier is configured using `configuration_classifier.py`. Here are the key parameters:
The reward classifier is configured using `lerobot/rewards/classifier/configuration_classifier.py`. Here are the key parameters:
- **model_name**: Base model architecture (e.g., we mainly use `"helper2424/resnet10"`)
- **model_type**: `"cnn"` or `"transformer"`
@@ -689,7 +692,7 @@ Example configuration for training the [reward classifier](https://huggingface.c
"repo_id": "hf_username/dataset_name",
"root": null
},
"policy": {
"reward_model": {
"type": "reward_classifier",
"model_name": "helper2424/resnet10",
"model_type": "cnn",
@@ -699,7 +702,6 @@ Example configuration for training the [reward classifier](https://huggingface.c
"dropout_rate": 0.1,
"learning_rate": 1e-4,
"device": "cuda",
"use_amp": true,
"input_features": {
"observation.images.front": {
"type": "VISUAL",
@@ -818,13 +820,14 @@ The LeRobot system uses a distributed actor-learner architecture for training. T
**Configuration Setup**
Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/configs/train.py`.
Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/rl/train_rl.py`.
1. Configure the policy settings (`type="sac"`, `device`, etc.)
2. Set `dataset` to your cropped dataset
3. Configure environment settings with crop parameters
4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/sac/configuration_sac.py#L79).
5. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
1. Configure the policy settings (`type="gaussian_actor"`, `device`, etc.)
2. Configure the algorithm settings under the top-level `algorithm` block (`type="sac"`, learning rates, discount, etc., defined in `lerobot/rl/algorithms/sac/configuration_sac.py`).
3. Set `dataset` to your cropped dataset
4. Configure environment settings with crop parameters
5. Check the other parameters related to the Gaussian Actor in [configuration_gaussian_actor.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/gaussian_actor/configuration_gaussian_actor.py#L79).
6. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
**Starting the Learner**
@@ -926,7 +929,7 @@ The ideal behaviour is that your intervention rate should drop gradually during
Some configuration values have a disproportionate impact on training stability and speed:
- **`temperature_init`** (`policy.temperature_init`) initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
- **`temperature_init`** (`algorithm.temperature_init`) initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) interval in _seconds_ between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency.
- **`storage_device`** (`policy.storage_device`) device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second.

View File

@@ -232,7 +232,7 @@ lerobot-record \
--dataset.private=true \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# --dataset.camera_encoder.vcodec=auto \
--display_data=true
```
@@ -278,6 +278,6 @@ lerobot-record \
--dataset.num_episodes=10 \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# --dataset.camera_encoder.vcodec=auto \
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
```

View File

@@ -193,7 +193,7 @@ lerobot-record \
--dataset.num_episodes=5 \
--dataset.single_task="Grab the black cube" \
--dataset.streaming_encoding=true \
# --dataset.vcodec=auto \
# --dataset.camera_encoder.vcodec=auto \
--dataset.encoder_threads=2
```
</hfoption>

View File

@@ -207,6 +207,56 @@ pip install 'lerobot[feetech]' # Feetech motor support
_Multiple extras can be combined (e.g., `.[core_scripts,pi,pusht]`). For a full list of available extras, refer to `pyproject.toml`._
### PyTorch CUDA variant (Linux only)
On Linux, the install path determines which CUDA wheel you get. macOS and Windows installs use the PyPI default (MPS / CPU / CUDA-Windows wheel respectively) and can skip this section.
<!-- prettier-ignore-start -->
<hfoptions id="cuda_variant">
<hfoption id="uv-source">
**Source install via `uv` (`uv sync` or `uv pip install -e .`)**
`torch` and `torchvision` are pinned by the project to the **CUDA 12.8** PyTorch index (`https://download.pytorch.org/whl/cu128`, driver floor **570.86**) — covers Ampere/Ada/Hopper/Blackwell GPUs. No action needed for typical NVIDIA setups.
To override for a different CUDA variant:
```bash
uv pip install --force-reinstall torch torchvision \
--index-url https://download.pytorch.org/whl/cu126 # older drivers; or cu130 for Blackwell on driver ≥ 580
```
</hfoption>
<hfoption id="pip-conda">
**Source install via `pip`/`conda`, or `pip install lerobot` from PyPI**
PyPI default torch wheel is currently a cu130-bundled Linux wheel, driver floor **580.65**.
To pick a specific CUDA variant:
**Using `pip` or `conda`** — install torch first with an explicit index, then lerobot:
```bash
pip install --index-url https://download.pytorch.org/whl/cu128 torch torchvision
pip install -e ".[all]" # source
# — or —
pip install lerobot # from PyPI
```
**Using `uv` to install from PyPI** — one-liner via `--torch-backend` (uv ≥ 0.6):
```bash
uv pip install --torch-backend cu128 lerobot
```
Supported values include `auto`, `cpu`, `cu126`, `cu128`, `cu129`, `cu130`, plus various `rocm*` and `xpu`. Swap as needed for your driver.
</hfoption>
</hfoptions>
<!-- prettier-ignore-end -->
### Troubleshooting
If you encounter build errors, you may need to install additional system dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.

View File

@@ -10,6 +10,7 @@ This docs will guide you to:
- Stream datasets without downloading using `StreamingLeRobotDataset`
- Apply image transforms for data augmentation during training
- Migrate existing `v2.1` datasets to `v3.0`
- Experiment with other `LeRobotDataset` formats and implementations like Lance
## Whats new in `v3`
@@ -43,7 +44,7 @@ lerobot-record \
--dataset.num_episodes=5 \
--dataset.single_task="Grab the black cube" \
--dataset.streaming_encoding=true \
# --dataset.vcodec=auto \
# --dataset.camera_encoder.vcodec=auto \
--dataset.encoder_threads=2
```
@@ -315,3 +316,39 @@ Dataset v3.0 uses incremental parquet writing with buffered metadata for efficie
- Ensures the dataset is valid for loading
Without calling `finalize()`, your parquet files will be incomplete and the dataset won't load properly.
## Other formats and implementations
### Lance
Lance is a useful format for multimodal AI datasets, especially for large-scale training requiring high performance IO and random access.
The `lerobot-lancedb` package implements `LeRobotLanceDataset` (for JPEG images) and `LeRobotLanceVideoDataset` (for mp4 videos).
Those two storage layouts both subclass LeRobotDataset and can provide data loading speed ups.
`LeRobotLanceDataset` is a drop-in replacement for `LeRobotDataset`:
```python
from lerobot.datasets import LeRobotDatasetMetadata
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot_lancedb import LeRobotLanceDataset, LeRobotLanceVideoDataset
cfg = DiffusionConfig(...)
meta = LeRobotDatasetMetadata(root=local_dataset_path) # or use repo_id=... to load metadata from the Hub
delta_timestamps = {...}
# Use LeRobotLanceDataset for image datasets
dataset = LeRobotLanceDataset(
root=local_dataset_path, # or use repo_id=... to stream from the Hub
delta_timestamps=delta_timestamps,
return_uint8=True,
)
# Or use LeRobotLanceVideoDataset for video datasets:
dataset = LeRobotLanceVideoDataset(
root=local_dataset_path, # or use repo_id=... to stream from the Hub
delta_timestamps=delta_timestamps,
return_uint8=True,
)
```
Join the discussion on [Github](https://github.com/huggingface/lerobot/issues/3608) and explore the `lerobot-lancedb` documentation [here](https://lancedb.github.io/lerobot-lancedb/).

View File

@@ -28,13 +28,15 @@ lerobot-train \
--steps=100000 \
--batch_size=32 \
--peft.method_type=LORA \
--peft.r=64
--peft.r=64 \
--peft.lora_alpha=64
```
Note the `--peft.method_type` parameter that let's you select which PEFT method to use. Here we use
[LoRA](https://huggingface.co/docs/peft/main/en/package_reference/lora) (Low-Rank Adapter) which is probably the most
popular fine-tuning method to date. Low-rank adaption means that we only fine-tune a matrix with comparably low rank
instead of the full weight matrix. This rank can be specified using the `--peft.r` parameter. The higher the rank
instead of the full weight matrix. This rank can be specified using the `--peft.r` parameter, and the LoRA scaling factor with
`--peft.lora_alpha` (where `scaling = lora_alpha / r`). The higher the rank
the closer you get to full fine-tuning
There are more complex methods that have more parameters. These are not yet supported, feel free to raise an issue

View File

@@ -161,7 +161,7 @@ lerobot-record \
--dataset.private=true \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# --dataset.camera_encoder.vcodec=auto \
--display_data=true
```
@@ -203,7 +203,7 @@ lerobot-record \
--dataset.private=true \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# --dataset.camera_encoder.vcodec=auto \
--display_data=true
```

View File

@@ -108,7 +108,7 @@ lerobot-record \
--dataset.num_episodes=10 \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# --dataset.camera_encoder.vcodec=auto \
# <- Teleop optional if you want to teleoperate in between episodes \
# --teleop.type=so100_leader \
# --teleop.port=/dev/ttyACM0 \

View File

@@ -17,9 +17,9 @@ This makes `save_episode()` near-instant (the video is already encoded by the ti
| Parameter | CLI Flag | Type | Default | Description |
| ----------------------- | --------------------------------- | ------------- | ------------- | ----------------------------------------------------------------- |
| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture |
| `vcodec` | `--dataset.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
| `vcodec` | `--dataset.camera_encoder.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance. `None` will leave the vcoded decide |
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM |
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `30` | Max buffered frames per camera (~1s at 30fps). Consumes RAM |
## 3. Performance Considerations
@@ -48,7 +48,7 @@ This parameter controls how many threads each encoder instance uses internally:
### Backpressure and Frame Dropping
Each camera has a bounded queue (`encoder_queue_maxsize`, default 60 frames). When the encoder can't keep up:
Each camera has a bounded queue (`encoder_queue_maxsize`, default 30 frames). When the encoder can't keep up:
1. The queue fills up (consuming RAM)
2. New frames are **dropped** (not blocked) — the capture loop continues uninterrupted
@@ -82,15 +82,15 @@ Use HW encoding when:
### Available HW Encoders
| Encoder | Platform | Hardware | CLI Value |
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------ |
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=h264_videotoolbox` |
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=hevc_videotoolbox` |
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=h264_nvenc` |
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=hevc_nvenc` |
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.vcodec=h264_vaapi` |
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.vcodec=h264_qsv` |
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.vcodec=auto` |
| Encoder | Platform | Hardware | CLI Value |
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | --------------------------------------------------- |
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder.vcodec=h264_videotoolbox` |
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder.vcodec=hevc_videotoolbox` |
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder.vcodec=h264_nvenc` |
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder.vcodec=hevc_nvenc` |
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.camera_encoder.vcodec=h264_vaapi` |
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.camera_encoder.vcodec=h264_qsv` |
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.camera_encoder.vcodec=auto` |
> [!NOTE]
> In order to use the HW accelerated encoders you might need to upgrade your GPU drivers.
@@ -100,15 +100,15 @@ Use HW encoding when:
## 5. Troubleshooting
| Symptom | Likely Cause | Fix |
| ------------------------------------------------------------------ | -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.vcodec=auto`) |
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.vcodec=auto`). |
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.vcodec=auto` |
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
| Symptom | Likely Cause | Fix |
| ------------------------------------------------------------------ | -------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.camera_encoder.vcodec=auto`) |
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.camera_encoder.vcodec=auto`). |
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.camera_encoder.vcodec=auto` |
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
## 6. Recommended Configurations
@@ -146,7 +146,7 @@ On very constrained systems, streaming encoding may compete too heavily with the
# 2camsx 640x480x3 @30fps: Requires some tuning.
# Use H.264, disable streaming, consider batching encoding
lerobot-record --dataset.vcodec=h264 --dataset.streaming_encoding=false ...
lerobot-record --dataset.camera_encoder.vcodec=h264 --dataset.streaming_encoding=false ...
```
## 7. Closing note

View File

@@ -117,10 +117,10 @@ lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_image_to_video \
--operation.output_dir outputs/pusht_video \
--operation.vcodec libsvtav1 \
--operation.pix_fmt yuv420p \
--operation.g 2 \
--operation.crf 30
--operation.camera_encoder.vcodec libsvtav1 \
--operation.camera_encoder.pix_fmt yuv420p \
--operation.camera_encoder.g 2 \
--operation.camera_encoder.crf 30
# Convert only specific episodes
lerobot-edit-dataset \
@@ -147,11 +147,7 @@ lerobot-edit-dataset \
**Parameters:**
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`)
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`)
- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2)
- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30)
- `fast_decode`: Fast decode tuning option (default: 0)
- `camera_encoder`: Video encoder settings — all sub-fields accessible via `--operation.camera_encoder.<field>. See [Video Encoding Parameters](./video_encoding_parameters) for more details.
- `episode_indices`: List of specific episodes to convert (default: all episodes)
- `num_workers`: Number of parallel workers for processing (default: 4)

View File

@@ -0,0 +1,117 @@
# Video encoding parameters
When video storage is enabled, LeRobot stores each camera stream as an **MP4** file instead of saving one image file per timestep. Video encoding compresses across time, which usually cuts dataset size and I/O compared to a pile of PNG, while keeping MP4 — a format every player and loader understands.
Encoding frames into an MP4 is a full FFmpeg pipeline: choice of encoder, pixel format, GOP/keyframes, quality vs. speed, and optional extra encoder flags. Most of these knobs are user-tunable through `camera_encoder`, a nested `VideoEncoderConfig` (`lerobot.configs.video.VideoEncoderConfig`) passed through PyAV.
You can set these parameters from the CLI with `--dataset.camera_encoder.<field>` (e.g. with `lerobot-record` or `lerobot-rollout`). The same block applies to every camera video stream in that run.
<Tip>
Video storage must be on for `camera_encoder` to have any effect —
`use_videos=True` in Python APIs, or `--dataset.video=true` on the CLI (the
recording default). With video off, inputs stay as images and `camera_encoder`
is ignored.
</Tip>
For details on **when** frames are written vs. encoded (streaming vs. post-episode), queues, and other top-level `--dataset.*` switches, see [Streaming Video Encoding](./streaming_video_encoding). For an encoding-parameter comparison and experiments, see the [video-benchmark Space](https://huggingface.co/spaces/lerobot/video-benchmark).
---
## Example
```bash
lerobot-record \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--robot.id=black \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--teleop.id=blue \
--dataset.repo_id=<my_username>/<my_dataset_name> \
--dataset.num_episodes=2 \
--dataset.single_task="Grab the cube" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
--dataset.camera_encoder.vcodec=h264 \
--dataset.camera_encoder.preset=fast \
--dataset.camera_encoder.extra_options={"tune": "film", "profile:v": "high", "bf": 2} \
--display_data=true
```
---
## Tuning parameters
<Tip warning={true}>
The defaults are tuned to balance **compression ratio**, **visual quality**, and **decoding/seek speed** for typical robotics datasets. Changing them can affect both recording (CPU load, frame drops) and training (decoding throughput, image quality).
Only override these parameters if you have a specific reason to, and measure the impact on your pipeline before relying on the new settings.
</Tip>
All flags below are prefixed with `--dataset.camera_encoder.` on the CLI.
| Parameter | Type | Default | Description |
| --------------- | ---------------- | ------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `vcodec` | `str` | `"libsvtav1"` | Video codec name. `"auto"` picks the first available hardware encoder from a fixed preference list, falling back to `libsvtav1`. |
| `pix_fmt` | `str` | `"yuv420p"` | Output pixel format. Must be supported by the chosen codec in your FFmpeg build. |
| `g` | `int` | `2` | GOP size — a keyframe every `g` frames. Emitted as FFmpeg option `g`. |
| `crf` | `int` or `float` | `30` | Abstract quality value, mapped per codec (see the [mapping](#mapping-videoencoderconfig--ffmpeg-options) below). Lower → higher quality / larger output where the mapping is monotone. |
| `preset` | `int` or `str` | `12` \* | Encoder speed preset; meaning depends on the codec. <br/>\* When unset and `vcodec=libsvtav1`, LeRobot defaults to `12`. |
| `fast_decode` | `int` | `0` | `libsvtav1`: `02`, passed via `svtav1-params`. <br/>`h264` / `hevc` (software): if `>0`, sets `tune=fastdecode`. <br/>Other codecs: usually unused. |
| `video_backend` | `str` | `"pyav"` | Only `"pyav"` is currently implemented for video encoding. |
| `extra_options` | `dict` | `{}` | Extra FFmpeg or codec specific options merged after the structured fields above. Cannot override keys already set by those fields. |
---
## Persistence in dataset metadata
After the first episode of a video stream is encoded, the encoder configuration is **persisted into the dataset metadata** (`meta/info.json`) under each video feature, alongside the values probed from the file itself. For a video feature `observation.images.<camera>`, the layout in `info.json` is:
```json
{
"features": {
"observation.images.laptop": {
"dtype": "video",
"shape": [480, 640, 3],
"info": {
"video.height": 480,
"video.width": 640,
"video.codec": "h264",
"video.pix_fmt": "yuv420p",
"video.fps": 30,
"video.channels": 3,
"is_depth_map": false,
"video.g": 2,
"video.crf": 30,
"video.preset": "fast",
"video.fast_decode": 0,
"video.video_backend": "pyav",
"video.extra_options": { "tune": "film", "profile:v": "high", "bf": 2 }
}
}
}
}
```
Two sources contribute to the `info` block:
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `is_depth_map`, plus `audio.*` if an audio stream is present.
- **Encoder-derived** (taken from `VideoEncoderConfig`): `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.video_backend`, `video.extra_options`.
<Tip>
This block is populated **once**, from the **first** episode. It assumes every
episode in the dataset was encoded with the same `camera_encoder`. Changing
encoder settings partway through a recording is not supported — the
`info.json` will only reflect the parameters used for the first episode.
</Tip>
---
## Merging datasets
When aggregating datasets with `merge_datasets`, video files are concatenated as-is (no re-encoding), and encoder fields in `info.json` are merged per-key:
- **Stream-derived fields must match** across sources: `video.codec`, `video.pix_fmt`, `video.height`, `video.width`, `video.fps`. Otherwise FFmpeg's concat demuxer fails.
- **Encoder-tuning fields are merged loosely**: `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.extra_options`. If every source agrees, the value is kept; if not, it's set to `null` (or `{}` for `video.extra_options`) and a warning is logged.

View File

@@ -80,7 +80,7 @@
"}\n",
"\n",
"# Dataset\n",
"HF_USER = \"your_hf_username\" # `huggingface-cli whoami` to find your username\n",
"HF_USER = \"your_hf_username\" # `hf auth whoami` to find your username\n",
"DATASET_NAME = \"my_so101_dataset\"\n",
"TASK_DESCRIPTION = \"pick and place the block\"\n",
"NUM_EPISODES = 10\n",
@@ -291,7 +291,34 @@
"\n",
"Uses `POLICY_PATH` from the Configuration cell (defaults to the Hub repo ID). You can also put there the `LAST_CHECKPOINT_PATH`.\n",
"\n",
"See the [inference docs](https://huggingface.co/docs/lerobot/il_robots#run-inference-and-evaluate-your-policy) for details."
"See the [inference docs](https://huggingface.co/docs/lerobot/il_robots#run-inference-and-evaluate-your-policy) for details.\n",
"\n",
"Recently ```lerobot-rollout``` was introduced, you can [read more about it here](https://huggingface.co/docs/lerobot/main/en/il_robots?eval=Base+mode+%28no+recording%29#run-inference-and-evaluate-your-policy)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print_cmd(\n",
" \"lerobot-rollout\",\n",
" \"--strategy.type=base\",\n",
" f\"--policy.path={POLICY_PATH}\",\n",
" f\"--robot.type={ROBOT_TYPE}\",\n",
" f\"--robot.port={ROBOT_PORT}\",\n",
" CAMERAS_FLAG,\n",
" f'--task=\"{TASK_DESCRIPTION}\"',\n",
" \"--duration=60\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"if you are using the V0.5.1 release you should use ```lerobot-record``` instead of rollout"
]
},
{

136
examples/omx/README.md Normal file
View File

@@ -0,0 +1,136 @@
# OMX Follower — Cube Pick And Place Example
This is an example of what is possible to do with LeRobot on a physical setup.
It is a WIP and being used internally at LeRobot and specific to our setup, but we hope it can be a useful reference for how to use LeRobot APIs and CLIs.
It includes an end-to-end example for the **OMX Follower** robot arm: pick and place a cube dataset, train a policy, and deploy it autonomously.
## Hardware
| Component | Value |
| --------- | ------------------------------------ |
| Robot | OMX Follower |
| Cameras | 2× OpenCV cameras (wrist + top-down) |
## Scripts
| Script | Purpose |
| ---------------------- | --------------------------------------------------------------- |
| `reset_environment.py` | Standalone utility: sweep workspace, grab cube, place cube |
| `record_grab.py` | Automated data collection: reset → place → record grab episodes |
## Setup
Make sure you have LeRobot installed in your env. (See [the installation guide](https://huggingface.co/docs/lerobot/installation))
Next, we will declare some environment variables for convenience. Adjust the camera indices and robot port to match your system configuration.
```bash
export ROBOT_PORT=/dev/ttyACM0
export TELEOP_PORT=/dev/ttyACM1
export HF_USERNAME=<your_hf_username>
export ROBOT_CAMERAS="{ wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30, fourcc: MJPG}, top: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30, fourcc: MJPG} }"
```
## Step 1 — Collect Data
```bash
lerobot-record \
--robot.type=omx_follower \
--robot.port=$ROBOT_PORT \
--robot.id=omx_follower \
--robot.cameras="$ROBOT_CAMERAS" \
--teleop.type=omx_leader \
--teleop.port=$TELEOP_PORT \
--teleop.id=omx_leader \
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
--dataset.root=data/omx_pickandplace \
--dataset.num_episodes=50 \
--dataset.single_task="Pick the cube and place it in the blue square" \
--dataset.streaming_encoding=true \
--dataset.push_to_hub=true
```
### Bonus Auto-Collect script
/!\ This is specific to our setup and the task of picking and placing a cube. It is not a general-purpose data collection script. As you may notice, it doesn't require a teleop.
```bash
python -m examples.omx.record_grab \
--robot.type=omx_follower \
--robot.port=$ROBOT_PORT \
--robot.id=omx_follower \
--robot.cameras="$ROBOT_CAMERAS" \
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
--dataset.root=data/omx_pickandplace \
--dataset.num_episodes=50 \
--dataset.single_task="Pick the cube and place it in the blue square" \
--dataset.streaming_encoding=true \
--dataset.push_to_hub=true
```
Each episode:
1. The arm grabs the cube from the center of the workspace and places it at a random position.
2. The arm returns to HOME.
3. A targeted grab is recorded: HOME → approach raised → lower onto cube → grasp → lift → carry → drop → HOME.
A dataset is already available here [`maximellerbach/omx_pickandplace`](https://huggingface.co/datasets/maximellerbach/omx_pickandplace), so you can skip directly to training if you want.
## Step 2 — Train
To train a simple `ACT` policy on the collected dataset, you can use the `lerobot-train` CLI:
```bash
lerobot-train \
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
--policy.type=act \
--output_dir=outputs/train/omx_pickandplace_act \
--policy.device=cuda \
--policy.repo_id=$HF_USERNAME/omx_pickandplace_act \
--steps=20000 \
--wandb.enable=true
```
A pretrained `ACT` policy is already available here [`maximellerbach/omx_pickandplace_act`](https://huggingface.co/maximellerbach/omx_pickandplace_act).
## Step 3 — Rollout
Use the `lerobot-rollout` CLI with base strategy:
```bash
lerobot-rollout \
--strategy.type=base \
--robot.type=omx_follower \
--robot.port=$ROBOT_PORT \
--robot.id=omx_follower \
--robot.cameras="$ROBOT_CAMERAS" \
--policy.path=$HF_USERNAME/omx_pickandplace_act \
```
For continuous recording with automatic upload (sentry mode):
```bash
lerobot-rollout \
--strategy.type=sentry \
--strategy.upload_every_n_episodes=10 \
--robot.type=omx_follower \
--robot.port=$ROBOT_PORT \
--robot.id=omx_follower \
--robot.cameras="$ROBOT_CAMERAS" \
--policy.path=$HF_USERNAME/omx_pickandplace_act \
--dataset.repo_id=$HF_USERNAME/rollout_omx_pickandplace_act \
```
## Environment Reset Utility
Those are specific to this particular physical setup. Those are scripts that execute hardcoded sequences of actions on the robot to reset the environment, which is useful for data collection and evaluation. They are not general-purpose scripts.
`reset_environment.py` can be run standalone to prepare the workspace:
```bash
# Grab cube + place it at a random position on the left side
python -m examples.omx.reset_environment --port $ROBOT_PORT --mode grab_and_place
```
It also exposes `grab_cube(robot)` and `place_cube(robot)` for use in custom scripts.

422
examples/omx/record_grab.py Normal file
View File

@@ -0,0 +1,422 @@
#!/usr/bin/env python3
"""
Auto-record grab episodes for the OMX robot arm.
Each episode cycle:
1. grab_and_place — grab cube from workspace center and place at a random (pan, reach) position
2. HOME — return arm to home with gripper open
3. record_grab — execute a targeted grab to the stored position while recording
observations + actions to a LeRobotDataset
Usage (run from repo root):
python -m examples.omx.record_grab \\
--robot.type=omx_follower \\
--robot.port=/dev/ttyACM0 \\
--robot.id=omx_follower \\
--robot.cameras="{ wrist: {type: opencv, index_or_path: 6, width: 640, height: 480, fps: 30, fourcc: MJPG}, top: {type: opencv, index_or_path: 4, width: 640, height: 480, fps: 30, fourcc: MJPG} }" \\
--dataset.repo_id=<hf_username>/<dataset_name> \\
--dataset.root=data/omx_grab \\
--dataset.num_episodes=50 \\
--dataset.single_task="Grab the cube" \\
--dataset.streaming_encoding=true
"""
import logging
from dataclasses import dataclass
from pprint import pformat
import numpy as np
from lerobot.cameras import CameraConfig # noqa: F401
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.dataset import DatasetRecordConfig
from lerobot.datasets import (
LeRobotDataset,
VideoEncodingManager,
aggregate_pipeline_dataset_features,
create_initial_features,
)
from lerobot.processor import make_default_processors
from lerobot.robots import RobotConfig, make_robot_from_config
from lerobot.robots.omx_follower import OmxFollower
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.utils.robot_utils import precise_sleep
from .reset_environment import (
APPROACH_SPEED,
GRIPPER_CLOSE_POS,
HOME_POSE,
PUSH_END_ELBOW_FLEX,
PUSH_END_SHOULDER_LIFT,
PUSH_START_ELBOW_FLEX,
PUSH_START_SHOULDER_LIFT,
array_to_pose,
grab_cube,
horizontal_wrist_flex,
move_to_pose,
place_cube,
pose_to_array,
)
# ── Grab-episode motion parameters ────────────────────────────────────────────
# Shoulder-lift offset for the raised approach phase (subtracted from the target sl, arm is higher).
GRAB_RAISE_SL_OFFSET = 20.0
GRAB_LOWER_SPEED = 20.0
RECORD_SPEED = 30.0
# Pose the arm travels to after closing the gripper (cube held).
GRAB_CARRY_POSE = {
"shoulder_pan.pos": -23.0,
"shoulder_lift.pos": 5.0,
"elbow_flex.pos": 18.0,
"wrist_flex.pos": -14.0,
"wrist_roll.pos": 0.0,
"gripper.pos": GRIPPER_CLOSE_POS,
}
# Per-joint jitter limits (degrees) applied to transit waypoints for human-like variation.
# Cube-approach and carry poses are never jittered to preserve precision.
_JITTER_LIMITS: dict[str, float] = {
"shoulder_pan.pos": 5.0,
"shoulder_lift.pos": 4.0,
"elbow_flex.pos": 4.0,
"wrist_flex.pos": 3.0,
"wrist_roll.pos": 2.0,
"gripper.pos": 0.0,
}
def _jitter_pose(pose: dict, rng: np.random.Generator) -> dict:
"""Return a copy of pose with independent per-joint random perturbations."""
return {
k: v + rng.uniform(-_JITTER_LIMITS.get(k, 0.0), _JITTER_LIMITS.get(k, 0.0)) for k, v in pose.items()
}
def _random_stuck_pose(rng: np.random.Generator) -> dict:
"""Return a physically plausible stuck pose (failed grasp), gripper closed.
ef bounds are piecewise-linear in sl so the arm stays in a reachable,
table-safe envelope across the full sl range:
sl=-50 → ef ∈ [ 0, 50] (arm raised, can be bent forward)
sl= 0 → ef ∈ [-25, 25] (mid reach)
sl= 30 → ef ∈ [-20, 0] (arm extended, little room to flex)
wrist_flex is randomly offset from the horizontal value.
"""
pan = float(rng.uniform(-5.0, 35.0))
sl = float(rng.uniform(-50.0, 30.0))
if sl <= 0.0:
alpha = (sl + 50.0) / 50.0 # 0 at sl=-50, 1 at sl=0
ef_lo = alpha * -25.0 # 0 → -25
ef_hi = 50.0 + alpha * -25.0 # 50 → 25
else:
alpha = sl / 30.0 # 0 at sl=0, 1 at sl=30
ef_lo = -25.0 + alpha * 5.0 # -25 → -20
ef_hi = 25.0 + alpha * -25.0 # 25 → 0
ef = float(rng.uniform(ef_lo, ef_hi))
wf = horizontal_wrist_flex(sl, ef) + float(rng.uniform(-15.0, 15.0))
return {
"shoulder_pan.pos": pan,
"shoulder_lift.pos": sl,
"elbow_flex.pos": ef,
"wrist_flex.pos": wf,
"wrist_roll.pos": float(rng.uniform(-15.0, 15.0)),
"gripper.pos": GRIPPER_CLOSE_POS,
}
logger = logging.getLogger(__name__)
@dataclass
class OmxRecordGrabConfig:
robot: RobotConfig
dataset: DatasetRecordConfig
# Resume recording on an existing dataset.
resume: bool = False
# Fraction of episodes that start from a random stuck pose (gripper closed) to
# generate recovery data. 0.0 = disabled, 1.0 = all episodes are recovery starts.
recovery_prob: float = 0.5
def record_episode_spline(
robot: OmxFollower,
waypoints: list[dict],
speeds: list[float],
dataset: LeRobotDataset,
task: str,
) -> None:
"""Execute a Catmull-Rom-style spline through waypoints, recording each frame.
Segment durations are parameterized from the maximum absolute joint delta
between consecutive waypoints divided by the requested segment speed,
producing non-uniform timing in joint space. Interior tangents are derived
from the adjacent per-segment velocities, with clamped (zero-velocity)
endpoints so the arm starts and stops smoothly. Each segment is cubic
Hermite, giving C1 continuity at every waypoint.
"""
pts = [pose_to_array(w) for w in waypoints]
n = len(pts)
# Steps and duration per segment
n_steps_list = []
timestamps = []
for i in range(n - 1):
max_dist = float(np.max(np.abs(pts[i + 1] - pts[i])))
ns = max(1, int(max_dist / speeds[i] * dataset.fps)) if max_dist >= 0.5 else 0
n_steps_list.append(ns)
timestamps.append(ns / dataset.fps)
# Velocity tangents (deg/sec) — clamped at endpoints, Catmull-Rom for interior
vels = [np.zeros_like(pts[0])]
for i in range(1, n - 1):
v_prev = (pts[i] - pts[i - 1]) / timestamps[i - 1] if timestamps[i - 1] > 0 else np.zeros_like(pts[0])
v_next = (pts[i + 1] - pts[i]) / timestamps[i] if timestamps[i] > 0 else np.zeros_like(pts[0])
vels.append(0.5 * (v_prev + v_next))
vels.append(np.zeros_like(pts[0]))
dt = 1.0 / dataset.fps
for seg in range(n - 1):
ns = n_steps_list[seg]
if ns == 0:
continue
p0, p1 = pts[seg], pts[seg + 1]
# Scale velocity (deg/sec) to t-space tangent (deg/t-unit, where t: 0→1 over ns steps)
m0 = vels[seg] * timestamps[seg]
m1 = vels[seg + 1] * timestamps[seg]
for step in range(1, ns + 1):
t = step / ns
h00 = 2 * t**3 - 3 * t**2 + 1
h10 = t**3 - 2 * t**2 + t
h01 = -2 * t**3 + 3 * t**2
h11 = t**3 - t**2
commanded = h00 * p0 + h10 * m0 + h01 * p1 + h11 * m1
action = array_to_pose(commanded)
robot.send_action(action)
obs = robot.get_observation()
obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
action_frame = build_dataset_frame(dataset.features, action, prefix=ACTION)
dataset.add_frame({**obs_frame, **action_frame, "task": task})
precise_sleep(dt)
def record_grab_episode(
robot: OmxFollower,
dataset: LeRobotDataset,
pan: float,
t: float,
task: str,
recovery_start: bool = False,
) -> None:
"""Execute a targeted grab to the stored (pan, t) position, recording every frame.
Normal sequence (initial HOME move is NOT recorded):
HOME → raised approach above cube → lower → close gripper
→ raise [jittered] → retract [jittered] → GRAB_CARRY_POSE → drop → HOME
Recovery sequence (recovery_start=True): arm is moved to a random stuck pose
(gripper closed) without recording, then recording begins from there:
stuck_pose → raised approach above cube → [normal grab sequence from there]
All segments are joined by a Catmull-Rom spline (C1-continuous velocities).
"""
sl = PUSH_START_SHOULDER_LIFT + t * (PUSH_END_SHOULDER_LIFT - PUSH_START_SHOULDER_LIFT)
ef = PUSH_START_ELBOW_FLEX + t * (PUSH_END_ELBOW_FLEX - PUSH_START_ELBOW_FLEX)
sl_raised = sl - GRAB_RAISE_SL_OFFSET
wf_horizontal = horizontal_wrist_flex(sl, ef)
rng = np.random.default_rng()
if recovery_start:
stuck_pose = _random_stuck_pose(rng)
logger.info(f"Recovery start: {stuck_pose}")
move_to_pose(robot, stuck_pose, APPROACH_SPEED)
first_waypoints = [stuck_pose]
first_speeds = []
else:
jittery_start = _jitter_pose(HOME_POSE, rng)
move_to_pose(robot, jittery_start, APPROACH_SPEED)
first_waypoints = [jittery_start]
first_speeds = []
waypoints = first_waypoints + [
{ # raised approach: arm above cube
"shoulder_pan.pos": pan,
"shoulder_lift.pos": sl_raised,
"elbow_flex.pos": ef,
"wrist_flex.pos": horizontal_wrist_flex(sl_raised, ef),
"wrist_roll.pos": 0.0,
"gripper.pos": 60.0,
},
{ # lower onto cube — no jitter: precision needed
"shoulder_pan.pos": pan,
"shoulder_lift.pos": sl,
"elbow_flex.pos": ef,
"wrist_flex.pos": wf_horizontal,
"wrist_roll.pos": 0.0,
"gripper.pos": 60.0,
},
{ # close gripper — no jitter: precision needed
"shoulder_pan.pos": pan,
"shoulder_lift.pos": sl,
"elbow_flex.pos": ef,
"wrist_flex.pos": wf_horizontal,
"wrist_roll.pos": 0.0,
"gripper.pos": GRIPPER_CLOSE_POS,
},
_jitter_pose(
{ # raise with cube
"shoulder_pan.pos": pan,
"shoulder_lift.pos": sl_raised,
"elbow_flex.pos": ef,
"wrist_flex.pos": horizontal_wrist_flex(sl_raised, ef),
"wrist_roll.pos": 0.0,
"gripper.pos": GRIPPER_CLOSE_POS,
},
rng,
),
_jitter_pose(
{ # retract: fold arm toward HOME before sweeping to carry zone
"shoulder_pan.pos": pan * 0.25,
"shoulder_lift.pos": HOME_POSE["shoulder_lift.pos"] + 5.0,
"elbow_flex.pos": HOME_POSE["elbow_flex.pos"] - 5.0,
"wrist_flex.pos": 0.0,
"wrist_roll.pos": 0.0,
"gripper.pos": GRIPPER_CLOSE_POS,
},
rng,
),
GRAB_CARRY_POSE, # no jitter: target drop zone
{**GRAB_CARRY_POSE, "gripper.pos": 60.0}, # drop cube
HOME_POSE,
]
speeds = first_speeds + [
RECORD_SPEED, # (HOME →) raised approach
GRAB_LOWER_SPEED, # raised approach → lower
GRAB_LOWER_SPEED, # lower → close gripper
RECORD_SPEED, # close gripper → raise
RECORD_SPEED, # raise → retract
RECORD_SPEED, # retract → carry pose
RECORD_SPEED, # carry pose → drop
RECORD_SPEED, # drop → HOME
]
record_episode_spline(robot, waypoints, speeds, dataset, task)
# Dwell at HOME for ~0.5 s before next episode
home_action = build_dataset_frame(dataset.features, HOME_POSE, prefix=ACTION)
dt = 1.0 / dataset.fps
for _ in range(int(dataset.fps * 0.5)):
robot.send_action(HOME_POSE)
obs = robot.get_observation()
obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
dataset.add_frame({**obs_frame, **home_action, "task": task})
precise_sleep(dt)
@parser.wrap()
def record_grab(cfg: OmxRecordGrabConfig) -> LeRobotDataset:
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger.info(pformat(cfg))
robot = make_robot_from_config(cfg.robot)
use_videos = cfg.dataset.video
teleop_action_processor, _, robot_obs_processor = make_default_processors()
dataset_features = combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=teleop_action_processor,
initial_features=create_initial_features(action=robot.action_features),
use_videos=use_videos,
),
aggregate_pipeline_dataset_features(
pipeline=robot_obs_processor,
initial_features=create_initial_features(observation=robot.observation_features),
use_videos=use_videos,
),
)
num_cameras = len(robot.cameras) if hasattr(robot, "cameras") else 0
dataset = None
try:
if cfg.resume:
dataset = LeRobotDataset.resume(
cfg.dataset.repo_id,
root=cfg.dataset.root,
streaming_encoding=cfg.dataset.streaming_encoding,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
vcodec=cfg.dataset.vcodec,
encoder_threads=cfg.dataset.encoder_threads,
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
if num_cameras > 0
else 0,
)
else:
cfg.dataset.stamp_repo_id()
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
root=cfg.dataset.root,
robot_type=robot.name,
features=dataset_features,
use_videos=use_videos,
streaming_encoding=cfg.dataset.streaming_encoding,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
vcodec=cfg.dataset.vcodec,
encoder_threads=cfg.dataset.encoder_threads,
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
if num_cameras > 0
else 0,
)
robot.connect(calibrate=True)
rng = np.random.default_rng()
with VideoEncodingManager(dataset):
for episode_idx in range(cfg.dataset.num_episodes):
logger.info(f"=== Episode {episode_idx + 1}/{cfg.dataset.num_episodes} ===")
logger.info("Step 1: grabbing and placing cube...")
grab_cube(robot)
pan, t = place_cube(robot)
logger.info(f"Cube placed at pan={pan:.1f}, reach={t:.2f}")
recovery_start = cfg.recovery_prob > 0 and float(rng.random()) < cfg.recovery_prob
logger.info(f"Step 2: recording {'recovery ' if recovery_start else ''}grab episode...")
record_grab_episode(
robot,
dataset,
pan,
t,
cfg.dataset.single_task,
recovery_start=recovery_start,
)
dataset.save_episode()
logger.info(f"Episode {episode_idx + 1} saved.")
finally:
if dataset:
dataset.finalize()
if robot.is_connected:
robot.disconnect()
if cfg.dataset.push_to_hub and dataset and dataset.num_episodes > 0:
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
return dataset
if __name__ == "__main__":
record_grab()

View File

@@ -0,0 +1,267 @@
#!/usr/bin/env python3
"""
Auto-reset and cube-grab utility for the OMX robot arm.
Provides:
- grab_cube(robot): sweep workspace, center cube, close gripper
- place_cube(robot): carry cube to a random position, release
Standalone usage (run from repo root):
python -m examples.omx.reset_environment --port /dev/ttyACM1 --mode grab
python -m examples.omx.reset_environment --port /dev/ttyACM1 --mode grab_and_place
Joint range: -100 to 100 for arm joints; gripper: 50 = closed, 80 = open.
To read current joint values for calibration, add after robot.connect():
obs = robot.get_observation()
print({k: round(obs[k], 1) for k in JOINT_NAMES})
robot.disconnect(); raise SystemExit
Parallel-to-ground IK: wrist_flex = WRIST_HORIZONTAL_OFFSET - shoulder_lift - elbow_flex.
Linear interpolation preserves this constraint between any two poses that satisfy it.
"""
import argparse
import logging
import numpy as np
from lerobot.robots.omx_follower import OmxFollower, OmxFollowerConfig
from lerobot.robots.robot import Robot
from lerobot.utils.robot_utils import precise_sleep
logger = logging.getLogger(__name__)
# ── Poses ─────────────────────────────────────────────────────────────────────
HOME_POSE = {
"shoulder_pan.pos": 0.0,
"shoulder_lift.pos": -50.0,
"elbow_flex.pos": 50.0,
"wrist_flex.pos": 0.0,
"wrist_roll.pos": 0.0,
"gripper.pos": 60.0,
}
SWEEP_WAYPOINTS = [
{
"shoulder_pan.pos": -60.0,
"shoulder_lift.pos": 50.0,
"elbow_flex.pos": -60.0,
"wrist_flex.pos": -20.0,
"wrist_roll.pos": 0.0,
"gripper.pos": 60.0,
},
{
"shoulder_pan.pos": -30.0,
"shoulder_lift.pos": 50.0,
"elbow_flex.pos": -60.0,
"wrist_flex.pos": -5.0,
"wrist_roll.pos": 0.0,
"gripper.pos": 60.0,
},
{
"shoulder_pan.pos": 20.0,
"shoulder_lift.pos": 50.0,
"elbow_flex.pos": -55.0,
"wrist_flex.pos": -5.0,
"wrist_roll.pos": 0.0,
"gripper.pos": 60.0,
},
]
# ── Motion parameters ─────────────────────────────────────────────────────────
CONTROL_HZ = 30
APPROACH_SPEED = 50.0
SWEEP_SPEED = 40.0
# ── Grab-sequence parameters ──────────────────────────────────────────────────
GRAB_PAN = 0.0
SWEEP_LEFT_PAN = -60.0
SWEEP_RIGHT_PAN = 60.0
SWEEP_END_OFFSET = 5.0 # stop before center so the cube isn't pushed past GRAB_PAN
SWEEP_END_PAN_RANGE = (15.0, 20.0)
SWEEP_LOW_SHOULDER_LIFT = 50.0
SWEEP_LOW_ELBOW_FLEX_START = -60.0
SWEEP_LOW_ELBOW_FLEX_END = -55.0
SWEEP_HIGH_WRIST_FLEX = -20.0 # wrist tilted up during high approach to clear obstacles
PUSH_START_SHOULDER_LIFT = 0.0
PUSH_START_ELBOW_FLEX = 45.0
PUSH_END_SHOULDER_LIFT = 50.0
PUSH_END_ELBOW_FLEX = -50.0
# Subtracted from shoulder_lift during the push sweep to clear the platform surface.
# Does not affect the grab-target interpolation in record_grab.py.
PUSH_RAISE_OFFSET = 5.0
WRIST_HORIZONTAL_OFFSET = 0.0 # tune if gripper tilts during push: + tilts nose up, - down
GRIPPER_CLOSE_POS = 50.0
PLACE_LEFT_PAN_RANGE = (5.0, 30.0) # random pan range for cube placement on the left side
PLACE_REACH_RANGE = (0.1, 0.7) # 0 = arm retracted (PUSH_START), 1 = fully extended (PUSH_END)
JOINT_NAMES = [
"shoulder_pan.pos",
"shoulder_lift.pos",
"elbow_flex.pos",
"wrist_flex.pos",
"wrist_roll.pos",
"gripper.pos",
]
# ── Helpers ───────────────────────────────────────────────────────────────────
def pose_to_array(pose: dict) -> np.ndarray:
return np.array([pose[k] for k in JOINT_NAMES])
def array_to_pose(arr: np.ndarray) -> dict:
return {k: float(arr[i]) for i, k in enumerate(JOINT_NAMES)}
def horizontal_wrist_flex(shoulder_lift: float, elbow_flex: float) -> float:
return WRIST_HORIZONTAL_OFFSET - shoulder_lift - elbow_flex
def _low_sweep_pose(pan: float, elbow_flex: float, wrist_flex: float | None = None) -> dict:
sl = SWEEP_LOW_SHOULDER_LIFT
return {
"shoulder_pan.pos": pan,
"shoulder_lift.pos": sl,
"elbow_flex.pos": elbow_flex,
"wrist_flex.pos": horizontal_wrist_flex(sl, elbow_flex) if wrist_flex is None else wrist_flex,
"wrist_roll.pos": 0.0,
"gripper.pos": 60.0,
}
def _high_sweep_pose(pan: float) -> dict:
return {**HOME_POSE, "shoulder_pan.pos": pan, "wrist_flex.pos": SWEEP_HIGH_WRIST_FLEX}
def _push_pose(shoulder_lift: float, elbow_flex: float, pan: float = GRAB_PAN, gripper: float = 70.0) -> dict:
return {
"shoulder_pan.pos": pan,
"shoulder_lift.pos": shoulder_lift,
"elbow_flex.pos": elbow_flex,
"wrist_flex.pos": horizontal_wrist_flex(shoulder_lift, elbow_flex),
"wrist_roll.pos": 0.0,
"gripper.pos": gripper,
}
def move_to_pose(robot: Robot, target: dict, speed: float) -> None:
"""Interpolate from current position to target at the given speed (units/s)."""
obs = robot.get_observation()
current = np.array([obs[k] for k in JOINT_NAMES])
goal = pose_to_array(target)
max_distance = float(np.max(np.abs(goal - current)))
if max_distance < 0.5:
return
n_steps = max(1, int(max_distance / speed * CONTROL_HZ))
dt = 1.0 / CONTROL_HZ
for step in range(1, n_steps + 1):
t = step / n_steps
robot.send_action(array_to_pose(current + t * (goal - current)))
precise_sleep(dt)
# ── Sequences ─────────────────────────────────────────────────────────────────
def grab_cube(robot: Robot) -> None:
"""Left sweep → right sweep → extend arm parallel to ground → close gripper."""
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
for pan, end_pan in [
(SWEEP_LEFT_PAN, GRAB_PAN - SWEEP_END_OFFSET),
(SWEEP_RIGHT_PAN, GRAB_PAN + SWEEP_END_OFFSET),
]:
logger.info(f"Sweeping {'left' if pan < 0 else 'right'} → center...")
move_to_pose(robot, _high_sweep_pose(pan), APPROACH_SPEED)
move_to_pose(
robot, _low_sweep_pose(pan, SWEEP_LOW_ELBOW_FLEX_START, wrist_flex=-20.0), APPROACH_SPEED
)
move_to_pose(robot, _low_sweep_pose(end_pan, SWEEP_LOW_ELBOW_FLEX_END, wrist_flex=0.0), SWEEP_SPEED)
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
logger.info("Extending to push cube into gripper...")
move_to_pose(
robot,
_push_pose(PUSH_START_SHOULDER_LIFT - PUSH_RAISE_OFFSET, PUSH_START_ELBOW_FLEX),
APPROACH_SPEED,
)
move_to_pose(
robot,
_push_pose(PUSH_END_SHOULDER_LIFT - PUSH_RAISE_OFFSET, PUSH_END_ELBOW_FLEX),
SWEEP_SPEED,
)
logger.info("Closing gripper...")
move_to_pose(
robot,
_push_pose(PUSH_END_SHOULDER_LIFT, PUSH_END_ELBOW_FLEX, gripper=GRIPPER_CLOSE_POS),
APPROACH_SPEED,
)
logger.info("Grab complete.")
def place_cube(robot: Robot) -> tuple[float, float]:
"""Carry the cube (gripper closed) to a random position on the left side, then release.
Returns:
(pan, t): pan angle and reach scalar [0, 1] of the placement position.
"""
pan = float(np.random.uniform(*PLACE_LEFT_PAN_RANGE))
t = float(np.random.uniform(*PLACE_REACH_RANGE))
sl = PUSH_START_SHOULDER_LIFT + t * (PUSH_END_SHOULDER_LIFT - PUSH_START_SHOULDER_LIFT)
ef = PUSH_START_ELBOW_FLEX + t * (PUSH_END_ELBOW_FLEX - PUSH_START_ELBOW_FLEX)
logger.info(f"Placing cube at pan={pan:.1f}, reach={t:.2f}...")
move_to_pose(robot, {**HOME_POSE, "gripper.pos": GRIPPER_CLOSE_POS}, APPROACH_SPEED)
move_to_pose(
robot, {**HOME_POSE, "shoulder_pan.pos": pan, "gripper.pos": GRIPPER_CLOSE_POS}, APPROACH_SPEED
)
move_to_pose(robot, _push_pose(sl, ef, pan=pan, gripper=GRIPPER_CLOSE_POS), APPROACH_SPEED)
move_to_pose(robot, _push_pose(sl, ef, pan=pan, gripper=80.0), APPROACH_SPEED)
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
logger.info("Place complete.")
return pan, t
# ── Entry point ───────────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser(description="OMX arm reset / grab script")
parser.add_argument("--port", default="/dev/ttyACM1")
parser.add_argument("--robot_id", default="omx_follower")
parser.add_argument("--mode", choices=["grab", "grab_and_place"], default="grab_and_place")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
robot = OmxFollower(OmxFollowerConfig(port=args.port, id=args.robot_id))
robot.connect(calibrate=True)
try:
if args.mode == "grab":
grab_cube(robot)
elif args.mode == "grab_and_place":
grab_cube(robot)
place_cube(robot)
finally:
robot.disconnect()
if __name__ == "__main__":
main()

View File

@@ -4,13 +4,13 @@ from pathlib import Path
from queue import Empty, Full
import torch
import torch.optim as optim
from lerobot.datasets import LeRobotDataset
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
from lerobot.policies import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies import GaussianActorConfig
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rewards.classifier.modeling_classifier import Classifier
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.gym_manipulator import make_robot_env
from lerobot.robots.so_follower import SO100FollowerConfig
@@ -28,7 +28,7 @@ def run_learner(
transitions_queue: mp.Queue,
parameters_queue: mp.Queue,
shutdown_event: mp.Event,
policy_learner: SACPolicy,
policy_learner: GaussianActorPolicy,
online_buffer: ReplayBuffer,
offline_buffer: ReplayBuffer,
lr: float = 3e-4,
@@ -40,8 +40,9 @@ def run_learner(
policy_learner.train()
policy_learner.to(device)
# Create Adam optimizer from scratch - simple and clean
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
algo_config = SACAlgorithmConfig.from_policy_config(policy_learner.config)
algorithm = SACAlgorithm(policy=policy_learner, config=algo_config)
algorithm.make_optimizers_and_scheduler()
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
@@ -83,24 +84,26 @@ def run_learner(
else:
batch[key] = online_batch[key]
loss, _ = policy_learner.forward(batch)
def batch_iter(b=batch):
while True:
yield b
optimizer.zero_grad()
loss.backward()
optimizer.step()
stats = algorithm.update(batch_iter())
training_step += 1
if training_step % LOG_EVERY == 0:
log_dict = stats.to_log_dict()
print(
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
f"[LEARNER] Training step {training_step}, "
f"critic_loss: {log_dict.get('critic', 'N/A'):.4f}, "
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
)
# Send updated parameters to actor every 10 training steps
if training_step % SEND_EVERY == 0:
try:
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
parameters_queue.put_nowait(state_dict)
weights = algorithm.get_weights()
parameters_queue.put_nowait(weights)
print("[LEARNER] Sent updated parameters to actor")
except Full:
# Missing write due to queue not being consumed (should happen rarely)
@@ -113,7 +116,7 @@ def run_actor(
transitions_queue: mp.Queue,
parameters_queue: mp.Queue,
shutdown_event: mp.Event,
policy_actor: SACPolicy,
policy_actor: GaussianActorPolicy,
reward_classifier: Classifier,
env_cfg: HILSerlRobotEnvConfig,
device: torch.device = "mps",
@@ -144,15 +147,15 @@ def run_actor(
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
try:
new_params = parameters_queue.get_nowait()
policy_actor.load_state_dict(new_params)
new_weights = parameters_queue.get_nowait()
policy_actor.load_state_dict(new_weights)
print("[ACTOR] Updated policy parameters from learner")
except Empty: # No new updated parameters available from learner, waiting
pass
# Get action from policy
# Get action from policy (returns full action: continuous + discrete)
policy_obs = make_policy_obs(obs, device=device)
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
action_tensor = policy_actor.select_action(policy_obs)
action = action_tensor.squeeze(0).cpu().numpy()
# Step environment
@@ -261,14 +264,14 @@ def main():
action_features = hw_to_dataset_features(env.robot.action_features, "action")
# Create SAC policy for action selection
policy_cfg = SACConfig(
policy_cfg = GaussianActorConfig(
device=device,
input_features=obs_features,
output_features=action_features,
)
policy_actor = SACPolicy(policy_cfg)
policy_learner = SACPolicy(policy_cfg)
policy_actor = GaussianActorPolicy(policy_cfg)
policy_learner = GaussianActorPolicy(policy_cfg)
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)

View File

@@ -59,8 +59,8 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
dependencies = [
# Core ML
"torch>=2.7,<2.11.0",
"torchvision>=0.22.0,<0.26.0",
"torch>=2.7,<2.12.0",
"torchvision>=0.22.0,<0.27.0",
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
"opencv-python-headless>=4.9.0,<4.14.0",
"Pillow>=10.0.0,<13.0.0",
@@ -99,7 +99,18 @@ dataset = [
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
"lerobot[av-dep]",
"torchcodec>=0.3.0,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10).
# NOTE: torchcodec wheel availability matrix (PyPI):
# - linux x86_64/amd64 + macOS arm64 : wheels since 0.3.0 (the historic supported set).
# - win32 x86_64 : wheels since 0.7.0 (needs torch>=2.8).
# - linux aarch64/arm64 : wheels since 0.11.0 (needs torch>=2.11).
# - macOS x86_64 (Intel) and linux armv7l: no wheels in any released version -> fall through to the PyAV decoder.
# Each platform gets its own line so the resolver picks the minimum version that has a wheel for it.
# Other torch/torchcodec pairings (informational): 0.8.1 = ffmpeg>=8 support, 0.10 = system-wide ffmpeg support, 0.12 needs torch==2.12.
"torchcodec>=0.3.0,<0.12.0; (sys_platform == 'linux' and (platform_machine == 'x86_64' or platform_machine == 'AMD64')) or (sys_platform == 'darwin' and platform_machine == 'arm64')",
"torchcodec>=0.7.0,<0.12.0; sys_platform == 'win32'",
"torchcodec>=0.11.0,<0.12.0; sys_platform == 'linux' and (platform_machine == 'aarch64' or platform_machine == 'arm64')",
"jsonlines>=4.0.0,<5.0.0",
]
training = [
@@ -195,7 +206,7 @@ groot = [
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
@@ -293,6 +304,20 @@ lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
# ---------------- Tool Configurations ----------------
# cu128 wheels keep broad hardware reach; the driver floor is 570.86.
# To use a different CUDA variant, reinstall torch with an explicit index, e.g.:
# uv pip install --force-reinstall torch torchvision \
# --index-url https://download.pytorch.org/whl/cu130
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true
[tool.uv.sources]
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
[tool.setuptools.package-data]
lerobot = ["envs/*.json"]

View File

@@ -430,7 +430,7 @@ class OpenCVCamera(Camera):
Internal loop run by the background thread for asynchronous reading.
On each iteration:
1. Reads a color frame
1. Reads a color frame (blocking call)
2. Stores result in latest_frame and updates timestamp (thread-safe)
3. Sets new_frame_event to notify listeners
@@ -439,8 +439,9 @@ class OpenCVCamera(Camera):
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
stop_event = self.stop_event
failure_count = 0
while not self.stop_event.is_set():
while not stop_event.is_set():
try:
raw_frame = self._read_from_hardware()
processed_frame = self._postprocess_image(raw_frame)
@@ -478,6 +479,8 @@ class OpenCVCamera(Camera):
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
if self.thread.is_alive():
logger.warning(f"{self} read thread did not terminate within timeout.")
self.thread = None
self.stop_event = None

View File

@@ -332,8 +332,8 @@ class RealSenseCamera(Camera):
from the camera hardware via the RealSense pipeline.
Returns:
np.ndarray: The depth map as a NumPy array (height, width)
of type `np.uint16` (raw depth values in millimeters) and rotation.
np.ndarray: The depth map as a NumPy array (height, width, 1)
of type `np.uint16` (raw depth values in millimeters).
Raises:
DeviceNotConnectedError: If the camera is not connected.
@@ -465,8 +465,8 @@ class RealSenseCamera(Camera):
Internal loop run by the background thread for asynchronous reading.
On each iteration:
1. Reads a color frame with 500ms timeout
2. Stores result in latest_frame and updates timestamp (thread-safe)
1. Reads a color/depth frame (blocking call with 10s timeout)
2. Stores result in latest_color_frame/latest_depth_frame and updates timestamp (thread-safe)
3. Sets new_frame_event to notify listeners
Stops on DeviceNotConnectedError, logs other errors and continues.
@@ -474,8 +474,9 @@ class RealSenseCamera(Camera):
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
stop_event = self.stop_event
failure_count = 0
while not self.stop_event.is_set():
while not stop_event.is_set():
try:
frame = self._read_from_hardware()
color_frame_raw = frame.get_color_frame()
@@ -486,6 +487,8 @@ class RealSenseCamera(Camera):
depth_frame_raw = frame.get_depth_frame()
depth_frame = np.asanyarray(depth_frame_raw.get_data())
processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True)
if processed_depth_frame.ndim == 2: # (H, W) -> (H, W, 1)
processed_depth_frame = processed_depth_frame[..., np.newaxis]
capture_time = time.perf_counter()
@@ -522,6 +525,8 @@ class RealSenseCamera(Camera):
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
if self.thread.is_alive(): # pragma: no cover
logger.warning(f"{self} read thread did not terminate within timeout.")
self.thread = None
self.stop_event = None
@@ -532,7 +537,6 @@ class RealSenseCamera(Camera):
self.latest_timestamp = None
self.new_frame_event.clear()
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
@@ -575,7 +579,6 @@ class RealSenseCamera(Camera):
return frame
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent (color) frame captured immediately (Peeking).
@@ -611,6 +614,71 @@ class RealSenseCamera(Camera):
return frame
@check_if_not_connected
def async_read_depth(self, timeout_ms: float = 200) -> NDArray[Any]:
"""Read the latest depth frame asynchronously, in metric meters.
Mirrors :meth:`async_read` but returns the depth stream rather than the
color stream. Output is ``np.uint16`` of shape ``(H, W, 1)``.
Raises:
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
the background read thread is not running.
TimeoutError: If no frame becomes available within ``timeout_ms``.
"""
if not self.use_depth:
raise RuntimeError(f"{self}: cannot read depth — camera was configured with use_depth=False.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
raise TimeoutError(f"Timed out waiting for depth frame from camera {self} after {timeout_ms} ms.")
with self.frame_lock:
depth_frame = self.latest_depth_frame
self.new_frame_event.clear()
if depth_frame is None:
raise RuntimeError(f"Internal error: Event set but no depth frame available for {self}.")
return depth_frame
@check_if_not_connected
def read_latest_depth(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent depth frame in metric meters (peeking).
Non-blocking counterpart of :meth:`read_latest` for the depth stream.
Output is ``np.uint16`` of shape ``(H, W, 1)`` in millimeters.
Raises:
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
no depth frame has been captured yet.
TimeoutError: If the latest depth frame is older than ``max_age_ms``.
"""
if not self.use_depth:
raise RuntimeError(f"{self}: cannot read depth — camera was configured with use_depth=False.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
with self.frame_lock:
depth_frame = self.latest_depth_frame
timestamp = self.latest_timestamp
if depth_frame is None or timestamp is None:
raise RuntimeError(f"{self} has not captured any depth frames yet.")
age_ms = (time.perf_counter() - timestamp) * 1e3
if age_ms > max_age_ms:
raise TimeoutError(
f"{self} latest depth frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
)
return depth_frame
def disconnect(self) -> None:
"""
Disconnects from the camera, stops the pipeline, and cleans up resources.

View File

@@ -249,8 +249,9 @@ class ZMQCamera(Camera):
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized.")
stop_event = self.stop_event
failure_count = 0
while not self.stop_event.is_set():
while not stop_event.is_set():
try:
frame = self._read_from_hardware()
capture_time = time.perf_counter()
@@ -292,6 +293,8 @@ class ZMQCamera(Camera):
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
if self.thread.is_alive():
logger.warning(f"{self} read thread did not terminate within timeout.")
self.thread = None
self.stop_event = None

View File

@@ -99,6 +99,7 @@ def save_checkpoint(
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
preprocessor: The preprocessor/pipeline to save. Defaults to None.
postprocessor: The postprocessor/pipeline to save. Defaults to None.
"""
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
policy.save_pretrained(pretrained_dir)

View File

@@ -31,6 +31,14 @@ from .types import (
PolicyFeature,
RTCAttentionSchedule,
)
from .video import (
VALID_VIDEO_CODECS,
VIDEO_ENCODER_INFO_KEYS,
DepthEncoderConfig,
VideoEncoderConfig,
camera_encoder_defaults,
depth_encoder_defaults,
)
__all__ = [
# Types
@@ -46,4 +54,12 @@ __all__ = [
"PeftConfig",
"PreTrainedConfig",
"WandBConfig",
"VideoEncoderConfig",
"DepthEncoderConfig",
# Defaults
"camera_encoder_defaults",
"depth_encoder_defaults",
# Constants
"VALID_VIDEO_CODECS",
"VIDEO_ENCODER_INFO_KEYS",
]

View File

@@ -14,10 +14,12 @@
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from .video import DepthEncoderConfig, VideoEncoderConfig, camera_encoder_defaults, depth_encoder_defaults
@dataclass
class DatasetRecordConfig:
@@ -55,10 +57,11 @@ class DatasetRecordConfig:
# Number of episodes to record before batch encoding videos
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
video_encoding_batch_size: int = 1
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
# or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
# Use 'auto' to auto-detect the best available hardware encoder.
vcodec: str = "libsvtav1"
# Video encoder settings for camera MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys,
# e.g. ``--dataset.camera_encoder.vcodec=h264`` (see ``VideoEncoderConfig``).
camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
# Video encoder settings for depth-map MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys.
depth_encoder: DepthEncoderConfig = field(default_factory=depth_encoder_defaults)
# Enable streaming video encoding: encode frames in real-time during capture instead
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
streaming_encoding: bool = False

View File

@@ -17,7 +17,7 @@
from dataclasses import dataclass, field
from lerobot.transforms import ImageTransformsConfig
from lerobot.utils.import_utils import get_safe_default_codec
from lerobot.utils.import_utils import get_safe_default_video_backend
@dataclass
@@ -34,7 +34,7 @@ class DatasetConfig:
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
revision: str | None = None
use_imagenet_stats: bool = True
video_backend: str = field(default_factory=get_safe_default_codec)
video_backend: str = field(default_factory=get_safe_default_video_backend)
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
return_uint8: bool = False
@@ -117,3 +117,9 @@ class PeftConfig:
# the rank used for the adapter. In general a higher rank means more trainable parameters and closer to full
# fine-tuning.
r: int = 16
# Alpha parameter for LoRA scaling (scaling = lora_alpha / r).
# In general, a higher alpha means stronger adaptation signal.
# If None, the PEFT library defaults to alpha=8, which may dampen high-rank adapters.
# Common values are r (alpha == rank) or 2*r.
lora_alpha: int | None = None

View File

@@ -18,8 +18,8 @@ from logging import getLogger
from pathlib import Path
from lerobot import envs, policies # noqa: F401
from lerobot.configs import parser
from . import parser
from .default import EvalConfig
from .policies import PreTrainedConfig
@@ -46,8 +46,11 @@ class EvalPipelineConfig:
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
yaml_overrides = parser.get_yaml_overrides("policy")
cli_overrides = parser.get_cli_overrides("policy") or []
self.policy = PreTrainedConfig.from_pretrained(
policy_path, cli_overrides=yaml_overrides + cli_overrides
)
self.policy.pretrained_path = Path(policy_path)
else:

View File

@@ -13,8 +13,10 @@
# limitations under the License.
import importlib
import inspect
import json
import pkgutil
import sys
import tempfile
from argparse import ArgumentError
from collections.abc import Callable, Iterable, Sequence
from functools import wraps
@@ -24,6 +26,7 @@ from types import ModuleType
from typing import Any, TypeVar, cast
import draccus
import yaml # type: ignore[import-untyped]
from lerobot.utils.utils import has_method
@@ -32,6 +35,29 @@ F = TypeVar("F", bound=Callable[..., object])
PATH_KEY = "path"
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
# Storage for path args extracted from YAML/JSON config files, so that
# get_path_arg() can find them even when they weren't passed via CLI.
_config_path_args: dict[str, str] = {}
# Storage for non-path YAML overrides so validate() can pass them to from_pretrained.
_config_yaml_overrides: dict[str, list[str]] = {}
def _flatten_to_cli_args(d: dict, prefix: str = "") -> list[str]:
"""Recursively flatten a nested dict to CLI-style args (e.g. {"lr": 1e-4} -> ["--lr=0.0001"])."""
args = []
for key, value in d.items():
if key in (PATH_KEY, draccus.CHOICE_TYPE_KEY):
continue
full_key = f"{prefix}.{key}" if prefix else key
if isinstance(value, bool):
value = str(value).lower()
if isinstance(value, dict):
args.extend(_flatten_to_cli_args(value, full_key))
elif value is not None and not isinstance(value, list):
args.append(f"--{full_key}={value}")
return args
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
"""Parses arguments from cli at a given nested attribute level.
@@ -145,7 +171,14 @@ def load_plugin(plugin_path: str) -> None:
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
return parse_arg(f"{field_name}.{PATH_KEY}", args)
result = parse_arg(f"{field_name}.{PATH_KEY}", args)
if result is None:
result = _config_path_args.get(field_name)
return result
def get_yaml_overrides(field_name: str) -> list[str]:
return _config_yaml_overrides.get(field_name, [])
def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
@@ -192,6 +225,52 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
return filtered_args
def extract_path_fields_from_config(config_path: str, path_fields: list[str]) -> str:
"""Extract `path` fields from a YAML/JSON config before draccus processes it.
When a user specifies e.g. ``policy.path: lerobot/smolvla_base`` in a YAML config,
draccus will fail because ``path`` is not a valid field on policy config classes.
This function extracts those path values, stores them in ``_config_path_args`` for
later retrieval by ``get_path_arg()``, and returns a cleaned temp config file path.
"""
config_file = Path(config_path)
suffix = config_file.suffix.lower()
if suffix in (".yaml", ".yml"):
with open(config_file) as f:
config_data = yaml.safe_load(f)
elif suffix == ".json":
with open(config_file) as f:
config_data = json.load(f)
else:
return config_path
if not isinstance(config_data, dict):
return config_path
modified = False
for field in path_fields:
if field in config_data and isinstance(config_data[field], dict) and PATH_KEY in config_data[field]:
_config_path_args[field] = str(config_data[field].pop(PATH_KEY))
remaining = config_data[field]
if remaining:
_config_yaml_overrides[field] = _flatten_to_cli_args(remaining)
else:
del config_data[field]
modified = True
if not modified:
return config_path
# Write cleaned config to a temp file
with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False) as tmp:
if suffix in (".yaml", ".yml"):
yaml.dump(config_data, tmp, default_flow_style=False)
else:
json.dump(config_data, tmp, indent=2)
return tmp.name
def wrap(config_path: Path | None = None) -> Callable[[F], F]:
"""
HACK: Similar to draccus.wrap but does three additional things:
@@ -225,6 +304,9 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
if has_method(argtype, "__get_path_fields__"):
path_fields = argtype.__get_path_fields__()
cli_args = filter_path_args(path_fields, cli_args)
# Also extract path fields from the YAML/JSON config file
if config_path_cli:
config_path_cli = extract_path_fields_from_config(config_path_cli, path_fields)
if has_method(argtype, "from_pretrained") and config_path_cli:
cli_args = filter_arg("config_path", cli_args)
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)

View File

@@ -27,12 +27,13 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
from lerobot.configs.types import PolicyFeature
from lerobot.optim.optimizers import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available
from lerobot.utils.hub import HubMixin
from .types import PolicyFeature
T = TypeVar("T", bound="RewardModelConfig")
logger = logging.getLogger(__name__)
@@ -89,9 +90,9 @@ class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg]
return None
@abc.abstractmethod
def get_optimizer_preset(self) -> OptimizerConfig:
raise NotImplementedError
def get_optimizer_preset(self) -> OptimizerConfig | None:
"""Default optimizer for this reward model, or ``None`` for zero-shot models."""
return None
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
return None

View File

@@ -25,11 +25,11 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.errors import HfHubHTTPError
from lerobot import envs
from lerobot.configs import parser
from lerobot.optim import LRSchedulerConfig, OptimizerConfig
from lerobot.utils.hub import HubMixin
from lerobot.utils.sample_weighting import SampleWeightingConfig
from . import parser
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from .policies import PreTrainedConfig
from .rewards import RewardModelConfig
@@ -144,8 +144,11 @@ class TrainPipelineConfig(HubMixin):
)
self.reward_model.pretrained_path = str(Path(reward_model_path))
elif policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
yaml_overrides = parser.get_yaml_overrides("policy")
cli_overrides = parser.get_cli_overrides("policy") or []
self.policy = PreTrainedConfig.from_pretrained(
policy_path, cli_overrides=yaml_overrides + cli_overrides
)
self.policy.pretrained_path = Path(policy_path)
elif self.resume:
config_path = parser.parse_arg("config_path")
@@ -256,7 +259,9 @@ class TrainPipelineConfig(HubMixin):
) from e
cli_args = kwargs.pop("cli_args", [])
if config_file is not None:
# Legacy RA-BC migration only applies to framework-saved checkpoints (always JSON).
# Hand-written YAML/TOML configs are expected to use the current sample_weighting schema.
if config_file is not None and config_file.endswith(".json"):
with open(config_file) as f:
config = json.load(f)
migrated_config = _migrate_legacy_rabc_fields(config)
@@ -267,10 +272,3 @@ class TrainPipelineConfig(HubMixin):
with draccus.config_type("json"):
return draccus.parse(cls, config_file, args=cli_args)
@dataclass(kw_only=True)
class TrainRLServerPipelineConfig(TrainPipelineConfig):
# NOTE: In RL, we don't need an offline dataset
# TODO: Make `TrainPipelineConfig.dataset` optional
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional

View File

@@ -0,0 +1,315 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: We subclass str so that serialization is straightforward
# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
"""Video encoder configurations."""
from __future__ import annotations
import logging
from dataclasses import dataclass, field, fields
from typing import Any, ClassVar
from lerobot.utils.import_utils import require_package
logger = logging.getLogger(__name__)
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and the chosen video backend.
# Determines the order of preference for auto-selection when vcodec="auto" is used.
HW_VIDEO_CODECS = [
"h264_videotoolbox", # macOS
"hevc_videotoolbox", # macOS
"h264_nvenc", # NVIDIA GPU
"hevc_nvenc", # NVIDIA GPU
"h264_vaapi", # Linux Intel/AMD
"h264_qsv", # Intel Quick Sync
]
VALID_VIDEO_CODECS: frozenset[str] = frozenset(
{"h264", "hevc", "libsvtav1", "ffv1", "auto", *HW_VIDEO_CODECS}
)
# Aliases for legacy video codec names.
VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"}
LIBSVTAV1_DEFAULT_PRESET: int = 12
# Keys persisted under ``features[*]["info"]`` as ``video.<name>`` (from :class:`VideoEncoderConfig`).
# ``vcodec``` and ``pix_fmt`` are derived from the video stream directly.
VIDEO_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset(
{"g", "crf", "preset", "fast_decode", "extra_options", "video_backend"}
)
VIDEO_ENCODER_INFO_KEYS: frozenset[str] = frozenset(
f"video.{name}" for name in VIDEO_ENCODER_INFO_FIELD_NAMES
)
# Default depth quantization and encoding parameters.
DEPTH_QUANT_BITS: int = 12
DEPTH_QMAX: int = (1 << DEPTH_QUANT_BITS) - 1 # 4095
DEFAULT_DEPTH_MIN: float = 0.01
DEFAULT_DEPTH_MAX: float = 10.0
DEFAULT_DEPTH_SHIFT: float = 3.5
DEFAULT_DEPTH_USE_LOG: bool = True
DEFAULT_DEPTH_PIX_FMT: str = "gray12le"
# Depth-specific tuning fields persisted under ``features[*]["info"]`` as ``video.<name>``.
DEPTH_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset({"depth_min", "depth_max", "shift", "use_log"})
@dataclass
class VideoEncoderConfig:
"""Video encoder configuration.
Attributes:
vcodec: Video encoder name. ``"auto"`` is resolved during
construction (HW encoder if available, else ``libsvtav1``).
pix_fmt: Pixel format (e.g. ``"yuv420p"``).
g: GOP size (keyframe interval).
crf: Quality level — mapped to the native quality parameter of the
codec (``crf`` for software, ``qp`` for NVENC/VAAPI,
``q:v`` for VideoToolbox, ``global_quality`` for QSV).
preset: Speed/quality preset. Accepted type is per-codec.
fast_decode: Fast-decode tuning. For ``libsvtav1`` this is a level (0-2)
embedded in ``svtav1-params``. For ``h264`` and ``hevc`` non-zero values
set ``tune=fastdecode``. Ignored for other codecs.
video_backend: Python to be used for encoding. Only ``"pyav"``
is currently supported.
extra_options: Free-form dictionary of additional video encoder options
(e.g. ``{"tune": "film", "profile:v": "high", "bf": 2}``).
"""
vcodec: str = "libsvtav1" # TODO(CarolinePascal): rename to codec ?
pix_fmt: str = "yuv420p"
g: int | None = 2
crf: int | float | None = 30
preset: int | str | None = None
fast_decode: int = 0
# TODO(CarolinePascal): add torchcodec support + find a way to unify the
# two backends (encoding and decoding).
video_backend: str = "pyav"
extra_options: dict[str, Any] = field(default_factory=dict)
# Source-data channel count this encoder is expected to handle (3 for RGB,
# 1 for depth, etc.)
_DEFAULT_CHANNELS: ClassVar[int] = 3
def __post_init__(self) -> None:
self.resolve_vcodec()
# Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work".
if self.preset is None and self.vcodec == "libsvtav1":
self.preset = LIBSVTAV1_DEFAULT_PRESET
self.validate()
@classmethod
def from_video_info(cls, video_info: dict | None) -> VideoEncoderConfig:
"""Reconstruct a :class:`VideoEncoderConfig` from a video feature's ``info`` block.
Missing or ``None`` values fall back to the class defaults.
"""
video_info = video_info or {}
kwargs: dict[str, Any] = {}
for src_key, dst_field in (("video.codec", "vcodec"), ("video.pix_fmt", "pix_fmt")):
value = video_info.get(src_key)
if value is not None:
kwargs[dst_field] = value
for field_name in VIDEO_ENCODER_INFO_FIELD_NAMES:
value = video_info.get(f"video.{field_name}")
if value is None:
continue
# Persisted as ``{}`` after merges with disagreeing sources — treat as default.
if field_name == "extra_options" and not value:
continue
kwargs[field_name] = value
return cls(**kwargs)
def detect_available_encoders(self, encoders: list[str] | str) -> list[str]:
"""Return the subset of available encoders based on the specified video backend.
Args:
encoders: List of encoder names to detect. If a string, it is converted to a list.
Returns:
List of available encoder names. If the video backend is not "pyav", returns an empty list.
"""
if self.video_backend == "pyav":
require_package("av", extra="dataset")
from lerobot.datasets import detect_available_encoders_pyav
return detect_available_encoders_pyav(encoders)
return []
def validate(self) -> None:
"""Validate the video encoder configuration."""
if self.video_backend == "pyav":
require_package("av", extra="dataset")
from lerobot.datasets import check_video_encoder_parameters_pyav
check_video_encoder_parameters_pyav(
self.vcodec, self.pix_fmt, self.get_codec_options(), channels=self._DEFAULT_CHANNELS
)
def resolve_vcodec(self) -> None:
"""Check ``vcodec`` and, when it is ``"auto"``, pick a concrete encoder.
For ``"auto"``, the first hardware encoder in the preference list that is available is chosen; if none are available, ``libsvtav1`` is used. If the
resolved codec (explicit or after auto-selection) is not available, raises ``ValueError``.
Stream-derived canonical codec names listed in :data:`VIDEO_CODECS_ALIASES` are
rewritten to their corresponding encoder name (e.g. ``"av1"`` → ``"libsvtav1"``).
"""
self.vcodec = VIDEO_CODECS_ALIASES.get(self.vcodec, self.vcodec)
if self.vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{self.vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
if self.vcodec == "auto":
available = self.detect_available_encoders(HW_VIDEO_CODECS)
for encoder in HW_VIDEO_CODECS:
if encoder in available:
logger.info(f"Auto-selected video codec: {encoder}")
self.vcodec = encoder
return
logger.warning("No hardware encoder available, falling back to software encoder 'libsvtav1'")
self.vcodec = "libsvtav1"
if self.detect_available_encoders(self.vcodec):
logger.info(f"Using video codec: {self.vcodec}")
return
raise ValueError(f"Unsupported video codec: {self.vcodec} with video backend {self.video_backend}")
def get_codec_options(
self, encoder_threads: int | None = None, as_strings: bool = False
) -> dict[str, Any]:
"""Translate the tuning fields to codec-specific options.
``VideoEncoderConfig.extra_options`` are merged last but never override a structured field.
Args:
encoder_threads: Number of encoder threads set globally for all VideoEncoderConfigs.
For libsvtav1, this is mapped to ``lp`` via ``svtav1-params``.
For h264/hevc, this is mapped to ``threads``.
Hardware encoders ignore this parameter.
as_strings: If ``True``, casts values to strings.
"""
opts: dict[str, Any] = {}
def set_if(key: str, value: Any) -> None:
if value is not None:
opts[key] = value if not as_strings else str(value)
# GOP size is not a codec-specific option, so it is always set.
set_if("g", self.g)
if self.vcodec == "libsvtav1":
set_if("crf", self.crf)
set_if("preset", self.preset)
svtav1_parts: list[str] = []
if self.fast_decode is not None:
svtav1_parts.append(f"fast-decode={max(0, min(2, self.fast_decode))}")
if encoder_threads is not None:
svtav1_parts.append(f"lp={encoder_threads}")
if svtav1_parts:
opts["svtav1-params"] = ":".join(svtav1_parts)
elif self.vcodec in ("h264", "hevc"):
set_if("crf", self.crf)
set_if("preset", self.preset)
if self.fast_decode:
opts["tune"] = "fastdecode"
set_if("threads", encoder_threads)
elif self.vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
if self.crf is not None:
opts["q:v"] = max(1, min(100, 100 - self.crf * 2))
elif self.vcodec in ("h264_nvenc", "hevc_nvenc"):
opts["rc"] = 0
set_if("qp", self.crf)
set_if("preset", self.preset)
elif self.vcodec == "h264_vaapi":
set_if("qp", self.crf)
elif self.vcodec == "h264_qsv":
set_if("global_quality", self.crf)
set_if("preset", self.preset)
elif self.vcodec == "ffv1":
# Lossless intra-frame codec. ``crf``/``preset``/``fast_decode``
# are not meaningful.
set_if("threads", encoder_threads)
else:
set_if("crf", self.crf)
set_if("preset", self.preset)
# Extra options are merged last but never override structured fields (values are kept as given).
for k, v in self.extra_options.items():
if k not in opts:
set_if(k, v)
return opts
def camera_encoder_defaults() -> VideoEncoderConfig:
"""Return a :class:`VideoEncoderConfig` with RGB-camera defaults."""
return VideoEncoderConfig()
@dataclass
class DepthEncoderConfig(VideoEncoderConfig):
"""Encoder configuration for depth-map streams.
Inherits the full :class:`VideoEncoderConfig` surface (codec, GOP, CRF,
preset, ``extra_options``…) and adds the four parameters of the depth
quantizer.
Defaults flip ``vcodec`` to ``"hevc"`` (Main 12 profile) and ``pix_fmt``
to ``"gray12le"``.
Attributes:
depth_min: Minimum depth in physical units (e.g. metres) represented
by quantum ``0``.
depth_max: Maximum depth represented by quantum :data:`DEPTH_QMAX`.
shift: Pre-log offset for numerical stability near zero.
use_log: ``True`` for logarithmic quantization (default; matches
sensor error profile), ``False`` for linear.
"""
vcodec: str = "hevc"
pix_fmt: str = "gray12le"
depth_min: float = DEFAULT_DEPTH_MIN
depth_max: float = DEFAULT_DEPTH_MAX
shift: float = DEFAULT_DEPTH_SHIFT
use_log: bool = DEFAULT_DEPTH_USE_LOG
_DEFAULT_CHANNELS: ClassVar[int] = 1
@classmethod
def from_video_info(cls, video_info: dict | None) -> DepthEncoderConfig:
"""Reconstruct a :class:`DepthEncoderConfig` from a depth feature's ``info`` block.
Reuses :meth:`VideoEncoderConfig.from_video_info` for the base
codec/tuning fields and then layers the depth-specific tuning
(``depth_min`` / ``depth_max`` / ``shift`` / ``use_log``) on top.
Missing keys fall back to the class defaults.
"""
base = VideoEncoderConfig.from_video_info(video_info)
kwargs: dict[str, Any] = {f.name: getattr(base, f.name) for f in fields(base) if f.init}
video_info = video_info or {}
for name in DEPTH_ENCODER_INFO_FIELD_NAMES:
value = video_info.get(f"video.{name}")
if value is not None:
kwargs[name] = value
return cls(**kwargs)
def depth_encoder_defaults() -> DepthEncoderConfig:
"""Return a :class:`DepthEncoderConfig` with depth-camera defaults."""
return DepthEncoderConfig()

View File

@@ -40,6 +40,7 @@ from .io_utils import load_episodes, write_stats
from .lerobot_dataset import LeRobotDataset
from .multi_dataset import MultiLeRobotDataset
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
from .sampler import EpisodeAwareSampler
from .streaming_dataset import StreamingLeRobotDataset
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
@@ -59,6 +60,8 @@ __all__ = [
"MultiLeRobotDataset",
"StreamingLeRobotDataset",
"VideoEncodingManager",
"check_video_encoder_parameters_pyav",
"detect_available_encoders_pyav",
"add_features",
"aggregate_datasets",
"aggregate_pipeline_dataset_features",

View File

@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import shutil
from pathlib import Path
@@ -23,9 +24,11 @@ import datasets
import pandas as pd
import tqdm
from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
from .compute_stats import aggregate_stats
from .dataset_metadata import LeRobotDatasetMetadata
from .feature_utils import get_hf_features_from_features
from .feature_utils import features_equal_for_merge, get_hf_features_from_features
from .io_utils import (
get_file_size_in_mb,
get_parquet_file_size_in_mb,
@@ -46,11 +49,54 @@ from .utils import (
from .video_utils import concatenate_video_files, get_video_duration_in_s
def merge_video_feature_info_for_aggregate(all_metadata: list[LeRobotDatasetMetadata]) -> dict[str, dict]:
"""Create a merged video feature info dictionary for aggregation. The video encoder info is merged field-by-field: each key is kept only when every source agrees; otherwise that key is set to ``null`` (or ``{}`` for ``video.extra_options``) and a warning is logged.
Args:
all_metadata: List of LeRobotDatasetMetadata objects to merge.
Returns:
dict: A dictionary of merged video feature info.
"""
merged_info = copy.deepcopy(all_metadata[0].features)
video_keys = [k for k in merged_info if merged_info[k].get("dtype") == "video"]
for vk in video_keys:
video_infos = [m.features.get(vk, {}).get("info") or {} for m in all_metadata]
base_video_info = video_infos[0]
merged_encoder_info: dict = {}
fallback_keys: list[str] = []
for info_key in VIDEO_ENCODER_INFO_KEYS:
values = [info.get(info_key, None) for info in video_infos]
first_value = values[0]
all_match = all(v == first_value for v in values[1:])
if all_match:
merged_encoder_info[info_key] = first_value
else:
fallback_keys.append(info_key)
merged_encoder_info[info_key] = {} if info_key == "video.extra_options" else None
if fallback_keys:
logging.warning(
f"Merging heterogeneous or incomplete video encoder metadata for feature {vk}. "
f"Setting these keys to null: {fallback_keys}.",
)
merged_info[vk]["info"] = {**base_video_info, **merged_encoder_info}
# TODO(CarolinePascal): make this variable once we have support for other video backends.
merged_info[vk]["info"]["video.video_backend"] = "pyav"
return merged_info
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
"""Validates that all dataset metadata have consistent properties.
Ensures all datasets have the same fps, robot_type, and features to guarantee
compatibility when aggregating them into a single dataset.
Video encoder info is not considered for validation but is merged during aggregation in ``merge_video_feature_info_for_aggregate``.
Args:
all_metadata: List of LeRobotDatasetMetadata objects to validate.
@@ -74,7 +120,7 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
raise ValueError(
f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}."
)
if features != meta.features:
if not features_equal_for_merge(features, meta.features):
raise ValueError(
f"Same features is expected, but got features={meta.features} instead of {features}."
)
@@ -274,7 +320,8 @@ def aggregate_datasets(
LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
]
)
fps, robot_type, features = validate_all_metadata(all_metadata)
fps, robot_type, _ = validate_all_metadata(all_metadata)
features = merge_video_feature_info_for_aggregate(all_metadata)
video_keys = [key for key in features if features[key]["dtype"] == "video"]
dst_meta = LeRobotDatasetMetadata.create(
@@ -332,7 +379,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
videos_idx: Dictionary tracking video chunk and file indices.
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
Returns:
dict: Updated videos_idx with current chunk and file indices.
"""
@@ -414,9 +460,11 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
current_dst_duration = dst_file_durations.get(dst_key, 0)
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
# TODO(CarolinePascal): Move the check before the loop to avoid failing in the middle + add possibility to re-encode the video if the check fails
concatenate_video_files(
[dst_path, src_path],
dst_path,
compatibility_check=True,
)
# Update duration of this destination file
dst_file_durations[dst_key] = current_dst_duration + src_duration

View File

@@ -550,8 +550,10 @@ def _validate_stat_value(value: np.ndarray, key: str, feature_key: str) -> None:
if key == "count" and value.shape != (1,):
raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.")
if "image" in feature_key and key != "count" and value.shape != (3, 1, 1):
raise ValueError(f"Shape of quantile '{key}' must be (3,1,1), but is {value.shape} instead.")
if "image" in feature_key and key != "count" and value.shape not in ((3, 1, 1), (1, 1, 1)):
raise ValueError(
f"Shape of quantile '{key}' must be (3,1,1) or (1,1,1) but is {value.shape} instead."
)
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from collections.abc import Callable
from pathlib import Path
import numpy as np
@@ -23,6 +24,7 @@ import pyarrow as pa
import pyarrow.parquet as pq
from huggingface_hub import snapshot_download
from lerobot.configs import VideoEncoderConfig
from lerobot.utils.constants import DEFAULT_FEATURES, HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
from lerobot.utils.feature_utils import _validate_feature_names
from lerobot.utils.utils import flatten_dict
@@ -189,6 +191,29 @@ class LeRobotDatasetMetadata:
if self.episodes is None:
self._load_metadata()
def filter_episodes(
self,
predicate: Callable[[dict], bool],
candidates: list[int] | None = None,
) -> list[int]:
"""Filter episodes whose metadata satisfies a given predicate.
Args:
predicate: Predicate over per-episode metadata rows used to select episodes.
candidates: Optional list of episode indices to restrict evaluation to.
Returns:
List of sorted episode indices that satisfy the predicate.
"""
self.ensure_readable()
if candidates is not None:
candidate_set = set(candidates)
combined = lambda ep: ep["episode_index"] in candidate_set and predicate(ep) # noqa: E731
else:
combined = predicate
filtered = self.episodes.filter(combined, keep_in_memory=True, load_from_cache_file=False)
return sorted(int(idx) for idx in filtered["episode_index"])
def _pull_from_repo(
self,
allow_patterns: list[str] | str | None = None,
@@ -313,6 +338,25 @@ class LeRobotDatasetMetadata:
"""Keys to access visual modalities stored as videos."""
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
@property
def depth_keys(self) -> list[str]:
"""Keys to access depth-map modalities stored as videos or images.
A depth key is a feature whose ``info`` dict carries ``"is_depth_map": True``
(or the legacy ``"video.is_depth_map"`` inside ``info`` or ``video_info``).
"""
def _is_depth(ft: dict) -> bool:
info = ft.get("info") or {}
video_info = ft.get("video_info") or {}
return (
info.get("is_depth_map", False)
or info.get("video.is_depth_map", False)
or video_info.get("video.is_depth_map", False)
)
return [key for key, ft in self.features.items() if _is_depth(ft)]
@property
def camera_keys(self) -> list[str]:
"""Keys to access visual modalities (regardless of their storage method)."""
@@ -510,19 +554,36 @@ class LeRobotDatasetMetadata:
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
write_stats(self.stats, self.root)
def update_video_info(self, video_key: str | None = None) -> None:
"""
def update_video_info(
self,
video_key: str | None = None,
video_encoder: VideoEncoderConfig | None = None,
) -> None:
"""Populate per-feature video info in ``info.json``.
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
been encoded the same way. Also, this means it assumes the first episode exists.
Args:
video_key: If provided, only update this video key. Otherwise update
all video keys in the dataset.
camera_encoder: Encoder configuration used to produce the
videos. When provided, its fields are recorded as
``video.<field>`` entries alongside the stream-derived
``video.*`` entries (see :func:`get_video_info`).
"""
if video_key is not None and video_key not in self.video_keys:
raise ValueError(f"Video key {video_key} not found in dataset")
video_keys = [video_key] if video_key is not None else self.video_keys
for key in video_keys:
if not self.features[key].get("info", None):
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
self.info.features[key]["info"] = get_video_info(video_path)
existing = self.features[key].get("info") or {}
# Skip only if real video info has already been written. The ``is_depth_map`` entry (created at feature creation) is not blocking.
if set(existing.keys()) - {"is_depth_map"}:
continue
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
new_info = get_video_info(video_path, video_encoder=video_encoder)
self.info.features[key]["info"] = {**existing, **new_info}
def update_chunk_settings(
self,

View File

@@ -22,7 +22,10 @@ from pathlib import Path
import datasets
import torch
from lerobot.configs.video import DepthEncoderConfig
from .dataset_metadata import LeRobotDatasetMetadata
from .depth_utils import dequantize_depth
from .feature_utils import (
check_delta_timestamps,
get_delta_indices,
@@ -86,6 +89,12 @@ class DatasetReader:
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
##TODO(CarolinePascal): Should we rather use a more lightweight structure ?
self._depth_encoder_configs: dict[str, DepthEncoderConfig] = {
vid_key: DepthEncoderConfig.from_video_info(self._meta.features[vid_key].get("info"))
for vid_key in self._meta.depth_keys
}
def try_load(self) -> bool:
"""Attempt to load from local cache. Returns True if data is sufficient."""
try:
@@ -247,7 +256,18 @@ class DatasetReader:
self._tolerance_s,
self._video_backend,
return_uint8=self._return_uint8,
is_depth=vid_key in self._meta.depth_keys,
)
if vid_key in self._meta.depth_keys:
depth_encoder = self._depth_encoder_configs[vid_key]
frames = dequantize_depth(
frames,
depth_min=depth_encoder.depth_min,
depth_max=depth_encoder.depth_max,
shift=depth_encoder.shift,
use_log=depth_encoder.use_log,
output_tensor=True,
)
return vid_key, frames.squeeze(0)
items = list(query_timestamps.items())

View File

@@ -36,6 +36,7 @@ import pyarrow.parquet as pq
import torch
from tqdm import tqdm
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE
from lerobot.utils.utils import flatten_dict
@@ -62,7 +63,10 @@ from .utils import (
DEFAULT_EPISODES_PATH,
update_chunk_file_indices,
)
from .video_utils import encode_video_frames, get_video_info
from .video_utils import (
encode_video_frames,
get_video_info,
)
def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict:
@@ -95,6 +99,11 @@ def delete_episodes(
) -> LeRobotDataset:
"""Delete episodes from a LeRobotDataset and create a new dataset.
Video segments that need re-encoding (because the source file mixes kept and
deleted episodes) are re-encoded with the source dataset's existing encoder
settings — read back from ``meta/info.json`` — so the output dataset stays
consistent with its own metadata.
Args:
dataset: The source LeRobotDataset.
episode_indices: List of episode indices to delete.
@@ -157,6 +166,11 @@ def split_dataset(
) -> dict[str, LeRobotDataset]:
"""Split a LeRobotDataset into multiple smaller datasets.
Video segments that need re-encoding (because the source file mixes episodes
that fall into different splits) are re-encoded with the source dataset's
existing encoder settings — read back from ``meta/info.json`` — so each
output split stays consistent with its own metadata.
Args:
dataset: The source LeRobotDataset to split.
splits: Either a dict mapping split names to episode indices, or a dict mapping
@@ -578,8 +592,7 @@ def _keep_episodes_from_video_with_av(
output_path: Path,
episodes_to_keep: list[tuple[int, int]],
fps: float,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
camera_encoder: VideoEncoderConfig,
) -> None:
"""Keep only specified episodes from a video file using PyAV.
@@ -593,8 +606,7 @@ def _keep_episodes_from_video_with_av(
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
is inclusive and end_frame is exclusive.
fps: Frame rate of the video.
vcodec: Video codec to use for encoding.
pix_fmt: Pixel format for output video.
camera_encoder: Video encoder settings used to re-encode the kept frames.
"""
from fractions import Fraction
@@ -619,12 +631,13 @@ def _keep_episodes_from_video_with_av(
# Convert fps to Fraction for PyAV compatibility.
fps_fraction = Fraction(fps).limit_denominator(1000)
v_out = out.add_stream(vcodec, rate=fps_fraction)
codec_options = camera_encoder.get_codec_options(as_strings=True)
v_out = out.add_stream(camera_encoder.vcodec, rate=fps_fraction, options=codec_options)
# PyAV type stubs don't distinguish video streams from audio/subtitle streams.
v_out.width = v_in.codec_context.width
v_out.height = v_in.codec_context.height
v_out.pix_fmt = pix_fmt
v_out.pix_fmt = camera_encoder.pix_fmt
# Set time_base to match the frame rate for proper timestamp handling.
v_out.time_base = Fraction(1, int(fps))
@@ -687,14 +700,14 @@ def _copy_and_reindex_videos(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
episode_mapping: dict[int, int],
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
) -> dict[int, dict]:
"""Copy and filter video files, only re-encoding files with deleted episodes.
For video files that only contain kept episodes, we copy them directly.
For files with mixed kept/deleted episodes, we use PyAV filters to efficiently
re-encode only the desired segments.
re-encode only the desired segments. The encoder used for re-encoding is
derived per video key from the source dataset's ``meta/info.json`` so the
destination metadata keeps describing the videos accurately.
Args:
src_dataset: Source dataset to copy from
@@ -711,6 +724,9 @@ def _copy_and_reindex_videos(
for video_key in src_dataset.meta.video_keys:
logging.info(f"Processing videos for {video_key}")
camera_encoder = VideoEncoderConfig.from_video_info(
src_dataset.meta.info.features.get(video_key, {}).get("info")
)
if dst_meta.video_path is None:
raise ValueError("Destination metadata has no video_path defined")
@@ -792,8 +808,7 @@ def _copy_and_reindex_videos(
dst_video_path,
episodes_to_keep_ranges,
src_dataset.meta.fps,
vcodec,
pix_fmt,
camera_encoder,
)
cumulative_ts = 0.0
@@ -1264,11 +1279,7 @@ def _estimate_frame_size_via_calibration(
episode_indices: list[int],
temp_dir: Path,
fps: int,
vcodec: str,
pix_fmt: str,
g: int,
crf: int,
fast_decode: int,
camera_encoder: VideoEncoderConfig,
num_calibration_frames: int = 30,
) -> float:
"""Estimate MB per frame by encoding a small calibration sample.
@@ -1282,11 +1293,7 @@ def _estimate_frame_size_via_calibration(
episode_indices: List of episode indices being processed.
temp_dir: Temporary directory for calibration files.
fps: Frames per second for video encoding.
vcodec: Video codec (libsvtav1, h264, hevc).
pix_fmt: Pixel format (yuv420p, etc.).
g: GOP size (group of pictures).
crf: Constant Rate Factor (quality).
fast_decode: Fast decode tuning parameter.
camera_encoder: Video encoder settings used for calibration encoding.
num_calibration_frames: Number of frames to use for calibration (default: 30).
Returns:
@@ -1322,11 +1329,7 @@ def _estimate_frame_size_via_calibration(
imgs_dir=calibration_dir,
video_path=calibration_video_path,
fps=fps,
vcodec=vcodec,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
video_encoder=camera_encoder,
overwrite=True,
)
@@ -1644,11 +1647,7 @@ def convert_image_to_video_dataset(
dataset: LeRobotDataset,
output_dir: Path | None = None,
repo_id: str | None = None,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
g: int = 2,
crf: int = 30,
fast_decode: int = 0,
camera_encoder: VideoEncoderConfig | None = None,
episode_indices: list[int] | None = None,
num_workers: int = 4,
max_episodes_per_batch: int | None = None,
@@ -1663,11 +1662,8 @@ def convert_image_to_video_dataset(
dataset: The source LeRobot dataset with images
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
vcodec: Video codec (default: libsvtav1)
pix_fmt: Pixel format (default: yuv420p)
g: Group of pictures size (default: 2)
crf: Constant rate factor (default: 30)
fast_decode: Fast decode tuning (default: 0)
camera_encoder: Video encoder settings
(``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`).
episode_indices: List of episode indices to convert (None = all episodes)
num_workers: Number of threads for parallel processing (default: 4)
max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit)
@@ -1676,6 +1672,9 @@ def convert_image_to_video_dataset(
Returns:
New LeRobotDataset with images encoded as videos
"""
if camera_encoder is None:
camera_encoder = camera_encoder_defaults()
# Check that it's an image dataset
if len(dataset.meta.video_keys) > 0:
raise ValueError(
@@ -1699,7 +1698,10 @@ def convert_image_to_video_dataset(
logging.info(
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
)
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
logging.info(
f"Video codec: {camera_encoder.vcodec}, pixel format: {camera_encoder.pix_fmt}, "
f"GOP: {camera_encoder.g}, CRF: {camera_encoder.crf}"
)
# Create new features dict, converting image features to video features
new_features = {}
@@ -1769,11 +1771,7 @@ def convert_image_to_video_dataset(
episode_indices=episode_indices,
temp_dir=temp_dir,
fps=fps,
vcodec=vcodec,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
camera_encoder=camera_encoder,
)
logging.info(f"Processing camera: {img_key}")
@@ -1815,11 +1813,7 @@ def convert_image_to_video_dataset(
imgs_dir=imgs_dir,
video_path=video_path,
fps=fps,
vcodec=vcodec,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
video_encoder=camera_encoder,
overwrite=True,
)
@@ -1865,7 +1859,9 @@ def convert_image_to_video_dataset(
video_path = new_meta.root / new_meta.video_path.format(
video_key=img_key, chunk_index=0, file_index=0
)
new_meta.info.features[img_key]["info"] = get_video_info(video_path)
new_meta.info.features[img_key]["info"] = get_video_info(
video_path, video_encoder=camera_encoder
)
write_info(new_meta.info, new_meta.root)

View File

@@ -31,6 +31,13 @@ import PIL.Image
import pyarrow.parquet as pq
import torch
from lerobot.configs import (
DepthEncoderConfig,
VideoEncoderConfig,
camera_encoder_defaults,
depth_encoder_defaults,
)
from .compute_stats import compute_episode_stats
from .dataset_metadata import LeRobotDatasetMetadata
from .feature_utils import (
@@ -46,6 +53,7 @@ from .io_utils import (
write_info,
)
from .utils import (
DEFAULT_DEPTH_PATH,
DEFAULT_EPISODES_PATH,
DEFAULT_IMAGE_PATH,
update_chunk_file_indices,
@@ -65,14 +73,24 @@ def _encode_video_worker(
episode_index: int,
root: Path,
fps: int,
vcodec: str = "libsvtav1",
video_encoder: VideoEncoderConfig | None = None,
encoder_threads: int | None = None,
) -> Path:
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
path_template = (
DEFAULT_DEPTH_PATH
if video_encoder is not None and isinstance(video_encoder, DepthEncoderConfig)
else DEFAULT_IMAGE_PATH
)
fpath = path_template.format(image_key=video_key, episode_index=episode_index, frame_index=0)
img_dir = (root / fpath).parent
encode_video_frames(
img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads
img_dir,
temp_path,
fps,
video_encoder=video_encoder,
encoder_threads=encoder_threads,
overwrite=True,
)
shutil.rmtree(img_dir)
return temp_path
@@ -89,20 +107,25 @@ class DatasetWriter:
self,
meta: LeRobotDatasetMetadata,
root: Path,
vcodec: str,
camera_encoder: VideoEncoderConfig | None,
depth_encoder: DepthEncoderConfig | None,
encoder_threads: int | None,
batch_encoding_size: int,
streaming_encoder: StreamingVideoEncoder | None = None,
initial_frames: int = 0,
):
"""Initialize the writer with metadata, codec, and encoding config.
"""Initialize the writer with metadata, codec, and encoder config.
Args:
meta: Dataset metadata instance (used for feature schema, chunk
settings, and episode persistence).
root: Local dataset root directory.
vcodec: Video codec for encoding (e.g. ``'libsvtav1'``, ``'h264'``).
encoder_threads: Threads per encoder instance. ``None`` for auto.
camera_encoder: Video encoder settings applied to all cameras.
``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`.
depth_encoder: Video encoder settings applied to all **depth** cameras.
``None`` uses :func:`~lerobot.configs.depth_encoder_defaults`.
encoder_threads: Number of encoder threads (global). ``None``
lets the codec decide.
batch_encoding_size: Number of episodes to accumulate before
batch-encoding videos.
streaming_encoder: Optional pre-built :class:`StreamingVideoEncoder`
@@ -111,7 +134,8 @@ class DatasetWriter:
"""
self._meta = meta
self._root = root
self._vcodec = vcodec
self._camera_encoder = camera_encoder or camera_encoder_defaults()
self._depth_encoder = depth_encoder or depth_encoder_defaults()
self._encoder_threads = encoder_threads
self._batch_encoding_size = batch_encoding_size
self._streaming_encoder = streaming_encoder
@@ -136,7 +160,8 @@ class DatasetWriter:
return ep_buffer
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
fpath = DEFAULT_IMAGE_PATH.format(
path_template = DEFAULT_DEPTH_PATH if image_key in self._meta.depth_keys else DEFAULT_IMAGE_PATH
fpath = path_template.format(
image_key=image_key, episode_index=episode_index, frame_index=frame_index
)
return self._root / fpath
@@ -186,6 +211,7 @@ class DatasetWriter:
if frame_index == 0 and self._streaming_encoder is not None:
self._streaming_encoder.start_episode(
video_keys=list(self._meta.video_keys),
depth_video_keys=set(self._meta.video_keys) & set(self._meta.depth_keys),
temp_dir=self._root,
)
@@ -284,7 +310,9 @@ class DatasetWriter:
episode_index,
self._root,
self._meta.fps,
self._vcodec,
self._depth_encoder
if video_key in self._meta.depth_keys
else self._camera_encoder,
self._encoder_threads,
): video_key
for video_key in self._meta.video_keys
@@ -495,7 +523,12 @@ class DatasetWriter:
# Update video info (only needed when first episode is encoded)
if episode_index == 0:
self._meta.update_video_info(video_key)
self._meta.update_video_info(
video_key,
video_encoder=self._depth_encoder
if video_key in self._meta.depth_keys
else self._camera_encoder,
)
write_info(self._meta.info, self._meta.root)
metadata = {
@@ -562,9 +595,15 @@ class DatasetWriter:
self.image_writer.wait_until_done()
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
"""Use ffmpeg to convert frames stored as png into mp4 videos."""
"""Use ffmpeg to convert frames stored as png/tiff into mp4 videos."""
is_depth = video_key in self._meta.depth_keys
return _encode_video_worker(
video_key, episode_index, self._root, self._meta.fps, self._vcodec, self._encoder_threads
video_key,
episode_index,
self._root,
self._meta.fps,
self._depth_encoder if is_depth else self._camera_encoder,
self._encoder_threads,
)
def close_writer(self) -> None:

View File

@@ -0,0 +1,214 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Depth encoding/decoding helpers for :class:`VideoEncoderConfig`.
"""
import math
from typing import Literal
import av
import numpy as np
import torch
from numpy.typing import NDArray
from lerobot.configs.video import (
DEFAULT_DEPTH_MAX,
DEFAULT_DEPTH_MIN,
DEFAULT_DEPTH_PIX_FMT,
DEFAULT_DEPTH_SHIFT,
DEFAULT_DEPTH_USE_LOG,
DEPTH_QMAX,
)
from .pyav_utils import write_u16_plane
_MM_PER_METRE = 1000.0
_UINT16_MAX = 65535
def _validate_log_quant_params(depth_min: float, shift: float) -> None:
"""Ensure ``log(depth_min + shift)`` is finite."""
if depth_min + shift <= 0:
raise ValueError(
f"depth_min + shift must be positive for logarithmic quantization, "
f"got depth_min={depth_min} + shift={shift} = {depth_min + shift}"
)
def _depth_input_to_float32_and_unit(
depth: NDArray[np.integer] | NDArray[np.floating],
input_unit: Literal["auto", "m", "mm"],
) -> tuple[NDArray[np.float32], Literal["m", "mm"]]:
"""Convert depth to float32 in the chosen unit, and return the resolved unit."""
resolved_unit = (
("m" if np.issubdtype(depth.dtype, np.floating) else "mm") if input_unit == "auto" else input_unit
)
return depth.astype(np.float32, order="K"), resolved_unit
def quantize_depth(
depth: NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor,
depth_min: float = DEFAULT_DEPTH_MIN,
depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG,
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
video_backend: str | None = "pyav",
input_unit: Literal["auto", "m", "mm"] = "auto",
) -> NDArray[np.uint16] | av.VideoFrame:
"""Quantize depth to 12-bit codes (``uint16``, values ``0…DEPTH_QMAX``).
Depth maps are packed into 12-bit integer frames so they fit in standard
high-bit-depth pixel formats (e.g. ``yuv420p12le`` / ``gray12le``)
and can be encoded by widely supported video codecs (HEVC Main 12, ffv1).
Logarithmic quantization is the default because it allocates more quanta
to near-range depth, which matches the (1/depth) error profile of typical
depth sensors. Math is ported from BEHAVIOR-1K's ``obs_utils.py``.
**Input units**:
- ``input_unit="auto"`` (default): infer from dtype (floating = m, non-floating = mm).
- ``input_unit="mm"``: interpret input values as millimetres.
- ``input_unit="m"``: interpret input values as metres.
Quantization math runs in the **resolved input unit**.
``depth_min``, ``depth_max``, and ``shift`` are always in **metres**.
Args:
depth: Depth map; ``torch.Tensor`` is moved to CPU for conversion.
depth_min: Depth (metres) at quantum ``0``.
depth_max: Depth (metres) at quantum :data:`DEPTH_QMAX`.
shift: Depth shift (metres); used in log mode. Must satisfy ``depth_min + shift > 0``.
use_log: If ``True`` (default), quantize in log space.
video_backend: Video backend to use for encoding. Defaults to "pyav".
input_unit: Input unit policy (``"auto"``, ``"mm"``, ``"m"``).
Returns:
``numpy.ndarray``, ``dtype=uint16``, same shape as ``depth``, values in
``[0, DEPTH_QMAX]``.
Raises:
ValueError: If ``input_unit`` is not ``"auto"``, ``"mm"``, or ``"m"``.
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
"""
if input_unit not in ("auto", "m", "mm"):
raise ValueError(f"input_unit must be 'auto', 'm', or 'mm', got {input_unit!r}")
if isinstance(depth, torch.Tensor):
depth = depth.detach().cpu().numpy()
# Squeeze single-channel dim: (H, W, 1) or (1, H, W) → (H, W)
if depth.ndim == 3 and (depth.shape[-1] == 1 or depth.shape[0] == 1):
depth = depth.squeeze()
depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit)
# Convert depth_min, depth_max, and shift to the resolved input unit.
depth_min_u = np.float32(depth_min) if resolved_unit == "m" else np.float32(depth_min * _MM_PER_METRE)
depth_max_u = np.float32(depth_max) if resolved_unit == "m" else np.float32(depth_max * _MM_PER_METRE)
shift_u = np.float32(shift) if resolved_unit == "m" else np.float32(shift * _MM_PER_METRE)
# Normalization and quantization is performed in the resolved input unit.
if use_log:
_validate_log_quant_params(depth_min, shift)
log_min = math.log(float(depth_min_u + shift_u))
log_max = math.log(float(depth_max_u + shift_u))
norm = (np.log(depth_f + shift_u) - log_min) / (log_max - log_min)
else:
norm = (depth_f - depth_min_u) / (depth_max_u - depth_min_u)
quantized = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX).astype(np.uint16, copy=False)
if video_backend == "pyav":
frame = av.VideoFrame.from_ndarray(quantized, format=pix_fmt)
write_u16_plane(frame.planes[0], quantized)
return frame
else:
return quantized
def dequantize_depth(
quantized: NDArray[np.uint16] | av.VideoFrame,
depth_min: float = DEFAULT_DEPTH_MIN,
depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG,
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
output_unit: Literal["m", "mm"] = "mm",
output_tensor: bool = False,
) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor:
"""Inverse of :func:`quantize_depth`.
Tuning arguments **must match** :func:`quantize_depth`.
Decoding inverts the same normalized code mapping as :func:`quantize_depth`
using ``depth_min`` / ``depth_max`` / ``shift`` (in metres), then returns
the requested output unit.
Args:
quantized: 12-bit codes ``[0, DEPTH_QMAX]``, ``dtype=uint16``.
depth_min, depth_max, shift, use_log: Same as :func:`quantize_depth` (metres).
output_unit: ``\"mm\"`` returns ``uint16`` millimetres (``rint``, clip
``[0, 65535]``). ``\"m\"`` returns ``float32`` metres in
``[depth_min, depth_max]``.
output_tensor: If True, return a torch.Tensor instead of a numpy array.
Returns:
Depth map in the requested unit and dtype.
Raises:
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
ValueError: If ``output_unit`` is not ``\"m\"`` or ``\"mm\"``.
"""
if output_unit not in ("m", "mm"):
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
if isinstance(quantized, av.VideoFrame):
quantized = quantized.to_ndarray(format=pix_fmt)
norm = np.asarray(quantized, dtype=np.float32, order="K") / DEPTH_QMAX
depth_min_m = np.float32(depth_min)
depth_max_m = np.float32(depth_max)
shift_m = np.float32(shift)
# The de-normalization and de-quantization is performed in meters (convenience choice).
if use_log:
_validate_log_quant_params(depth_min, shift)
log_min = math.log(float(depth_min_m + shift_m))
log_max = math.log(float(depth_max_m + shift_m))
depth_m = np.exp(norm * (log_max - log_min) + log_min) - shift_m
else:
depth_m = norm * (depth_max_m - depth_min_m) + depth_min_m
depth_m = np.clip(depth_m, depth_min_m, depth_max_m).astype(np.float32, copy=False)
# Add single-channel dim: (H, W) → (H, W, 1)
if depth_m.ndim == 2:
depth_m = depth_m[..., np.newaxis]
# Return depth as float32 meters.
if output_unit == "m":
return torch.from_numpy(depth_m) if output_tensor else depth_m
# Return depth as uint16 millimeters.
mm = np.rint(depth_m * _MM_PER_METRE).clip(0, _UINT16_MAX).astype(np.uint16, copy=False)
if output_tensor:
# torch.uint16 support is very limited, we convert to float32 instead.
return torch.from_numpy(mm.astype(np.float32))
else:
return mm

View File

@@ -19,6 +19,7 @@ import datasets
import numpy as np
from PIL import Image as PILImage
from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
from lerobot.utils.constants import DEFAULT_FEATURES
from lerobot.utils.utils import is_valid_numpy_dtype_string
@@ -108,6 +109,41 @@ def create_empty_dataset_info(
)
def features_equal_for_merge(features_a: dict[str, dict], features_b: dict[str, dict]) -> bool:
"""Return whether two LeRobotDatasetMetadata ``features`` dicts are compatible for aggregation.
For video features, keys under ``info`` related to video encoding parameters are ignored during
comparison as they do not prevent aggregation.
"""
def _without_encoder_info_keys(feature: dict) -> dict:
filtered = dict(feature)
filtered_info = filtered.get("info")
if isinstance(filtered_info, dict):
filtered["info"] = {
info_key: info_value
for info_key, info_value in filtered_info.items()
if info_key not in VIDEO_ENCODER_INFO_KEYS
}
return filtered
if set(features_a) != set(features_b):
return False
for key in features_a:
fa_key = features_a[key]
fb_key = features_b[key]
if fa_key.get("dtype") != fb_key.get("dtype"):
return False
if fa_key.get("dtype") != "video":
if fa_key != fb_key:
return False
continue
if _without_encoder_info_keys(fa_key) != _without_encoder_info_keys(fb_key):
return False
return True
def check_delta_timestamps(
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
) -> bool:
@@ -285,7 +321,7 @@ def validate_feature_image_or_video(
Args:
name (str): The name of the feature.
expected_shape (list[str]): The expected shape (C, H, W).
expected_shape (list[str]): The expected shape, e.g. (C, H, W) or (H, W, C).
value: The image data to validate.
Returns:

View File

@@ -42,10 +42,41 @@ def safe_stop_image_writer(func):
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
# TODO(aliberts): handle 1 channel and 4 for depth images
if image_array.ndim != 3:
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
"""Convert a NumPy array to a PIL Image, preserving precision for grayscale.
Behaviour by shape:
- ``(H, W)`` or ``(1, H, W)`` / ``(H, W, 1)``: single-channel grayscale.
The native dtype is preserved using the matching PIL mode
(``I;16`` / ``F``). This is the path used for raw depth maps (no rescaling, clamping, or downcasting)
- ``(3, H, W)`` / ``(H, W, 3)``: RGB. Channels-first inputs are transposed
to channels-last. Float inputs in ``[0, 1]`` are scaled to ``uint8``
(existing behaviour, gated by ``range_check``).
Other shapes / channel counts raise ``NotImplementedError`` or
``ValueError``.
"""
# TODO(CarolinePascal): 4 dimensions RGB-D images
if image_array.ndim not in (2, 3):
raise ValueError(f"The array has {image_array.ndim} dimensions, but 2 or 3 is expected for an image.")
# Squeeze 3D single-channel inputs to 2D so depth maps work whether the
# caller emits (H, W), (1, H, W), or (H, W, 1).
if image_array.ndim == 3:
if image_array.shape[0] == 1:
image_array = image_array[0]
elif image_array.shape[-1] == 1:
image_array = image_array[..., 0]
if image_array.ndim == 2:
if image_array.dtype not in [np.uint16, np.float32]:
raise ValueError(
f"Unsupported single-channel image dtype: {image_array.dtype}. "
f"Supported dtypes: {sorted(str(d) for d in [np.uint16, np.float32])}."
)
return PIL.Image.fromarray(np.ascontiguousarray(image_array))
# 3D path: must be RGB (3 channels), channels-first or channels-last.
if image_array.shape[0] == 3:
# Transpose from pytorch convention (C, H, W) to (H, W, C)
image_array = image_array.transpose(1, 2, 0)
@@ -71,13 +102,28 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True)
return PIL.Image.fromarray(image_array)
def save_kwargs_for_path(fpath: Path, compress_level: int) -> dict:
"""Pick the right format-specific kwargs for :meth:`PIL.Image.Image.save`.
PNG uses ``compress_level`` (0-9, zlib). TIFF uses ``compression`` (raw) for lossless raw depth maps.
"""
suffix = Path(fpath).suffix.lower()
if suffix == ".png":
return {"compress_level": compress_level}
if suffix in (".tif", ".tiff"):
return {"compression": "raw"}
return {}
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1):
"""
Saves a NumPy array or PIL Image to a file.
This function handles both NumPy arrays and PIL Image objects, converting
the former to a PIL Image before saving. It includes error handling for
the save operation.
the save operation. The output format is inferred from the *fpath*
extension: ``.png`` → PNG with ``compress_level``, ``.tiff`` / ``.tif``
→ lossless raw depth maps (TIFF).
Args:
image (np.ndarray | PIL.Image.Image): The image data to save.
@@ -101,7 +147,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
img = image
else:
raise TypeError(f"Unsupported image type: {type(image)}")
img.save(fpath, compress_level=compress_level)
img.save(fpath, **save_kwargs_for_path(fpath, compress_level))
except Exception as e:
logger.error("Error writing image %s: %s", fpath, e)

View File

@@ -24,6 +24,7 @@ import torch.utils
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.errors import RevisionNotFoundError
from lerobot.configs import DepthEncoderConfig, VideoEncoderConfig
from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
@@ -36,8 +37,7 @@ from .utils import (
)
from .video_utils import (
StreamingVideoEncoder,
get_safe_default_codec,
resolve_vcodec,
get_safe_default_video_backend,
)
logger = logging.getLogger(__name__)
@@ -49,6 +49,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
repo_id: str,
root: str | Path | None = None,
episodes: list[int] | None = None,
episode_filter: Callable[[dict], bool] | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
tolerance_s: float = 1e-4,
@@ -58,10 +59,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: str | None = None,
return_uint8: bool = False,
batch_encoding_size: int = 1,
vcodec: str = "libsvtav1",
camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
encoder_threads: int | None = None,
streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30,
encoder_threads: int | None = None,
):
"""
2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -153,6 +155,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
``$HF_LEROBOT_HOME/hub``.
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None.
episode_filter (Callable[[dict], bool] | None, optional): Predicate over per-episode
metadata rows used to select episodes. Evaluated against ``meta/`` without ``stats`` keys
(e.g.``task_index``, ``episode_index``, ``length``, ``from_timestamp``, ``to_timestamp``).
Intersected with ``episodes`` when both are set. Example: ``lambda ep: ep["length"] >= 100``.
Defaults to None.
image_transforms (Callable | None, optional):
Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
conversion. This works for both image-backed and video-backed observations and can later be
@@ -177,16 +184,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
'libsvtav1', 'auto', or hardware-specific codecs like 'h264_videotoolbox', 'h264_nvenc'.
Defaults to 'libsvtav1'. Use 'auto' to auto-detect the best available hardware encoder.
camera_encoder (VideoEncoderConfig | None, optional): Video encoder settings for cameras
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults`
is used by the writer.
depth_encoder (DepthEncoderConfig | None, optional): Video encoder settings for depth cameras
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.depth.depth_encoder_defaults`
is used by the writer.
encoder_threads (int | None, optional): Number of encoder threads (global). ``None`` lets the
codec decide.
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
instead of writing PNG images first. This makes save_episode() near-instant. Defaults to False.
encoder_queue_maxsize (int, optional): Maximum number of frames to buffer per camera when using
streaming encoding. Defaults to 30 (~1s at 30fps).
encoder_threads (int | None, optional): Number of threads per encoder instance. None lets the
codec auto-detect (default). Lower values reduce CPU usage per encoder. Maps to 'lp' (via svtav1-params) for
libsvtav1 and 'threads' for h264/hevc.
Note:
Write-mode parameters (``streaming_encoding``, ``batch_encoding_size``) passed to
@@ -199,13 +208,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.reader = None
self.set_image_transforms(image_transforms)
self.delta_timestamps = delta_timestamps
self.episodes = episodes
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self._video_backend = video_backend if video_backend else get_safe_default_codec()
self._video_backend = video_backend if video_backend else get_safe_default_video_backend()
self._return_uint8 = return_uint8
self._batch_encoding_size = batch_encoding_size
self._vcodec = resolve_vcodec(vcodec)
self._encoder_threads = encoder_threads
if self._requested_root is not None:
@@ -218,6 +225,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.root = self.meta.root
self.revision = self.meta.revision
if episodes is not None and any(
episode >= self.meta.total_episodes or episode < 0 for episode in episodes
):
logger.warning(
f"Some episodes in the provided episodes list are out of range for this dataset ({self.meta.total_episodes})."
)
if episode_filter is not None:
resolved = self.meta.filter_episodes(episode_filter, candidates=episodes)
if not resolved:
raise ValueError(
"The episode filter did not match any episode. Make sure the filter and episodes list are valid and compatible."
)
logger.info(f"The episode filter matched {len(resolved)} episode(s).")
episodes = resolved
self.episodes = episodes
# Create reader (hf_dataset loaded below)
self.reader = DatasetReader(
meta=self.meta,
@@ -251,12 +275,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
streaming_enc = None
if streaming_encoding and len(self.meta.video_keys) > 0:
streaming_enc = self._build_streaming_encoder(
self.meta.fps, self._vcodec, encoder_queue_maxsize, encoder_threads
self.meta.fps,
camera_encoder,
depth_encoder,
encoder_queue_maxsize,
encoder_threads,
)
self.writer = DatasetWriter(
meta=self.meta,
root=self.root,
vcodec=self._vcodec,
camera_encoder=camera_encoder,
depth_encoder=depth_encoder,
encoder_threads=encoder_threads,
batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc,
@@ -298,17 +327,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
@staticmethod
def _build_streaming_encoder(
fps: int,
vcodec: str,
camera_encoder: VideoEncoderConfig | None,
depth_encoder: DepthEncoderConfig | None,
encoder_queue_maxsize: int,
encoder_threads: int | None,
) -> StreamingVideoEncoder:
return StreamingVideoEncoder(
fps=fps,
vcodec=vcodec,
pix_fmt="yuv420p",
g=2,
crf=30,
preset=None,
camera_encoder=camera_encoder,
depth_encoder=depth_encoder,
queue_maxsize=encoder_queue_maxsize,
encoder_threads=encoder_threads,
)
@@ -625,7 +652,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_writer_threads: int = 0,
video_backend: str | None = None,
batch_encoding_size: int = 1,
vcodec: str = "libsvtav1",
camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
metadata_buffer_size: int = 10,
streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30,
@@ -656,20 +684,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: Video decoding backend (used when reading back).
batch_encoding_size: Number of episodes to accumulate before
batch-encoding videos. ``1`` means encode immediately.
vcodec: Video codec for encoding. Options include ``'libsvtav1'``,
``'h264'``, ``'hevc'``, ``'auto'``.
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
depth_encoder: Video encoder settings for depth cameras (codec, quality, etc.).
When ``None``, :func:`~lerobot.configs.depth.depth_encoder_defaults` is used.
encoder_threads: Number of encoder threads (global). ``None``
lets the codec decide.
metadata_buffer_size: Number of episode metadata records to buffer
before flushing to parquet.
streaming_encoding: If ``True``, encode video frames in real-time
during capture instead of writing images first.
encoder_queue_maxsize: Max buffered frames per camera when using
streaming encoding.
encoder_threads: Threads per encoder instance. ``None`` for auto.
Returns:
A new :class:`LeRobotDataset` in write mode.
"""
vcodec = resolve_vcodec(vcodec)
obj = cls.__new__(cls)
obj.meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
@@ -690,23 +720,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.image_transforms = None
obj.delta_timestamps = None
obj.episodes = None
obj._video_backend = video_backend if video_backend is not None else get_safe_default_codec()
obj._video_backend = video_backend if video_backend is not None else get_safe_default_video_backend()
obj._return_uint8 = False
obj._batch_encoding_size = batch_encoding_size
obj._vcodec = vcodec
obj._encoder_threads = encoder_threads
# Reader is lazily created on first access (write-only mode)
obj.reader = None
# Create writer
streaming_enc = None
if streaming_encoding and len(obj.meta.video_keys) > 0:
streaming_enc = cls._build_streaming_encoder(fps, vcodec, encoder_queue_maxsize, encoder_threads)
streaming_enc = cls._build_streaming_encoder(
fps, camera_encoder, depth_encoder, encoder_queue_maxsize, encoder_threads
)
obj.writer = DatasetWriter(
meta=obj.meta,
root=obj.root,
vcodec=vcodec,
camera_encoder=camera_encoder,
depth_encoder=depth_encoder,
encoder_threads=encoder_threads,
batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc,
@@ -729,12 +760,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
force_cache_sync: bool = False,
video_backend: str | None = None,
batch_encoding_size: int = 1,
vcodec: str = "libsvtav1",
camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
encoder_threads: int | None = None,
image_writer_processes: int = 0,
image_writer_threads: int = 0,
streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30,
encoder_threads: int | None = None,
) -> "LeRobotDataset":
"""Resume recording on an existing dataset.
@@ -757,13 +789,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: Video decoding backend for reading back data.
batch_encoding_size: Number of episodes to accumulate before
batch-encoding videos.
vcodec: Video codec for encoding.
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
depth_encoder: Video encoder settings for depth cameras (codec, quality, etc.).
When ``None``, :func:`~lerobot.configs.depth.depth_encoder_defaults` is used.
encoder_threads: Number of encoder threads (global). ``None``
lets the codec decide.
image_writer_processes: Subprocesses for async image writing.
image_writer_threads: Threads for async image writing.
streaming_encoding: If ``True``, encode video in real-time during
capture.
encoder_queue_maxsize: Max buffered frames per camera for streaming.
encoder_threads: Threads per encoder instance. ``None`` for auto.
Returns:
A :class:`LeRobotDataset` in write mode, ready to append episodes.
@@ -774,7 +810,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
"Writing into the revision-safe Hub snapshot cache (used when root=None) would corrupt "
"the shared cache. Please provide a local directory path."
)
vcodec = resolve_vcodec(vcodec)
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj._requested_root = Path(root)
@@ -783,11 +818,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.image_transforms = None
obj.delta_timestamps = None
obj.episodes = None
obj._video_backend = video_backend if video_backend else get_safe_default_codec()
obj._video_backend = video_backend if video_backend else get_safe_default_video_backend()
obj._return_uint8 = False
obj._batch_encoding_size = batch_encoding_size
obj._vcodec = vcodec
obj._encoder_threads = encoder_threads
if obj._requested_root is not None:
obj._requested_root.mkdir(exist_ok=True, parents=True)
@@ -796,21 +829,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.meta = LeRobotDatasetMetadata(
obj.repo_id, obj._requested_root, obj.revision, force_cache_sync=force_cache_sync
)
obj._encoder_threads = encoder_threads
obj.root = obj.meta.root
# Reader is lazily created on first access (write-only mode)
obj.reader = None
# Create writer for appending
streaming_enc = None
if streaming_encoding and len(obj.meta.video_keys) > 0:
streaming_enc = cls._build_streaming_encoder(
obj.meta.fps, vcodec, encoder_queue_maxsize, encoder_threads
obj.meta.fps, camera_encoder, depth_encoder, encoder_queue_maxsize, encoder_threads
)
obj.writer = DatasetWriter(
meta=obj.meta,
root=obj.root,
vcodec=vcodec,
camera_encoder=camera_encoder,
depth_encoder=depth_encoder,
encoder_threads=encoder_threads,
batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc,

View File

@@ -0,0 +1,209 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyAV-based compatibility checks for :class:`VideoEncoderConfig`.
Centralises all :mod:`av` introspection of the bundled FFmpeg build.
Checks degrade to a no-op when the target codec isn't available locally.
"""
import functools
import logging
from typing import Any
import av
import numpy as np
logger = logging.getLogger(__name__)
FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
def write_u16_plane(plane: av.video.plane.VideoPlane, src: np.ndarray, fill_value: int | None = None) -> None:
"""Copy ``src`` into a uint16 plane respecting FFmpeg line padding."""
height, width = src.shape
stride_u16 = plane.line_size // np.dtype(np.uint16).itemsize
dst = np.frombuffer(plane, dtype=np.uint16).reshape(height, stride_u16)
if fill_value is not None:
dst.fill(fill_value)
dst[:, :width] = src
@functools.cache
def get_pix_fmt_channels(pix_fmt: str) -> int:
"""Return the number of components (channels) for *pix_fmt*."""
return len(av.VideoFormat(pix_fmt).components)
@functools.cache
def get_codec(vcodec: str) -> av.codec.Codec | None:
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
try:
return av.codec.Codec(vcodec, "w")
except Exception:
return None
@functools.cache
def _get_codec_options_by_name(vcodec: str) -> dict[str, av.option.Option]:
"""Private-option name → PyAV ``Option`` for *vcodec* (empty if unavailable)."""
codec = get_codec(vcodec)
if codec is None:
return {}
return {opt.name: opt for opt in codec.descriptor.options}
@functools.cache
def _get_codec_video_formats(vcodec: str) -> tuple[str, ...]:
"""Pixel formats accepted by *vcodec* in PyAV's preferred order (empty if unknown)."""
codec = get_codec(vcodec)
if codec is None:
return ()
return tuple(fmt.name for fmt in (codec.video_formats or []))
def detect_available_encoders_pyav(encoders: list[str] | str) -> list[str]:
"""Return the subset of *encoders* available as video encoders in the local FFmpeg build.
Each name is probed directly via :func:`get_codec`; input order is preserved.
"""
if isinstance(encoders, str):
encoders = [encoders]
available: list[str] = []
for name in encoders:
codec = get_codec(name)
if codec is not None and codec.type == "video":
available.append(name)
else:
logger.debug("encoder '%s' not available as video encoder", name)
return available
def _check_option_value(vcodec: str, label: str, value: Any, opt: av.option.Option) -> None:
"""Range-check numeric *value* and choice-check string *value* against *opt*."""
type_name = opt.type.name
if type_name in FFMPEG_NUMERIC_OPTION_TYPES:
if isinstance(value, bool):
raise ValueError(
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
)
elif isinstance(value, str):
try:
num_val = float(value)
except ValueError as e:
raise ValueError(
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
) from e
elif isinstance(value, (float, int)):
num_val = value
else:
raise ValueError(
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
)
# Check integer type compatibility
if type_name in FFMPEG_INTEGER_OPTION_TYPES and not num_val.is_integer():
raise ValueError(
f"{label}={num_val!r} must be an integer for codec {vcodec!r} "
f"(FFmpeg option {opt.name!r} is {type_name}); float values are not allowed."
)
# Check numeric range compatibility
lo, hi = float(opt.min), float(opt.max)
if lo < hi and not (lo <= num_val <= hi):
raise ValueError(
f"{label}={num_val} is out of range for codec {vcodec!r}; must be in [{lo}, {hi}]"
)
elif type_name == "STRING":
if isinstance(value, bool):
raise ValueError(f"{label}={value!r} is not a valid string value for codec {vcodec!r}.")
if isinstance(value, str):
str_val = value
elif isinstance(value, (int, float)):
str_val = str(value)
else:
raise ValueError(f"{label}={value!r} has unsupported type for STRING option on codec {vcodec!r}")
# Check string choice compatibility
choices = [c.name for c in (opt.choices or [])]
if choices and str_val not in choices:
raise ValueError(
f"{label}={str_val!r} is not a supported choice for codec "
f"{vcodec!r}; valid choices: {choices}"
)
else:
return
def _check_pixel_format(vcodec: str, pix_fmt: str) -> None:
formats = _get_codec_video_formats(vcodec)
if formats and pix_fmt not in formats:
raise ValueError(
f"pix_fmt={pix_fmt!r} is not supported by codec {vcodec!r}; "
f"supported pixel formats: {list(formats)}"
)
def _check_pix_fmt_channels(pix_fmt: str, channels: int) -> None:
"""Ensure *pix_fmt* can carry at least *channels* components."""
pix_fmt_channels = get_pix_fmt_channels(pix_fmt)
if pix_fmt_channels < channels:
raise ValueError(
f"pix_fmt={pix_fmt!r} carries only {pix_fmt_channels} component(s) "
f"but the source data has {channels} channel(s)."
)
def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
"""Validate merged encoder options (typed) against the codec's published AVOptions."""
supported_options = _get_codec_options_by_name(vcodec)
for key, value in codec_options.items():
# GOP size is not a codec-specific option, it has to be validated separately.
if key == "g":
if isinstance(value, bool) or not isinstance(value, int) or value < 1:
raise ValueError(f"g={value!r} must be a positive integer for codec {vcodec!r}")
continue
if key not in supported_options:
continue
_check_option_value(vcodec, key, value, supported_options[key])
def check_video_encoder_parameters_pyav(
vcodec: str,
pix_fmt: str,
codec_options: dict[str, Any],
channels: int | None = None,
) -> None:
"""Verify *config* is compatible with the bundled FFmpeg build.
Checks pixel format, abstract tuning-field compatibility, and each merged
encoder option from :meth:`~lerobot.configs.video.VideoEncoderConfig.get_codec_options`
against PyAV (including numeric ``extra_options`` present in that dict).
When given, additionally verify that *pix_fmt* carries as many components as the source data channels.
No-op when ``config.vcodec`` isn't in the local FFmpeg build.
Raises:
ValueError: on the first incompatibility encountered.
"""
options = _get_codec_options_by_name(vcodec)
if not options:
raise ValueError(f"Codec {vcodec!r} is not available in the bundled FFmpeg build")
_check_pixel_format(vcodec, pix_fmt)
if channels is not None:
_check_pix_fmt_channels(pix_fmt, channels)
_check_codec_options(vcodec, codec_options)

View File

@@ -93,6 +93,7 @@ DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
DEFAULT_DEPTH_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.tiff"
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"

View File

@@ -22,7 +22,7 @@ import shutil
import tempfile
import threading
import warnings
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from fractions import Fraction
from pathlib import Path
from threading import Lock
@@ -33,90 +33,22 @@ import fsspec
import numpy as np
import pyarrow as pa
import torch
import torchvision
from datasets.features.features import register_feature
from PIL import Image
from lerobot.utils.import_utils import get_safe_default_codec
from lerobot.configs import (
DepthEncoderConfig,
VideoEncoderConfig,
camera_encoder_defaults,
depth_encoder_defaults,
)
from lerobot.utils.import_utils import get_safe_default_video_backend
from .depth_utils import quantize_depth
from .pyav_utils import get_pix_fmt_channels
logger = logging.getLogger(__name__)
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build.
# Determines the order of preference for auto-selection when vcodec="auto" is used.
HW_ENCODERS = [
"h264_videotoolbox", # macOS
"hevc_videotoolbox", # macOS
"h264_nvenc", # NVIDIA GPU
"hevc_nvenc", # NVIDIA GPU
"h264_vaapi", # Linux Intel/AMD
"h264_qsv", # Intel Quick Sync
]
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto"} | set(HW_ENCODERS)
def _get_codec_options(
vcodec: str,
g: int | None = 2,
crf: int | None = 30,
preset: int | None = None,
) -> dict:
"""Build codec-specific options dict for video encoding."""
options = {}
# GOP size (keyframe interval) - supported by VideoToolbox and software encoders
if g is not None and (vcodec in ("h264_videotoolbox", "hevc_videotoolbox") or vcodec not in HW_ENCODERS):
options["g"] = str(g)
# Quality control (codec-specific parameter names)
if crf is not None:
if vcodec in ("h264", "hevc", "libsvtav1"):
options["crf"] = str(crf)
elif vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
quality = max(1, min(100, int(100 - crf * 2)))
options["q:v"] = str(quality)
elif vcodec in ("h264_nvenc", "hevc_nvenc"):
options["rc"] = "constqp"
options["qp"] = str(crf)
elif vcodec in ("h264_vaapi",):
options["qp"] = str(crf)
elif vcodec in ("h264_qsv",):
options["global_quality"] = str(crf)
# Preset (only for libsvtav1)
if vcodec == "libsvtav1":
options["preset"] = str(preset) if preset is not None else "12"
return options
def detect_available_hw_encoders() -> list[str]:
"""Probe PyAV/FFmpeg for available hardware video encoders."""
available = []
for codec_name in HW_ENCODERS:
try:
av.codec.Codec(codec_name, "w")
available.append(codec_name)
except Exception: # nosec B110
logger.debug("HW encoder '%s' not available", codec_name) # nosec B110
return available
def resolve_vcodec(vcodec: str) -> str:
"""Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1."""
if vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
if vcodec != "auto":
logger.info(f"Using video codec: {vcodec}")
return vcodec
available = detect_available_hw_encoders()
for encoder in HW_ENCODERS:
if encoder in available:
logger.info(f"Auto-selected video codec: {encoder}")
return encoder
logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
return "libsvtav1"
def decode_video_frames(
video_path: Path | str,
@@ -124,6 +56,7 @@ def decode_video_frames(
tolerance_s: float,
backend: str | None = None,
return_uint8: bool = False,
is_depth: bool = False,
) -> torch.Tensor:
"""
Decodes video frames using the specified backend.
@@ -132,7 +65,9 @@ def decode_video_frames(
video_path (Path): Path to the video file.
timestamps (list[float]): List of timestamps to extract frames.
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav".
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available
in the platform; otherwise, defaults to "pyav". The legacy value "video_reader" is
accepted for one release as an alias for "pyav" and will be removed in a future version.
return_uint8 (bool): If True, return raw uint8 frames without float32 normalization.
This reduces memory for DataLoader IPC; normalization can be done on GPU afterward.
@@ -141,89 +76,101 @@ def decode_video_frames(
Currently supports torchcodec on cpu and pyav.
"""
if backend != "pyav" and is_depth:
logger.warning("Decoding depth maps is only supported with the 'pyav' backend.")
# We do not actually return uint8 here, but we avoid the 255 normalization step.
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=True, is_depth=True)
if backend is None:
backend = get_safe_default_codec()
backend = get_safe_default_video_backend()
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
elif backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(
video_path, timestamps, tolerance_s, backend, return_uint8=return_uint8
)
elif backend == "pyav":
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
elif backend == "video_reader":
logger.warning("backend='video_reader' is deprecated and now aliases to 'pyav'.")
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
else:
raise ValueError(f"Unsupported video backend: {backend}")
def decode_video_frames_torchvision(
def decode_video_frames_pyav(
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
backend: str = "pyav",
log_loaded_timestamps: bool = False,
return_uint8: bool = False,
is_depth: bool = False,
) -> torch.Tensor:
"""Loads frames associated to the requested timestamps of a video
"""Loads frames associated to the requested timestamps of a video using PyAV.
The backend can be either "pyav" (default) or "video_reader".
"video_reader" requires installing torchvision from source, see:
https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
(note that you need to compile against ffmpeg<4.3)
This is the fallback decoder for platforms where torchcodec has no wheel (currently macOS
x86_64 and linux armv7l — see the torchcodec block in pyproject.toml for the full matrix).
On supported platforms, prefer `decode_video_frames_torchcodec`, which is faster and supports
accurate seek.
While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
For more info on video decoding, see `benchmark/video/README.md`
PyAV doesn't support accurate seek: we seek to the nearest preceding keyframe and decode
forward until we have covered the requested timestamp range. The number of key frames in a
video can be adjusted at encoding time to trade off decoding speed against file size.
See torchvision doc for more info on these two backends:
https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend
Args:
video_path: Path to the video file.
timestamps: List of timestamps (in seconds) to extract frames for.
tolerance_s: Allowed deviation in seconds between a queried timestamp and the closest
decoded frame.
log_loaded_timestamps: When True, log every decoded frame's timestamp at INFO level.
return_uint8: When True, return raw uint8 frames (C, H, W). Otherwise, return float32 in
[0, 1] range.
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
and all subsequent frames until reaching the requested frame. The number of key frames in a video
can be adjusted during encoding to take into account decoding time and video size in bytes.
Returns:
torch.Tensor of shape (len(timestamps), C, H, W).
"""
video_path = str(video_path)
# set backend
keyframes_only = False
torchvision.set_video_backend(backend)
if backend == "pyav":
keyframes_only = True # pyav doesn't support accurate seek
# set a video stream reader
# TODO(rcadene): also load audio stream at the same time
reader = torchvision.io.VideoReader(video_path, "video")
video_path = str(video_path)
# set the first and last requested timestamps
# Note: previous timestamps are usually loaded, since we need to access the previous key frame
first_ts = min(timestamps)
last_ts = max(timestamps)
# access closest key frame of the first requested frame
# Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
reader.seek(first_ts, keyframes_only=keyframes_only)
loaded_frames: list[torch.Tensor] = []
loaded_ts: list[float] = []
# load all frames until last requested frame
loaded_frames = []
loaded_ts = []
for frame in reader:
current_ts = frame["pts"]
if log_loaded_timestamps:
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
loaded_frames.append(frame["data"])
loaded_ts.append(current_ts)
if current_ts >= last_ts:
break
# Seek + decode. `container.seek(offset)` with no `stream` argument expects the offset in
# av.time_base units (microseconds). `backward=True` lands us on the nearest keyframe at or
# before `first_ts`, so we can then decode forward until we cover `last_ts`. See:
# https://pyav.basswood-io.com/docs/stable/api/container.html#av.container.InputContainer.seek
with av.open(video_path) as container:
stream = container.streams.video[0]
container.seek(int(first_ts * av.time_base), backward=True)
if backend == "pyav":
reader.container.close()
for frame in container.decode(stream):
if frame.pts is None:
continue
current_ts = float(frame.pts * stream.time_base)
if log_loaded_timestamps:
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
if is_depth:
arr = frame.to_ndarray(format="gray12le") # (H, W) uint12
loaded_frames.append(torch.from_numpy(arr).unsqueeze(0).contiguous())
else:
arr = frame.to_ndarray(format="rgb24") # (H, W, 3)
# Convert to CHW uint8 to match torchcodec's output layout.
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
loaded_ts.append(current_ts)
if current_ts >= last_ts:
break
reader = None
if not loaded_frames:
raise FrameTimestampError(
f"No frames could be decoded from {video_path} in the timestamp range [{first_ts}, {last_ts}]."
)
query_ts = torch.tensor(timestamps)
loaded_ts = torch.tensor(loaded_ts)
loaded_ts_t = torch.tensor(loaded_ts)
# compute distances between each query timestamp and timestamps of all loaded frames
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
dist = torch.cdist(query_ts[:, None], loaded_ts_t[:, None], p=1)
min_, argmin_ = dist.min(1)
is_within_tol = min_ < tolerance_s
@@ -234,14 +181,14 @@ def decode_video_frames_torchvision(
" This might be due to synchronization issues with timestamps during data collection."
" To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nloaded timestamps: {loaded_ts_t}"
f"\nvideo: {video_path}"
f"\nbackend: {backend}"
f"\nbackend: pyav"
)
# get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
closest_ts = loaded_ts[argmin_]
closest_ts = loaded_ts_t[argmin_]
if log_loaded_timestamps:
logger.info(f"{closest_ts=}")
@@ -282,7 +229,11 @@ class VideoDecoderCache:
with self._lock:
if video_path not in self._cache:
file_handle = fsspec.open(video_path).__enter__()
decoder = VideoDecoder(file_handle, seek_mode="approximate")
try:
decoder = VideoDecoder(file_handle, seek_mode="approximate")
except Exception:
file_handle.close()
raise
self._cache[video_path] = (decoder, file_handle)
return self._cache[video_path][0]
@@ -400,18 +351,17 @@ def encode_video_frames(
imgs_dir: Path | str,
video_path: Path | str,
fps: int,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
g: int | None = 2,
crf: int | None = 30,
fast_decode: int = 0,
video_encoder: VideoEncoderConfig | None = None,
encoder_threads: int | None = None,
*,
log_level: int | None = av.logging.WARNING,
overwrite: bool = False,
preset: int | None = None,
encoder_threads: int | None = None,
) -> None:
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
vcodec = resolve_vcodec(vcodec)
if video_encoder is None:
video_encoder = camera_encoder_defaults()
vcodec = video_encoder.vcodec
pix_fmt = video_encoder.pix_fmt
video_path = Path(video_path)
imgs_dir = Path(imgs_dir)
@@ -422,42 +372,19 @@ def encode_video_frames(
video_path.parent.mkdir(parents=True, exist_ok=True)
# Encoders/pixel formats incompatibility check
if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
logger.warning(
f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
)
pix_fmt = "yuv420p"
# Get input frames
template = "frame-" + ("[0-9]" * 6) + ".png"
suffix = ".png" if not isinstance(video_encoder, DepthEncoderConfig) else ".tiff"
template = "frame-" + ("[0-9]" * 6) + suffix
input_list = sorted(
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
)
# Define video output frame size (assuming all input frames are the same size)
if len(input_list) == 0:
raise FileNotFoundError(f"No images found in {imgs_dir}.")
with Image.open(input_list[0]) as dummy_image:
width, height = dummy_image.size
# Define video codec options
video_options = _get_codec_options(vcodec, g, crf, preset)
if fast_decode:
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
video_options[key] = value
if encoder_threads is not None:
if vcodec == "libsvtav1":
lp_param = f"lp={encoder_threads}"
if "svtav1-params" in video_options:
video_options["svtav1-params"] += f":{lp_param}"
else:
video_options["svtav1-params"] = lp_param
else:
video_options["threads"] = str(encoder_threads)
video_options = video_encoder.get_codec_options(encoder_threads, as_strings=True)
# Set logging level
if log_level is not None:
@@ -494,7 +421,10 @@ def encode_video_frames(
def concatenate_video_files(
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
input_video_paths: list[Path | str],
output_video_path: Path,
overwrite: bool = True,
compatibility_check: bool = False,
):
"""
Concatenate multiple video files into a single video file using pyav.
@@ -507,6 +437,7 @@ def concatenate_video_files(
input_video_paths: Ordered list of input video file paths to concatenate.
output_video_path: Path to the output video file.
overwrite: Whether to overwrite the output video file if it already exists. Default is True.
compatibility_check: Whether to check if the input videos are compatible. Default is False.
Note:
- Creates a temporary directory for intermediate files that is cleaned up after use.
@@ -525,6 +456,22 @@ def concatenate_video_files(
if len(input_video_paths) == 0:
raise FileNotFoundError("No input video paths provided.")
# This check may be skipped at recording time as videos are encoded with the same encoder config.
if compatibility_check:
reference_video_info = get_video_info(input_video_paths[0])
for input_path in input_video_paths[1:]:
video_info = get_video_info(input_path)
if (
video_info["video.height"] != reference_video_info["video.height"]
or video_info["video.width"] != reference_video_info["video.width"]
or video_info["video.fps"] != reference_video_info["video.fps"]
or video_info["video.codec"] != reference_video_info["video.codec"]
or video_info["video.pix_fmt"] != reference_video_info["video.pix_fmt"]
):
raise ValueError(
f"Input video {input_path} is not compatible with the reference video {input_video_paths[0]}."
)
# Create a temporary .ffconcat file to list the input video paths
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
tmp_concatenate_file.write("ffconcat version 1.0\n")
@@ -589,11 +536,7 @@ class _CameraEncoderThread(threading.Thread):
self,
video_path: Path,
fps: int,
vcodec: str,
pix_fmt: str,
g: int | None,
crf: int | None,
preset: int | None,
video_encoder: VideoEncoderConfig,
frame_queue: queue.Queue,
result_queue: queue.Queue,
stop_event: threading.Event,
@@ -602,11 +545,8 @@ class _CameraEncoderThread(threading.Thread):
super().__init__(daemon=True)
self.video_path = video_path
self.fps = fps
self.vcodec = vcodec
self.pix_fmt = pix_fmt
self.g = g
self.crf = crf
self.preset = preset
self.video_encoder = video_encoder
self.is_depth = isinstance(video_encoder, DepthEncoderConfig)
self.frame_queue = frame_queue
self.result_queue = result_queue
self.stop_event = stop_event
@@ -635,38 +575,42 @@ class _CameraEncoderThread(threading.Thread):
# Sentinel: flush and close
break
# Ensure HWC uint8 numpy array
# Ensure HWC (RGB or depth) uint8 (RGB only) numpy array
if isinstance(frame_data, np.ndarray):
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
if frame_data.ndim == 3 and frame_data.shape[0] in (1, 3):
# CHW -> HWC
frame_data = frame_data.transpose(1, 2, 0)
if frame_data.dtype != np.uint8:
if not self.is_depth and frame_data.dtype != np.uint8:
frame_data = (frame_data * 255).astype(np.uint8)
# Open container on first frame (to get width/height)
if container is None:
height, width = frame_data.shape[:2]
video_options = _get_codec_options(self.vcodec, self.g, self.crf, self.preset)
if self.encoder_threads is not None:
if self.vcodec == "libsvtav1":
lp_param = f"lp={self.encoder_threads}"
if "svtav1-params" in video_options:
video_options["svtav1-params"] += f":{lp_param}"
else:
video_options["svtav1-params"] = lp_param
else:
video_options["threads"] = str(self.encoder_threads)
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
container = av.open(str(self.video_path), "w")
output_stream = container.add_stream(self.vcodec, self.fps, options=video_options)
output_stream.pix_fmt = self.pix_fmt
output_stream = container.add_stream(
self.video_encoder.vcodec,
self.fps,
options=self.video_encoder.get_codec_options(self.encoder_threads, as_strings=True),
)
output_stream.pix_fmt = self.video_encoder.pix_fmt
output_stream.width = width
output_stream.height = height
output_stream.time_base = Fraction(1, self.fps)
# Encode frame with explicit timestamps
pil_img = Image.fromarray(frame_data)
video_frame = av.VideoFrame.from_image(pil_img)
if not self.is_depth:
pil_img = Image.fromarray(frame_data)
video_frame = av.VideoFrame.from_image(pil_img)
else:
video_frame = quantize_depth(
frame_data,
depth_min=self.video_encoder.depth_min,
depth_max=self.video_encoder.depth_max,
shift=self.video_encoder.shift,
use_log=self.video_encoder.use_log,
video_backend=self.video_encoder.video_backend,
)
video_frame.pts = frame_count
video_frame.time_base = Fraction(1, self.fps)
packet = output_stream.encode(video_frame)
@@ -724,22 +668,26 @@ class StreamingVideoEncoder:
def __init__(
self,
fps: int,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
g: int | None = 2,
crf: int | None = 30,
preset: int | None = None,
camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
queue_maxsize: int = 30,
encoder_threads: int | None = None,
):
"""
Args:
fps: Frames per second for the output videos.
camera_encoder: Video encoder settings applied to all cameras.
When ``None``, :func:`camera_encoder_defaults` is used.
encoder_threads: Number of encoder threads (global setting).
``None`` lets the codec decide.
queue_maxsize: Max frames to buffer per camera before
back-pressure drops frames.
"""
self.fps = fps
self.vcodec = resolve_vcodec(vcodec)
self.pix_fmt = pix_fmt
self.g = g
self.crf = crf
self.preset = preset
self._camera_encoder = camera_encoder or camera_encoder_defaults()
self._depth_encoder = depth_encoder or depth_encoder_defaults()
self._encoder_threads = encoder_threads
self.queue_maxsize = queue_maxsize
self.encoder_threads = encoder_threads
self._frame_queues: dict[str, queue.Queue] = {}
self._result_queues: dict[str, queue.Queue] = {}
@@ -750,18 +698,25 @@ class StreamingVideoEncoder:
self._episode_active = False
self._closed = False
def start_episode(self, video_keys: list[str], temp_dir: Path) -> None:
def start_episode(
self, video_keys: list[str], temp_dir: Path, depth_video_keys: list[str] | None = None
) -> None:
"""Start encoder threads for a new episode.
Args:
video_keys: List of video feature keys (e.g. ["observation.images.laptop"])
temp_dir: Base directory for temporary MP4 files
depth_video_keys: List of video feature keys that carry depth maps (e.g.
["observation.images.laptop_depth"]). Defaults to ``[]`` (no depth keys).
"""
if self._episode_active:
self.cancel_episode()
self._dropped_frames.clear()
if depth_video_keys is None:
depth_video_keys = []
for video_key in video_keys:
frame_queue: queue.Queue = queue.Queue(maxsize=self.queue_maxsize)
result_queue: queue.Queue = queue.Queue(maxsize=1)
@@ -770,18 +725,15 @@ class StreamingVideoEncoder:
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
encoder = self._depth_encoder if video_key in depth_video_keys else self._camera_encoder
encoder_thread = _CameraEncoderThread(
video_path=video_path,
fps=self.fps,
vcodec=self.vcodec,
pix_fmt=self.pix_fmt,
g=self.g,
crf=self.crf,
preset=self.preset,
video_encoder=encoder,
frame_queue=frame_queue,
result_queue=result_queue,
stop_event=stop_event,
encoder_threads=self.encoder_threads,
encoder_threads=self._encoder_threads,
)
encoder_thread.start()
@@ -986,8 +938,18 @@ def get_audio_info(video_path: Path | str) -> dict:
return audio_info
def get_video_info(video_path: Path | str) -> dict:
# Set logging level
def get_video_info(
video_path: Path | str,
video_encoder: VideoEncoderConfig | None = None,
) -> dict:
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
Args:
video_path: Path to the encoded video file to probe.
video_encoder: If provided, record the exact encoder settings used to encode this
video. Stream-derived values take precedence — encoder fields are only written for keys
not already populated from the video file itself.
"""
logging.getLogger("libav").setLevel(av.logging.WARNING)
# Getting video stream information
@@ -1004,13 +966,10 @@ def get_video_info(video_path: Path | str) -> dict:
video_info["video.width"] = video_stream.width
video_info["video.codec"] = video_stream.codec.canonical_name
video_info["video.pix_fmt"] = video_stream.pix_fmt
video_info["video.is_depth_map"] = False
# Calculate fps from r_frame_rate
video_info["video.fps"] = int(video_stream.base_rate)
pixel_channels = get_video_pixel_channels(video_stream.pix_fmt)
video_info["video.channels"] = pixel_channels
video_info["video.channels"] = get_pix_fmt_channels(video_stream.pix_fmt)
# Reset logging level
av.logging.restore_default_callback()
@@ -1018,20 +977,19 @@ def get_video_info(video_path: Path | str) -> dict:
# Adding audio stream information
video_info.update(**get_audio_info(video_path))
# Add additional encoder configuration if provided
if video_encoder is not None:
for field_name, field_value in asdict(video_encoder).items():
# vcodec is already populated from the video stream
if field_name == "vcodec":
continue
video_info.setdefault(f"video.{field_name}", field_value)
video_info["is_depth_map"] = isinstance(video_encoder, DepthEncoderConfig)
return video_info
def get_video_pixel_channels(pix_fmt: str) -> int:
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
return 1
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
return 4
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
return 3
else:
raise ValueError("Unknown format")
def get_video_duration_in_s(video_path: Path | str) -> float:
"""
Get the duration of a video file in seconds using PyAV.

View File

@@ -18,13 +18,13 @@ from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .eo1.configuration_eo1 import EO1Config as EO1Config
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
from .sac.configuration_sac import SACConfig as SACConfig
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .utils import make_robot_action, prepare_observation_for_inference
@@ -32,21 +32,21 @@ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
from .wall_x.configuration_wall_x import WallXConfig as WallXConfig
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
# NOTE: Policy modeling classes (e.g., SACPolicy) are intentionally NOT re-exported here.
# NOTE: Policy modeling classes (e.g., GaussianActorPolicy) are intentionally NOT re-exported here.
# They have heavy optional dependencies and are loaded lazily via get_policy_class().
# Import directly: ``from lerobot.policies.sac.modeling_sac import SACPolicy``
# Import directly: ``from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy``
__all__ = [
# Configuration classes
"ACTConfig",
"DiffusionConfig",
"EO1Config",
"GaussianActorConfig",
"GrootConfig",
"MultiTaskDiTConfig",
"EO1Config",
"PI0Config",
"PI0FastConfig",
"PI05Config",
"SACConfig",
"SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",

View File

@@ -28,11 +28,12 @@ import torch.nn.functional as F # noqa: N812
import torch.utils.checkpoint
from torch import Tensor
from lerobot.policies.eo1.configuration_eo1 import EO1Config
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.import_utils import _transformers_available, require_package
from ..pretrained import PreTrainedPolicy
from .configuration_eo1 import EO1Config
if TYPE_CHECKING or _transformers_available:
from transformers.activations import ACT2FN
from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration

View File

@@ -22,7 +22,6 @@ from typing import TYPE_CHECKING, Any
import torch
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.policies.eo1.configuration_eo1 import EO1Config
from lerobot.processor import (
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
@@ -44,6 +43,8 @@ from lerobot.utils.constants import (
)
from lerobot.utils.import_utils import _transformers_available, require_package
from .configuration_eo1 import EO1Config
if TYPE_CHECKING or _transformers_available:
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
else:

View File

@@ -47,12 +47,12 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
from .act.configuration_act import ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig
from .eo1.configuration_eo1 import EO1Config
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from .groot.configuration_groot import GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config
from .pi05.configuration_pi05 import PI05Config
from .pretrained import PreTrainedPolicy
from .sac.configuration_sac import SACConfig
from .smolvla.configuration_smolvla import SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig
from .utils import validate_visual_features_consistency
@@ -88,7 +88,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x".
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x".
Returns:
The policy class corresponding to the given name.
@@ -127,10 +127,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .pi05.modeling_pi05 import PI05Policy
return PI05Policy
elif name == "sac":
from .sac.modeling_sac import SACPolicy
elif name == "gaussian_actor":
from .gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
return SACPolicy
return GaussianActorPolicy
elif name == "smolvla":
from .smolvla.modeling_smolvla import SmolVLAPolicy
@@ -167,7 +167,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
"smolvla", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
@@ -191,8 +191,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return PI0Config(**kwargs)
elif policy_type == "pi05":
return PI05Config(**kwargs)
elif policy_type == "sac":
return SACConfig(**kwargs)
elif policy_type == "gaussian_actor":
return GaussianActorConfig(**kwargs)
elif policy_type == "smolvla":
return SmolVLAConfig(**kwargs)
elif policy_type == "groot":
@@ -365,10 +365,10 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, SACConfig):
from .sac.processor_sac import make_sac_pre_post_processors
elif isinstance(policy_cfg, GaussianActorConfig):
from .gaussian_actor.processor_gaussian_actor import make_gaussian_actor_pre_post_processors
processors = make_sac_pre_post_processors(
processors = make_gaussian_actor_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)

View File

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_sac import SACConfig
from .modeling_sac import SACPolicy
from .processor_sac import make_sac_pre_post_processors
from .configuration_gaussian_actor import GaussianActorConfig
from .modeling_gaussian_actor import GaussianActorPolicy
from .processor_gaussian_actor import make_gaussian_actor_pre_post_processors
__all__ = ["SACConfig", "SACPolicy", "make_sac_pre_post_processors"]
__all__ = ["GaussianActorConfig", "GaussianActorPolicy", "make_gaussian_actor_pre_post_processors"]

View File

@@ -1,4 +1,4 @@
# !/usr/bin/env python
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
@@ -75,18 +75,19 @@ class PolicyConfig:
init_final: float = 0.05
@PreTrainedConfig.register_subclass("sac")
@PreTrainedConfig.register_subclass("gaussian_actor")
@dataclass
class SACConfig(PreTrainedConfig):
"""Soft Actor-Critic (SAC) configuration.
class GaussianActorConfig(PreTrainedConfig):
"""Gaussian actor configuration.
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
reinforcement learning framework. It learns a policy and a Q-function simultaneously
using experience collected from the environment.
This configures the policy-side (actor + observation encoder) of a Gaussian
policy, as used by SAC and related maximum-entropy continuous-control algorithms.
By default the actor output is a tanh-squashed diagonal Gaussian
(``TanhMultivariateNormalDiag``); the tanh squashing can be disabled via
``policy_kwargs.use_tanh_squash``. The critics, temperature, and Bellman-update
logic live on the algorithm side (see ``lerobot.rl.algorithms.sac``).
This configuration class contains all the parameters needed to define a SAC agent,
including network architectures, optimization settings, and algorithm-specific
hyperparameters.
CLI: ``--policy.type=gaussian_actor``.
"""
# Mapping of feature types to normalization modes
@@ -122,7 +123,7 @@ class SACConfig(PreTrainedConfig):
device: str = "cpu"
# Device to store the model on
storage_device: str = "cpu"
# Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
# Name of the vision encoder model (Set to "lerobot/resnet10" for hil serl resnet10)
vision_encoder_name: str | None = None
# Whether to freeze the vision encoder during training
freeze_vision_encoder: bool = True
@@ -135,7 +136,13 @@ class SACConfig(PreTrainedConfig):
# Dimension of the image embedding pooling
image_embedding_pooling_dim: int = 8
# Training parameter
# Encoder architecture
# Hidden dimension size for the state encoder
state_encoder_hidden_dim: int = 256
# Dimension of the latent space
latent_dim: int = 256
# Online training (TODO(Khalil): relocate to TrainRLServerPipelineConfig)
# Number of steps for online training
online_steps: int = 1000000
# Capacity of the online replay buffer
@@ -146,67 +153,38 @@ class SACConfig(PreTrainedConfig):
async_prefetch: bool = False
# Number of steps before learning starts
online_step_before_learning: int = 100
# Frequency of policy updates
policy_update_freq: int = 1
# SAC algorithm parameters
# Discount factor for the SAC algorithm
discount: float = 0.99
# Initial temperature value
temperature_init: float = 1.0
# Number of critics in the ensemble
num_critics: int = 2
# Number of subsampled critics for training
num_subsample_critics: int | None = None
# Learning rate for the critic network
critic_lr: float = 3e-4
# Learning rate for the actor network
actor_lr: float = 3e-4
# Learning rate for the temperature parameter
temperature_lr: float = 3e-4
# Weight for the critic target update
critic_target_update_weight: float = 0.005
# Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1)
utd_ratio: int = 1
# Hidden dimension size for the state encoder
state_encoder_hidden_dim: int = 256
# Dimension of the latent space
latent_dim: int = 256
# Target entropy for the SAC algorithm
target_entropy: float | None = None
# Whether to use backup entropy for the SAC algorithm
use_backup_entropy: bool = True
# Gradient clipping norm for the SAC algorithm
grad_clip_norm: float = 40.0
# Network configuration
# Configuration for the critic network architecture
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Configuration for the actor network architecture
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
# Configuration for the policy parameters
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
# Configuration for the discrete critic network
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Actor-learner transport (TODO(Khalil): relocate to TrainRLServerPipelineConfig).
# Configuration for actor-learner architecture
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
# Optimizations
use_torch_compile: bool = True
# Network architecture
# Configuration for the actor network architecture
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
# Configuration for the policy parameters (Gaussian head)
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
# Configuration for the discrete critic network
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
def __post_init__(self):
super().__post_init__()
# Any validation specific to SAC configuration
# Any validation specific to GaussianActor configuration
def get_optimizer_preset(self) -> MultiAdamConfig:
# Default learning rate used to satisfy the abstract ``get_optimizer_preset()``
# contract from ``PreTrainedConfig``. The actual optimizers used during RL
# training are built by ``SACAlgorithm.make_optimizers_and_scheduler()`` from
# ``SACAlgorithmConfig.{actor_lr,critic_lr,temperature_lr}`` and fully bypass
# this preset.
default_lr = 3e-4
return MultiAdamConfig(
weight_decay=0.0,
optimizer_groups={
"actor": {"lr": self.actor_lr},
"critic": {"lr": self.critic_lr},
"temperature": {"lr": self.temperature_lr},
"actor": {"lr": default_lr},
"critic": {"lr": default_lr},
"temperature": {"lr": default_lr},
},
)

View File

@@ -15,16 +15,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Callable
from dataclasses import asdict
from typing import Literal
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
@@ -32,20 +27,20 @@ from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
from ..pretrained import PreTrainedPolicy
from ..utils import get_device_from_parameters
from .configuration_sac import SACConfig, is_image_feature
from .configuration_gaussian_actor import GaussianActorConfig, is_image_feature
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
class SACPolicy(
class GaussianActorPolicy(
PreTrainedPolicy,
):
config_class = SACConfig
name = "sac"
config_class = GaussianActorConfig
name = "gaussian_actor"
def __init__(
self,
config: SACConfig | None = None,
config: GaussianActorConfig | None = None,
):
super().__init__(config)
config.validate_features()
@@ -54,9 +49,8 @@ class SACPolicy(
# Determine action dimension and initialize all components
continuous_action_dim = config.output_features[ACTION].shape[0]
self._init_encoders()
self._init_critics(continuous_action_dim)
self._init_actor(continuous_action_dim)
self._init_temperature()
self._init_discrete_critic()
def get_optim_params(self) -> dict:
optim_params = {
@@ -65,11 +59,7 @@ class SACPolicy(
for n, p in self.actor.named_parameters()
if not n.startswith("encoder") or not self.shared_encoder
],
"critic": self.critic_ensemble.parameters(),
"temperature": self.log_alpha,
}
if self.config.num_discrete_actions is not None:
optim_params["discrete_critic"] = self.discrete_critic.parameters()
return optim_params
def reset(self):
@@ -79,7 +69,9 @@ class SACPolicy(
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
raise NotImplementedError(
"GaussianActorPolicy does not support action chunking. It returns single actions!"
)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -92,360 +84,43 @@ class SACPolicy(
actions, _, _ = self.actor(batch, observations_features)
if self.config.num_discrete_actions is not None:
discrete_action_value = self.discrete_critic(batch, observations_features)
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
if self.discrete_critic is not None:
discrete_action_value = self.discrete_critic(batch, observations_features)
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
else:
discrete_action = torch.ones(
(*actions.shape[:-1], 1), device=actions.device, dtype=actions.dtype
)
actions = torch.cat([actions, discrete_action], dim=-1)
return actions
def critic_forward(
self,
observations: dict[str, Tensor],
actions: Tensor,
use_target: bool = False,
observation_features: Tensor | None = None,
) -> Tensor:
"""Forward pass through a critic network ensemble
def forward(self, batch: dict[str, Tensor | dict[str, Tensor]]) -> dict[str, Tensor]:
"""Actor forward pass: sample actions and return log-probabilities.
Args:
observations: Dictionary of observations
actions: Action tensor
use_target: If True, use target critics, otherwise use ensemble critics
batch: A flat observation dict, or a training dict containing
``"state"`` (observations) and optionally ``"observation_feature"``
(pre-computed encoder features).
Returns:
Tensor of Q-values from all critics
Dict with ``"action"``, ``"log_prob"``, and ``"action_mean"`` tensors.
"""
critics = self.critic_target if use_target else self.critic_ensemble
q_values = critics(observations, actions, observation_features)
return q_values
def discrete_critic_forward(
self, observations, use_target=False, observation_features=None
) -> torch.Tensor:
"""Forward pass through a discrete critic network
Args:
observations: Dictionary of observations
use_target: If True, use target critics, otherwise use ensemble critics
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
Returns:
Tensor of Q-values from the discrete critic network
"""
discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic
q_values = discrete_critic(observations, observation_features)
return q_values
def forward(
self,
batch: dict[str, Tensor | dict[str, Tensor]],
model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic",
) -> dict[str, Tensor]:
"""Compute the loss for the given model
Args:
batch: Dictionary containing:
- action: Action tensor
- reward: Reward tensor
- state: Observations tensor dict
- next_state: Next observations tensor dict
- done: Done mask tensor
- observation_feature: Optional pre-computed observation features
- next_observation_feature: Optional pre-computed next observation features
model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature")
Returns:
The computed loss tensor
"""
# Extract common components from batch
actions: Tensor = batch[ACTION]
observations: dict[str, Tensor] = batch["state"]
observation_features: Tensor = batch.get("observation_feature")
if model == "critic":
# Extract critic-specific components
rewards: Tensor = batch["reward"]
next_observations: dict[str, Tensor] = batch["next_state"]
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
loss_critic = self.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
return {"loss_critic": loss_critic}
if model == "discrete_critic" and self.config.num_discrete_actions is not None:
# Extract critic-specific components
rewards: Tensor = batch["reward"]
next_observations: dict[str, Tensor] = batch["next_state"]
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
complementary_info = batch.get("complementary_info")
loss_discrete_critic = self.compute_loss_discrete_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
complementary_info=complementary_info,
)
return {"loss_discrete_critic": loss_discrete_critic}
if model == "actor":
return {
"loss_actor": self.compute_loss_actor(
observations=observations,
observation_features=observation_features,
)
}
if model == "temperature":
return {
"loss_temperature": self.compute_loss_temperature(
observations=observations,
observation_features=observation_features,
)
}
raise ValueError(f"Unknown model type: {model}")
def update_target_networks(self):
"""Update target networks with exponential moving average"""
for target_param, param in zip(
self.critic_target.parameters(),
self.critic_ensemble.parameters(),
strict=True,
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
if self.config.num_discrete_actions is not None:
for target_param, param in zip(
self.discrete_critic_target.parameters(),
self.discrete_critic.parameters(),
strict=True,
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
@property
def temperature(self) -> float:
"""Return the current temperature value, always in sync with log_alpha."""
return self.log_alpha.exp().item()
def compute_loss_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features: Tensor | None = None,
next_observation_features: Tensor | None = None,
) -> Tensor:
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
# 2- compute q targets
q_targets = self.critic_forward(
observations=next_observations,
actions=next_action_preds,
use_target=True,
observation_features=next_observation_features,
)
# subsample critics to prevent overfitting if use high UTD (update to date)
# TODO: Get indices before forward pass to avoid unnecessary computation
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
min_q = min_q - (self.temperature * next_log_probs)
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
if self.config.num_discrete_actions is not None:
# NOTE: We only want to keep the continuous action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
q_preds = self.critic_forward(
observations=observations,
actions=actions,
use_target=False,
observation_features=observation_features,
)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
critics_loss = (
F.mse_loss(
input=q_preds,
target=td_target_duplicate,
reduction="none",
).mean(dim=1)
).sum()
return critics_loss
def compute_loss_discrete_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features=None,
next_observation_features=None,
complementary_info=None,
):
# NOTE: We only want to keep the discrete action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
actions_discrete = torch.round(actions_discrete)
actions_discrete = actions_discrete.long()
discrete_penalties: Tensor | None = None
if complementary_info is not None:
discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty")
with torch.no_grad():
# For DQN, select actions using online network, evaluate with target network
next_discrete_qs = self.discrete_critic_forward(
next_observations, use_target=False, observation_features=next_observation_features
)
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
# Get target Q-values from target network
target_next_discrete_qs = self.discrete_critic_forward(
observations=next_observations,
use_target=True,
observation_features=next_observation_features,
)
# Use gather to select Q-values for best actions
target_next_discrete_q = torch.gather(
target_next_discrete_qs, dim=1, index=best_next_discrete_action
).squeeze(-1)
# Compute target Q-value with Bellman equation
rewards_discrete = rewards
if discrete_penalties is not None:
rewards_discrete = rewards + discrete_penalties
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
# Get predicted Q-values for current observations
predicted_discrete_qs = self.discrete_critic_forward(
observations=observations, use_target=False, observation_features=observation_features
)
# Use gather to select Q-values for taken actions
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
# Compute MSE loss between predicted and target Q-values
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
return discrete_critic_loss
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations, observation_features)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
return temperature_loss
def compute_loss_actor(
self,
observations,
observation_features: Tensor | None = None,
) -> Tensor:
actions_pi, log_probs, _ = self.actor(observations, observation_features)
q_preds = self.critic_forward(
observations=observations,
actions=actions_pi,
use_target=False,
observation_features=observation_features,
)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
return actor_loss
observations = batch.get("state", batch)
observation_features = batch.get("observation_feature") if isinstance(batch, dict) else None
actions, log_probs, means = self.actor(observations, observation_features)
return {"action": actions, "log_prob": log_probs, "action_mean": means}
def _init_encoders(self):
"""Initialize shared or separate encoders for actor and critic."""
self.shared_encoder = self.config.shared_encoder
self.encoder_critic = SACObservationEncoder(self.config)
self.encoder_critic = GaussianActorObservationEncoder(self.config)
self.encoder_actor = (
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
self.encoder_critic if self.shared_encoder else GaussianActorObservationEncoder(self.config)
)
def _init_critics(self, continuous_action_dim):
"""Build critic ensemble, targets, and optional discrete critic."""
heads = [
CriticHead(
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
target_heads = [
CriticHead(
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
if self.config.use_torch_compile:
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
if self.config.num_discrete_actions is not None:
self._init_discrete_critics()
def _init_discrete_critics(self):
"""Build discrete discrete critic ensemble and target networks."""
self.discrete_critic = DiscreteCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
self.discrete_critic_target = DiscreteCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
# TODO: (maractingi, azouitine) Compile the discrete critic
self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict())
def _init_actor(self, continuous_action_dim):
"""Initialize policy actor network and default target entropy."""
"""Initialize policy actor network."""
# NOTE: The actor select only the continuous action part
self.actor = Policy(
encoder=self.encoder_actor,
@@ -455,21 +130,25 @@ class SACPolicy(
**asdict(self.config.policy_kwargs),
)
self.target_entropy = self.config.target_entropy
if self.target_entropy is None:
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
self.target_entropy = -np.prod(dim) / 2
def _init_discrete_critic(self) -> None:
"""Initialize discrete critic network."""
if self.config.num_discrete_actions is None:
self.discrete_critic = None
return
def _init_temperature(self) -> None:
"""Set up temperature parameter (log_alpha)."""
temp_init = self.config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
# TODO(Khalil): Compile the discrete critic
self.discrete_critic = DiscreteCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
class SACObservationEncoder(nn.Module):
class GaussianActorObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: SACConfig) -> None:
def __init__(self, config: GaussianActorConfig) -> None:
super().__init__()
self.config = config
self._init_image_layers()
@@ -677,84 +356,6 @@ class MLP(nn.Module):
return self.net(x)
class CriticHead(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dims: list[int],
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: float | None = None,
init_final: float | None = None,
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
):
super().__init__()
self.net = MLP(
input_dim=input_dim,
hidden_dims=hidden_dims,
activations=activations,
activate_final=activate_final,
dropout_rate=dropout_rate,
final_activation=final_activation,
)
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
if init_final is not None:
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.output_layer.weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.output_layer(self.net(x))
class CriticEnsemble(nn.Module):
"""
CriticEnsemble wraps multiple CriticHead modules into an ensemble.
Args:
encoder (SACObservationEncoder): encoder for observations.
ensemble (List[CriticHead]): list of critic heads.
init_final (float | None): optional initializer scale for final layers.
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
"""
def __init__(
self,
encoder: SACObservationEncoder,
ensemble: list[CriticHead],
init_final: float | None = None,
):
super().__init__()
self.encoder = encoder
self.init_final = init_final
self.critics = nn.ModuleList(ensemble)
def forward(
self,
observations: dict[str, torch.Tensor],
actions: torch.Tensor,
observation_features: torch.Tensor | None = None,
) -> torch.Tensor:
device = get_device_from_parameters(self)
# Move each tensor in observations to device
observations = {k: v.to(device) for k, v in observations.items()}
obs_enc = self.encoder(observations, cache=observation_features)
inputs = torch.cat([obs_enc, actions], dim=-1)
# Loop through critics and collect outputs
q_values = []
for critic in self.critics:
q_values.append(critic(inputs))
# Stack outputs to match expected shape [num_critics, batch_size]
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
return q_values
class DiscreteCritic(nn.Module):
def __init__(
self,
@@ -800,7 +401,7 @@ class DiscreteCritic(nn.Module):
class Policy(nn.Module):
def __init__(
self,
encoder: SACObservationEncoder,
encoder: GaussianActorObservationEncoder,
network: nn.Module,
action_dim: int,
std_min: float = -5,
@@ -811,7 +412,7 @@ class Policy(nn.Module):
encoder_is_shared: bool = False,
):
super().__init__()
self.encoder: SACObservationEncoder = encoder
self.encoder: GaussianActorObservationEncoder = encoder
self.network = network
self.action_dim = action_dim
self.std_min = std_min
@@ -885,7 +486,7 @@ class Policy(nn.Module):
class DefaultImageEncoder(nn.Module):
def __init__(self, config: SACConfig):
def __init__(self, config: GaussianActorConfig):
super().__init__()
image_key = next(key for key in config.input_features if is_image_feature(key))
self.image_enc_layers = nn.Sequential(
@@ -931,12 +532,12 @@ def freeze_image_encoder(image_encoder: nn.Module):
class PretrainedImageEncoder(nn.Module):
def __init__(self, config: SACConfig):
def __init__(self, config: GaussianActorConfig):
super().__init__()
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
def _load_pretrained_vision_encoder(self, config: SACConfig):
def _load_pretrained_vision_encoder(self, config: GaussianActorConfig):
"""Set up CNN encoder"""
from transformers import AutoModel

View File

@@ -32,18 +32,18 @@ from lerobot.processor import (
)
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from .configuration_sac import SACConfig
from .configuration_gaussian_actor import GaussianActorConfig
def make_sac_pre_post_processors(
config: SACConfig,
def make_gaussian_actor_pre_post_processors(
config: GaussianActorConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Constructs pre-processor and post-processor pipelines for the SAC policy.
Constructs pre-processor and post-processor pipelines for the Gaussian actor policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
@@ -56,7 +56,7 @@ def make_sac_pre_post_processors(
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the SAC policy.
config: The configuration object for the tanh-Gaussian policy.
dataset_stats: A dictionary of statistics for normalization.
Returns:

View File

@@ -441,13 +441,13 @@ class PaliGemmaWithExpertModel(
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
features = image_outputs.pooler_output
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.embed_tokens(tokens)
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
def forward(
self,
@@ -662,8 +662,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Process language tokens
def lang_embed_func(tokens):
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
lang_emb_dim = lang_emb.shape[-1]
return lang_emb * math.sqrt(lang_emb_dim)
return lang_emb
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
embs.append(lang_emb)

View File

@@ -939,7 +939,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
target_dtype = torch.get_autocast_dtype(query_states.device.type)
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype

View File

@@ -985,7 +985,7 @@ class Florence2FlashAttention2(Florence2Attention):
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
target_dtype = torch.get_autocast_dtype(query_states.device.type)
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype

View File

@@ -4,7 +4,6 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
@@ -321,6 +320,7 @@ class GymHILAdapterProcessorStep(ProcessorStep):
This step normalizes the `transition` object by:
1. Copying `teleop_action` from `info` to `complementary_data`.
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
3. Copying `discrete_penalty` from `info` to `complementary_data`.
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -330,6 +330,9 @@ class GymHILAdapterProcessorStep(ProcessorStep):
if TELEOP_ACTION_KEY in info:
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
if DISCRETE_PENALTY_KEY in info:
complementary_data[DISCRETE_PENALTY_KEY] = info[DISCRETE_PENALTY_KEY]
if "is_intervention" in info:
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
@@ -348,18 +351,24 @@ class GymHILAdapterProcessorStep(ProcessorStep):
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessorStep(ProcessorStep):
"""
Applies a penalty for inefficient gripper usage.
Applies a small per-transition cost on the discrete gripper action.
This step penalizes actions that attempt to close an already closed gripper or
open an already open one, based on position thresholds.
Fires only when the commanded action would actually transition the gripper
from one extreme to the other (close-while-open or open-while-closed).
This discourages gripper oscillation while leaving "stay" and saturating-further
commands unpenalized.
Attributes:
penalty: The negative reward value to apply.
max_gripper_pos: The maximum position value for the gripper, used for normalization.
open_threshold: Normalized state below which the gripper is considered "open".
closed_threshold: Normalized state above which the gripper is considered "closed".
"""
penalty: float = -0.01
penalty: float = -0.02
max_gripper_pos: float = 30.0
open_threshold: float = 0.1
closed_threshold: float = 0.9
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
@@ -379,11 +388,15 @@ class GripperPenaltyProcessorStep(ProcessorStep):
if raw_joint_positions is None:
return new_transition
current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None)
current_gripper_pos = raw_joint_positions.get(f"{GRIPPER_KEY}.pos", None)
if current_gripper_pos is None:
return new_transition
# Gripper action is a PolicyAction at this stage
# During reset, the transition may not carry any action yet.
if action is None:
return new_transition
# Gripper action is expected as the last action dimension.
gripper_action = action[-1].item()
gripper_action_normalized = gripper_action / self.max_gripper_pos
@@ -391,9 +404,13 @@ class GripperPenaltyProcessorStep(ProcessorStep):
gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
# Calculate penalty boolean as in original
gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
)
# - currently open AND target is closed -> close transition
# - currently closed AND target is open -> open transition
is_open = gripper_state_normalized < self.open_threshold
is_closed = gripper_state_normalized > self.closed_threshold
cmd_close = gripper_action_normalized > self.closed_threshold
cmd_open = gripper_action_normalized < self.open_threshold
gripper_penalty_bool = (is_open and cmd_close) or (is_closed and cmd_open)
gripper_penalty = self.penalty * int(gripper_penalty_bool)
@@ -409,11 +426,14 @@ class GripperPenaltyProcessorStep(ProcessorStep):
Returns the configuration of the step for serialization.
Returns:
A dictionary containing the penalty value and max gripper position.
A dictionary containing the penalty value, max gripper position,
and the open/closed thresholds.
"""
return {
"penalty": self.penalty,
"max_gripper_pos": self.max_gripper_pos,
"open_threshold": self.open_threshold,
"closed_threshold": self.closed_threshold,
}
def reset(self) -> None:

View File

@@ -134,6 +134,24 @@ class _NormalizationMixin:
if self.dtype is None:
self.dtype = torch.float32
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
self._reshape_visual_stats()
def _reshape_visual_stats(self) -> None:
"""Reshape flat ``(C,)`` visual stats to ``(C, 1, 1)`` for image broadcasting.
No-op for stats from :func:`~lerobot.datasets.compute_stats.compute_stats`
(already ``(C, 1, 1)``). Needed by RL training, which can start without
a dataset and supplies stats manually via JSON config.
"""
for key, feature in self.features.items():
if feature.type != FeatureType.VISUAL:
continue
if key not in self._tensor_stats:
continue
for stat_name, stat_tensor in self._tensor_stats[key].items():
if not isinstance(stat_tensor, Tensor) or stat_tensor.ndim != 1:
continue
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
@@ -152,6 +170,7 @@ class _NormalizationMixin:
if dtype is not None:
self.dtype = dtype
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
self._reshape_visual_stats()
return self
def state_dict(self) -> dict[str, Tensor]:
@@ -201,6 +220,7 @@ class _NormalizationMixin:
# Don't load from state_dict, keep the explicitly provided stats
# But ensure _tensor_stats is properly initialized
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
self._reshape_visual_stats()
return
# Normal behavior: load stats from state_dict
@@ -211,6 +231,7 @@ class _NormalizationMixin:
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
dtype=torch.float32, device=self.device
)
self._reshape_visual_stats()
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
# and other functions that rely on self.stats

View File

@@ -30,7 +30,7 @@ class RewardClassifierConfig(RewardModelConfig):
latent_dim: int = 256
image_embedding_pooling_dim: int = 8
dropout_rate: float = 0.1
model_name: str = "helper2424/resnet10" # TODO: This needs to be updated. The model on the Hub doesn't call self.post_init() in its __init__, which is required by transformers v5 to set all_tied_weights_keys. The from_pretrained call fails when it tries to access this attribute during _finalize_model_loading.
model_name: str = "lerobot/resnet10"
device: str = "cpu"
model_type: str = "cnn" # "transformer" or "cnn"
num_cameras: int = 2

View File

@@ -17,10 +17,11 @@ import logging
import torch
from torch import Tensor, nn
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
from lerobot.rewards.pretrained import PreTrainedRewardModel
from lerobot.utils.constants import OBS_IMAGE, REWARD
from ..pretrained import PreTrainedRewardModel
from .configuration_classifier import RewardClassifierConfig
class ClassifierOutput:
"""Wrapper for classifier outputs with additional metadata."""
@@ -105,6 +106,7 @@ class Classifier(PreTrainedRewardModel):
def __init__(
self,
config: RewardClassifierConfig,
**kwargs,
):
from transformers import AutoModel

View File

@@ -25,7 +25,8 @@ from lerobot.processor import (
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
from .configuration_classifier import RewardClassifierConfig
def make_classifier_processor(

View File

@@ -22,9 +22,10 @@ import torch
from lerobot.configs.rewards import RewardModelConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
from lerobot.rewards.pretrained import PreTrainedRewardModel
from lerobot.rewards.sarm.configuration_sarm import SARMConfig
from .classifier.configuration_classifier import RewardClassifierConfig
from .pretrained import PreTrainedRewardModel
from .sarm.configuration_sarm import SARMConfig
def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:

View File

@@ -58,9 +58,10 @@ import torch
from tqdm import tqdm
from lerobot.datasets import LeRobotDataset
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
from lerobot.rewards.sarm.processor_sarm import make_sarm_pre_post_processors
from lerobot.rewards.sarm.sarm_utils import normalize_stage_tau
from .modeling_sarm import SARMRewardModel
from .processor_sarm import make_sarm_pre_post_processors
from .sarm_utils import normalize_stage_tau
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:

View File

@@ -32,13 +32,14 @@ import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from lerobot.rewards.pretrained import PreTrainedRewardModel
from lerobot.rewards.sarm.configuration_sarm import SARMConfig
from lerobot.rewards.sarm.sarm_utils import (
from lerobot.utils.constants import OBS_STR
from ..pretrained import PreTrainedRewardModel
from .configuration_sarm import SARMConfig
from .sarm_utils import (
normalize_stage_tau,
pad_state_to_max_dim,
)
from lerobot.utils.constants import OBS_STR
class StageTransformer(nn.Module):

View File

@@ -58,15 +58,16 @@ from lerobot.processor import (
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.rewards.sarm.configuration_sarm import SARMConfig
from lerobot.rewards.sarm.sarm_utils import (
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from .configuration_sarm import SARMConfig
from .sarm_utils import (
apply_rewind_augmentation,
compute_absolute_indices,
find_stage_and_tau,
pad_state_to_max_dim,
)
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
class SARMEncodingProcessorStep(ProcessorStep):

View File

@@ -12,23 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Reinforcement learning modules.
"""Reinforcement learning modules.
Requires: ``pip install 'lerobot[hilserl]'``
Available modules (import directly)::
from lerobot.rl.actor import ...
from lerobot.rl.learner import ...
from lerobot.rl.learner_service import ...
from lerobot.rl.buffer import ...
from lerobot.rl.eval_policy import ...
from lerobot.rl.gym_manipulator import ...
Distributed actor / learner entry points (``actor``, ``learner``,
``learner_service``) require ``pip install 'lerobot[hilserl]'``. Algorithms,
buffer, data sources and trainer are gRPC-free and usable standalone.
"""
from lerobot.utils.import_utils import require_package
from .algorithms.base import RLAlgorithm as RLAlgorithm
from .algorithms.configs import RLAlgorithmConfig as RLAlgorithmConfig, TrainingStats as TrainingStats
from .algorithms.factory import (
make_algorithm as make_algorithm,
make_algorithm_config as make_algorithm_config,
)
from .algorithms.sac.configuration_sac import SACAlgorithmConfig as SACAlgorithmConfig
from .buffer import ReplayBuffer as ReplayBuffer
from .data_sources import DataMixer as DataMixer, OnlineOfflineMixer as OnlineOfflineMixer
from .trainer import RLTrainer as RLTrainer
require_package("grpcio", extra="hilserl", import_name="grpc")
__all__: list[str] = []
__all__ = [
"RLAlgorithm",
"RLAlgorithmConfig",
"TrainingStats",
"make_algorithm",
"make_algorithm_config",
"SACAlgorithmConfig",
"RLTrainer",
"ReplayBuffer",
"DataMixer",
"OnlineOfflineMixer",
]

View File

@@ -49,39 +49,53 @@ https://github.com/michel-aractingi/lerobot-hilserl-guide
import logging
import os
import time
from collections.abc import Generator
from functools import lru_cache
from queue import Empty
from typing import TYPE_CHECKING, Any
from lerobot.utils.import_utils import _grpc_available, require_package
if TYPE_CHECKING or _grpc_available:
import grpc
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import (
bytes_to_state_dict,
grpc_channel_options,
python_object_to_bytes,
receive_bytes_in_chunks,
send_bytes_in_chunks,
transitions_to_bytes,
)
else:
grpc = None
services_pb2 = None
services_pb2_grpc = None
bytes_to_state_dict = None
grpc_channel_options = None
python_object_to_bytes = None
receive_bytes_in_chunks = None
send_bytes_in_chunks = None
transitions_to_bytes = None
import grpc
import torch
from torch import nn
from torch.multiprocessing import Event, Queue
from torch.multiprocessing import Queue
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies import make_policy
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.processor import TransitionKey
from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import (
bytes_to_state_dict,
grpc_channel_options,
python_object_to_bytes,
receive_bytes_in_chunks,
send_bytes_in_chunks,
transitions_to_bytes,
)
from lerobot.types import TransitionKey
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.transition import (
Transition,
move_state_dict_to_device,
move_transition_to_device,
)
from lerobot.utils.utils import (
@@ -89,19 +103,24 @@ from lerobot.utils.utils import (
init_logging,
)
from .algorithms.base import RLAlgorithm
from .algorithms.factory import make_algorithm
from .gym_manipulator import (
create_transition,
make_processors,
make_robot_env,
reset_and_build_transition,
step_env_and_process_transition,
)
from .queue import get_last_item_from_queue
from .train_rl import TrainRLServerPipelineConfig
# Main entry point
@parser.wrap()
def actor_cli(cfg: TrainRLServerPipelineConfig):
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
require_package("grpcio", extra="hilserl", import_name="grpc")
cfg.validate()
display_pid = False
if not use_threads(cfg):
@@ -212,7 +231,7 @@ def actor_cli(cfg: TrainRLServerPipelineConfig):
def act_with_policy(
cfg: TrainRLServerPipelineConfig,
shutdown_event: any, # Event,
shutdown_event: Any, # Event
parameters_queue: Queue,
transitions_queue: Queue,
interactions_queue: Queue,
@@ -252,22 +271,24 @@ def act_with_policy(
logging.info("make_policy")
### Instantiate the policy in both the actor and learner processes
### To avoid sending a SACPolicy object through the port, we create a policy instance
### To avoid sending a policy object through the port, we create a policy instance
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
policy: SACPolicy = make_policy(
policy = make_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
)
policy = policy.eval()
policy = policy.to(device).eval()
assert isinstance(policy, nn.Module)
obs, info = online_env.reset()
env_processor.reset()
action_processor.reset()
# Build the algorithm
algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy)
# Process initial observation
transition = create_transition(observation=obs, info=info)
transition = env_processor(transition)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
dataset_stats=cfg.policy.dataset_stats,
)
transition = reset_and_build_transition(online_env, env_processor, action_processor)
# NOTE: For the moment we will solely handle the case of a single environment
sum_reward_episode = 0
@@ -291,8 +312,17 @@ def act_with_policy(
# Time policy inference and check if it meets FPS requirement
with policy_timer:
# Extract observation from transition for policy
action = policy.select_action(batch=observation)
normalized_observation = preprocessor.process_observation(observation)
action = policy.select_action(batch=normalized_observation)
# Unnormalize only the continuous part.
if cfg.policy.num_discrete_actions is not None:
continuous_action = postprocessor.process_action(action[..., :-1])
discrete_action = action[..., -1:].to(
device=continuous_action.device, dtype=continuous_action.dtype
)
action = torch.cat([continuous_action, discrete_action], dim=-1)
else:
action = postprocessor.process_action(action)
policy_fps = policy_timer.fps_last
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
@@ -326,7 +356,8 @@ def act_with_policy(
# Check for intervention from transition info
intervention_info = new_transition[TransitionKey.INFO]
if intervention_info.get(TeleopEvents.IS_INTERVENTION, False):
is_intervention = bool(intervention_info.get(TeleopEvents.IS_INTERVENTION, False))
if is_intervention:
episode_intervention = True
episode_intervention_steps += 1
@@ -334,6 +365,7 @@ def act_with_policy(
"discrete_penalty": torch.tensor(
[new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
),
TeleopEvents.IS_INTERVENTION.value: is_intervention,
}
# Create transition for learner (convert to old format)
list_transition_to_send_to_learner.append(
@@ -354,7 +386,7 @@ def act_with_policy(
if done or truncated:
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
update_policy_parameters(algorithm=algorithm, parameters_queue=parameters_queue, device=device)
if len(list_transition_to_send_to_learner) > 0:
push_transitions_to_transport_queue(
@@ -390,14 +422,7 @@ def act_with_policy(
episode_intervention_steps = 0
episode_total_steps = 0
# Reset environment and processors
obs, info = online_env.reset()
env_processor.reset()
action_processor.reset()
# Process initial observation
transition = create_transition(observation=obs, info=info)
transition = env_processor(transition)
transition = reset_and_build_transition(online_env, env_processor, action_processor)
if cfg.env.fps is not None:
dt_time = time.perf_counter() - start_time
@@ -408,10 +433,10 @@ def act_with_policy(
def establish_learner_connection(
stub: services_pb2_grpc.LearnerServiceStub,
shutdown_event: Event, # type: ignore
stub: "services_pb2_grpc.LearnerServiceStub",
shutdown_event: Any, # Event
attempts: int = 30,
):
) -> bool:
"""Establish a connection with the learner.
Args:
@@ -441,12 +466,14 @@ def establish_learner_connection(
def learner_service_client(
host: str = "127.0.0.1",
port: int = 50051,
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
"""
Returns a client for the learner service.
) -> "tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]":
"""Return a client for the learner service.
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
So we need to create only one client and reuse it.
Returns:
tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: The stub and the channel.
"""
channel = grpc.insecure_channel(
@@ -461,16 +488,18 @@ def learner_service_client(
def receive_policy(
cfg: TrainRLServerPipelineConfig,
parameters_queue: Queue,
shutdown_event: Event, # type: ignore
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
):
shutdown_event: Any, # Event
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: "grpc.Channel | None" = None,
) -> None:
"""Receive parameters from the learner.
Args:
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
parameters_queue (Queue): The queue to receive the parameters.
shutdown_event (Event): The event to check if the process should shutdown.
learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub.
grpc_channel (grpc.Channel | None): Optional pre-created channel.
"""
logging.info("[ACTOR] Start receiving parameters from the Learner")
if not use_threads(cfg):
@@ -513,12 +542,11 @@ def receive_policy(
def send_transitions(
cfg: TrainRLServerPipelineConfig,
transitions_queue: Queue,
shutdown_event: any, # Event,
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> services_pb2.Empty:
"""
Sends transitions to the learner.
shutdown_event: Any, # Event
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: "grpc.Channel | None" = None,
) -> None:
"""Send transitions to the learner.
This function continuously retrieves messages from the queue and processes:
@@ -526,6 +554,13 @@ def send_transitions(
- A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner.
Args:
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
transitions_queue (Queue): The queue to receive the transitions.
shutdown_event (Event): The event to check if the process should shutdown.
learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub.
grpc_channel (grpc.Channel | None): Optional pre-created channel.
"""
if not use_threads(cfg):
@@ -563,18 +598,24 @@ def send_transitions(
def send_interactions(
cfg: TrainRLServerPipelineConfig,
interactions_queue: Queue,
shutdown_event: Event, # type: ignore
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> services_pb2.Empty:
"""
Sends interactions to the learner.
shutdown_event: Any, # Event
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: "grpc.Channel | None" = None,
) -> None:
"""Send interactions to the learner.
This function continuously retrieves messages from the queue and processes:
- Interaction Messages:
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
Args:
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
interactions_queue (Queue): The queue to receive the interactions.
shutdown_event (Event): The event to check if the process should shutdown.
learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub.
grpc_channel (grpc.Channel | None): Optional pre-created channel.
"""
if not use_threads(cfg):
@@ -613,7 +654,11 @@ def send_interactions(
logging.info("[ACTOR] Interactions process stopped")
def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout: float) -> services_pb2.Empty: # type: ignore
def transitions_stream(
shutdown_event: Any, # Event
transitions_queue: Queue,
timeout: float,
) -> "Generator[Any, None, services_pb2.Empty]":
while not shutdown_event.is_set():
try:
message = transitions_queue.get(block=True, timeout=timeout)
@@ -629,10 +674,10 @@ def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout:
def interactions_stream(
shutdown_event: Event,
shutdown_event: Any, # Event
interactions_queue: Queue,
timeout: float, # type: ignore
) -> services_pb2.Empty:
timeout: float,
) -> "Generator[Any, None, services_pb2.Empty]":
while not shutdown_event.is_set():
try:
message = interactions_queue.get(block=True, timeout=timeout)
@@ -652,7 +697,8 @@ def interactions_stream(
# Policy functions
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
def update_policy_parameters(algorithm: RLAlgorithm, parameters_queue: Queue, device):
"""Drain the latest learner-pushed weights into ``algorithm.policy``."""
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
if bytes_state_dict is not None:
logging.info("[ACTOR] Load new parameters from Learner.")
@@ -667,18 +713,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
# - Send critic's encoder state when shared_encoder=True
# - Skip encoder params entirely when freeze_vision_encoder=True
# - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
# Load actor state dict
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
policy.actor.load_state_dict(actor_state_dict)
# Load discrete critic if present
if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts:
discrete_critic_state_dict = move_state_dict_to_device(
state_dicts["discrete_critic"], device=device
)
policy.discrete_critic.load_state_dict(discrete_critic_state_dict)
logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
algorithm.load_weights(state_dicts, device=device)
# Utilities functions

View File

@@ -0,0 +1,20 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .sac import SACAlgorithm, SACAlgorithmConfig
__all__ = [
"SACAlgorithm",
"SACAlgorithmConfig",
]

View File

@@ -0,0 +1,207 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
import builtins
import os
from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_file as load_safetensors, save_file as save_safetensors
from torch.optim import Optimizer
from lerobot.types import BatchType
from lerobot.utils.hub import HubMixin
from .configs import RLAlgorithmConfig, TrainingStats
if TYPE_CHECKING:
from torch import nn
from ..data_sources.data_mixer import DataMixer
T = TypeVar("T", bound="RLAlgorithm")
class RLAlgorithm(HubMixin, abc.ABC):
"""Base for all RL algorithms."""
config_class: type[RLAlgorithmConfig]
name: str
config: RLAlgorithmConfig
@abc.abstractmethod
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
"""One complete training step.
The algorithm calls ``next(batch_iterator)`` as many times as it
needs (e.g. ``utd_ratio`` times for SAC) to obtain fresh batches.
The iterator is owned by the trainer; the algorithm just consumes
from it.
"""
raise NotImplementedError
def configure_data_iterator(
self,
data_mixer: DataMixer,
batch_size: int,
*,
async_prefetch: bool = True,
queue_size: int = 2,
) -> Iterator[BatchType]:
"""Create the data iterator this algorithm needs.
The default implementation uses the standard ``data_mixer.get_iterator()``.
Algorithms that need specialised sampling should override this method.
"""
return data_mixer.get_iterator(
batch_size=batch_size,
async_prefetch=async_prefetch,
queue_size=queue_size,
)
@abc.abstractmethod
def make_optimizers_and_scheduler(self) -> dict[str, Optimizer]:
"""Build and return the optimizers used during training.
Called once on the learner side after construction.
"""
raise NotImplementedError
def get_optimizers(self) -> dict[str, Optimizer]:
"""Return optimizers for checkpointing / external scheduling."""
return {}
@property
def optimization_step(self) -> int:
"""Current learner optimization step.
Part of the stable contract for checkpoint/resume. Algorithms can
either use this default storage or override for custom behavior.
"""
return getattr(self, "_optimization_step", 0)
@optimization_step.setter
def optimization_step(self, value: int) -> None:
self._optimization_step = int(value)
def get_weights(self) -> dict[str, Any]:
"""Policy state-dict to push to actors."""
return {}
@abc.abstractmethod
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
"""Load policy state-dict received from the learner."""
raise NotImplementedError
@abc.abstractmethod
def state_dict(self) -> dict[str, torch.Tensor]:
"""Algorithm-owned trainable tensors.
Must return a flat tensor mapping for everything the algorithm owns
that is not part of the policy (e.g. critic ensembles, target networks,
temperature parameters). Algorithms with no training-only tensors
should explicitly return an empty dict.
"""
raise NotImplementedError
@abc.abstractmethod
def load_state_dict(
self,
state_dict: dict[str, torch.Tensor],
device: str | torch.device = "cpu",
) -> None:
"""In-place load of algorithm-owned tensors.
Implementations MUST keep the identity of any ``nn.Parameter`` that an
optimizer references (e.g. SAC's ``log_alpha``) by using ``.copy_()``
rather than rebinding the attribute.
"""
raise NotImplementedError
def _save_pretrained(self, save_directory: Path) -> None:
"""Persist the algorithm's tensors and config to ``save_directory``.
Writes ``model.safetensors`` (algorithm tensors via :meth:`state_dict`)
and ``config.json`` (via :meth:`RLAlgorithmConfig.save_pretrained`).
"""
tensors = {k: v.detach().cpu().contiguous() for k, v in self.state_dict().items()}
save_safetensors(tensors, str(save_directory / SAFETENSORS_SINGLE_FILE))
self.config._save_pretrained(save_directory)
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
policy: nn.Module,
config: RLAlgorithmConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
device: str | torch.device = "cpu",
**algo_kwargs: Any,
) -> T:
"""Build an algorithm and load its weights from ``pretrained_name_or_path``."""
if config is None:
config = cls.config_class.from_pretrained(
pretrained_name_or_path,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
if hasattr(config, "policy_config"):
config.policy_config = policy.config
instance = cls(policy=policy, config=config, **algo_kwargs)
model_id = str(pretrained_name_or_path)
if os.path.isdir(model_id):
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
else:
try:
model_file = hf_hub_download(
repo_id=model_id,
filename=SAFETENSORS_SINGLE_FILE,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
) from e
tensors = load_safetensors(model_file)
instance.load_state_dict(tensors, device=device)
return instance

View File

@@ -0,0 +1,138 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
import builtins
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, TypeVar
import draccus
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
from lerobot.utils.hub import HubMixin
T = TypeVar("T", bound="RLAlgorithmConfig")
logger = logging.getLogger(__name__)
@dataclass
class TrainingStats:
"""Returned by ``algorithm.update()`` for logging and checkpointing."""
losses: dict[str, float] = field(default_factory=dict)
grad_norms: dict[str, float] = field(default_factory=dict)
extra: dict[str, float] = field(default_factory=dict)
def to_log_dict(self) -> dict[str, float]:
"""Flatten all stats into a single dict for logging."""
d: dict[str, float] = {}
for name, val in self.losses.items():
d[name] = val
for name, val in self.grad_norms.items():
d[f"{name}_grad_norm"] = val
for name, val in self.extra.items():
d[name] = val
return d
@dataclass
class RLAlgorithmConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
"""Registry for algorithm configs."""
@property
def type(self) -> str:
"""Registered name of this algorithm config (e.g. ``"sac"``)."""
choice_name = self.get_choice_name(self.__class__)
if not isinstance(choice_name, str):
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
return choice_name
@classmethod
@abc.abstractmethod
def from_policy_config(cls, policy_cfg: Any) -> RLAlgorithmConfig:
"""Build an algorithm config from a policy config.
Must be overridden by every registered config subclass.
"""
raise NotImplementedError(f"{cls.__name__} must implement from_policy_config()")
def _save_pretrained(self, save_directory: Path) -> None:
"""Serialize this config as ``config.json`` inside ``save_directory``."""
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
draccus.dump(self, f, indent=4)
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict[Any, Any] | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
**algo_kwargs: Any,
) -> T:
model_id = str(pretrained_name_or_path)
config_file: str | None = None
if Path(model_id).is_dir():
if CONFIG_NAME in os.listdir(model_id):
config_file = os.path.join(model_id, CONFIG_NAME)
else:
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
else:
try:
config_file = hf_hub_download(
repo_id=model_id,
filename=CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
) from e
if config_file is None:
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
with draccus.config_type("json"):
instance = draccus.parse(RLAlgorithmConfig, config_file, args=[])
if cls is not RLAlgorithmConfig and not isinstance(instance, cls):
raise TypeError(
f"Config at {model_id} has type '{instance.type}' but was loaded via "
f"{cls.__name__}; use the matching subclass or RLAlgorithmConfig.from_pretrained()."
)
for key, value in algo_kwargs.items():
if hasattr(instance, key):
setattr(instance, key, value)
return instance

View File

@@ -0,0 +1,99 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
from .base import RLAlgorithm
from .configs import RLAlgorithmConfig
def make_algorithm_config(algorithm_type: str, **kwargs) -> RLAlgorithmConfig:
"""Instantiate an `RLAlgorithmConfig` from its registered type name.
Args:
algorithm_type: Registry key of the algorithm (e.g. ``"sac"``).
**kwargs: Keyword arguments forwarded to the config class constructor.
Returns:
An instance of the matching ``RLAlgorithmConfig`` subclass.
Raises:
ValueError: If ``algorithm_type`` is not registered.
"""
try:
cls = RLAlgorithmConfig.get_choice_class(algorithm_type)
except KeyError as err:
raise ValueError(
f"Algorithm type '{algorithm_type}' is not registered. "
f"Available: {list(RLAlgorithmConfig.get_known_choices().keys())}"
) from err
return cls(**kwargs)
def get_algorithm_class(name: str) -> type[RLAlgorithm]:
"""
Retrieves an RL algorithm class by its registered name.
This function uses dynamic imports to avoid loading all algorithm classes into
memory at once, improving startup time and reducing dependencies.
Args:
name: The name of the algorithm. Supported names are "sac".
Returns:
The algorithm class corresponding to the given name.
Raises:
ValueError: If the algorithm name is not recognized.
"""
if name == "sac":
from .sac.sac_algorithm import SACAlgorithm
return SACAlgorithm
raise ValueError(
f"Algorithm type '{name}' is not available. "
f"Known: {list(RLAlgorithmConfig.get_known_choices().keys())}"
)
def make_algorithm(cfg: RLAlgorithmConfig, policy: torch.nn.Module) -> RLAlgorithm:
"""
Instantiate an RL algorithm.
This factory function looks up the :class:`RLAlgorithm` subclass that matches
``cfg.type`` and instantiates it with the provided policy. It also enforces
that ``cfg.policy_config`` has been populated before construction (this is
normally handled by :meth:`TrainRLServerPipelineConfig.validate`).
Args:
cfg: The algorithm configuration. Must have ``policy_config`` set.
policy: The policy module the algorithm will train.
Returns:
An instantiated :class:`RLAlgorithm`.
Raises:
ValueError: If ``cfg.policy_config`` is ``None`` or ``cfg.type`` is not
registered.
"""
if getattr(cfg, "policy_config", None) is None:
raise ValueError(
f"{type(cfg).__name__}.policy_config is None. "
"It must be populated (typically by TrainRLServerPipelineConfig.validate) "
"before calling make_algorithm()."
)
cls = get_algorithm_class(cfg.type)
return cls(policy=policy, config=cfg)

View File

@@ -0,0 +1,18 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_sac import SACAlgorithmConfig
from .sac_algorithm import SACAlgorithm
__all__ = ["SACAlgorithm", "SACAlgorithmConfig"]

View File

@@ -0,0 +1,99 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
CriticNetworkConfig,
GaussianActorConfig,
)
from ..configs import RLAlgorithmConfig
@RLAlgorithmConfig.register_subclass("sac")
@dataclass
class SACAlgorithmConfig(RLAlgorithmConfig):
"""Soft Actor-Critic (SAC) algorithm configuration.
SAC is an off-policy actor-critic deep RL algorithm based on the maximum
entropy reinforcement learning framework. It learns a policy and a Q-function
simultaneously using experience collected from the environment.
This configuration class contains the algorithm-side hyperparameters: critic
ensemble, target networks, temperature / entropy tuning, and the Bellman
update loop. The policy-side (actor + observation encoder) lives in
:class:`~lerobot.policies.gaussian_actor.GaussianActorConfig` and is
referenced via :attr:`policy_config`.
"""
# Optimizer learning rates
# Learning rate for the actor network
actor_lr: float = 3e-4
# Learning rate for the critic network
critic_lr: float = 3e-4
# Learning rate for the temperature parameter
temperature_lr: float = 3e-4
# Bellman update
# Discount factor for the SAC algorithm
discount: float = 0.99
# Whether to use backup entropy for the SAC algorithm
use_backup_entropy: bool = True
# Weight for the critic target update
critic_target_update_weight: float = 0.005
# Critic ensemble
# Number of critics in the ensemble
num_critics: int = 2
# Number of subsampled critics for training
num_subsample_critics: int | None = None
# Configuration for the critic network architecture
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Configuration for the discrete critic network
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Temperature / entropy
# Initial temperature value
temperature_init: float = 1.0
# Target entropy for automatic temperature tuning. If ``None``, defaults to
# ``-|A|/2`` where ``|A|`` is the total action dimension (continuous + 1 if
# there is a discrete action head).
target_entropy: float | None = None
# Update loop
# Update-to-data ratio. Set to >1 to enable extra critic updates per env step.
utd_ratio: int = 1
# Frequency of policy updates
policy_update_freq: int = 1
# Gradient clipping norm for the SAC algorithm
grad_clip_norm: float = 40.0
# Optimizations
# torch.compile is currently disabled by default
use_torch_compile: bool = False
# Policy config
policy_config: PreTrainedConfig | None = None
@classmethod
def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig:
"""Build an algorithm config with default hyperparameters for a given policy."""
return cls(
policy_config=policy_cfg,
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
)

View File

@@ -0,0 +1,672 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from collections.abc import Callable, Iterator
from dataclasses import asdict
from typing import Any
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.optim import Optimizer
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import (
DISCRETE_DIMENSION_INDEX,
MLP,
DiscreteCritic,
GaussianActorObservationEncoder,
GaussianActorPolicy,
orthogonal_init,
)
from lerobot.policies.utils import get_device_from_parameters
from lerobot.types import BatchType
from lerobot.utils.constants import ACTION
from lerobot.utils.transition import move_state_dict_to_device
from ..base import RLAlgorithm
from ..configs import TrainingStats
from .configuration_sac import SACAlgorithmConfig
class SACAlgorithm(RLAlgorithm):
"""Soft Actor-Critic. Owns critics, targets, temperature, and loss computation."""
config_class = SACAlgorithmConfig
name = "sac"
def __init__(
self,
policy: GaussianActorPolicy,
config: SACAlgorithmConfig,
):
self.config = config
self.policy_config = config.policy_config
self.policy = policy
self.optimizers: dict[str, Optimizer] = {}
self._optimization_step: int = 0
action_dim = self.policy.config.output_features[ACTION].shape[0]
self._init_critics(action_dim)
self._init_temperature(action_dim)
self._device = torch.device(self.policy.config.device)
self._move_to_device()
def _init_critics(self, action_dim) -> None:
"""Build critic ensemble, targets."""
encoder = self.policy.encoder_critic
heads = [
CriticHead(
input_dim=encoder.output_dim + action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_ensemble = CriticEnsemble(encoder=encoder, ensemble=heads)
target_heads = [
CriticHead(
input_dim=encoder.output_dim + action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_target = CriticEnsemble(encoder=encoder, ensemble=target_heads)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
# TODO(Khalil): Investigate and fix torch.compile
# NOTE: torch.compile is disabled, policy does not converge when enabled.
if self.config.use_torch_compile:
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
self.discrete_critic_target = None
if self.policy_config.num_discrete_actions is not None:
self.discrete_critic_target = self._init_discrete_critic_target(encoder)
def _init_discrete_critic_target(self, encoder: GaussianActorObservationEncoder) -> DiscreteCritic:
"""Build target discrete critic (main network is owned by the policy)."""
discrete_critic_target = DiscreteCritic(
encoder=encoder,
input_dim=encoder.output_dim,
output_dim=self.policy_config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
# TODO(Khalil): Compile the discrete critic
discrete_critic_target.load_state_dict(self.policy.discrete_critic.state_dict())
return discrete_critic_target
def _init_temperature(self, continuous_action_dim: int) -> None:
"""Set up temperature parameter (log_alpha) and target entropy."""
temp_init = self.config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
self.target_entropy = self.config.target_entropy
if self.target_entropy is None:
total_action_dim = continuous_action_dim + (
1 if self.policy_config.num_discrete_actions is not None else 0
)
self.target_entropy = -total_action_dim / 2
def _move_to_device(self) -> None:
self.policy.to(self._device)
self.critic_ensemble.to(self._device)
self.critic_target.to(self._device)
self.log_alpha = nn.Parameter(self.log_alpha.data.to(self._device))
if self.discrete_critic_target is not None:
self.discrete_critic_target.to(self._device)
@property
def temperature(self) -> float:
"""Return the current temperature value, always in sync with log_alpha."""
return self.log_alpha.exp().item()
def _critic_forward(
self,
observations: dict[str, Tensor],
actions: Tensor,
use_target: bool = False,
observation_features: Tensor | None = None,
) -> Tensor:
"""Forward pass through a critic network ensemble
Args:
observations: Dictionary of observations
actions: Action tensor
use_target: If True, use target critics, otherwise use ensemble critics
Returns:
Tensor of Q-values from all critics
"""
critics = self.critic_target if use_target else self.critic_ensemble
q_values = critics(observations, actions, observation_features)
return q_values
def _discrete_critic_forward(
self, observations, use_target=False, observation_features=None
) -> torch.Tensor:
"""Forward pass through a discrete critic network
Args:
observations: Dictionary of observations
use_target: If True, use target critics, otherwise use ensemble critics
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
Returns:
Tensor of Q-values from the discrete critic network
"""
discrete_critic = self.discrete_critic_target if use_target else self.policy.discrete_critic
q_values = discrete_critic(observations, observation_features)
return q_values
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
"""Run one SAC training step (critic / discrete-critic / actor / temperature).
Pulls ``utd_ratio`` batches from ``batch_iterator``, computes the relevant
losses, backpropagates each, and updates target networks.
Args:
batch_iterator: yields batches each containing
- ``action``: Action tensor
- ``reward``: Reward tensor
- ``state``: Observations tensor dict
- ``next_state``: Next observations tensor dict
- ``done``: Done mask tensor
- ``observation_feature``: Optional pre-computed observation features
- ``next_observation_feature``: Optional pre-computed next observation features
- ``complementary_info`` (optional): per-step extras like discrete penalties
Returns:
TrainingStats with per-component losses and grad norms.
"""
clip = self.config.grad_clip_norm
for _ in range(self.config.utd_ratio - 1):
batch = next(batch_iterator)
fb = self._prepare_forward_batch(batch, include_complementary_info=True)
loss_critic = self._compute_loss_critic(fb)
self.optimizers["critic"].zero_grad()
loss_critic.backward()
torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip)
self.optimizers["critic"].step()
if self.policy_config.num_discrete_actions is not None:
loss_dc = self._compute_loss_discrete_critic(fb)
self.optimizers["discrete_critic"].zero_grad()
loss_dc.backward()
torch.nn.utils.clip_grad_norm_(self.policy.discrete_critic.parameters(), max_norm=clip)
self.optimizers["discrete_critic"].step()
self._update_target_networks()
batch = next(batch_iterator)
fb = self._prepare_forward_batch(batch, include_complementary_info=False)
loss_critic = self._compute_loss_critic(fb)
self.optimizers["critic"].zero_grad()
loss_critic.backward()
critic_grad = torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip).item()
self.optimizers["critic"].step()
stats = TrainingStats(
losses={"loss_critic": loss_critic.item()},
grad_norms={"critic": critic_grad},
)
if self.policy_config.num_discrete_actions is not None:
loss_dc = self._compute_loss_discrete_critic(fb)
self.optimizers["discrete_critic"].zero_grad()
loss_dc.backward()
dc_grad = torch.nn.utils.clip_grad_norm_(
self.policy.discrete_critic.parameters(), max_norm=clip
).item()
self.optimizers["discrete_critic"].step()
stats.losses["loss_discrete_critic"] = loss_dc.item()
stats.grad_norms["discrete_critic"] = dc_grad
if self._optimization_step % self.config.policy_update_freq == 0:
for _ in range(self.config.policy_update_freq):
loss_actor = self._compute_loss_actor(fb)
self.optimizers["actor"].zero_grad()
loss_actor.backward()
actor_grad = torch.nn.utils.clip_grad_norm_(
self.policy.actor.parameters(), max_norm=clip
).item()
self.optimizers["actor"].step()
loss_temp = self._compute_loss_temperature(fb)
self.optimizers["temperature"].zero_grad()
loss_temp.backward()
temp_grad = torch.nn.utils.clip_grad_norm_([self.log_alpha], max_norm=clip).item()
self.optimizers["temperature"].step()
stats.losses["loss_actor"] = loss_actor.item()
stats.losses["loss_temperature"] = loss_temp.item()
stats.grad_norms["actor"] = actor_grad
stats.grad_norms["temperature"] = temp_grad
stats.extra["temperature"] = self.temperature
self._update_target_networks()
self._optimization_step += 1
return stats
def _compute_loss_critic(self, batch: dict[str, Any]) -> Tensor:
# Extract common components from batch
observations = batch["state"]
actions = batch[ACTION]
observation_features = batch.get("observation_feature")
# Extract critic-specific components
rewards = batch["reward"]
next_observations = batch["next_state"]
done = batch["done"]
next_observation_features = batch.get("next_observation_feature")
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.policy.actor(
next_observations, next_observation_features
)
# 2- compute q targets
q_targets = self._critic_forward(
observations=next_observations,
actions=next_action_preds,
use_target=True,
observation_features=next_observation_features,
)
# subsample critics to prevent overfitting if use high UTD (update to date)
# TODO: Get indices before forward pass to avoid unnecessary computation
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
min_q = min_q - (self.temperature * next_log_probs)
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
if self.policy_config.num_discrete_actions is not None:
# NOTE: We only want to keep the continuous action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
q_preds = self._critic_forward(
observations=observations,
actions=actions,
use_target=False,
observation_features=observation_features,
)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
critics_loss = (
F.mse_loss(
input=q_preds,
target=td_target_duplicate,
reduction="none",
).mean(dim=1)
).sum()
return critics_loss
def _compute_loss_discrete_critic(self, batch: dict[str, Any]) -> Tensor:
observations = batch["state"]
actions = batch[ACTION]
rewards = batch["reward"]
next_observations = batch["next_state"]
done = batch["done"]
observation_features = batch.get("observation_feature")
next_observation_features = batch.get("next_observation_feature")
complementary_info = batch.get("complementary_info")
# NOTE: We only want to keep the discrete action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
actions_discrete = torch.round(actions_discrete)
actions_discrete = actions_discrete.long()
discrete_penalties: Tensor | None = None
if complementary_info is not None:
discrete_penalties = complementary_info.get("discrete_penalty")
with torch.no_grad():
# For DQN, select actions using online network, evaluate with target network
next_discrete_qs = self._discrete_critic_forward(
next_observations, use_target=False, observation_features=next_observation_features
)
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
# Get target Q-values from target network
target_next_discrete_qs = self._discrete_critic_forward(
observations=next_observations,
use_target=True,
observation_features=next_observation_features,
)
# Use gather to select Q-values for best actions
target_next_discrete_q = torch.gather(
target_next_discrete_qs, dim=1, index=best_next_discrete_action
).squeeze(-1)
# Compute target Q-value with Bellman equation
rewards_discrete = rewards
if discrete_penalties is not None:
rewards_discrete = rewards + discrete_penalties
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
# Get predicted Q-values for current observations
predicted_discrete_qs = self._discrete_critic_forward(
observations=observations, use_target=False, observation_features=observation_features
)
# Use gather to select Q-values for taken actions
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
# Compute MSE loss between predicted and target Q-values
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
return discrete_critic_loss
def _compute_loss_actor(self, batch: dict[str, Any]) -> Tensor:
observations = batch["state"]
observation_features = batch.get("observation_feature")
actions_pi, log_probs, _ = self.policy.actor(observations, observation_features)
q_preds = self._critic_forward(
observations=observations,
actions=actions_pi,
use_target=False,
observation_features=observation_features,
)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
return actor_loss
def _compute_loss_temperature(self, batch: dict[str, Any]) -> Tensor:
"""Compute the temperature loss"""
observations = batch["state"]
observation_features = batch.get("observation_feature")
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.policy.actor(observations, observation_features)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
return temperature_loss
def _update_target_networks(self) -> None:
"""Update target networks with exponential moving average"""
for target_p, p in zip(
self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=True
):
target_p.data.copy_(
p.data * self.config.critic_target_update_weight
+ target_p.data * (1.0 - self.config.critic_target_update_weight)
)
if self.policy_config.num_discrete_actions is not None:
for target_p, p in zip(
self.discrete_critic_target.parameters(),
self.policy.discrete_critic.parameters(),
strict=True,
):
target_p.data.copy_(
p.data * self.config.critic_target_update_weight
+ target_p.data * (1.0 - self.config.critic_target_update_weight)
)
def _prepare_forward_batch(
self, batch: BatchType, *, include_complementary_info: bool = True
) -> dict[str, Any]:
observations = batch["state"]
next_observations = batch["next_state"]
observation_features, next_observation_features = self.get_observation_features(
observations, next_observations
)
forward_batch: dict[str, Any] = {
ACTION: batch[ACTION],
"reward": batch["reward"],
"state": observations,
"next_state": next_observations,
"done": batch["done"],
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
}
if include_complementary_info and "complementary_info" in batch:
forward_batch["complementary_info"] = batch["complementary_info"]
return forward_batch
def make_optimizers_and_scheduler(self) -> dict[str, Optimizer]:
"""
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
This function sets up Adam optimizers for:
- The **actor network**, ensuring that only relevant parameters are optimized.
- The **critic ensemble**, which evaluates the value function.
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
It also initializes a learning rate scheduler, though currently, it is set to `None`.
NOTE:
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
Args:
cfg: Configuration object containing hyperparameters.
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
Returns:
A dictionary mapping component names ("actor", "critic", "temperature")
to their respective Adam optimizers.
"""
actor_params = self.policy.get_optim_params()["actor"]
self.optimizers = {
"actor": torch.optim.Adam(actor_params, lr=self.config.actor_lr),
"critic": torch.optim.Adam(self.critic_ensemble.parameters(), lr=self.config.critic_lr),
"temperature": torch.optim.Adam([self.log_alpha], lr=self.config.temperature_lr),
}
if self.policy_config.num_discrete_actions is not None:
self.optimizers["discrete_critic"] = torch.optim.Adam(
self.policy.discrete_critic.parameters(), lr=self.config.critic_lr
)
return self.optimizers
def get_optimizers(self) -> dict[str, Optimizer]:
return self.optimizers
def get_weights(self) -> dict[str, Any]:
"""Send actor + discrete-critic state dicts."""
state_dicts: dict[str, Any] = {
"policy": move_state_dict_to_device(self.policy.actor.state_dict(), device="cpu"),
}
if self.policy_config.num_discrete_actions is not None:
state_dicts["discrete_critic"] = move_state_dict_to_device(
self.policy.discrete_critic.state_dict(), device="cpu"
)
return state_dicts
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
"""Load actor + discrete-critic weights into the policy."""
actor_sd = move_state_dict_to_device(weights["policy"], device=device)
self.policy.actor.load_state_dict(actor_sd)
if "discrete_critic" in weights and self.policy.discrete_critic is not None:
discrete_sd = move_state_dict_to_device(weights["discrete_critic"], device=device)
self.policy.discrete_critic.load_state_dict(discrete_sd)
def state_dict(self) -> dict[str, torch.Tensor]:
"""Algorithm-owned trainable tensors.
Encoder weights are stripped because they are owned by the policy
(``policy.encoder_critic``) and already saved via ``policy.save_pretrained``.
"""
bundle: dict[str, torch.Tensor] = {}
for k, v in _strip_encoder_keys(self.critic_ensemble.state_dict()).items():
bundle[f"critic_ensemble.{k}"] = v
for k, v in _strip_encoder_keys(self.critic_target.state_dict()).items():
bundle[f"critic_target.{k}"] = v
if self.discrete_critic_target is not None:
for k, v in _strip_encoder_keys(self.discrete_critic_target.state_dict()).items():
bundle[f"discrete_critic_target.{k}"] = v
bundle["log_alpha"] = self.log_alpha.detach()
return bundle
def load_state_dict(
self,
state_dict: dict[str, torch.Tensor],
device: str | torch.device = "cpu",
) -> None:
"""In-place load of algorithm-owned tensors.
``log_alpha`` is restored via ``Parameter.data.copy_`` so the
``temperature`` optimizer's reference to the parameter object stays
valid after resume.
"""
critic_ensemble_state = _split_prefix(state_dict, "critic_ensemble.")
critic_target_state = _split_prefix(state_dict, "critic_target.")
self.critic_ensemble.load_state_dict(critic_ensemble_state, strict=False)
self.critic_target.load_state_dict(critic_target_state, strict=False)
if self.discrete_critic_target is not None:
discrete_target_state = _split_prefix(state_dict, "discrete_critic_target.")
self.discrete_critic_target.load_state_dict(discrete_target_state, strict=False)
if "log_alpha" in state_dict:
self.log_alpha.data.copy_(state_dict["log_alpha"].to(self.log_alpha.device))
def get_observation_features(
self, observations: Tensor, next_observations: Tensor
) -> tuple[Tensor | None, Tensor | None]:
"""
Get observation features from the policy encoder. It act as cache for the observation features.
when the encoder is frozen, the observation features are not updated.
We can save compute by caching the observation features.
Args:
policy: The policy model
observations: The current observations
next_observations: The next observations
Returns:
tuple: observation_features, next_observation_features
"""
if self.policy.config.vision_encoder_name is None or not self.policy.config.freeze_vision_encoder:
return None, None
with torch.no_grad():
observation_features = self.policy.actor.encoder.get_cached_image_features(observations)
next_observation_features = self.policy.actor.encoder.get_cached_image_features(next_observations)
return observation_features, next_observation_features
def _strip_encoder_keys(state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Drop ``encoder.*`` keys from a critic-module state dict."""
return {k: v for k, v in state.items() if not k.startswith("encoder.")}
def _split_prefix(state: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]:
"""Return the subset of ``state`` whose keys start with ``prefix``, prefix-stripped."""
return {k.removeprefix(prefix): v for k, v in state.items() if k.startswith(prefix)}
class CriticHead(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dims: list[int],
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: float | None = None,
init_final: float | None = None,
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
):
super().__init__()
self.net = MLP(
input_dim=input_dim,
hidden_dims=hidden_dims,
activations=activations,
activate_final=activate_final,
dropout_rate=dropout_rate,
final_activation=final_activation,
)
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
if init_final is not None:
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.output_layer.weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.output_layer(self.net(x))
class CriticEnsemble(nn.Module):
"""
CriticEnsemble wraps multiple CriticHead modules into an ensemble.
Args:
encoder (GaussianActorObservationEncoder): encoder for observations.
ensemble (List[CriticHead]): list of critic heads.
init_final (float | None): optional initializer scale for final layers.
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
"""
def __init__(
self,
encoder: GaussianActorObservationEncoder,
ensemble: list[CriticHead],
init_final: float | None = None,
):
super().__init__()
self.encoder = encoder
self.init_final = init_final
self.critics = nn.ModuleList(ensemble)
def forward(
self,
observations: dict[str, torch.Tensor],
actions: torch.Tensor,
observation_features: torch.Tensor | None = None,
) -> torch.Tensor:
device = get_device_from_parameters(self)
# Move each tensor in observations to device
observations = {k: v.to(device) for k, v in observations.items()}
obs_enc = self.encoder(observations, cache=observation_features)
inputs = torch.cat([obs_enc, actions], dim=-1)
# Loop through critics and collect outputs
q_values = []
for critic in self.critics:
q_values.append(critic(inputs))
# Stack outputs to match expected shape [num_critics, batch_size]
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
return q_values

View File

@@ -97,8 +97,8 @@ class ReplayBuffer:
Args:
capacity (int): Maximum number of transitions to store in the buffer.
device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
state_keys (list[str]): The list of keys that appear in `state` and `next_state`.
image_augmentation_function (Callable | None): A function that takes a batch of images
and returns a batch of augmented images. If None, a default augmentation function is used.
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored.
@@ -634,7 +634,7 @@ class ReplayBuffer:
If None, you must handle or define default keys.
Returns:
transitions (List[Transition]):
transitions (list[Transition]):
A list of Transition dictionaries with the same length as `dataset`.
"""
if state_keys is None:

View File

@@ -176,11 +176,11 @@ def convert_lerobot_dataset_to_cropped_lerobot_dataset(
Args:
original_dataset (LeRobotDataset): The source dataset.
crop_params_dict (Dict[str, Tuple[int, int, int, int]]):
crop_params_dict (dict[str, Tuple[int, int, int, int]]):
A dictionary mapping observation keys to crop parameters (top, left, height, width).
new_repo_id (str): Repository id for the new dataset.
new_dataset_root (str): The root directory where the new dataset will be written.
resize_size (Tuple[int, int], optional): The target size (height, width) after cropping.
resize_size (tuple[int, int], optional): The target size (height, width) after cropping.
Defaults to (128, 128).
Returns:

View File

@@ -0,0 +1,19 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.types import BatchType
from .data_mixer import DataMixer, OnlineOfflineMixer
__all__ = ["BatchType", "DataMixer", "OnlineOfflineMixer"]

View File

@@ -0,0 +1,97 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
from lerobot.types import BatchType
from ..buffer import ReplayBuffer, concatenate_batch_transitions
class DataMixer(abc.ABC):
"""Abstract interface for all data mixing strategies."""
@abc.abstractmethod
def sample(self, batch_size: int) -> BatchType:
"""Draw one batch of ``batch_size`` transitions."""
raise NotImplementedError
def get_iterator(
self,
batch_size: int,
async_prefetch: bool = True,
queue_size: int = 2,
):
"""Infinite iterator that yields batches."""
while True:
yield self.sample(batch_size)
class OnlineOfflineMixer(DataMixer):
"""Mixes transitions from an online and an offline replay buffer."""
def __init__(
self,
online_buffer: ReplayBuffer,
offline_buffer: ReplayBuffer | None = None,
online_ratio: float = 1.0,
):
if not 0.0 <= online_ratio <= 1.0:
raise ValueError(f"online_ratio must be in [0, 1], got {online_ratio}")
self.online_buffer = online_buffer
self.offline_buffer = offline_buffer
self.online_ratio = online_ratio
def sample(self, batch_size: int) -> BatchType:
if self.offline_buffer is None:
return self.online_buffer.sample(batch_size)
n_online = max(1, int(batch_size * self.online_ratio))
n_offline = batch_size - n_online
online_batch = self.online_buffer.sample(n_online)
offline_batch = self.offline_buffer.sample(n_offline)
return concatenate_batch_transitions(online_batch, offline_batch)
def get_iterator(
self,
batch_size: int,
async_prefetch: bool = True,
queue_size: int = 2,
):
"""Yield batches by composing buffer async iterators."""
n_online = max(1, int(batch_size * self.online_ratio))
online_iter = self.online_buffer.get_iterator(
batch_size=n_online,
async_prefetch=async_prefetch,
queue_size=queue_size,
)
if self.offline_buffer is None:
yield from online_iter
return
n_offline = batch_size - n_online
offline_iter = self.offline_buffer.get_iterator(
batch_size=n_offline,
async_prefetch=async_prefetch,
queue_size=queue_size,
)
while True:
yield concatenate_batch_transitions(next(online_iter), next(offline_iter))

View File

@@ -17,7 +17,6 @@ import logging
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.datasets import LeRobotDataset
from lerobot.policies import make_policy
from lerobot.robots import ( # noqa: F401
@@ -31,6 +30,7 @@ from lerobot.teleoperators import (
)
from .gym_manipulator import make_robot_env
from .train_rl import TrainRLServerPipelineConfig
logging.basicConfig(level=logging.INFO)

View File

@@ -74,6 +74,7 @@ from lerobot.teleoperators import (
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD
from lerobot.utils.import_utils import require_package
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
@@ -312,6 +313,7 @@ def make_robot_env(cfg: HILSerlRobotEnvConfig) -> tuple[gym.Env, Any]:
# Check if this is a GymHIL simulation environment
if cfg.name == "gym_hil":
assert cfg.robot is None and cfg.teleop is None, "GymHIL environment does not support robot or teleop"
require_package("gym-hil", extra="hilserl", import_name="gym_hil")
import gym_hil # noqa: F401
# Extract gripper settings with defaults
@@ -383,10 +385,21 @@ def make_processors(
GymHILAdapterProcessorStep(),
Numpy2TorchActionProcessorStep(),
VanillaObservationProcessorStep(),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=device),
]
# Add time limit processor if reset config exists
if cfg.processor.reset is not None:
env_pipeline_steps.append(
TimeLimitProcessorStep(max_episode_steps=int(cfg.processor.reset.control_time_s * cfg.fps))
)
env_pipeline_steps.extend(
[
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=device),
]
)
return DataProcessorPipeline(
steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition
), DataProcessorPipeline(
@@ -551,8 +564,19 @@ def step_env_and_process_transition(
terminated = terminated or processed_action_transition[TransitionKey.DONE]
truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED]
complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy()
if hasattr(env, "get_raw_joint_positions"):
raw_joint_positions = env.get_raw_joint_positions()
if raw_joint_positions is not None:
complementary_data["raw_joint_positions"] = raw_joint_positions
# Merge env and action-processor info: env wins for str keys, action-processor
# wins for `TeleopEvents` enum keys
action_info = processed_action_transition[TransitionKey.INFO]
new_info = info.copy()
new_info.update(processed_action_transition[TransitionKey.INFO])
for key, value in action_info.items():
if isinstance(key, TeleopEvents):
new_info[key] = value
new_transition = create_transition(
observation=obs,
@@ -568,6 +592,24 @@ def step_env_and_process_transition(
return new_transition
def reset_and_build_transition(
env: gym.Env,
env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
action_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
) -> EnvTransition:
"""Reset env + processors and return the first env-processed transition."""
obs, info = env.reset()
env_processor.reset()
action_processor.reset()
complementary_data: dict[str, Any] = {}
if hasattr(env, "get_raw_joint_positions"):
raw_joint_positions = env.get_raw_joint_positions()
if raw_joint_positions is not None:
complementary_data["raw_joint_positions"] = raw_joint_positions
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
return env_processor(data=transition)
def control_loop(
env: gym.Env,
env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
@@ -593,17 +635,7 @@ def control_loop(
print("- When not intervening, robot will stay still")
print("- Press Ctrl+C to exit")
# Reset environment and processors
obs, info = env.reset()
complementary_data = (
{"raw_joint_positions": info.pop("raw_joint_positions")} if "raw_joint_positions" in info else {}
)
env_processor.reset()
action_processor.reset()
# Process initial observation
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
transition = env_processor(data=transition)
transition = reset_and_build_transition(env, env_processor, action_processor)
# Determine if gripper is used
use_gripper = cfg.env.processor.gripper.use_gripper if cfg.env.processor.gripper is not None else True
@@ -659,79 +691,82 @@ def control_loop(
episode_step = 0
episode_start_time = time.perf_counter()
while episode_idx < cfg.dataset.num_episodes_to_record:
step_start_time = time.perf_counter()
try:
while episode_idx < cfg.dataset.num_episodes_to_record:
step_start_time = time.perf_counter()
# Create a neutral action (no movement)
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
if use_gripper:
neutral_action = torch.cat([neutral_action, torch.tensor([0.0])]) # Gripper stay
# Create a neutral action (no movement)
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
if use_gripper:
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
# Use the new step function
transition = step_env_and_process_transition(
env=env,
transition=transition,
action=neutral_action,
env_processor=env_processor,
action_processor=action_processor,
)
terminated = transition.get(TransitionKey.DONE, False)
truncated = transition.get(TransitionKey.TRUNCATED, False)
if cfg.mode == "record":
observations = {
observation = {
k: v.squeeze(0).cpu()
for k, v in transition[TransitionKey.OBSERVATION].items()
if isinstance(v, torch.Tensor)
}
# Use teleop_action if available, otherwise use the action from the transition
action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get(
"teleop_action", transition[TransitionKey.ACTION]
transition = step_env_and_process_transition(
env=env,
transition=transition,
action=neutral_action,
env_processor=env_processor,
action_processor=action_processor,
)
frame = {
**observations,
ACTION: action_to_record.cpu(),
REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
DONE: np.array([terminated or truncated], dtype=bool),
}
if use_gripper:
discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)
frame["complementary_info.discrete_penalty"] = np.array([discrete_penalty], dtype=np.float32)
terminated = transition.get(TransitionKey.DONE, False)
truncated = transition.get(TransitionKey.TRUNCATED, False)
if dataset is not None:
frame["task"] = cfg.dataset.task
dataset.add_frame(frame)
if cfg.mode == "record":
action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get(
"teleop_action", transition[TransitionKey.ACTION]
)
frame = {
**observation,
ACTION: action_to_record.cpu(),
REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
DONE: np.array([terminated or truncated], dtype=bool),
}
if use_gripper:
discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get(
"discrete_penalty", 0.0
)
frame["complementary_info.discrete_penalty"] = np.array(
[discrete_penalty], dtype=np.float32
)
episode_step += 1
if dataset is not None:
frame["task"] = cfg.dataset.task
dataset.add_frame(frame)
# Handle episode termination
if terminated or truncated:
episode_time = time.perf_counter() - episode_start_time
logging.info(
f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}"
)
episode_step = 0
episode_idx += 1
episode_step += 1
if dataset is not None:
if transition[TransitionKey.INFO].get(TeleopEvents.RERECORD_EPISODE, False):
logging.info(f"Re-recording episode {episode_idx}")
dataset.clear_episode_buffer()
episode_idx -= 1
else:
logging.info(f"Saving episode {episode_idx}")
dataset.save_episode()
# Handle episode termination
if terminated or truncated:
episode_time = time.perf_counter() - episode_start_time
logging.info(
f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}"
)
episode_step = 0
episode_idx += 1
# Reset for new episode
obs, info = env.reset()
env_processor.reset()
action_processor.reset()
if dataset is not None:
if transition[TransitionKey.INFO].get(TeleopEvents.RERECORD_EPISODE, False):
logging.info(f"Re-recording episode {episode_idx}")
dataset.clear_episode_buffer()
episode_idx -= 1
else:
logging.info(f"Saving episode {episode_idx}")
dataset.save_episode()
transition = create_transition(observation=obs, info=info)
transition = env_processor(transition)
# Reset for new episode
transition = reset_and_build_transition(env, env_processor, action_processor)
# Maintain fps timing
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
# Maintain fps timing
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
finally:
if dataset is not None and dataset.writer is not None and dataset.writer.image_writer is not None:
logging.info("Waiting for image writer to finish...")
dataset.writer.image_writer.stop()
if dataset is not None and cfg.dataset.push_to_hub:
logging.info("Finalizing dataset before pushing to hub")

View File

@@ -51,9 +51,21 @@ import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from pprint import pformat
from typing import TYPE_CHECKING, Any
from lerobot.utils.import_utils import _grpc_available, require_package
if TYPE_CHECKING or _grpc_available:
import grpc
from lerobot.transport import services_pb2_grpc
else:
grpc = None
services_pb2_grpc = None
import grpc
import torch
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from safetensors.torch import load_file as load_safetensors
from termcolor import colored
from torch import nn
from torch.multiprocessing import Queue
@@ -68,14 +80,11 @@ from lerobot.common.train_utils import (
)
from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.datasets import LeRobotDataset, make_dataset
from lerobot.policies import make_policy
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.transport import services_pb2_grpc
from lerobot.transport.utils import (
MAX_MESSAGE_SIZE,
bytes_to_python_object,
@@ -84,26 +93,35 @@ from lerobot.transport.utils import (
)
from lerobot.utils.constants import (
ACTION,
ALGORITHM_DIR,
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
PRETRAINED_MODEL_DIR,
TRAINING_STATE_DIR,
TRAINING_STEP,
)
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.io_utils import load_json, write_json
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed
from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device
from lerobot.utils.utils import (
format_big_number,
init_logging,
)
from .buffer import ReplayBuffer, concatenate_batch_transitions
from .algorithms.base import RLAlgorithm
from .algorithms.factory import make_algorithm
from .buffer import ReplayBuffer
from .data_sources import OnlineOfflineMixer
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
from .train_rl import TrainRLServerPipelineConfig
from .trainer import RLTrainer
@parser.wrap()
def train_cli(cfg: TrainRLServerPipelineConfig):
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
require_package("grpcio", extra="hilserl", import_name="grpc")
if not use_threads(cfg):
import torch.multiprocessing as mp
@@ -179,7 +197,7 @@ def train(cfg: TrainRLServerPipelineConfig, job_name: str | None = None):
def start_learner_threads(
cfg: TrainRLServerPipelineConfig,
wandb_logger: WandBLogger | None,
shutdown_event: any, # Event,
shutdown_event: Any, # Event
) -> None:
"""
Start the learner threads for training.
@@ -253,7 +271,7 @@ def start_learner_threads(
def add_actor_information_and_train(
cfg: TrainRLServerPipelineConfig,
wandb_logger: WandBLogger | None,
shutdown_event: any, # Event,
shutdown_event: Any, # Event
transition_queue: Queue,
interaction_message_queue: Queue,
parameters_queue: Queue,
@@ -266,8 +284,8 @@ def add_actor_information_and_train(
- Transfers transitions from the actor to the replay buffer.
- Logs received interaction messages.
- Ensures training begins only when the replay buffer has a sufficient number of transitions.
- Samples batches from the replay buffer and performs multiple critic updates.
- Periodically updates the actor, critic, and temperature optimizers.
- Delegates training updates to an ``RLAlgorithm``.
- Periodically pushes updated weights to actors.
- Logs training statistics, including loss values and optimization frequency.
NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
@@ -286,17 +304,13 @@ def add_actor_information_and_train(
# of 7%
device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
clip_grad_norm_value = cfg.policy.grad_clip_norm
online_step_before_learning = cfg.policy.online_step_before_learning
utd_ratio = cfg.policy.utd_ratio
fps = cfg.env.fps
log_freq = cfg.log_freq
save_freq = cfg.save_freq
policy_update_freq = cfg.policy.policy_update_freq
policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
saving_checkpoint = cfg.save_checkpoint
online_steps = cfg.policy.online_steps
async_prefetch = cfg.policy.async_prefetch
# Initialize logging for multiprocessing
if not use_threads(cfg):
@@ -308,7 +322,7 @@ def add_actor_information_and_train(
logging.info("Initializing policy")
policy: SACPolicy = make_policy(
policy = make_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
)
@@ -317,15 +331,17 @@ def add_actor_information_and_train(
policy.train()
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
dataset_stats=cfg.policy.dataset_stats,
)
# Push initial policy weights to actors
push_actor_policy_to_queue(parameters_queue=parameters_queue, algorithm=algorithm)
last_time_policy_pushed = time.time()
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
# If we are resuming, we need to load the training state
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
log_training_info(cfg=cfg, policy=policy)
replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
@@ -338,21 +354,37 @@ def add_actor_information_and_train(
device=device,
storage_device=storage_device,
)
batch_size: int = batch_size // 2 # We will sample from both replay buffer
# DataMixer: online-only or online/offline 50-50 mix
data_mixer = OnlineOfflineMixer(
online_buffer=replay_buffer,
offline_buffer=offline_replay_buffer,
online_ratio=cfg.online_ratio,
)
# RLTrainer owns the iterator, preprocessor, and creates optimizers.
trainer = RLTrainer(
algorithm=algorithm,
data_mixer=data_mixer,
batch_size=batch_size,
preprocessor=preprocessor,
)
# If we are resuming, we need to load the training state
optimizers = algorithm.get_optimizers()
resume_optimization_step, resume_interaction_step = load_training_state(
cfg=cfg, optimizers=optimizers, algorithm=algorithm, device=device
)
logging.info("Starting learner thread")
interaction_message = None
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
algorithm.optimization_step = optimization_step
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
dataset_repo_id = None
if cfg.dataset is not None:
dataset_repo_id = cfg.dataset.repo_id
# Initialize iterators
online_iterator = None
offline_iterator = None
# NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
while True:
# Exit the training loop if shutdown is requested
@@ -365,7 +397,6 @@ def add_actor_information_and_train(
transition_queue=transition_queue,
replay_buffer=replay_buffer,
offline_replay_buffer=offline_replay_buffer,
device=device,
dataset_repo_id=dataset_repo_id,
shutdown_event=shutdown_event,
)
@@ -382,180 +413,20 @@ def add_actor_information_and_train(
if len(replay_buffer) < online_step_before_learning:
continue
if online_iterator is None:
online_iterator = replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)
if offline_replay_buffer is not None and offline_iterator is None:
offline_iterator = offline_replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)
time_for_one_optimization_step = time.time()
for _ in range(utd_ratio - 1):
# Sample from the iterators
batch = next(online_iterator)
if dataset_repo_id is not None:
batch_offline = next(offline_iterator)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch[ACTION]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observation_features, next_observation_features = get_observation_features(
policy=policy, observations=observations, next_observations=next_observations
)
# Create a batch dictionary with all required elements for the forward method
forward_batch = {
ACTION: actions,
"reward": rewards,
"state": observations,
"next_state": next_observations,
"done": done,
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
"complementary_info": batch["complementary_info"],
}
# Use the forward method for critic loss
critic_output = policy.forward(forward_batch, model="critic")
# Main critic optimization
loss_critic = critic_output["loss_critic"]
optimizers["critic"].zero_grad()
loss_critic.backward()
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
)
optimizers["critic"].step()
# Discrete critic optimization (if available)
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
optimizers["discrete_critic"].zero_grad()
loss_discrete_critic.backward()
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
)
optimizers["discrete_critic"].step()
# Update target networks (main and discrete)
policy.update_target_networks()
# Sample for the last update in the UTD ratio
batch = next(online_iterator)
if dataset_repo_id is not None:
batch_offline = next(offline_iterator)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch[ACTION]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observation_features, next_observation_features = get_observation_features(
policy=policy, observations=observations, next_observations=next_observations
)
# Create a batch dictionary with all required elements for the forward method
forward_batch = {
ACTION: actions,
"reward": rewards,
"state": observations,
"next_state": next_observations,
"done": done,
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
}
critic_output = policy.forward(forward_batch, model="critic")
loss_critic = critic_output["loss_critic"]
optimizers["critic"].zero_grad()
loss_critic.backward()
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["critic"].step()
# Initialize training info dictionary
training_infos = {
"loss_critic": loss_critic.item(),
"critic_grad_norm": critic_grad_norm,
}
# Discrete critic optimization (if available)
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
optimizers["discrete_critic"].zero_grad()
loss_discrete_critic.backward()
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["discrete_critic"].step()
# Add discrete critic info to training info
training_infos["loss_discrete_critic"] = loss_discrete_critic.item()
training_infos["discrete_critic_grad_norm"] = discrete_critic_grad_norm
# Actor and temperature optimization (at specified frequency)
if optimization_step % policy_update_freq == 0:
for _ in range(policy_update_freq):
# Actor optimization
actor_output = policy.forward(forward_batch, model="actor")
loss_actor = actor_output["loss_actor"]
optimizers["actor"].zero_grad()
loss_actor.backward()
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["actor"].step()
# Add actor info to training info
training_infos["loss_actor"] = loss_actor.item()
training_infos["actor_grad_norm"] = actor_grad_norm
# Temperature optimization
temperature_output = policy.forward(forward_batch, model="temperature")
loss_temperature = temperature_output["loss_temperature"]
optimizers["temperature"].zero_grad()
loss_temperature.backward()
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
).item()
optimizers["temperature"].step()
# Add temperature info to training info
training_infos["loss_temperature"] = loss_temperature.item()
training_infos["temperature_grad_norm"] = temp_grad_norm
training_infos["temperature"] = policy.temperature
# One training step (trainer owns data_mixer iterator; algorithm owns UTD loop)
stats = trainer.training_step()
# Push policy to actors if needed
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
push_actor_policy_to_queue(parameters_queue=parameters_queue, algorithm=algorithm)
last_time_policy_pushed = time.time()
# Update target networks (main and discrete)
policy.update_target_networks()
training_infos = stats.to_log_dict()
# Log training metrics at specified intervals
optimization_step = algorithm.optimization_step
if optimization_step % log_freq == 0:
training_infos["replay_buffer_size"] = len(replay_buffer)
if offline_replay_buffer is not None:
@@ -583,7 +454,6 @@ def add_actor_information_and_train(
custom_step_key="Optimization step",
)
optimization_step += 1
if optimization_step % log_freq == 0:
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
@@ -597,9 +467,12 @@ def add_actor_information_and_train(
policy=policy,
optimizers=optimizers,
replay_buffer=replay_buffer,
algorithm=algorithm,
offline_replay_buffer=offline_replay_buffer,
dataset_repo_id=dataset_repo_id,
fps=fps,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
@@ -607,7 +480,7 @@ def start_learner(
parameters_queue: Queue,
transition_queue: Queue,
interaction_message_queue: Queue,
shutdown_event: any, # Event,
shutdown_event: Any, # Event
cfg: TrainRLServerPipelineConfig,
):
"""
@@ -681,9 +554,12 @@ def save_training_checkpoint(
policy: nn.Module,
optimizers: dict[str, Optimizer],
replay_buffer: ReplayBuffer,
algorithm: RLAlgorithm | None = None,
offline_replay_buffer: ReplayBuffer | None = None,
dataset_repo_id: str | None = None,
fps: int = 30,
preprocessor=None,
postprocessor=None,
) -> None:
"""
Save training checkpoint and associated data.
@@ -707,6 +583,8 @@ def save_training_checkpoint(
offline_replay_buffer: Optional offline replay buffer to save
dataset_repo_id: Repository ID for dataset
fps: Frames per second for dataset
preprocessor: Optional preprocessor pipeline to save
postprocessor: Optional postprocessor pipeline to save
"""
logging.info(f"Checkpoint policy after step {optimization_step}")
_num_digits = max(6, len(str(online_steps)))
@@ -715,7 +593,7 @@ def save_training_checkpoint(
# Create checkpoint directory
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
# Save checkpoint
# Save policy artifacts (pretrained_model/) + Trainer scaffolding (training_state/).
save_checkpoint(
checkpoint_dir=checkpoint_dir,
step=optimization_step,
@@ -723,13 +601,22 @@ def save_training_checkpoint(
policy=policy,
optimizer=optimizers,
scheduler=None,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
# Save interaction step manually
training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
os.makedirs(training_state_dir, exist_ok=True)
training_state = {"step": optimization_step, "interaction_step": interaction_step}
torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
# Algorithm-owned tensors live in their own component subfolder
# so they can be `push_to_hub`'d independently and don't bloat the inference artifact.
if algorithm is not None:
algorithm.save_pretrained(checkpoint_dir / ALGORITHM_DIR)
# Enrich training_step.json with the RL-specific interaction_step counter so
# both can be restored from a single file.
training_state_dir = checkpoint_dir / TRAINING_STATE_DIR
write_json(
{"step": optimization_step, "interaction_step": interaction_step},
training_state_dir / TRAINING_STEP,
)
# Update the "last" symlink
update_last_checkpoint(checkpoint_dir)
@@ -760,58 +647,6 @@ def save_training_checkpoint(
logging.info("Resume training")
def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.Module):
"""
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
This function sets up Adam optimizers for:
- The **actor network**, ensuring that only relevant parameters are optimized.
- The **critic ensemble**, which evaluates the value function.
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
It also initializes a learning rate scheduler, though currently, it is set to `None`.
NOTE:
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
Args:
cfg: Configuration object containing hyperparameters.
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
Returns:
Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]:
A tuple containing:
- `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers.
- `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling.
"""
optimizer_actor = torch.optim.Adam(
params=[
p
for n, p in policy.actor.named_parameters()
if not policy.config.shared_encoder or not n.startswith("encoder")
],
lr=cfg.policy.actor_lr,
)
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
if cfg.policy.num_discrete_actions is not None:
optimizer_discrete_critic = torch.optim.Adam(
params=policy.discrete_critic.parameters(), lr=cfg.policy.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
"critic": optimizer_critic,
"temperature": optimizer_temperature,
}
if cfg.policy.num_discrete_actions is not None:
optimizers["discrete_critic"] = optimizer_discrete_critic
return optimizers, lr_scheduler
# Training setup functions
@@ -875,13 +710,20 @@ def handle_resume_logic(cfg: TrainRLServerPipelineConfig) -> TrainRLServerPipeli
def load_training_state(
cfg: TrainRLServerPipelineConfig,
optimizers: Optimizer | dict[str, Optimizer],
algorithm: RLAlgorithm | None = None,
device: str | torch.device = "cpu",
):
"""
Loads the training state (optimizers, step count, etc.) from a checkpoint.
Loads the training state (optimizers, RNG, step + interaction step, and
algorithm-owned tensors) from the most recent checkpoint.
Args:
cfg (TrainRLServerPipelineConfig): Training configuration
optimizers (Optimizer | dict): Optimizers to load state into
cfg: Training configuration.
optimizers: Optimizers to load state into.
algorithm: Algorithm whose state dict should be restored.
Required for full main-equivalent resume;
the policy itself is restored separately via ``make_policy``.
device: Device on which to place loaded algorithm tensors.
Returns:
tuple: (optimization_step, interaction_step) or (None, None) if not resuming
@@ -890,20 +732,31 @@ def load_training_state(
return None, None
# Construct path to the last checkpoint directory
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
checkpoint_dir = Path(cfg.output_dir) / CHECKPOINTS_DIR / LAST_CHECKPOINT_LINK
logging.info(f"Loading training state from {checkpoint_dir}")
try:
# Use the utility function from train_utils which loads the optimizer state
step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)
# Restore optimizers + RNG + step from the standard `training_state/` folder
step, optimizers, _ = utils_load_training_state(checkpoint_dir, optimizers, None)
# Load interaction step separately from training_state.pt
training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
interaction_step = 0
if os.path.exists(training_state_path):
training_state = torch.load(training_state_path, weights_only=False) # nosec B614: Safe usage of torch.load
interaction_step = training_state.get("interaction_step", 0)
# Restore algorithm-owned tensors
if algorithm is not None:
algo_dir = checkpoint_dir / ALGORITHM_DIR
if algo_dir.is_dir():
tensors = load_safetensors(str(algo_dir / SAFETENSORS_SINGLE_FILE))
algorithm.load_state_dict(tensors, device=device)
logging.info(f"Loaded algorithm state from {algo_dir}")
else:
logging.warning(
f"No algorithm state found at {algo_dir}; "
"will keep their freshly-initialised values. Adam moments restored from the "
"old optimizer state may not match these reset parameters."
)
# Read interaction_step from the enriched training_step.json
training_step_path = checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP
interaction_step = int(load_json(training_step_path).get("interaction_step", 0))
logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
return step, interaction_step
@@ -1016,33 +869,6 @@ def initialize_offline_replay_buffer(
# Utilities/Helpers functions
def get_observation_features(
policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
"""
Get observation features from the policy encoder. It act as cache for the observation features.
when the encoder is frozen, the observation features are not updated.
We can save compute by caching the observation features.
Args:
policy: The policy model
observations: The current observations
next_observations: The next observations
Returns:
tuple: observation_features, next_observation_features
"""
if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
return None, None
with torch.no_grad():
observation_features = policy.actor.encoder.get_cached_image_features(observations)
next_observation_features = policy.actor.encoder.get_cached_image_features(next_observations)
return observation_features, next_observation_features
def use_threads(cfg: TrainRLServerPipelineConfig) -> bool:
return cfg.policy.concurrency.learner == "threads"
@@ -1093,19 +919,11 @@ def check_nan_in_transition(
return nan_detected
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
def push_actor_policy_to_queue(parameters_queue: Queue, algorithm: RLAlgorithm) -> None:
logging.debug("[LEARNER] Pushing actor policy to the queue")
# Create a dictionary to hold all the state dicts
state_dicts = {"policy": move_state_dict_to_device(policy.actor.state_dict(), device="cpu")}
# Add discrete critic if it exists
if hasattr(policy, "discrete_critic") and policy.discrete_critic is not None:
state_dicts["discrete_critic"] = move_state_dict_to_device(
policy.discrete_critic.state_dict(), device="cpu"
)
logging.debug("[LEARNER] Including discrete critic in state dict push")
state_dicts = algorithm.get_weights()
state_bytes = state_to_bytes(state_dicts)
parameters_queue.put(state_bytes)
@@ -1129,9 +947,8 @@ def process_transitions(
transition_queue: Queue,
replay_buffer: ReplayBuffer,
offline_replay_buffer: ReplayBuffer,
device: str,
dataset_repo_id: str | None,
shutdown_event: any,
shutdown_event: Any, # Event
):
"""Process all available transitions from the queue.
@@ -1139,7 +956,6 @@ def process_transitions(
transition_queue: Queue for receiving transitions from the actor
replay_buffer: Replay buffer to add transitions to
offline_replay_buffer: Offline replay buffer to add transitions to
device: Device to move transitions to
dataset_repo_id: Repository ID for dataset
shutdown_event: Event to signal shutdown
"""
@@ -1148,8 +964,6 @@ def process_transitions(
transition_list = bytes_to_transitions(buffer=transition_list)
for transition in transition_list:
transition = move_transition_to_device(transition=transition, device=device)
# Skip transitions with NaN values
if check_nan_in_transition(
observations=transition["state"],
@@ -1163,7 +977,7 @@ def process_transitions(
# Add to offline buffer if it's an intervention
if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
TeleopEvents.IS_INTERVENTION
TeleopEvents.IS_INTERVENTION.value
):
offline_replay_buffer.add(**transition)
@@ -1172,7 +986,7 @@ def process_interaction_messages(
interaction_message_queue: Queue,
interaction_step_shift: int,
wandb_logger: WandBLogger | None,
shutdown_event: any,
shutdown_event: Any, # Event
) -> dict | None:
"""Process all available interaction messages from the queue.

View File

@@ -18,17 +18,32 @@
import logging
import time
from multiprocessing import Event, Queue
from typing import TYPE_CHECKING
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
from lerobot.utils.import_utils import _grpc_available
from .queue import get_last_item_from_queue
if TYPE_CHECKING or _grpc_available:
import grpc
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
_ServicerBase = services_pb2_grpc.LearnerServiceServicer
else:
grpc = None
services_pb2 = None
services_pb2_grpc = None
receive_bytes_in_chunks = None
send_bytes_in_chunks = None
_ServicerBase = object
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
SHUTDOWN_TIMEOUT = 10
class LearnerService(services_pb2_grpc.LearnerServiceServicer):
class LearnerService(_ServicerBase):
"""
Implementation of the LearnerService gRPC service
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
@@ -51,7 +66,9 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
self.interaction_message_queue = interaction_message_queue
self.queue_get_timeout = queue_get_timeout
def StreamParameters(self, request, context): # noqa: N802
def StreamParameters( # noqa: N802
self, request: "services_pb2.Empty", context: "grpc.ServicerContext"
):
# TODO: authorize the request
logging.info("[LEARNER] Received request to stream parameters from the Actor")
@@ -86,7 +103,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
logging.info("[LEARNER] Stream parameters finished")
return services_pb2.Empty()
def SendTransitions(self, request_iterator, _context): # noqa: N802
def SendTransitions(self, request_iterator, _context: "grpc.ServicerContext"): # noqa: N802
# TODO: authorize the request
logging.info("[LEARNER] Received request to receive transitions from the Actor")
@@ -100,7 +117,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
logging.debug("[LEARNER] Finished receiving transitions")
return services_pb2.Empty()
def SendInteractions(self, request_iterator, _context): # noqa: N802
def SendInteractions(self, request_iterator, _context: "grpc.ServicerContext"): # noqa: N802
# TODO: authorize the request
logging.info("[LEARNER] Received request to receive interactions from the Actor")
@@ -114,5 +131,5 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
logging.debug("[LEARNER] Finished receiving interactions")
return services_pb2.Empty()
def Ready(self, request, context): # noqa: N802
def Ready(self, request: "services_pb2.Empty", context: "grpc.ServicerContext"): # noqa: N802
return services_pb2.Empty()

View File

@@ -0,0 +1,50 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Top-level pipeline config for distributed RL training (actor / learner)."""
from __future__ import annotations
from dataclasses import dataclass
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from .algorithms.configs import RLAlgorithmConfig
from .algorithms.factory import make_algorithm_config
from .algorithms.sac import SACAlgorithmConfig # noqa: F401
@dataclass(kw_only=True)
class TrainRLServerPipelineConfig(TrainPipelineConfig):
# NOTE: In RL, we don't need an offline dataset
# TODO: Make `TrainPipelineConfig.dataset` optional
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
# Algorithm config.
algorithm: RLAlgorithmConfig | None = None
# Data mixer strategy name. Currently supports "online_offline".
mixer: str = "online_offline"
# Fraction sampled from online replay when using OnlineOfflineMixer.
online_ratio: float = 0.5
def validate(self) -> None:
super().validate()
if self.algorithm is None:
self.algorithm = make_algorithm_config("sac")
if getattr(self.algorithm, "policy_config", None) is None:
self.algorithm.policy_config = self.policy

101
src/lerobot/rl/trainer.py Normal file
View File

@@ -0,0 +1,101 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Iterator
from typing import Any
from lerobot.types import BatchType
from .algorithms.base import RLAlgorithm
from .algorithms.configs import TrainingStats
from .data_sources.data_mixer import DataMixer
class RLTrainer:
"""Unified training step orchestrator.
Holds the algorithm, a DataMixer, and an optional preprocessor.
"""
def __init__(
self,
algorithm: RLAlgorithm,
data_mixer: DataMixer,
batch_size: int,
*,
preprocessor: Any | None = None,
):
self.algorithm = algorithm
self.data_mixer = data_mixer
self.batch_size = batch_size
self._preprocessor = preprocessor
self._iterator: Iterator[BatchType] | None = None
self.algorithm.make_optimizers_and_scheduler()
def _build_data_iterator(self) -> Iterator[BatchType]:
"""Create a fresh algorithm-configured iterator (optionally preprocessed)."""
raw = self.algorithm.configure_data_iterator(
data_mixer=self.data_mixer,
batch_size=self.batch_size,
)
if self._preprocessor is not None:
return _PreprocessedIterator(raw, self._preprocessor)
return raw
def reset_data_iterator(self) -> None:
"""Discard the current iterator so it will be rebuilt lazily next step."""
self._iterator = None
def set_data_mixer(self, data_mixer: DataMixer, *, reset: bool = True) -> None:
"""Swap the active data mixer, optionally resetting the iterator."""
self.data_mixer = data_mixer
if reset:
self.reset_data_iterator()
def training_step(self) -> TrainingStats:
"""Run one training step (algorithm-agnostic)."""
if self._iterator is None:
self._iterator = self._build_data_iterator()
return self.algorithm.update(self._iterator)
def preprocess_rl_batch(preprocessor: Any, batch: BatchType) -> BatchType:
"""Apply policy preprocessing to RL observations only."""
observations = batch["state"]
next_observations = batch["next_state"]
batch["state"] = preprocessor.process_observation(observations)
batch["next_state"] = preprocessor.process_observation(next_observations)
return batch
class _PreprocessedIterator:
"""Iterator wrapper that preprocesses each sampled RL batch."""
__slots__ = ("_raw", "_preprocessor")
def __init__(self, raw_iterator: Iterator[BatchType], preprocessor: Any) -> None:
self._raw = raw_iterator
self._preprocessor = preprocessor
def __iter__(self) -> _PreprocessedIterator:
return self
def __next__(self) -> BatchType:
batch = next(self._raw)
return preprocess_rl_batch(self._preprocessor, batch)

View File

@@ -353,7 +353,8 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
speed_factor: A scaling factor to convert the normalized velocity command to a position change.
clip_min: The minimum allowed gripper joint position.
clip_max: The maximum allowed gripper joint position.
discrete_gripper: If True, treat the input action as discrete (0: open, 1: close, 2: stay).
discrete_gripper: If True, interpret the input as a discrete class index
{0 = close, 1 = stay, 2 = open}, matching `GamepadTeleop.GripperAction`.
"""
speed_factor: float = 20.0
@@ -377,10 +378,10 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
raise ValueError("Joints observation is require for computing robot kinematics")
if self.discrete_gripper:
# Discrete gripper actions are in [0, 1, 2]
# 0: open, 1: close, 2: stay
# We need to shift them to [-1, 0, 1] and then scale them to clip_max
gripper_vel = (gripper_vel - 1) * self.clip_max
# Map discrete command {0=close, 1=stay, 2=open} -> signed velocity.
# Negation accounts for SO100 sign (joint position increases on close).
# 0 -> +clip_max (close), 1 -> 0 (stay), 2 -> -clip_max (open)
gripper_vel = -(gripper_vel - 1) * self.clip_max
# Compute desired gripper position
delta = gripper_vel * float(self.speed_factor)

View File

@@ -68,9 +68,12 @@ class SOFollower(Robot):
@property
def _cameras_ft(self) -> dict[str, tuple]:
return {
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
}
features: dict[str, tuple] = {}
for cam in self.cameras:
features[cam] = (self.cameras[cam].height, self.cameras[cam].width, 3)
if getattr(self.cameras[cam], "use_depth", False):
features[f"{cam}_depth"] = (self.cameras[cam].height, self.cameras[cam].width, 1)
return features
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
@@ -190,6 +193,12 @@ class SOFollower(Robot):
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
if getattr(cam, "use_depth", False):
start = time.perf_counter()
obs_dict[f"{cam_key}_depth"] = cam.read_latest_depth()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key} depth: {dt_ms:.1f}ms")
return obs_dict
@check_if_not_connected

View File

@@ -332,7 +332,8 @@ def build_rollout_context(
cfg.dataset.repo_id,
root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
vcodec=cfg.dataset.vcodec,
camera_encoder=cfg.dataset.camera_encoder,
depth_encoder=cfg.dataset.depth_encoder,
streaming_encoding=cfg.dataset.streaming_encoding,
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
encoder_threads=cfg.dataset.encoder_threads,
@@ -367,7 +368,8 @@ def build_rollout_context(
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
* len(robot.cameras if hasattr(robot, "cameras") else []),
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
vcodec=cfg.dataset.vcodec,
camera_encoder=cfg.dataset.camera_encoder,
depth_encoder=cfg.dataset.depth_encoder,
streaming_encoding=cfg.dataset.streaming_encoding,
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
encoder_threads=cfg.dataset.encoder_threads,

Some files were not shown because too many files have changed in this diff Show More