Compare commits

...

34 Commits

Author SHA1 Message Date
CarolinePascal
237bae51e8 feat(default values): applying a consistent naming convention for default RGB cameras video encoder parameters 2026-05-04 18:05:23 +02:00
CarolinePascal
df8b33fc68 fix(camera_encoder_config): Removing camera_encoder_config from LeRobotDataset, as it's only required in LeRobotDatasetWriter. 2026-05-04 18:00:14 +02:00
CarolinePascal
50e2d7b5f4 chore(doctrings): updating docstrings 2026-05-04 17:01:11 +02:00
CarolinePascal
016799dfa1 chore(format): formatting code 2026-04-30 14:42:37 +02:00
CarolinePascal
51b9038458 chore(PyAV): cleaning up PyAV utils and encoding parameters checks to stick to the minimun required tooling. 2026-04-30 14:31:08 +02:00
CarolinePascal
cc9a2e5c99 chore(format): fixing formatting issues 2026-04-29 16:48:57 +02:00
CarolinePascal
a2376389f9 test(new): adding new tests for encoding related features 2026-04-29 16:48:56 +02:00
CarolinePascal
57a619ab02 test(existing): adapting existing tests 2026-04-29 16:48:56 +02:00
CarolinePascal
7f624adcc5 chore(duplicate): removing duplicate get_codec_options definition 2026-04-29 16:48:56 +02:00
CarolinePascal
375cf1fdf3 feat(pyav checks): making pyav parameters checks more robust 2026-04-29 16:48:56 +02:00
CarolinePascal
b2c2bb7641 feat(VideoEncoderConfig init): making VideoEncoderConfig more robust and adaptable to multiple backends 2026-04-29 16:48:56 +02:00
CarolinePascal
4a87ee1537 fix(concatenation compatibility): adding compatibility check when concatenating video files 2026-04-29 16:48:56 +02:00
CarolinePascal
e44f86e516 feat(metadata): adding encoding parameters in dataset metadata 2026-04-29 16:48:56 +02:00
CarolinePascal
a0e3acdb67 chore(docs): updating the docs 2026-04-29 16:46:16 +02:00
CarolinePascal
38ff579bcc feat(VideoEncoderConfig): propagating the VideoEncoderConfig in the codebase 2026-04-29 16:44:47 +02:00
CarolinePascal
479e444517 feat(VideoEncoderConfig): creating a VideoEncoderConfig to encapsulate encoding parameters 2026-04-29 16:42:14 +02:00
CarolinePascal
9787b8fa26 feat(pyav utils): adding suport for PyAV encoding parameters validation 2026-04-29 16:42:14 +02:00
CarolinePascal
71f39f6912 chore(video backend): renaming codec into video_backend in get_safe_default_video_backend() 2026-04-29 16:42:14 +02:00
Khalil Meftah
b5f65e5332 Expose sarm package API and ship reward model card template (#3477)
* chore: List lerobot_rewardmodel_modelcard_template.md in MANIFEST.in

* chore: export SARMConfig, SARMRewardModel, and make_sarm_pre_post_processors from rewards.sarm.
2026-04-29 16:17:16 +02:00
Khalil Meftah
cd6b43ea7a fix(train): migrate legacy RA-BC fields in train config loading (#3480) 2026-04-29 16:17:00 +02:00
Steven Palma
2236bbe7a3 fix(rollout): propagate policy-specific CLI config paramaters (#3483)
Co-authored-by: Maxime Ellerbach <maxime.ellerbach@huggingface.co>
2026-04-29 16:13:10 +02:00
Maxime Ellerbach
cb0a944941 refactor(datasets): replace untyped dict with typed DatasetInfo dataclass (#3472)
* refactor(datasets): replace untyped dict with typed DatasetInfo dataclass

Introduce typed DatasetInfo dataclass to replace untyped dict representation of info.json.

Changes:
- Add DatasetInfo dataclass with explicit fields and validation
- Implement __post_init__ for shape conversion (list ↔ tuple)
- Add dict-style compatibility layer (__getitem__, __setitem__, .get())
- Add from_dict() and to_dict() for JSON serialization
- Update io_utils to use load_info/write_info with DatasetInfo
- Update dataset utilities and metadata to use attribute access
- Remove aggregate.py dict-style field access
- Add tests fixture support for DatasetInfo

Benefits:
- Type safety with IDE auto-completion
- Validation at construction time
- Explicit schema documentation

* fix pre-commit

* update docstring inside DatasetInfo.from_dict()

* sorts the unknown to have deterministic output

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* refactoring the last few old fieds


* fix crop dataset roi type mismatch


* use consistantly int for data and video_files_size_in_mb

---------

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
Co-authored-by: jjolla93 <jjolla93@gmail.com>
2026-04-28 18:40:30 +02:00
Khalil Meftah
8a3d64033f Reward models refactor (#3142)
* feat(rewards): add RewardModelConfig and PreTrainedRewardModel base classes

* refactor(rewards): migrate Classifier from policies/sac/reward_model/ to rewards/classifier/

* refactor(rewards): migrate SARM from policies/sarm/ to rewards/sarm/

* refactor(rewards): add rewards/factory.py and remove reward model code from policies/factory.py

* refactor(rewards): update imports and delete old reward model locations

* test(rewards): add reward model tests and update existing test imports

* fix(rewards): restore full Classifier and SARM implementations

* test(rewards): restore missing CUDA and mixed precision classifier processor tests

* refactor(lerobot_train.py): remove rabc specific configuration and replace it with a generic samplerweight class in lerobot_train

* refactor(lerobot_train.py): add missing sampling weight script

* linter + missing files

* add testing for sampl weighter

* revert some useless changes, improve typing

* update docs

* add automatic detection of the progress path

* remove type exp

* improve comment

* fix: move rabc.py to rewards/sarm/ and update import paths

* refactor(imports): update reward model imports to new module structure

* refactor(imports): update reward model imports to reflect new module structure

* refactor(imports): conditionally import pandas based on availability

* feat(configs): add reward_model field to TrainPipelineConfig and Hub fields to RewardModelConfig

* refactor(policies): remove reward model branches from policy factory and __init__

* refactor(rewards): expand __init__ facade and fix SARMConfig __post_init__ crash

* feat(train): route reward model training through rewards/factory instead of policies/factory

* refactor(train): streamline reward model training logic

* fix(rewards): ensure FileNotFoundError is raised for missing config_file

* refactor(train): update __get_path_fields__ to include reward_model for config loading

* refactor(classifier): remove redundant input normalization in predict_reward method

* fix(train): raise ValueError for non-trainable reward models in train function

* refactor(pretrained_rm): add model card template

* refactor(tests): reward models

* refactor(sarm): update reset method and remove unused action prediction methods

* refactor(wandb): differentiate tags for reward model and policy training in cfg_to_group function

* fix(train): raise ValueError for PEFT usage in reward model training

* refactor(rewards): enhance RewardModelConfig with device handling and delta indices properties

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2026-04-28 17:56:24 +02:00
Steven Palma
03ee50e08f chore(ci): bump docs workflows (#3476) 2026-04-28 15:06:44 +02:00
Steven Palma
ca87ccd941 feat(rollout): decouple policy deployment from data recording with new lerobot-rollout CLI (#3413)
* feat(scripts): lerobot-rollout

* fix(rollout) require dataset in dagger + use duration too

* fix(docs): dagger num_episodes

* test(rollout): fix expectations

* fix(rollout): features check

* fix(rollout): device and task propagation + feature pos + warn fps + move rename_map config

* docs(rollout): edit rename_map instructions

* chore(rollout): multiple minor improvements

* chore(rollout): address coments + minor improvements

* fix(rollout): enable default

* fix(tests): default value RTCConfig

* fix(rollout): robot_observation_processor and notify_observation at policy frequency instead of interpolator rate

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* fix(rollout): prevent relativeactions with sync inference engine

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* fix(rollout): rtc reanchor to non normalized state

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* fix(rollout): fixing the episode length to use hwc (#3469)

also reducing default length to 5 minutes

* feat(rollout): go back to initial position is now a config

* fix(rollout): properly propagating video_files_size_in_mb to lerobot_dataset (#3470)

* chore(rollout): note about dagger correction stage

* chore(docs): update comments and docstring

* fix(test): move rtc relative out of rollout module

* fix(rollout): address the review comments

---------

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Maxime Ellerbach <maxime.ellerbach@huggingface.co>
2026-04-28 00:57:35 +02:00
Steven Palma
77352c495c chore(dependencies): update uv.lock (#3437)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-04-27 23:15:46 +02:00
Steven Palma
05a5223885 fix(pi): avoid peak RAM in PiGemma construction by freeing replaced submodules (#3454)
Co-Authored-By: Daiki Kamata <daiki.kamata@access-company.com>
Co-Authored-By: Jack Vial <jackvial@users.noreply.github.com>
Co-Authored-By: Ajay Anubolu <AjAnubolu@users.noreply.github.com>
Co-Authored-By: Finn F. <F-Fer@users.noreply.github.com>
2026-04-24 17:50:12 +02:00
Steven Palma
580d818aa9 fix(dataset): no default overwrite in lerobot tool recompute stats (#3452) 2026-04-24 15:07:19 +02:00
Steven Palma
587aa82021 fix(imports): realsense import name is platform dependent (#3451) 2026-04-24 12:55:38 +02:00
Chuyao Shen
12b88fce02 not use dataclass (#3414)
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-04-24 11:26:59 +02:00
masato-ka
fc6c94c82a fix(sarm): handle BaseModelOutputWithPooling from transformers 5.x in… (#3419)
* fix(sarm): handle BaseModelOutputWithPooling from transformers 5.x in CLIP encoding

In transformers 5.x, CLIPModel.get_image_features() and get_text_features()
return BaseModelOutputWithPooling instead of a plain torch.FloatTensor.
Added isinstance check to extract pooler_output when the return value is not
a tensor, maintaining backward compatibility with transformers 4.x.

Fixes AttributeError: 'BaseModelOutputWithPooling' object has no attribute 'detach'

* Adding assertion check for pooler_output of CLIP. This change is response to below comment.
https://github.com/huggingface/lerobot/pull/3419#discussion_r3112594387

* Adding assertion check for pooler_output of CLIP. This change is response to below comment. Change to simple check and rise
https://github.com/huggingface/lerobot/pull/3419#discussion_r3126953776

---------
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-04-23 16:26:58 +02:00
Steven Palma
1add460678 fix(policy): loss normalization for padded actions in ACT, Diffusion, and MultiTaskDiT (#3442)
* Fix loss normalization for padded actions in ACT, Diffusion, and MultiTaskDiT

When action_is_pad masks out padded timesteps, the subsequent .mean()
still divides by the total element count (including zeroed-out padding),
underestimating the loss. With 60-70% padding this can cut the effective
gradient signal by 2-3x.

Replace mask-then-mean with mask-then-sum / valid-count for all three
affected policies. TDMPC is not affected because it sums over time
before averaging over batch.

Fixes #3353

* linting

Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* Update src/lerobot/policies/diffusion/modeling_diffusion.py

Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* Update src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py

Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* Update src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py

Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* apply ACT loss normalization suggestion from review

Divide by num_valid (timesteps * action_dim) instead of just timesteps,
matching the diffusion/multi_task_dit fix. Addresses review from
@whats2000 (https://github.com/huggingface/lerobot/pull/3377#discussion_r3106845791).

* fix(test): update safetensor act

---------

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Yufeng He <40085740+he-yufeng@users.noreply.github.com>
Co-authored-by: Maxime Ellerbach <maxime@ellerbach.net>
Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com>
2026-04-23 15:23:54 +02:00
Qi Jia
4587c2b648 fix xvla docs (#3291)
Co-authored-by: Qi Jia <kaufou@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-04-23 14:50:32 +02:00
whats2000
2236cdb302 fix(smolvla): correct loss normalization for padded actions (#3434)
Apply the same per-scalar-mean fix to SmolVLA that #3377 landed for
ACT / Diffusion / MultiTaskDiT. The pre-patch form applies the
`action_is_pad` mask to zero out padded timesteps, then calls `.mean()`
(or `.mean(dim=(1, 2))`). Because `.mean()` divides by the total number
of elements including the zeroed padding, the loss is diluted by the
padding fraction.

Fixed by normalizing only over valid (non-padded) scalar entries:

    num_valid = ((~actions_is_pad).sum(...) * losses.shape[-1]).clamp_min(1)
    loss = losses.sum(...) / num_valid

`clamp_min(1)` preserves the all-padded-batch edge case (0/1 = 0). Both
reduction paths are updated. Behavior when `action_is_pad` is missing is
unchanged (`losses.mean()`).

Empirical A/B on aloha_sim_transfer_cube_human (chunk_size=40, batch=2,
30 steps, fixed seed, GB200) shows `loss_A / loss_B = 0.9672 (±0.088)` —
same direction and magnitude as PR #3377's `loss_A / loss_C ≈ 0.96` for
ACT. Heavier-padding recipes will see a larger gap.

Refs: #3353 (original report for ACT), #3377 (fix for the other three
policies).
2026-04-23 10:34:11 +02:00
156 changed files with 9830 additions and 4114 deletions

View File

@@ -33,7 +33,7 @@ jobs:
github.event.workflow_run.event == 'pull_request' && github.event.workflow_run.event == 'pull_request' &&
github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.conclusion == 'success' &&
github.repository == 'huggingface/lerobot' github.repository == 'huggingface/lerobot'
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@9ad2de8582b56c017cb530c1165116d40433f1c6 # main uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
with: with:
package_name: lerobot package_name: lerobot
secrets: secrets:

View File

@@ -55,7 +55,7 @@ jobs:
github.repository == 'huggingface/lerobot' github.repository == 'huggingface/lerobot'
permissions: permissions:
contents: read contents: read
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
with: with:
commit_sha: ${{ github.sha }} commit_sha: ${{ github.sha }}
package: lerobot package: lerobot
@@ -78,7 +78,7 @@ jobs:
permissions: permissions:
contents: read contents: read
pull-requests: write pull-requests: write
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
with: with:
commit_sha: ${{ github.event.pull_request.head.sha }} commit_sha: ${{ github.event.pull_request.head.sha }}
pr_number: ${{ github.event.number }} pr_number: ${{ github.event.number }}

View File

@@ -1,3 +1,4 @@
include src/lerobot/templates/lerobot_modelcard_template.md include src/lerobot/templates/lerobot_modelcard_template.md
include src/lerobot/templates/lerobot_rewardmodel_modelcard_template.md
include src/lerobot/datasets/card_template.md include src/lerobot/datasets/card_template.md
include src/lerobot/envs/metaworld_config.json include src/lerobot/envs/metaworld_config.json

View File

@@ -39,6 +39,7 @@ from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.video_utils import ( from lerobot.datasets.video_utils import (
VideoEncoderConfig,
decode_video_frames, decode_video_frames,
encode_video_frames, encode_video_frames,
) )
@@ -251,10 +252,13 @@ def benchmark_encoding_decoding(
imgs_dir=imgs_dir, imgs_dir=imgs_dir,
video_path=video_path, video_path=video_path,
fps=fps, fps=fps,
vcodec=encoding_cfg["vcodec"], camera_encoder_config=VideoEncoderConfig(
pix_fmt=encoding_cfg["pix_fmt"], vcodec=encoding_cfg["vcodec"],
g=encoding_cfg.get("g"), pix_fmt=encoding_cfg["pix_fmt"],
crf=encoding_cfg.get("crf"), g=encoding_cfg.get("g"),
crf=encoding_cfg.get("crf"),
preset=encoding_cfg.get("preset"),
),
# fast_decode=encoding_cfg.get("fastdecode"), # fast_decode=encoding_cfg.get("fastdecode"),
overwrite=True, overwrite=True,
) )

View File

@@ -61,6 +61,8 @@
title: SARM title: SARM
title: "Reward Models" title: "Reward Models"
- sections: - sections:
- local: inference
title: Policy Deployment (lerobot-rollout)
- local: async - local: async
title: Use Async Inference title: Use Async Inference
- local: rtc - local: rtc

View File

@@ -90,6 +90,6 @@ lerobot-record \
--dataset.single_task="Your task description" \ --dataset.single_task="Your task description" \
--dataset.streaming_encoding=true \ --dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \ --dataset.encoder_threads=2 \
# --dataset.vcodec=auto \ # --dataset.camera_encoder_config.vcodec=auto \
--policy.path=${HF_USER}/act_policy --policy.path=${HF_USER}/act_policy
``` ```

View File

@@ -194,7 +194,7 @@ lerobot-record \
--dataset.single_task="Navigate around obstacles" \ --dataset.single_task="Navigate around obstacles" \
--dataset.streaming_encoding=true \ --dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \ --dataset.encoder_threads=2 \
# --dataset.vcodec=auto \ # --dataset.camera_encoder_config.vcodec=auto \
--display_data=true --display_data=true
``` ```

View File

@@ -123,7 +123,7 @@ lerobot-record \
--dataset.single_task="Grab and handover the red cube to the other arm" \ --dataset.single_task="Grab and handover the red cube to the other arm" \
--dataset.streaming_encoding=true \ --dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \ --dataset.encoder_threads=2 \
# --dataset.vcodec=auto \ # --dataset.camera_encoder_config.vcodec=auto \
--policy.path=<user>/groot-bimanual \ # your trained model --policy.path=<user>/groot-bimanual \ # your trained model
--dataset.episode_time_s=30 \ --dataset.episode_time_s=30 \
--dataset.reset_time_s=10 --dataset.reset_time_s=10

View File

@@ -50,30 +50,30 @@ This process can be repeated iteratively: deploy, collect, fine-tune, repeat. Ea
### Teleoperator Requirements ### Teleoperator Requirements
The `examples/hil` HIL scripts require **teleoperators with active motors** that can: The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with active motors** that can:
- Enable/disable torque programmatically - Enable/disable torque programmatically
- Move to target positions (to mirror the robot state when pausing) - Move to target positions (to mirror the robot state when pausing)
**Compatible teleoperators in the current `examples/hil` scripts:** **Compatible teleoperators:**
- `openarm_mini` - OpenArm Mini - `openarm_mini` - OpenArm Mini
- `so_leader` - SO100 / SO101 leader arm - `so_leader` - SO100 / SO101 leader arm
> [!IMPORTANT] > [!IMPORTANT]
> The provided `examples/hil` commands default to `bi_openarm_follower` + `openarm_mini`. > The provided commands default to `bi_openarm_follower` + `openarm_mini`.
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags. > `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
--- ---
## Script ## Script
A single script handles both synchronous and RTC-based inference. Toggle RTC with `--rtc.enabled=true`: Use `lerobot-rollout` with `--strategy.type=dagger` for HIL data collection. Select the inference backend with `--inference.type=sync|rtc`:
| Mode | Flag | Models | | Mode | Flag | Models |
| ------------------------ | -------------------- | --------------------- | | ------------------------ | ---------------------- | --------------------- |
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy | | Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
| Real-Time Chunking (RTC) | `--rtc.enabled=true` | Pi0, Pi0.5, SmolVLA | | Real-Time Chunking (RTC) | `--inference.type=rtc` | Pi0, Pi0.5, SmolVLA |
--- ---
@@ -97,7 +97,7 @@ python src/lerobot/scripts/lerobot_train.py \
**Standard inference (ACT, Diffusion Policy):** **Standard inference (ACT, Diffusion Policy):**
```bash ```bash
python examples/hil/hil_data_collection.py \ lerobot-rollout --strategy.type=dagger \
--robot.type=bi_openarm_follower \ --robot.type=bi_openarm_follower \
--robot.left_arm_config.port=can1 \ --robot.left_arm_config.port=can1 \
--robot.left_arm_config.side=left \ --robot.left_arm_config.side=left \
@@ -108,11 +108,10 @@ python examples/hil/hil_data_collection.py \
--teleop.port_left=/dev/ttyACM0 \ --teleop.port_left=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \ --teleop.port_right=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \ --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/hil-dataset \ --dataset.repo_id=your-username/rollout_hil_dataset \
--dataset.single_task="Fold the T-shirt properly" \ --dataset.single_task="Fold the T-shirt properly" \
--dataset.fps=30 \ --dataset.fps=30 \
--dataset.episode_time_s=1000 \ --strategy.num_episodes=50 \
--dataset.num_episodes=50 \
--interpolation_multiplier=2 --interpolation_multiplier=2
``` ```
@@ -121,11 +120,11 @@ python examples/hil/hil_data_collection.py \
For models with high inference latency, enable RTC for smooth execution: For models with high inference latency, enable RTC for smooth execution:
```bash ```bash
python examples/hil/hil_data_collection.py \ lerobot-rollout --strategy.type=dagger \
--rtc.enabled=true \ --inference.type=rtc \
--rtc.execution_horizon=20 \ --inference.rtc.execution_horizon=20 \
--rtc.max_guidance_weight=5.0 \ --inference.rtc.max_guidance_weight=5.0 \
--rtc.prefix_attention_schedule=LINEAR \ --inference.rtc.prefix_attention_schedule=LINEAR \
--robot.type=bi_openarm_follower \ --robot.type=bi_openarm_follower \
--robot.left_arm_config.port=can1 \ --robot.left_arm_config.port=can1 \
--robot.left_arm_config.side=left \ --robot.left_arm_config.side=left \
@@ -136,11 +135,10 @@ python examples/hil/hil_data_collection.py \
--teleop.port_left=/dev/ttyACM0 \ --teleop.port_left=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \ --teleop.port_right=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \ --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/hil-rtc-dataset \ --dataset.repo_id=your-username/rollout_hil_rtc_dataset \
--dataset.single_task="Fold the T-shirt properly" \ --dataset.single_task="Fold the T-shirt properly" \
--dataset.fps=30 \ --dataset.fps=30 \
--dataset.episode_time_s=1000 \ --strategy.num_episodes=50 \
--dataset.num_episodes=50 \
--interpolation_multiplier=3 --interpolation_multiplier=3
``` ```
@@ -235,7 +233,7 @@ This HIL data collection approach builds on ideas from interactive imitation lea
- **HG-DAgger** (Kelly et al., 2019) made this practical for robotics: a human expert monitors the robot and only intervenes when needed, rather than labeling every state. The gating between autonomous and human control is exactly the pause → takeover → return-to-policy loop used in the scripts here. - **HG-DAgger** (Kelly et al., 2019) made this practical for robotics: a human expert monitors the robot and only intervenes when needed, rather than labeling every state. The gating between autonomous and human control is exactly the pause → takeover → return-to-policy loop used in the scripts here.
- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the HIL scripts in `examples/hil`. - **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the DAgger strategy in `lerobot-rollout`.
- **π0.6/RECAP** (Physical Intelligence, 2025) applies the same iterative collect-and-finetune loop at scale with VLA models, showing that even large pretrained policies benefit substantially from targeted human corrections on their own failure modes. π0.6 is trained using RECAP. - **π0.6/RECAP** (Physical Intelligence, 2025) applies the same iterative collect-and-finetune loop at scale with VLA models, showing that even large pretrained policies benefit substantially from targeted human corrections on their own failure modes. π0.6 is trained using RECAP.

View File

@@ -232,7 +232,7 @@ lerobot-record \
--dataset.private=true \ --dataset.private=true \
--dataset.streaming_encoding=true \ --dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \ --dataset.encoder_threads=2 \
# --dataset.vcodec=auto \ # --dataset.camera_encoder_config.vcodec=auto \
--display_data=true --display_data=true
``` ```
@@ -278,6 +278,6 @@ lerobot-record \
--dataset.num_episodes=10 \ --dataset.num_episodes=10 \
--dataset.streaming_encoding=true \ --dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \ --dataset.encoder_threads=2 \
# --dataset.vcodec=auto \ # --dataset.camera_encoder_config.vcodec=auto \
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model --policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
``` ```

View File

@@ -193,7 +193,7 @@ lerobot-record \
--dataset.num_episodes=5 \ --dataset.num_episodes=5 \
--dataset.single_task="Grab the black cube" \ --dataset.single_task="Grab the black cube" \
--dataset.streaming_encoding=true \ --dataset.streaming_encoding=true \
# --dataset.vcodec=auto \ # --dataset.camera_encoder_config.vcodec=auto \
--dataset.encoder_threads=2 --dataset.encoder_threads=2
``` ```
</hfoption> </hfoption>
@@ -509,121 +509,42 @@ hf upload ${HF_USER}/act_so101_test${CKPT} \
## Run inference and evaluate your policy ## Run inference and evaluate your policy
You can use the `record` script from [`lerobot-record`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes: Use `lerobot-rollout` to deploy a trained policy on your robot. You can choose different strategies depending on your needs:
<hfoptions id="eval"> <hfoptions id="eval">
<hfoption id="Command"> <hfoption id="Base mode (no recording)">
```bash ```bash
lerobot-record \ lerobot-rollout \
--strategy.type=base \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \ --robot.type=so100_follower \
--robot.port=/dev/ttyACM1 \ --robot.port=/dev/ttyACM1 \
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \ --robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
--robot.id=my_awesome_follower_arm \ --task="Put lego brick into the transparent box" \
--display_data=false \ --duration=60
--dataset.repo_id=${HF_USER}/eval_so100 \
--dataset.single_task="Put lego brick into the transparent box" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# <- Teleop optional if you want to teleoperate in between episodes \
# --teleop.type=so100_leader \
# --teleop.port=/dev/ttyACM0 \
# --teleop.id=my_awesome_leader_arm \
--policy.path=${HF_USER}/my_policy
``` ```
</hfoption> </hfoption>
<hfoption id="API example"> <hfoption id="Sentry mode (with recording)">
```bash
<!-- prettier-ignore-start --> lerobot-rollout \
```python --strategy.type=sentry \
from lerobot.cameras.opencv import OpenCVCameraConfig --strategy.upload_every_n_episodes=5 \
from lerobot.datasets import LeRobotDataset --policy.path=${HF_USER}/my_policy \
from lerobot.utils.feature_utils import hw_to_dataset_features --robot.type=so100_follower \
from lerobot.policies.act import ACTPolicy --robot.port=/dev/ttyACM1 \
from lerobot.policies import make_pre_post_processors --robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig --dataset.repo_id=${HF_USER}/eval_so100 \
from lerobot.scripts.lerobot_record import record_loop --dataset.single_task="Put lego brick into the transparent box" \
from lerobot.common.control_utils import init_keyboard_listener --duration=600
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
NUM_EPISODES = 5
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
# Create the robot configuration
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", cameras=camera_config
)
# Initialize the robot
robot = SO100Follower(robot_config)
# Initialize the policy
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_DATASET_ID,
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
init_rerun(session_name="recording")
# Connect the robot
robot.connect()
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
)
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Run the policy inference loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
dataset.save_episode()
# Clean up
robot.disconnect()
dataset.push_to_hub()
``` ```
<!-- prettier-ignore-end -->
</hfoption> </hfoption>
</hfoptions> </hfoptions>
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed: The `--strategy.type` flag selects the execution mode:
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so101_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so101_test`). - `base`: Autonomous rollout with no data recording (useful for quick evaluation)
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so101_test`). - `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).

261
docs/source/inference.mdx Normal file
View File

@@ -0,0 +1,261 @@
# Policy Deployment (lerobot-rollout)
`lerobot-rollout` is the single CLI for deploying trained policies on real robots. It supports multiple execution strategies and inference backends, from quick evaluation to continuous recording and human-in-the-loop data collection.
## Quick Start
No extra dependencies are needed beyond your robot and policy extras.
```bash
lerobot-rollout \
--strategy.type=base \
--policy.path=lerobot/act_koch_real \
--robot.type=koch_follower \
--robot.port=/dev/ttyACM0 \
--task="pick up cube" \
--duration=30
```
This runs the policy for 30 seconds with no recording.
---
## Strategies
Select a strategy with `--strategy.type=<name>`. Each strategy defines a different control loop with its own recording and interaction semantics.
### Base (`--strategy.type=base`)
Autonomous policy execution with no data recording. Use this for quick evaluation, demos, or when you only need to observe the robot.
```bash
lerobot-rollout \
--strategy.type=base \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Put lego brick into the box" \
--duration=60
```
| Flag | Description |
| ---------------- | ------------------------------------------------------ |
| `--duration` | Run time in seconds (0 = infinite) |
| `--task` | Task description passed to the policy |
| `--display_data` | Stream observations/actions to Rerun for visualization |
### Sentry (`--strategy.type=sentry`)
Continuous autonomous recording with periodic upload to the Hugging Face Hub. Episode boundaries are auto-computed from camera resolution and FPS so each saved episode produces a complete video file, keeping uploads efficient.
Policy state (hidden state, RTC queue) persists across episode boundaries: the robot does not reset between episodes.
```bash
lerobot-rollout \
--strategy.type=sentry \
--strategy.upload_every_n_episodes=5 \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--dataset.repo_id=${HF_USER}/rollout_eval_data \
--dataset.single_task="Put lego brick into the box" \
--duration=3600
```
| Flag | Description |
| -------------------------------------- | ----------------------------------------------------------- |
| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) |
| `--strategy.target_video_file_size_mb` | Target video file size for episode rotation (default: auto) |
| `--dataset.repo_id` | **Required.** Hub repository for the recorded dataset |
| `--dataset.push_to_hub` | Whether to push to Hub on teardown (default: true) |
### Highlight (`--strategy.type=highlight`)
Autonomous rollout with on-demand recording via a memory-bounded ring buffer. The robot runs continuously while the buffer captures the last N seconds of telemetry. Press the save key to flush the buffer and start live recording; press it again to save the episode.
```bash
lerobot-rollout \
--strategy.type=highlight \
--strategy.ring_buffer_seconds=30 \
--strategy.save_key=s \
--strategy.push_key=h \
--policy.path=${HF_USER}/my_policy \
--robot.type=koch_follower \
--robot.port=/dev/ttyACM0 \
--dataset.repo_id=${HF_USER}/rollout_highlight_data \
--dataset.single_task="Pick up the red cube"
```
**Keyboard controls:**
| Key | Action |
| ------------------ | -------------------------------------------------------- |
| `s` (configurable) | Start recording (flushes buffer) / stop and save episode |
| `h` (configurable) | Push dataset to Hub |
| `ESC` | Stop the session |
| Flag | Description |
| -------------------------------------- | ---------------------------------------------- |
| `--strategy.ring_buffer_seconds` | Duration of buffered telemetry (default: 30) |
| `--strategy.ring_buffer_max_memory_mb` | Memory cap for the ring buffer (default: 2048) |
| `--strategy.save_key` | Key to toggle recording (default: `s`) |
| `--strategy.push_key` | Key to push to Hub (default: `h`) |
### DAgger (`--strategy.type=dagger`)
Human-in-the-loop data collection. Alternates between autonomous policy execution and human intervention via a teleoperator. Intervention frames are tagged with `intervention=True`. Requires a teleoperator (`--teleop.type`).
See the [Human-In-the-Loop Data Collection](./hil_data_collection) guide for a detailed walkthrough.
**Corrections-only mode** (default): Only human correction windows are recorded. Each correction becomes one episode.
```bash
lerobot-rollout \
--strategy.type=dagger \
--strategy.num_episodes=20 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--robot.type=bi_openarm_follower \
--teleop.type=openarm_mini \
--dataset.repo_id=${HF_USER}/rollout_hil_data \
--dataset.single_task="Fold the T-shirt"
```
**Continuous recording mode** (`--strategy.record_autonomous=true`): Both autonomous and correction frames are recorded with time-based episode rotation (same as Sentry).
```bash
lerobot-rollout \
--strategy.type=dagger \
--strategy.record_autonomous=true \
--strategy.num_episodes=50 \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--teleop.type=so101_leader \
--teleop.port=/dev/ttyACM1 \
--dataset.repo_id=${HF_USER}/rollout_dagger_data \
--dataset.single_task="Grasp the block"
```
**Keyboard controls** (default input device):
| Key | Action |
| ------- | ------------------------------------------- |
| `Space` | Pause / resume policy execution |
| `Tab` | Start / stop human correction |
| `Enter` | Push dataset to Hub (corrections-only mode) |
| `ESC` | Stop the session |
Foot pedal input is also supported via `--strategy.input_device=pedal`. Configure pedal codes with `--strategy.pedal.*` flags.
| Flag | Description |
| ------------------------------------ | ------------------------------------------------------- |
| `--strategy.num_episodes` | Number of correction episodes to record (default: 10) |
| `--strategy.record_autonomous` | Record autonomous frames too (default: false) |
| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) |
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
| `--teleop.type` | **Required.** Teleoperator type |
---
## Inference Backends
Select a backend with `--inference.type=<name>`. All strategies work with both backends.
### Sync (default)
One policy call per control tick. The main loop blocks until the action is computed.
Works with all policies. No extra flags needed.
### Real-Time Chunking (`--inference.type=rtc`)
A background thread produces action chunks asynchronously. The main control loop polls for the next ready action while the policy computes the next chunk in parallel.
Use RTC with large, slow VLA models (Pi0, Pi0.5, SmolVLA) for smooth, continuous motion despite high inference latency.
```bash
lerobot-rollout \
--strategy.type=base \
--inference.type=rtc \
--inference.rtc.execution_horizon=10 \
--inference.rtc.max_guidance_weight=10.0 \
--policy.path=${HF_USER}/pi0_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Pick up the cube" \
--duration=60 \
--device=cuda
```
| Flag | Description |
| ------------------------------------------- | -------------------------------------------------------------- |
| `--inference.rtc.execution_horizon` | Steps to blend with previous chunk (default: varies by policy) |
| `--inference.rtc.max_guidance_weight` | Consistency enforcement strength (default: varies by policy) |
| `--inference.rtc.prefix_attention_schedule` | Blend schedule: `LINEAR`, `EXP`, `ONES`, `ZEROS` |
| `--inference.queue_threshold` | Max queue size before backpressure (default: 30) |
See the [Real-Time Chunking](./rtc) guide for details on tuning RTC parameters.
---
## Common Flags
| Flag | Description | Default |
| --------------------------------- | ----------------------------------------------------------------- | ------- |
| `--policy.path` | **Required.** HF Hub model ID or local checkpoint path | -- |
| `--robot.type` | **Required.** Robot type (e.g. `so100_follower`, `koch_follower`) | -- |
| `--robot.port` | Serial port for the robot | -- |
| `--robot.cameras` | Camera configuration (JSON dict) | -- |
| `--fps` | Control loop frequency | 30 |
| `--duration` | Run time in seconds (0 = infinite) | 0 |
| `--device` | Torch device (`cpu`, `cuda`, `mps`) | auto |
| `--task` | Task description (used when no dataset is provided) | -- |
| `--display_data` | Stream telemetry to Rerun visualization | false |
| `--display_ip` / `--display_port` | Remote Rerun server address | -- |
| `--interpolation_multiplier` | Action interpolation factor | 1 |
| `--use_torch_compile` | Enable `torch.compile` for inference | false |
| `--resume` | Resume a previous recording session | false |
| `--play_sounds` | Vocal synthesis for events | true |
---
## Programmatic Usage
For custom deployments (e.g. with kinematics processors), use the rollout module API directly:
```python
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies import BaseStrategy
from lerobot.utils.process import ProcessSignalHandler
cfg = RolloutConfig(
robot=my_robot_config,
policy=my_policy_config,
strategy=BaseStrategyConfig(),
inference=SyncInferenceConfig(),
fps=30,
duration=60,
task="my task",
)
signal_handler = ProcessSignalHandler(use_threads=True)
ctx = build_rollout_context(
cfg,
signal_handler.shutdown_event,
robot_action_processor=my_custom_action_processor, # optional
robot_observation_processor=my_custom_obs_processor, # optional
)
strategy = BaseStrategy(cfg.strategy)
try:
strategy.setup(ctx)
strategy.run(ctx)
finally:
strategy.teardown(ctx)
```
See `examples/so100_to_so100_EE/rollout.py` and `examples/phone_to_so100/rollout.py` for full examples with kinematics processors.

View File

@@ -43,7 +43,7 @@ lerobot-record \
--dataset.num_episodes=5 \ --dataset.num_episodes=5 \
--dataset.single_task="Grab the black cube" \ --dataset.single_task="Grab the black cube" \
--dataset.streaming_encoding=true \ --dataset.streaming_encoding=true \
# --dataset.vcodec=auto \ # --dataset.camera_encoder_config.vcodec=auto \
--dataset.encoder_threads=2 --dataset.encoder_threads=2
``` ```

View File

@@ -161,7 +161,7 @@ lerobot-record \
--dataset.private=true \ --dataset.private=true \
--dataset.streaming_encoding=true \ --dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \ --dataset.encoder_threads=2 \
# --dataset.vcodec=auto \ # --dataset.camera_encoder_config.vcodec=auto \
--display_data=true --display_data=true
``` ```
@@ -203,7 +203,7 @@ lerobot-record \
--dataset.private=true \ --dataset.private=true \
--dataset.streaming_encoding=true \ --dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \ --dataset.encoder_threads=2 \
# --dataset.vcodec=auto \ # --dataset.camera_encoder_config.vcodec=auto \
--display_data=true --display_data=true
``` ```

View File

@@ -61,17 +61,6 @@ lerobot-eval \
--rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}' --rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}'
``` ```
### Recording
`lerobot-record` also supports rename maps, nested under the dataset config:
```bash
lerobot-record \ # When running inference
--policy.path="<user>/smolVLA_finetuned" \
... \
--dataset.rename_map='{"observation.images.glove2": "observation.images.image"}'
```
## Alternative: edit the policy config directly ## Alternative: edit the policy config directly
If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed. If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed.
@@ -105,10 +94,10 @@ XVLA-base has three visual inputs and `empty_cameras=0` by default. Your dataset
## Quick reference ## Quick reference
| Goal | What to do | | Goal | What to do |
| ----------------------------------------- | --------------------------------------------------------------------------- | | --------------------------------------- | --------------------------------------------------------------------------- |
| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` | | Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` |
| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` | | Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` |
| Recording with different keys (inference) | `--dataset.rename_map='{"source_key": "policy_key", ...}'`. | | Rollout with different keys (inference) | `--rename_map='{"source_key": "policy_key", ...}'`. |
| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) | | Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source | | Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source |

View File

@@ -34,7 +34,7 @@ pip install -e ".[smolvla]"
### Using RTC with Pi0 ### Using RTC with Pi0
You can find a complete reference implementation in [eval_with_real_robot.py](examples/rtc/eval_with_real_robot.py). You can use `lerobot-rollout --strategy.type=base --inference.type=rtc` for RTC deployment on real robots.
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline: The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
```python ```python
@@ -137,8 +137,12 @@ The script generates a visualization of the denoising process, comparing standar
## Testing RTC with a Real Robot ## Testing RTC with a Real Robot
```bash ```bash
python examples/rtc/eval_with_real_robot.py \ lerobot-rollout \
--strategy.type=base \
--policy.path=${HF_USERNAME}/policy_repo_id \ --policy.path=${HF_USERNAME}/policy_repo_id \
--inference.type=rtc \
--inference.rtc.execution_horizon=10 \
--inference.rtc.max_guidance_weight=10.0 \
--robot.type=so100_follower \ --robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \ --robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ --robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
@@ -178,7 +182,7 @@ visualizer = RTCDebugVisualizer()
# ... create plots # ... create plots
``` ```
See `examples/rtc/eval_dataset.py` for a complete example of visualization. See `examples/rtc/eval_dataset.py` for a complete example of offline RTC visualization.
## References ## References

View File

@@ -46,7 +46,7 @@ This ensures identical task states map to consistent progress values, even acros
## Inputs and Targets (What the new code expects) ## Inputs and Targets (What the new code expects)
SARM is trained through its processor (`src/lerobot/policies/sarm/processor_sarm.py`), which: SARM is trained through its processor (`src/lerobot/rewards/sarm/processor_sarm.py`), which:
- **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features` - **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features`
- **Pads/truncates** robot state into `state_features` (up to `max_state_dim`) - **Pads/truncates** robot state into `state_features` (up to `max_state_dim`)
@@ -347,7 +347,7 @@ Use `compute_rabc_weights.py` with `--visualize-only` to visualize model predict
<hfoption id="single_stage"> <hfoption id="single_stage">
```bash ```bash
python src/lerobot/policies/sarm/compute_rabc_weights.py \ python -m lerobot.rewards.sarm.compute_rabc_weights \
--dataset-repo-id your-username/your-dataset \ --dataset-repo-id your-username/your-dataset \
--reward-model-path your-username/sarm-model \ --reward-model-path your-username/sarm-model \
--visualize-only \ --visualize-only \
@@ -360,7 +360,7 @@ python src/lerobot/policies/sarm/compute_rabc_weights.py \
<hfoption id="dense_only"> <hfoption id="dense_only">
```bash ```bash
python src/lerobot/policies/sarm/compute_rabc_weights.py \ python -m lerobot.rewards.sarm.compute_rabc_weights \
--dataset-repo-id your-username/your-dataset \ --dataset-repo-id your-username/your-dataset \
--reward-model-path your-username/sarm-model \ --reward-model-path your-username/sarm-model \
--visualize-only \ --visualize-only \
@@ -373,7 +373,7 @@ python src/lerobot/policies/sarm/compute_rabc_weights.py \
<hfoption id="dual"> <hfoption id="dual">
```bash ```bash
python src/lerobot/policies/sarm/compute_rabc_weights.py \ python -m lerobot.rewards.sarm.compute_rabc_weights \
--dataset-repo-id your-username/your-dataset \ --dataset-repo-id your-username/your-dataset \
--reward-model-path your-username/sarm-model \ --reward-model-path your-username/sarm-model \
--visualize-only \ --visualize-only \
@@ -429,7 +429,7 @@ The weighting follows **Equations 8-9** from the paper:
First, run the SARM model on all frames in your dataset to compute progress values: First, run the SARM model on all frames in your dataset to compute progress values:
```bash ```bash
python src/lerobot/policies/sarm/compute_rabc_weights.py \ python -m lerobot.rewards.sarm.compute_rabc_weights \
--dataset-repo-id your-username/your-dataset \ --dataset-repo-id your-username/your-dataset \
--reward-model-path your-username/sarm-model \ --reward-model-path your-username/sarm-model \
--head-mode sparse \ --head-mode sparse \
@@ -465,15 +465,15 @@ This script:
### Step 5b: Train Policy with RA-BC ### Step 5b: Train Policy with RA-BC
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC: Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`) if not explicitly provided. Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
```bash ```bash
lerobot-train \ lerobot-train \
--dataset.repo_id=your-username/your-dataset \ --dataset.repo_id=your-username/your-dataset \
--policy.type=pi0 \ --policy.type=pi0 \
--use_rabc=true \ --sample_weighting.type=rabc \
--rabc_head_mode=sparse \ --sample_weighting.head_mode=sparse \
--rabc_kappa=0.01 \ --sample_weighting.kappa=0.01 \
--output_dir=outputs/train/policy_rabc \ --output_dir=outputs/train/policy_rabc \
--batch_size=32 \ --batch_size=32 \
--steps=40000 --steps=40000
@@ -488,12 +488,13 @@ The training script automatically:
**RA-BC Arguments:** **RA-BC Arguments:**
| Argument | Description | Default | | Argument | Description | Default |
| ---------------------- | ---------------------------------------------------------- | ---------------------------------- | | ---------------------------------- | ------------------------------------------------------ | ----------------------- |
| `--use_rabc` | Enable RA-BC sample weighting | `false` | | `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` |
| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset | | `--sample_weighting.progress_path` | Path to progress parquet file | `sarm_progress.parquet` |
| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` | | `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` | | `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` |
| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` |
### Tuning RA-BC Kappa ### Tuning RA-BC Kappa
@@ -511,30 +512,30 @@ The `kappa` parameter is the threshold that determines which samples get full we
Monitor these WandB metrics during training: Monitor these WandB metrics during training:
| Metric | Healthy Range | Problem Indicator | | Metric | Healthy Range | Problem Indicator |
| ------------------ | ------------- | ------------------------- | | ----------------------------- | ------------- | ------------------------- |
| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low | | `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
| `rabc_delta_mean` | > 0 | Should be positive | | `sample_weighting/delta_mean` | > 0 | Should be positive |
| `rabc_delta_std` | > 0 | Variance in data quality | | `sample_weighting/delta_std` | > 0 | Variance in data quality |
**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC. **If `sample_weight_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
**Setting kappa based on your data:** **Setting kappa based on your data:**
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`: The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `sample_weighting/delta_mean` and `sample_weighting/delta_std`:
``` ```
# If delta_mean ≈ 0.03 and delta_std ≈ 0.02: # If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
# Most deltas fall in range [0.01, 0.05] # Most deltas fall in range [0.01, 0.05]
# Option 1: Set kappa = delta_mean (medium selectivity) # Option 1: Set kappa = delta_mean (medium selectivity)
--rabc_kappa=0.03 --sample_weighting.kappa=0.03
# Option 2: Set kappa = delta_mean + delta_std (high selectivity) # Option 2: Set kappa = delta_mean + delta_std (high selectivity)
--rabc_kappa=0.05 --sample_weighting.kappa=0.05
# Option 3: Set kappa = delta_mean + 2*delta_std (very selective) # Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
--rabc_kappa=0.07 --sample_weighting.kappa=0.07
``` ```
**When RA-BC may not help:** **When RA-BC may not help:**
@@ -550,8 +551,8 @@ accelerate launch \
src/lerobot/scripts/lerobot_train.py \ src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \ --dataset.repo_id=your-username/your-dataset \
--policy.type=pi0 \ --policy.type=pi0 \
--use_rabc=true \ --sample_weighting.type=rabc \
--rabc_kappa=0.01 \ --sample_weighting.kappa=0.01 \
--output_dir=outputs/train/policy_rabc \ --output_dir=outputs/train/policy_rabc \
--batch_size=32 \ --batch_size=32 \
--steps=40000 --steps=40000
@@ -576,7 +577,7 @@ accelerate launch \
### RA-BC ### RA-BC
1. **Train SARM first**: RA-BC quality depends entirely on SARM quality 1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa)) 2. **Monitor `sample_weight_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
--- ---

View File

@@ -108,7 +108,7 @@ lerobot-record \
--dataset.num_episodes=10 \ --dataset.num_episodes=10 \
--dataset.streaming_encoding=true \ --dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \ --dataset.encoder_threads=2 \
# --dataset.vcodec=auto \ # --dataset.camera_encoder_config.vcodec=auto \
# <- Teleop optional if you want to teleoperate in between episodes \ # <- Teleop optional if you want to teleoperate in between episodes \
# --teleop.type=so100_leader \ # --teleop.type=so100_leader \
# --teleop.port=/dev/ttyACM0 \ # --teleop.port=/dev/ttyACM0 \

View File

@@ -14,12 +14,22 @@ This makes `save_episode()` near-instant (the video is already encoded by the ti
## 2. Tuning Parameters ## 2. Tuning Parameters
| Parameter | CLI Flag | Type | Default | Description | All encoding parameters are grouped under `camera_encoder_config` (a `VideoEncoderConfig` dataclass), accessible from the CLI via `--dataset.camera_encoder_config.<field>`.
| ----------------------- | --------------------------------- | ------------- | ------------- | ----------------------------------------------------------------- |
| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture | | Parameter | CLI Flag | Type | Default | Description |
| `vcodec` | `--dataset.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder | | ----------------------- | --------------------------------------------- | ------------- | ------------- | ------------------------------------------------------------------- |
| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance. `None` will leave the vcoded decide | | `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture |
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM | | `vcodec` | `--dataset.camera_encoder_config.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
| `pix_fmt` | `--dataset.camera_encoder_config.pix_fmt` | `str` | `"yuv420p"` | Pixel format |
| `g` | `--dataset.camera_encoder_config.g` | `int \| None` | `2` | GOP size (keyframe interval) |
| `crf` | `--dataset.camera_encoder_config.crf` | `int \| None` | `30` | Quality level (mapped to codec-specific parameter) |
| `preset` | `--dataset.camera_encoder_config.preset` | `int \| None` | `12` | Speed preset (libsvtav1 only, 0 = slowest … 13 = fastest) |
| `fast_decode` | `--dataset.camera_encoder_config.fast_decode` | `int` | `0` | Fast-decode tuning level |
| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance (global). `None` lets the codec decide |
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM |
> [!TIP]
> Not all parameters apply to every codec. `VideoEncoderConfig` will warn at startup if you set a parameter that your chosen codec ignores (e.g. `preset` with `h264_nvenc`).
## 3. Performance Considerations ## 3. Performance Considerations
@@ -40,7 +50,7 @@ Streaming encoding means the CPU is encoding video **during** the capture loop,
### `encoder_threads` Tuning ### `encoder_threads` Tuning
This parameter controls how many threads each encoder instance uses internally: This parameter (`--dataset.encoder_threads`) controls how many threads each encoder instance uses internally:
- **Higher values** (e.g., 4-5): Faster encoding, but uses more CPU cores per camera. Good for high-end systems with many cores. - **Higher values** (e.g., 4-5): Faster encoding, but uses more CPU cores per camera. Good for high-end systems with many cores.
- **Lower values** (e.g., 1-2): Less CPU per camera, freeing cores for capture and visualization. Good for low-res images and capable CPUs. - **Lower values** (e.g., 1-2): Less CPU per camera, freeing cores for capture and visualization. Good for low-res images and capable CPUs.
@@ -82,15 +92,15 @@ Use HW encoding when:
### Available HW Encoders ### Available HW Encoders
| Encoder | Platform | Hardware | CLI Value | | Encoder | Platform | Hardware | CLI Value |
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------ | | ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ---------------------------------------------------------- |
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=h264_videotoolbox` | | `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder_config.vcodec=h264_videotoolbox` |
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=hevc_videotoolbox` | | `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder_config.vcodec=hevc_videotoolbox` |
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=h264_nvenc` | | `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder_config.vcodec=h264_nvenc` |
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=hevc_nvenc` | | `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder_config.vcodec=hevc_nvenc` |
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.vcodec=h264_vaapi` | | `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.camera_encoder_config.vcodec=h264_vaapi` |
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.vcodec=h264_qsv` | | `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.camera_encoder_config.vcodec=h264_qsv` |
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.vcodec=auto` | | `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.camera_encoder_config.vcodec=auto` |
> [!NOTE] > [!NOTE]
> In order to use the HW accelerated encoders you might need to upgrade your GPU drivers. > In order to use the HW accelerated encoders you might need to upgrade your GPU drivers.
@@ -100,15 +110,15 @@ Use HW encoding when:
## 5. Troubleshooting ## 5. Troubleshooting
| Symptom | Likely Cause | Fix | | Symptom | Likely Cause | Fix |
| ------------------------------------------------------------------ | -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | | ------------------------------------------------------------------ | -------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.vcodec=auto`) | | System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.camera_encoder_config.vcodec=auto`) |
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.vcodec=auto`). | | "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.camera_encoder_config.vcodec=auto`). |
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding | | High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows | | Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` | | `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.vcodec=auto` | | Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.camera_encoder_config.vcodec=auto` |
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. | | Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
## 6. Recommended Configurations ## 6. Recommended Configurations
@@ -146,10 +156,10 @@ On very constrained systems, streaming encoding may compete too heavily with the
# 2camsx 640x480x3 @30fps: Requires some tuning. # 2camsx 640x480x3 @30fps: Requires some tuning.
# Use H.264, disable streaming, consider batching encoding # Use H.264, disable streaming, consider batching encoding
lerobot-record --dataset.vcodec=h264 --dataset.streaming_encoding=false ... lerobot-record --dataset.camera_encoder_config.vcodec=h264 --dataset.streaming_encoding=false ...
``` ```
## 7. Closing note ## 7. Closing note
Performance ultimately depends on your exact setup — frames-per-second, resolution, CPU cores and load, available memory, episode length, and the encoder you choose. Always test with your target workload, be mindful about your CPU & system capabilities and tune `encoder_threads`, `encoder_queue_maxsize`, and Performance ultimately depends on your exact setup — frames-per-second, resolution, CPU cores and load, available memory, episode length, and the encoder you choose. Always test with your target workload, be mindful about your CPU & system capabilities and tune `encoder_threads`, `encoder_queue_maxsize`, and
`vcodec` reasonably. That said, a common practical configuration (for many applications) is three cameras at 640×480x3 @30fps; this usually runs fine with the default streaming video encoding settings in modern systems. Always verify your recorded dataset is healthy by comparing the video duration to the CLI episode duration and confirming the row count equals FPS × CLI duration. `camera_encoder_config.vcodec` reasonably. That said, a common practical configuration (for many applications) is three cameras at 640×480x3 @30fps; this usually runs fine with the default streaming video encoding settings in modern systems. Always verify your recorded dataset is healthy by comparing the video duration to the CLI episode duration and confirming the row count equals FPS × CLI duration.

View File

@@ -274,7 +274,8 @@ python src/lerobot/scripts/lerobot_train.py \
Once trained, we recommend deploying policies using inference-time RTC: Once trained, we recommend deploying policies using inference-time RTC:
```bash ```bash
python examples/rtc/eval_with_real_robot.py \ lerobot-rollout \
--strategy.type=base \
--policy.path=your-username/your-repo-id \ --policy.path=your-username/your-repo-id \
--policy.device=cuda \ --policy.device=cuda \
--robot.type=unitree_g1 \ --robot.type=unitree_g1 \
@@ -284,7 +285,7 @@ python examples/rtc/eval_with_real_robot.py \
--task="task_description" \ --task="task_description" \
--duration=1000 \ --duration=1000 \
--fps=30 \ --fps=30 \
--rtc.enabled=true --inference.type=rtc
``` ```
--- ---

View File

@@ -117,10 +117,10 @@ lerobot-edit-dataset \
--repo_id lerobot/pusht_image \ --repo_id lerobot/pusht_image \
--operation.type convert_image_to_video \ --operation.type convert_image_to_video \
--operation.output_dir outputs/pusht_video \ --operation.output_dir outputs/pusht_video \
--operation.vcodec libsvtav1 \ --operation.camera_encoder_config.vcodec libsvtav1 \
--operation.pix_fmt yuv420p \ --operation.camera_encoder_config.pix_fmt yuv420p \
--operation.g 2 \ --operation.camera_encoder_config.g 2 \
--operation.crf 30 --operation.camera_encoder_config.crf 30
# Convert only specific episodes # Convert only specific episodes
lerobot-edit-dataset \ lerobot-edit-dataset \
@@ -147,11 +147,14 @@ lerobot-edit-dataset \
**Parameters:** **Parameters:**
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`) - `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`) - `camera_encoder_config`: Video encoder settings — all sub-fields accessible via `--operation.camera_encoder_config.<field>`:
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`) - `vcodec`: Video codec — `h264`, `hevc`, `libsvtav1`, `auto`, or hardware codecs (default: `libsvtav1`)
- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2) - `pix_fmt`: Pixel format — `yuv420p`, `yuv444p` (default: `yuv420p`)
- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30) - `g`: GOP size — lower values give better quality but larger files (default: 2)
- `fast_decode`: Fast decode tuning option (default: 0) - `crf`: Quality level — lower is better, 0 is lossless (default: 30)
- `preset`: Speed preset, libsvtav1 only (default: 12)
- `fast_decode`: Fast-decode tuning (default: 0)
- `encoder_threads`: Threads per encoder instance — global setting, separate from `camera_encoder_config` (default: None)
- `episode_indices`: List of specific episodes to convert (default: all episodes) - `episode_indices`: List of specific episodes to convert (default: all episodes)
- `num_workers`: Number of parallel workers for processing (default: 4) - `num_workers`: Number of parallel workers for processing (default: 4)

View File

@@ -220,7 +220,7 @@ REAL_DIM = 12
# Postprocessing: Trim 20D predictions to 12D for deployment # Postprocessing: Trim 20D predictions to 12D for deployment
``` ```
See the [action_hub.py](/home/jade_choghari/robot/lerobot/src/lerobot/policies/xvla/action_hub.py) implementation for details. See the [action_hub.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/action_hub.py) implementation for details.
#### Auto Action Mode (Recommended) #### Auto Action Mode (Recommended)
@@ -519,9 +519,9 @@ If you use X-VLA in your research, please cite:
- [X-VLA Paper](https://arxiv.org/pdf/2510.10274) - [X-VLA Paper](https://arxiv.org/pdf/2510.10274)
- [LeRobot Documentation](https://github.com/huggingface/lerobot) - [LeRobot Documentation](https://github.com/huggingface/lerobot)
- [Action Registry Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/action_hub.py) - [Action Registry Implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/action_hub.py)
- [Processor Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/processor_xvla.py) - [Processor Implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/processor_xvla.py)
- [Model Configuration](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/configuration_xvla.py) - [Model Configuration](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/configuration_xvla.py)
## Contributing ## Contributing

View File

@@ -69,7 +69,7 @@ class ComputeProgressShards(PipelineStep):
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from lerobot.policies.sarm.compute_rabc_weights import ( from lerobot.rewards.sarm.compute_rabc_weights import (
generate_all_frame_indices, generate_all_frame_indices,
interpolate_progress, interpolate_progress,
load_sarm_resources, load_sarm_resources,

File diff suppressed because it is too large Load Diff

View File

@@ -1,226 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utilities for Human-in-the-Loop data collection scripts."""
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from lerobot.common.control_utils import is_headless
from lerobot.processor import (
IdentityProcessorStep,
RobotAction,
RobotObservation,
RobotProcessorPipeline,
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from lerobot.robots import Robot
from lerobot.teleoperators import Teleoperator
from lerobot.utils.robot_utils import precise_sleep
logger = logging.getLogger(__name__)
@dataclass
class HILDatasetConfig:
repo_id: str
single_task: str
root: str | Path | None = None
fps: int = 30
episode_time_s: float = 120
num_episodes: int = 50
video: bool = True
push_to_hub: bool = True
private: bool = False
tags: list[str] | None = None
num_image_writer_processes: int = 0
num_image_writer_threads_per_camera: int = 4
video_encoding_batch_size: int = 1
vcodec: str = "auto"
streaming_encoding: bool = True
encoder_queue_maxsize: int = 30
encoder_threads: int | None = None
rename_map: dict[str, str] = field(default_factory=dict)
def teleop_has_motor_control(teleop: Teleoperator) -> bool:
"""Check if teleoperator has motor control capabilities."""
return all(hasattr(teleop, attr) for attr in ("enable_torque", "disable_torque", "write_goal_positions"))
def teleop_disable_torque(teleop: Teleoperator) -> None:
"""Disable teleop torque if supported."""
if hasattr(teleop, "disable_torque"):
teleop.disable_torque()
def teleop_enable_torque(teleop: Teleoperator) -> None:
"""Enable teleop torque if supported."""
if hasattr(teleop, "enable_torque"):
teleop.enable_torque()
def teleop_smooth_move_to(teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50):
"""Smoothly move teleop to target position if motor control is available."""
if not teleop_has_motor_control(teleop):
logger.warning("Teleop does not support motor control - cannot mirror robot position")
return
teleop_enable_torque(teleop)
current = teleop.get_action()
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {}
for k in current:
if k in target_pos:
interp[k] = current[k] * (1 - t) + target_pos[k] * t
else:
interp[k] = current[k]
teleop.write_goal_positions(interp)
time.sleep(1 / fps)
def init_keyboard_listener():
"""Initialize keyboard listener with HIL controls."""
events = {
"exit_early": False,
"rerecord_episode": False,
"stop_recording": False,
"policy_paused": False,
"correction_active": False,
"resume_policy": False,
"in_reset": False,
"start_next_episode": False,
}
if is_headless():
logger.warning("Headless environment - keyboard controls unavailable")
return None, events
from pynput import keyboard
def on_press(key):
try:
if events["in_reset"]:
if key in [keyboard.Key.space, keyboard.Key.right]:
logger.info("[HIL] Starting next episode...")
events["start_next_episode"] = True
elif hasattr(key, "char") and key.char == "c":
events["start_next_episode"] = True
elif key == keyboard.Key.esc:
logger.info("[HIL] ESC - Stop recording, pushing to hub...")
events["stop_recording"] = True
events["start_next_episode"] = True
else:
if key == keyboard.Key.space:
if not events["policy_paused"] and not events["correction_active"]:
logger.info("[HIL] PAUSED - Press 'c' to take control or 'p' to resume policy")
events["policy_paused"] = True
elif hasattr(key, "char") and key.char == "c":
if events["policy_paused"] and not events["correction_active"]:
logger.info("[HIL] Taking control...")
events["start_next_episode"] = True
elif hasattr(key, "char") and key.char == "p":
if events["policy_paused"] or events["correction_active"]:
logger.info("[HIL] Resuming policy...")
events["resume_policy"] = True
elif key == keyboard.Key.right:
logger.info("[HIL] End episode")
events["exit_early"] = True
elif key == keyboard.Key.left:
logger.info("[HIL] Re-record episode")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
logger.info("[HIL] ESC - Stop recording...")
events["stop_recording"] = True
events["exit_early"] = True
except Exception as e:
logger.info(f"Key error: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
return listener, events
def make_identity_processors():
"""Create identity processors for recording."""
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[IdentityProcessorStep()],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[IdentityProcessorStep()],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
return teleop_proc, obs_proc
def reset_loop(robot: Robot, teleop: Teleoperator, events: dict, fps: int):
"""Reset period where human repositions environment."""
logger.info("[HIL] RESET")
events["in_reset"] = True
events["start_next_episode"] = False
obs = robot.get_observation()
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
logger.info("Press any key to enable teleoperation")
while not events["start_next_episode"] and not events["stop_recording"]:
precise_sleep(0.05)
if events["stop_recording"]:
return
events["start_next_episode"] = False
teleop_disable_torque(teleop)
logger.info("Teleop enabled - press any key to start episode")
while not events["start_next_episode"] and not events["stop_recording"]:
loop_start = time.perf_counter()
action = teleop.get_action()
robot.send_action(action)
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
events["in_reset"] = False
events["start_next_episode"] = False
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
events["resume_policy"] = False
def print_controls(rtc: bool = False):
"""Print control instructions."""
mode = "Human-in-the-Loop Data Collection" + (" (RTC)" if rtc else "")
logger.info(
"%s\n Controls:\n"
" SPACE - Pause policy\n"
" c - Take control\n"
" p - Resume policy after pause/correction\n"
" → - End episode\n"
" ESC - Stop and push to hub",
mode,
)

View File

@@ -14,17 +14,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from lerobot.common.control_utils import init_keyboard_listener import logging
import time
from lerobot.common.control_utils import init_keyboard_listener, predict_action
from lerobot.datasets import LeRobotDataset from lerobot.datasets import LeRobotDataset
from lerobot.policies import make_pre_post_processors from lerobot.policies import make_pre_post_processors
from lerobot.policies.act import ACTPolicy from lerobot.policies.act import ACTPolicy
from lerobot.policies.utils import make_robot_action
from lerobot.processor import make_default_processors from lerobot.processor import make_default_processors
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.scripts.lerobot_record import record_loop
from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import hw_to_dataset_features from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
NUM_EPISODES = 2 NUM_EPISODES = 2
FPS = 30 FPS = 30
@@ -35,6 +39,9 @@ HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
def main(): def main():
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
# This script provides a self-contained example for educational purposes.
# Create the robot configuration & robot # Create the robot configuration & robot
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi") robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
@@ -83,43 +90,67 @@ def main():
raise ValueError("Robot is not connected!") raise ValueError("Robot is not connected!")
print("Starting evaluate loop...") print("Starting evaluate loop...")
control_interval = 1 / FPS
recorded_episodes = 0 recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}") log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
# Main record loop # Inline evaluation loop: predict actions and send to robot
record_loop( timestamp = 0
robot=robot, start_episode_t = time.perf_counter()
events=events, while timestamp < EPISODE_TIME_SEC:
fps=FPS, start_loop_t = time.perf_counter()
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors if events["exit_early"]:
postprocessor=postprocessor, events["exit_early"] = False
dataset=dataset, break
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION, # Get robot observation
display_data=True, obs = robot.get_observation()
teleop_action_processor=teleop_action_processor, obs_processed = robot_observation_processor(obs)
robot_action_processor=robot_action_processor, observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
robot_observation_processor=robot_observation_processor,
) # Predict action using the policy
action_tensor = predict_action(
observation=observation_frame,
policy=policy,
device=policy.config.device,
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.device.type == "cuda",
task=TASK_DESCRIPTION,
robot_type=robot.name,
)
# Convert policy output to robot action dict
action_values = make_robot_action(action_tensor, dataset.features)
# Process and send action to robot
robot_action_to_send = robot_action_processor((action_values, obs))
robot.send_action(robot_action_to_send)
# Write to dataset
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
dataset.add_frame(frame)
log_rerun_data(observation=obs_processed, action=action_values)
dt_s = time.perf_counter() - start_loop_t
sleep_time_s = control_interval - dt_s
if sleep_time_s < 0:
logging.warning(
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
)
precise_sleep(max(sleep_time_s, 0.0))
timestamp = time.perf_counter() - start_episode_t
# Reset the environment if not stopping or re-recording # Reset the environment if not stopping or re-recording
if not events["stop_recording"] and ( if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"] (recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
): ):
log_say("Reset the environment") log_say("Reset the environment")
record_loop( log_say("Waiting for environment reset, press right arrow key when ready...")
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
if events["rerecord_episode"]: if events["rerecord_episode"]:
log_say("Re-record episode") log_say("Re-record episode")

View File

@@ -45,9 +45,6 @@ def main():
leader_arm = SO100Leader(leader_arm_config) leader_arm = SO100Leader(leader_arm_config)
keyboard = KeyboardTeleop(keyboard_config) keyboard = KeyboardTeleop(keyboard_config)
# TODO(Steven): Update this example to use pipelines
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
# Configure the dataset features # Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, ACTION) action_features = hw_to_dataset_features(robot.action_features, ACTION)
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR) obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
@@ -77,6 +74,10 @@ def main():
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected: if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!") raise ValueError("Robot or teleop is not connected!")
teleop_action_processor, robot_action_processor, robot_observation_processor = (
make_default_processors()
)
print("Starting record loop...") print("Starting record loop...")
recorded_episodes = 0 recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
@@ -87,14 +88,14 @@ def main():
robot=robot, robot=robot,
events=events, events=events,
fps=FPS, fps=FPS,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
dataset=dataset, dataset=dataset,
teleop=[leader_arm, keyboard], teleop=[leader_arm, keyboard],
control_time_s=EPISODE_TIME_SEC, control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION, single_task=TASK_DESCRIPTION,
display_data=True, display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
) )
# Reset the environment if not stopping or re-recording # Reset the environment if not stopping or re-recording
@@ -106,13 +107,13 @@ def main():
robot=robot, robot=robot,
events=events, events=events,
fps=FPS, fps=FPS,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=[leader_arm, keyboard], teleop=[leader_arm, keyboard],
control_time_s=RESET_TIME_SEC, control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION, single_task=TASK_DESCRIPTION,
display_data=True, display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
) )
if events["rerecord_episode"]: if events["rerecord_episode"]:

View File

@@ -0,0 +1,77 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run a trained policy on LeKiwi without recording (base rollout).
Uses the rollout engine's :class:`BaseStrategy` (autonomous execution,
no dataset) with :class:`SyncInferenceConfig` (inline policy call per
control tick). For a CLI entry point with the same capabilities plus
recording, upload, and human-in-the-loop variants, see ``lerobot-rollout``.
"""
from lerobot.configs import PreTrainedConfig
from lerobot.robots.lekiwi import LeKiwiClientConfig
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies import BaseStrategy
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging
FPS = 30
DURATION_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
def main():
init_logging()
# Robot: LeKiwi client — make sure lekiwi_host is already running on the robot.
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
# Policy: load the pretrained config. ``pretrained_path`` is read downstream
# by ``build_rollout_context`` to reload the full model.
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
policy_config.pretrained_path = HF_MODEL_ID
# Assemble the rollout config: base strategy (no recording) + sync inference.
cfg = RolloutConfig(
robot=robot_config,
policy=policy_config,
strategy=BaseStrategyConfig(),
inference=SyncInferenceConfig(),
fps=FPS,
duration=DURATION_SEC,
task=TASK_DESCRIPTION,
)
# Graceful Ctrl-C: the strategy loop exits when shutdown_event is set.
signal_handler = ProcessSignalHandler(use_threads=True)
# Build the context (connects robot, loads policy, wires the inference strategy).
# No custom processors here — LeKiwi runs on raw joint features.
ctx = build_rollout_context(cfg, signal_handler.shutdown_event)
strategy = BaseStrategy(cfg.strategy)
try:
strategy.setup(ctx)
strategy.run(ctx)
finally:
strategy.teardown(ctx)
if __name__ == "__main__":
main()

View File

@@ -14,13 +14,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import time
from lerobot.cameras.opencv import OpenCVCameraConfig from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.common.control_utils import init_keyboard_listener from lerobot.common.control_utils import init_keyboard_listener, predict_action
from lerobot.configs import FeatureType, PolicyFeature from lerobot.configs import FeatureType, PolicyFeature
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
from lerobot.model.kinematics import RobotKinematics from lerobot.model.kinematics import RobotKinematics
from lerobot.policies import make_pre_post_processors from lerobot.policies import make_pre_post_processors
from lerobot.policies.act import ACTPolicy from lerobot.policies.act import ACTPolicy
from lerobot.policies.utils import make_robot_action
from lerobot.processor import ( from lerobot.processor import (
RobotProcessorPipeline, RobotProcessorPipeline,
make_default_teleop_action_processor, make_default_teleop_action_processor,
@@ -34,11 +38,12 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE, ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints, InverseKinematicsEEToJoints,
) )
from lerobot.scripts.lerobot_record import record_loop
from lerobot.types import RobotAction, RobotObservation from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.feature_utils import combine_feature_dicts from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
NUM_EPISODES = 5 NUM_EPISODES = 5
FPS = 30 FPS = 30
@@ -49,6 +54,9 @@ HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
def main(): def main():
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
# This script provides a self-contained example for educational purposes.
# Create the robot configuration & robot # Create the robot configuration & robot
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig( robot_config = SO100FollowerConfig(
@@ -143,43 +151,67 @@ def main():
raise ValueError("Robot is not connected!") raise ValueError("Robot is not connected!")
print("Starting evaluate loop...") print("Starting evaluate loop...")
control_interval = 1 / FPS
episode_idx = 0 episode_idx = 0
for episode_idx in range(NUM_EPISODES): for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop # Inline evaluation loop: predict actions and send to robot
record_loop( timestamp = 0
robot=robot, start_episode_t = time.perf_counter()
events=events, while timestamp < EPISODE_TIME_SEC:
fps=FPS, start_loop_t = time.perf_counter()
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors if events["exit_early"]:
postprocessor=postprocessor, events["exit_early"] = False
dataset=dataset, break
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION, # Get robot observation
display_data=True, obs = robot.get_observation()
teleop_action_processor=make_default_teleop_action_processor(), obs_processed = robot_joints_to_ee_pose_processor(obs)
robot_action_processor=robot_ee_to_joints_processor, observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
robot_observation_processor=robot_joints_to_ee_pose_processor,
) # Predict action using the policy
action_tensor = predict_action(
observation=observation_frame,
policy=policy,
device=policy.config.device,
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.device.type == "cuda",
task=TASK_DESCRIPTION,
robot_type=robot.name,
)
# Convert policy output to robot action dict
action_values = make_robot_action(action_tensor, dataset.features)
# Process and send action to robot (EE -> joints via IK)
robot_action_to_send = robot_ee_to_joints_processor((action_values, obs))
robot.send_action(robot_action_to_send)
# Write to dataset
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
dataset.add_frame(frame)
log_rerun_data(observation=obs_processed, action=action_values)
dt_s = time.perf_counter() - start_loop_t
sleep_time_s = control_interval - dt_s
if sleep_time_s < 0:
logging.warning(
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
)
precise_sleep(max(sleep_time_s, 0.0))
timestamp = time.perf_counter() - start_episode_t
# Reset the environment if not stopping or re-recording # Reset the environment if not stopping or re-recording
if not events["stop_recording"] and ( if not events["stop_recording"] and (
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"] (episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
): ):
log_say("Reset the environment") log_say("Reset the environment")
record_loop( log_say("Waiting for environment reset, press right arrow key when ready...")
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
if events["rerecord_episode"]: if events["rerecord_episode"]:
log_say("Re-record episode") log_say("Re-record episode")
@@ -190,7 +222,6 @@ def main():
# Save episode # Save episode
dataset.save_episode() dataset.save_episode()
episode_idx += 1
finally: finally:
# Clean up # Clean up
log_say("Stop recording") log_say("Stop recording")

View File

@@ -65,14 +65,15 @@ def main():
robot = SO100Follower(robot_config) robot = SO100Follower(robot_config)
phone = Phone(teleop_config) phone = Phone(teleop_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics( kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf", urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link", target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()), joint_names=list(robot.bus.motors.keys()),
) )
# Build pipeline to convert phone action to EE action # Build pipeline to convert phone action to EE action (with gripper velocity mapped to joint).
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[ phone_to_robot_ee_pose_processor = RobotProcessorPipeline[
tuple[RobotAction, RobotObservation], RobotAction tuple[RobotAction, RobotObservation], RobotAction
]( ](
@@ -94,7 +95,7 @@ def main():
to_output=transition_to_robot_action, to_output=transition_to_robot_action,
) )
# Build pipeline to convert EE action to joints action # Build pipeline to convert EE action to joints action (IK).
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[ steps=[
InverseKinematicsEEToJoints( InverseKinematicsEEToJoints(
@@ -107,7 +108,7 @@ def main():
to_output=transition_to_robot_action, to_output=transition_to_robot_action,
) )
# Build pipeline to convert joint observation to EE observation # Build pipeline to convert joint observation to EE observation (FK).
robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation]( robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[ steps=[
ForwardKinematicsJointsToEE( ForwardKinematicsJointsToEE(
@@ -118,13 +119,12 @@ def main():
to_output=transition_to_observation, to_output=transition_to_observation,
) )
# Create the dataset # Create the dataset, deriving features from the pipelines so the on-disk schema
# matches exactly what the pipelines produce at runtime.
dataset = LeRobotDataset.create( dataset = LeRobotDataset.create(
repo_id=HF_REPO_ID, repo_id=HF_REPO_ID,
fps=FPS, fps=FPS,
features=combine_feature_dicts( features=combine_feature_dicts(
# Run the feature contract of the pipelines
# This tells you how the features would look like after the pipeline steps
aggregate_pipeline_dataset_features( aggregate_pipeline_dataset_features(
pipeline=phone_to_robot_ee_pose_processor, pipeline=phone_to_robot_ee_pose_processor,
initial_features=create_initial_features(action=phone.action_features), initial_features=create_initial_features(action=phone.action_features),
@@ -163,14 +163,14 @@ def main():
robot=robot, robot=robot,
events=events, events=events,
fps=FPS, fps=FPS,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
teleop=phone, teleop=phone,
dataset=dataset, dataset=dataset,
control_time_s=EPISODE_TIME_SEC, control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION, single_task=TASK_DESCRIPTION,
display_data=True, display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
) )
# Reset the environment if not stopping or re-recording # Reset the environment if not stopping or re-recording
@@ -182,13 +182,13 @@ def main():
robot=robot, robot=robot,
events=events, events=events,
fps=FPS, fps=FPS,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
teleop=phone, teleop=phone,
control_time_s=RESET_TIME_SEC, control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION, single_task=TASK_DESCRIPTION,
display_data=True, display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
) )
if events["rerecord_episode"]: if events["rerecord_episode"]:

View File

@@ -0,0 +1,126 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run a trained EE-space policy on SO100 (phone-trained) without recording.
Mirrors ``examples/so100_to_so100_EE/rollout.py`` — the model was trained
with phone teleoperation in EE space, so at deployment we only need the
joint↔EE conversion on the robot side; the phone is not used.
Uses :class:`BaseStrategy` (no recording) + :class:`SyncInferenceConfig`
(inline policy call). For recording during rollout, switch to Sentry,
Highlight, or DAgger via ``lerobot-rollout --strategy.type=...``.
"""
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.configs import PreTrainedConfig
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import (
RobotProcessorPipeline,
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies import BaseStrategy
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging
FPS = 30
DURATION_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
def main():
init_logging()
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471",
id="my_awesome_follower_arm",
cameras=camera_config,
use_degrees=True,
)
# Peek at motor names once to build the kinematic solver.
temp_robot = SO100Follower(robot_config)
motor_names = list(temp_robot.bus.motors.keys())
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=motor_names,
)
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=motor_names,
initial_guess_current_joints=True,
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
policy_config.pretrained_path = HF_MODEL_ID
cfg = RolloutConfig(
robot=robot_config,
policy=policy_config,
strategy=BaseStrategyConfig(),
inference=SyncInferenceConfig(),
fps=FPS,
duration=DURATION_SEC,
task=TASK_DESCRIPTION,
)
signal_handler = ProcessSignalHandler(use_threads=True)
ctx = build_rollout_context(
cfg,
signal_handler.shutdown_event,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
strategy = BaseStrategy(cfg.strategy)
try:
strategy.setup(ctx)
strategy.run(ctx)
finally:
strategy.teardown(ctx)
if __name__ == "__main__":
main()

View File

@@ -1,673 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies on real robots.
This script demonstrates:
1. Creating a robot and policy (SmolVLA, Pi0, etc.) with RTC
2. Consuming actions from the policy while the robot executes
3. Periodically requesting new action chunks in the background using threads
4. Managing action buffers and timing for real-time operation
For simulation environments, see eval_with_simulation.py
Usage:
# Run RTC with Real robot with RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run RTC with Real robot without RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=false \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run RTC with Real robot with pi0.5 policy
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=<USER>/pi05_check_rtc \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run RTC with bi_openarm_follower (dual-arm OpenArms) and pi0.5 policy
python examples/rtc/eval_with_real_robot.py \
--policy.path=lerobot-data-collection/folding_final \
--robot.type=bi_openarm_follower \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}}' \
--robot.left_arm_config.port=can0 \
--robot.left_arm_config.side=left \
--robot.left_arm_config.can_interface=socketcan \
--robot.left_arm_config.disable_torque_on_disconnect=true \
--robot.left_arm_config.max_relative_target=8.0 \
--robot.right_arm_config.port=can1 \
--robot.right_arm_config.side=right \
--robot.right_arm_config.can_interface=socketcan \
--robot.right_arm_config.disable_torque_on_disconnect=true \
--robot.right_arm_config.max_relative_target=8.0 \
--task="Fold the T-shirt properly" \
--fps=30 \
--duration=2000 \
--interpolation_multiplier=3 \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--rtc.max_guidance_weight=5.0 \
--rtc.prefix_attention_schedule=LINEAR \
--device=cuda
"""
import logging
import math
import sys
import time
import traceback
from dataclasses import dataclass, field
from threading import Event, Lock, Thread
import torch
from torch import Tensor
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401
from lerobot.configs import PreTrainedConfig, RTCAttentionSchedule, parser
from lerobot.policies import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
from lerobot.processor import (
NormalizerProcessorStep,
RelativeActionsProcessorStep,
TransitionKey,
create_transition,
make_default_robot_action_processor,
make_default_robot_observation_processor,
to_relative_actions,
)
from lerobot.rl.process import ProcessSignalHandler
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_openarm_follower,
bi_so_follower,
koch_follower,
so_follower,
unitree_g1,
)
from lerobot.robots.utils import make_robot_from_config
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class RobotWrapper:
def __init__(self, robot: Robot):
self.robot = robot
self.lock = Lock()
def get_observation(self) -> dict[str, Tensor]:
with self.lock:
return self.robot.get_observation()
def send_action(self, action: Tensor):
with self.lock:
self.robot.send_action(action)
def observation_features(self) -> list[str]:
with self.lock:
return self.robot.observation_features
def action_features(self) -> list[str]:
with self.lock:
return self.robot.action_features
@dataclass
class RTCDemoConfig(HubMixin):
"""Configuration for RTC demo with action chunking policies and real robots."""
# Policy configuration
policy: PreTrainedConfig | None = None
# Robot configuration
robot: RobotConfig | None = None
# RTC configuration
rtc: RTCConfig = field(
default_factory=lambda: RTCConfig(
execution_horizon=10,
max_guidance_weight=1.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
)
)
# Demo parameters
duration: float = 30.0 # Duration to run the demo (seconds)
fps: float = 10.0 # Action execution frequency (Hz)
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
# Compute device
device: str | None = None # Device to run on (cuda, cpu, auto)
# Get new actions horizon. The amount of executed steps after which will be requested new actions.
# It should be higher than inference delay + execution horizon.
action_queue_size_to_get_new_actions: int = 30
# Task to execute
task: str = field(default="", metadata={"help": "Task to execute"})
# Torch compile configuration
use_torch_compile: bool = field(
default=False,
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
)
torch_compile_backend: str = field(
default="inductor",
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
)
torch_compile_mode: str = field(
default="default",
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
)
torch_compile_disable_cudagraphs: bool = field(
default=True,
metadata={
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
},
)
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
else:
raise ValueError("Policy path is required")
# Validate that robot configuration is provided
if self.robot is None:
raise ValueError("Robot configuration must be provided")
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
def is_image_key(k: str) -> bool:
return k.startswith(OBS_IMAGES)
def _reanchor_relative_rtc_prefix(
prev_actions_absolute: Tensor,
current_state: Tensor,
relative_step: RelativeActionsProcessorStep,
normalizer_step: NormalizerProcessorStep | None,
policy_device: torch.device | str,
) -> Tensor:
"""Convert absolute leftovers into model-space for relative-action RTC policies.
When a policy uses relative actions, the RTC prefix (leftover actions from
the previous chunk) is stored in absolute space. Before feeding it back to
the policy we need to re-express it relative to the *current* robot state
and then re-normalize.
"""
state = current_state.detach().cpu()
if state.dim() == 1:
state = state.unsqueeze(0)
action_cpu = prev_actions_absolute.detach().cpu()
mask = relative_step._build_mask(action_cpu.shape[-1])
relative_actions = to_relative_actions(action_cpu, state, mask)
transition = create_transition(action=relative_actions)
if normalizer_step is not None:
transition = normalizer_step(transition)
return transition[TransitionKey.ACTION].to(policy_device)
def get_actions(
policy,
robot: RobotWrapper,
robot_observation_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: RTCDemoConfig,
):
"""Thread function to request action chunks from the policy.
Args:
policy: The policy instance (SmolVLA, Pi0, etc.)
robot: The robot instance for getting observations
robot_observation_processor: Processor for raw robot observations
action_queue: Queue to put new action chunks
shutdown_event: Event to signal shutdown
cfg: Demo configuration
"""
try:
logger.info("[GET_ACTIONS] Starting get actions thread")
latency_tracker = LatencyTracker() # Track latency of action chunks
fps = cfg.fps
time_per_chunk = 1.0 / fps
# Only keep .pos joints + camera streams if the policy was trained on positions,
# not the full pos/vel/torque state the robot exposes.
observation_features_hw = {
key: value
for key, value in robot.observation_features().items()
if key.endswith(".pos") or isinstance(value, tuple)
}
dataset_features = hw_to_dataset_features(observation_features_hw, "observation")
policy_device = policy.config.device
# Load preprocessor and postprocessor from pretrained files
# The stats are embedded in the processor .safetensors files
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=None, # Will load from pretrained processor files
preprocessor_overrides={
"device_processor": {"device": cfg.policy.device},
},
)
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
relative_step = next(
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
None,
)
normalizer_step = next(
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
None,
)
if relative_step is not None:
if relative_step.action_names is None:
cfg_names = getattr(cfg.policy, "action_feature_names", None)
if cfg_names:
relative_step.action_names = list(cfg_names)
else:
relative_step.action_names = [
k for k in robot.robot.action_features if k.endswith(".pos")
]
logger.info("[GET_ACTIONS] Relative actions enabled: will re-anchor RTC prefix")
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
if not cfg.rtc.enabled:
get_actions_threshold = 0
while not shutdown_event.is_set():
if action_queue.qsize() <= get_actions_threshold:
current_time = time.perf_counter()
action_index_before_inference = action_queue.get_action_index()
prev_actions = action_queue.get_left_over()
inference_latency = latency_tracker.max()
inference_delay = math.ceil(inference_latency / time_per_chunk)
obs = robot.get_observation()
# Apply robot observation processor
obs_processed = robot_observation_processor(obs)
obs_with_policy_features = build_dataset_frame(
dataset_features, obs_processed, prefix="observation"
)
for name in obs_with_policy_features:
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
if "image" in name:
obs_with_policy_features[name] = (
obs_with_policy_features[name].type(torch.float32) / 255
)
obs_with_policy_features[name] = (
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
)
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
obs_with_policy_features["task"] = [cfg.task] # Task should be a list, not a string!
obs_with_policy_features["robot_type"] = (
robot.robot.name if hasattr(robot.robot, "name") else ""
)
preproceseded_obs = preprocessor(obs_with_policy_features)
# Re-anchor leftover actions for relative-action policies.
# We need the *postprocessed* (absolute) leftover, not the original
# (normalized/relative) one that get_left_over() returns.
if (
prev_actions is not None
and relative_step is not None
and OBS_STATE in obs_with_policy_features
):
with action_queue.lock:
if action_queue.queue is not None:
prev_actions_abs = action_queue.queue[action_queue.last_index :].clone()
else:
prev_actions_abs = None
if prev_actions_abs is not None and prev_actions_abs.numel() > 0:
prev_actions = _reanchor_relative_rtc_prefix(
prev_actions_absolute=prev_actions_abs,
current_state=obs_with_policy_features[OBS_STATE],
relative_step=relative_step,
normalizer_step=normalizer_step,
policy_device=policy_device,
)
# Generate actions WITH RTC
actions = policy.predict_action_chunk(
preproceseded_obs,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
# Store original actions (before postprocessing) for RTC
original_actions = actions.squeeze(0).clone()
postprocessed_actions = postprocessor(actions)
postprocessed_actions = postprocessed_actions.squeeze(0)
new_latency = time.perf_counter() - current_time
new_delay = math.ceil(new_latency / time_per_chunk)
latency_tracker.add(new_latency)
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
logger.warning(
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon."
)
action_queue.merge(
original_actions, postprocessed_actions, new_delay, action_index_before_inference
)
else:
# Small sleep to prevent busy waiting
time.sleep(0.1)
logger.info("[GET_ACTIONS] get actions thread shutting down")
except Exception as e:
logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
logger.error(traceback.format_exc())
sys.exit(1)
def actor_control(
robot: RobotWrapper,
robot_action_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: RTCDemoConfig,
):
"""Thread function to execute actions on the robot.
Args:
robot: The robot instance
action_queue: Queue to get actions from
shutdown_event: Event to signal shutdown
cfg: Demo configuration
"""
try:
logger.info("[ACTOR] Starting actor thread")
action_keys = [k for k in robot.action_features() if k.endswith(".pos")]
action_count = 0
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
action_interval = interpolator.get_control_interval(cfg.fps)
while not shutdown_event.is_set():
start_time = time.perf_counter()
if interpolator.needs_new_action():
new_action = action_queue.get()
if new_action is not None:
interpolator.add(new_action.cpu())
action = interpolator.get()
if action is not None:
action = action.cpu()
action_dict = {key: action[i].item() for i, key in enumerate(action_keys)}
action_processed = robot_action_processor((action_dict, None))
robot.send_action(action_processed)
action_count += 1
dt_s = time.perf_counter() - start_time
time.sleep(max(0, (action_interval - dt_s) - 0.001))
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
except Exception as e:
logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
logger.error(traceback.format_exc())
sys.exit(1)
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
"""Apply torch.compile to the policy's predict_action_chunk method.
Args:
policy: Policy instance to compile
cfg: Configuration containing torch compile settings
Returns:
Policy with compiled predict_action_chunk method
"""
# PI models handle their own compilation
if policy.type == "pi05" or policy.type == "pi0":
return policy
try:
# Check if torch.compile is available (PyTorch 2.0+)
if not hasattr(torch, "compile"):
logger.warning(
f"torch.compile is not available. Requires PyTorch 2.0+. "
f"Current version: {torch.__version__}. Skipping compilation."
)
return policy
logger.info("Applying torch.compile to predict_action_chunk...")
logger.info(f" Backend: {cfg.torch_compile_backend}")
logger.info(f" Mode: {cfg.torch_compile_mode}")
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
# Compile the predict_action_chunk method
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
compile_kwargs = {
"backend": cfg.torch_compile_backend,
"mode": cfg.torch_compile_mode,
}
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
if cfg.torch_compile_disable_cudagraphs:
compile_kwargs["options"] = {"triton.cudagraphs": False}
original_method = policy.predict_action_chunk
compiled_method = torch.compile(original_method, **compile_kwargs)
policy.predict_action_chunk = compiled_method
logger.info("✓ Successfully compiled predict_action_chunk")
except Exception as e:
logger.error(f"Failed to apply torch.compile: {e}")
logger.warning("Continuing without torch.compile")
return policy
@parser.wrap()
def demo_cli(cfg: RTCDemoConfig):
"""Main entry point for RTC demo with draccus configuration."""
# Initialize logging
init_logging()
logger.info(f"Using device: {cfg.device}")
# Setup signal handler for graceful shutdown
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
shutdown_event = signal_handler.shutdown_event
policy = None
robot = None
get_actions_thread = None
actor_thread = None
policy_class = get_policy_class(cfg.policy.type)
# Load config and set compile_model for pi0/pi05 models
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
config.compile_model = cfg.use_torch_compile
if config.use_peft:
from peft import PeftConfig, PeftModel
peft_pretrained_path = cfg.policy.pretrained_path
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
policy = policy_class.from_pretrained(
pretrained_name_or_path=peft_config.base_model_name_or_path, config=config
)
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
else:
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
# Turn on RTC
policy.config.rtc_config = cfg.rtc
# Init RTC processort, as by default if RTC disabled in the config
# The processor won't be created
policy.init_rtc_processor()
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
policy = policy.to(cfg.device)
policy.eval()
# Apply torch.compile to predict_action_chunk method if enabled
if cfg.use_torch_compile:
policy = _apply_torch_compile(policy, cfg)
# Create robot
logger.info(f"Initializing robot: {cfg.robot.type}")
robot = make_robot_from_config(cfg.robot)
robot.connect()
robot_wrapper = RobotWrapper(robot)
# Create robot observation processor
robot_observation_processor = make_default_robot_observation_processor()
robot_action_processor = make_default_robot_action_processor()
# Create action queue for communication between threads
action_queue = ActionQueue(cfg.rtc)
# Start chunk requester thread
get_actions_thread = Thread(
target=get_actions,
args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
daemon=True,
name="GetActions",
)
get_actions_thread.start()
logger.info("Started get actions thread")
# Start action executor thread
actor_thread = Thread(
target=actor_control,
args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
daemon=True,
name="Actor",
)
actor_thread.start()
logger.info("Started actor thread")
logger.info("Started stop by duration thread")
# Main thread monitors for duration or shutdown
logger.info(f"Running demo for {cfg.duration} seconds...")
start_time = time.time()
while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
time.sleep(10)
# Log queue status periodically
if int(time.time() - start_time) % 5 == 0:
logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
if time.time() - start_time > cfg.duration:
break
logger.info("Demo duration reached or shutdown requested")
# Signal shutdown
shutdown_event.set()
# Wait for threads to finish
if get_actions_thread and get_actions_thread.is_alive():
logger.info("Waiting for chunk requester thread to finish...")
get_actions_thread.join()
if actor_thread and actor_thread.is_alive():
logger.info("Waiting for action executor thread to finish...")
actor_thread.join()
# Cleanup robot
if robot:
robot.disconnect()
logger.info("Robot disconnected")
logger.info("Cleanup completed")
if __name__ == "__main__":
demo_cli()
logging.info("RTC demo finished")

View File

@@ -14,13 +14,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import time
from lerobot.cameras.opencv import OpenCVCameraConfig from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.common.control_utils import init_keyboard_listener from lerobot.common.control_utils import init_keyboard_listener, predict_action
from lerobot.configs import FeatureType, PolicyFeature from lerobot.configs import FeatureType, PolicyFeature
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
from lerobot.model.kinematics import RobotKinematics from lerobot.model.kinematics import RobotKinematics
from lerobot.policies import make_pre_post_processors from lerobot.policies import make_pre_post_processors
from lerobot.policies.act import ACTPolicy from lerobot.policies.act import ACTPolicy
from lerobot.policies.utils import make_robot_action
from lerobot.processor import ( from lerobot.processor import (
RobotProcessorPipeline, RobotProcessorPipeline,
make_default_teleop_action_processor, make_default_teleop_action_processor,
@@ -34,11 +38,12 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE, ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints, InverseKinematicsEEToJoints,
) )
from lerobot.scripts.lerobot_record import record_loop
from lerobot.types import RobotAction, RobotObservation from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.feature_utils import combine_feature_dicts from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
NUM_EPISODES = 5 NUM_EPISODES = 5
FPS = 30 FPS = 30
@@ -49,6 +54,9 @@ HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
def main(): def main():
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
# This script provides a self-contained example for educational purposes.
# Create the robot configuration & robot # Create the robot configuration & robot
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig( robot_config = SO100FollowerConfig(
@@ -143,43 +151,67 @@ def main():
raise ValueError("Robot is not connected!") raise ValueError("Robot is not connected!")
print("Starting evaluate loop...") print("Starting evaluate loop...")
control_interval = 1 / FPS
episode_idx = 0 episode_idx = 0
for episode_idx in range(NUM_EPISODES): for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop # Inline evaluation loop: predict actions and send to robot
record_loop( timestamp = 0
robot=robot, start_episode_t = time.perf_counter()
events=events, while timestamp < EPISODE_TIME_SEC:
fps=FPS, start_loop_t = time.perf_counter()
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors if events["exit_early"]:
postprocessor=postprocessor, events["exit_early"] = False
dataset=dataset, break
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION, # Get robot observation
display_data=True, obs = robot.get_observation()
teleop_action_processor=make_default_teleop_action_processor(), obs_processed = robot_joints_to_ee_pose_processor(obs)
robot_action_processor=robot_ee_to_joints_processor, observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
robot_observation_processor=robot_joints_to_ee_pose_processor,
) # Predict action using the policy
action_tensor = predict_action(
observation=observation_frame,
policy=policy,
device=policy.config.device,
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.device.type == "cuda",
task=TASK_DESCRIPTION,
robot_type=robot.name,
)
# Convert policy output to robot action dict
action_values = make_robot_action(action_tensor, dataset.features)
# Process and send action to robot (EE -> joints via IK)
robot_action_to_send = robot_ee_to_joints_processor((action_values, obs))
robot.send_action(robot_action_to_send)
# Write to dataset
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
dataset.add_frame(frame)
log_rerun_data(observation=obs_processed, action=action_values)
dt_s = time.perf_counter() - start_loop_t
sleep_time_s = control_interval - dt_s
if sleep_time_s < 0:
logging.warning(
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
)
precise_sleep(max(sleep_time_s, 0.0))
timestamp = time.perf_counter() - start_episode_t
# Reset the environment if not stopping or re-recording # Reset the environment if not stopping or re-recording
if not events["stop_recording"] and ( if not events["stop_recording"] and (
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"] (episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
): ):
log_say("Reset the environment") log_say("Reset the environment")
record_loop( log_say("Waiting for environment reset, press right arrow key when ready...")
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
if events["rerecord_episode"]: if events["rerecord_episode"]:
log_say("Re-record episode") log_say("Re-record episode")
@@ -190,7 +222,6 @@ def main():
# Save episode # Save episode
dataset.save_episode() dataset.save_episode()
episode_idx += 1
finally: finally:
# Clean up # Clean up
log_say("Stop recording") log_say("Stop recording")

View File

@@ -62,21 +62,20 @@ def main():
follower = SO100Follower(follower_config) follower = SO100Follower(follower_config)
leader = SO100Leader(leader_config) leader = SO100Leader(leader_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
follower_kinematics_solver = RobotKinematics( follower_kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf", urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link", target_frame_name="gripper_frame_link",
joint_names=list(follower.bus.motors.keys()), joint_names=list(follower.bus.motors.keys()),
) )
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
leader_kinematics_solver = RobotKinematics( leader_kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf", urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link", target_frame_name="gripper_frame_link",
joint_names=list(leader.bus.motors.keys()), joint_names=list(leader.bus.motors.keys()),
) )
# Build pipeline to convert follower joints to EE observation # Build pipeline to convert follower joints to EE observation.
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation]( follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[ steps=[
ForwardKinematicsJointsToEE( ForwardKinematicsJointsToEE(
@@ -87,7 +86,7 @@ def main():
to_output=transition_to_observation, to_output=transition_to_observation,
) )
# Build pipeline to convert leader joints to EE action # Build pipeline to convert leader joints to EE action.
leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[ steps=[
ForwardKinematicsJointsToEE( ForwardKinematicsJointsToEE(
@@ -98,9 +97,9 @@ def main():
to_output=transition_to_robot_action, to_output=transition_to_robot_action,
) )
# Build pipeline to convert EE action to follower joints # Build pipeline to convert EE action to follower joints (with safety bounds).
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
[ steps=[
EEBoundsAndSafety( EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10, max_ee_step_m=0.10,
@@ -115,13 +114,12 @@ def main():
to_output=transition_to_robot_action, to_output=transition_to_robot_action,
) )
# Create the dataset # Create the dataset, deriving features from the pipelines so the on-disk schema
# matches exactly what the pipelines produce at runtime.
dataset = LeRobotDataset.create( dataset = LeRobotDataset.create(
repo_id=HF_REPO_ID, repo_id=HF_REPO_ID,
fps=FPS, fps=FPS,
features=combine_feature_dicts( features=combine_feature_dicts(
# Run the feature contract of the pipelines
# This tells you how the features would look like after the pipeline steps
aggregate_pipeline_dataset_features( aggregate_pipeline_dataset_features(
pipeline=leader_joints_to_ee, pipeline=leader_joints_to_ee,
initial_features=create_initial_features(action=leader.action_features), initial_features=create_initial_features(action=leader.action_features),
@@ -144,7 +142,7 @@ def main():
# Initialize the keyboard listener and rerun visualization # Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener() listener, events = init_keyboard_listener()
init_rerun(session_name="recording_phone") init_rerun(session_name="recording_so100_ee")
try: try:
if not leader.is_connected or not follower.is_connected: if not leader.is_connected or not follower.is_connected:
@@ -160,14 +158,14 @@ def main():
robot=follower, robot=follower,
events=events, events=events,
fps=FPS, fps=FPS,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
teleop=leader, teleop=leader,
dataset=dataset, dataset=dataset,
control_time_s=EPISODE_TIME_SEC, control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION, single_task=TASK_DESCRIPTION,
display_data=True, display_data=True,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
) )
# Reset the environment if not stopping or re-recording # Reset the environment if not stopping or re-recording
@@ -179,13 +177,13 @@ def main():
robot=follower, robot=follower,
events=events, events=events,
fps=FPS, fps=FPS,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
teleop=leader, teleop=leader,
control_time_s=RESET_TIME_SEC, control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION, single_task=TASK_DESCRIPTION,
display_data=True, display_data=True,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
) )
if events["rerecord_episode"]: if events["rerecord_episode"]:

View File

@@ -0,0 +1,134 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run a trained EE-space policy on SO100 without recording (base rollout).
Uses the rollout engine's :class:`BaseStrategy` (autonomous execution,
no dataset) with :class:`SyncInferenceConfig` (inline policy call per
control tick). The custom observation/action processors convert between
joint space (robot hardware) and end-effector space (policy I/O) via
forward/inverse kinematics.
"""
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.configs import PreTrainedConfig
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import (
RobotProcessorPipeline,
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies import BaseStrategy
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging
FPS = 30
DURATION_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
def main():
init_logging()
# Robot configuration — the rollout engine will connect it inside build_rollout_context.
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem5A460814411",
id="my_awesome_follower_arm",
cameras=camera_config,
use_degrees=True,
)
# Kinematic solver: we need the motor-name list, so peek at the robot once.
# (The rollout engine owns the connected instance; we only use this for introspection.)
temp_robot = SO100Follower(robot_config)
motor_names = list(temp_robot.bus.motors.keys())
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=motor_names,
)
# Joint-space observation → EE-space observation (consumed by the policy).
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
# EE-space action (produced by the policy) → joint-space action (sent to robot).
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=motor_names,
initial_guess_current_joints=True,
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Policy config (full model is loaded inside build_rollout_context).
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
policy_config.pretrained_path = HF_MODEL_ID
cfg = RolloutConfig(
robot=robot_config,
policy=policy_config,
strategy=BaseStrategyConfig(),
inference=SyncInferenceConfig(),
fps=FPS,
duration=DURATION_SEC,
task=TASK_DESCRIPTION,
)
signal_handler = ProcessSignalHandler(use_threads=True)
# Pass the EE kinematic processors via kwargs; the defaults (identity) would
# otherwise skip the joint↔EE conversion and the policy would receive the
# wrong observation/action space.
ctx = build_rollout_context(
cfg,
signal_handler.shutdown_event,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
strategy = BaseStrategy(cfg.strategy)
try:
strategy.setup(ctx)
strategy.run(ctx)
finally:
strategy.teardown(ctx)
if __name__ == "__main__":
main()

View File

@@ -10,7 +10,7 @@ from lerobot.datasets import LeRobotDataset
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
from lerobot.policies import SACConfig from lerobot.policies import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier from lerobot.rewards.classifier.modeling_classifier import Classifier
from lerobot.rl.buffer import ReplayBuffer from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.gym_manipulator import make_robot_env from lerobot.rl.gym_manipulator import make_robot_env
from lerobot.robots.so_follower import SO100FollowerConfig from lerobot.robots.so_follower import SO100FollowerConfig

View File

@@ -1,7 +1,7 @@
import torch import torch
from lerobot.datasets import LeRobotDataset from lerobot.datasets import LeRobotDataset
from lerobot.policies import RewardClassifierConfig, make_policy, make_pre_post_processors from lerobot.rewards import RewardClassifierConfig, make_reward_model, make_reward_pre_post_processors
def main(): def main():
@@ -22,10 +22,10 @@ def main():
model_name="microsoft/resnet-18", model_name="microsoft/resnet-18",
) )
# Make policy, preprocessor, and optimizer # Make reward model, preprocessor, and optimizer
policy = make_policy(config, ds_meta=dataset.meta) reward_model = make_reward_model(config, dataset_stats=dataset.meta.stats)
optimizer = config.get_optimizer_preset().build(policy.parameters()) optimizer = config.get_optimizer_preset().build(reward_model.parameters())
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats) preprocessor, _ = make_reward_pre_post_processors(config, dataset_stats=dataset.meta.stats)
classifier_id = "<user>/reward_classifier_hil_serl_example" classifier_id = "<user>/reward_classifier_hil_serl_example"
@@ -42,7 +42,7 @@ def main():
batch = preprocessor(batch) batch = preprocessor(batch)
# Forward pass # Forward pass
loss, output_dict = policy.forward(batch) loss, output_dict = reward_model.forward(batch)
# Backward pass and optimization # Backward pass and optimization
optimizer.zero_grad() optimizer.zero_grad()
@@ -58,8 +58,8 @@ def main():
print("Training finished!") print("Training finished!")
# You can now save the trained policy. # You can now save the trained reward model.
policy.push_to_hub(classifier_id) reward_model.push_to_hub(classifier_id)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -289,6 +289,7 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main" lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main" lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main" lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
# ---------------- Tool Configurations ---------------- # ---------------- Tool Configurations ----------------
[tool.setuptools.package-data] [tool.setuptools.package-data]

View File

@@ -17,6 +17,7 @@ Provides the RealSenseCamera class for capturing frames from Intel RealSense cam
""" """
import logging import logging
import sys
import time import time
from threading import Event, Lock, Thread from threading import Event, Lock, Thread
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
@@ -41,6 +42,7 @@ from ..utils import get_cv2_rotation
from .configuration_realsense import RealSenseCameraConfig from .configuration_realsense import RealSenseCameraConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
pkg_name = "pyrealsense2-macosx" if sys.platform == "darwin" else "pyrealsense2"
class RealSenseCamera(Camera): class RealSenseCamera(Camera):
@@ -114,7 +116,7 @@ class RealSenseCamera(Camera):
Args: Args:
config: The configuration settings for the camera. config: The configuration settings for the camera.
""" """
require_package("pyrealsense2", extra="intelrealsense") require_package(pkg_name, extra="intelrealsense", import_name="pyrealsense2")
super().__init__(config) super().__init__(config)
self.config = config self.config = config

View File

@@ -41,8 +41,12 @@ def cfg_to_group(
return tag return tag
return tag[:max_tag_length] return tag[:max_tag_length]
if cfg.is_reward_model_training:
trainable_tag = f"reward_model:{cfg.reward_model.type}"
else:
trainable_tag = f"policy:{cfg.policy.type}"
lst = [ lst = [
f"policy:{cfg.policy.type}", trainable_tag,
f"seed:{cfg.seed}", f"seed:{cfg.seed}",
] ]
if cfg.dataset is not None: if cfg.dataset is not None:

View File

@@ -21,6 +21,7 @@ are intentionally NOT re-exported here to avoid circular dependencies
Import them directly: ``from lerobot.configs.train import TrainPipelineConfig`` Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
""" """
from .dataset import DatasetRecordConfig
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from .policies import PreTrainedConfig from .policies import PreTrainedConfig
from .types import ( from .types import (
@@ -39,6 +40,7 @@ __all__ = [
"PolicyFeature", "PolicyFeature",
"RTCAttentionSchedule", "RTCAttentionSchedule",
# Config classes # Config classes
"DatasetRecordConfig",
"DatasetConfig", "DatasetConfig",
"EvalConfig", "EvalConfig",
"PeftConfig", "PeftConfig",

View File

@@ -0,0 +1,80 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
@dataclass
class DatasetRecordConfig:
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
repo_id: str = ""
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
single_task: str = ""
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
root: str | Path | None = None
# Limit the frames per second.
fps: int = 30
# Number of seconds for data recording for each episode.
episode_time_s: int | float = 60
# Number of seconds for resetting the environment after each episode.
reset_time_s: int | float = 60
# Number of episodes to record.
num_episodes: int = 50
# Encode frames in the dataset into video
video: bool = True
# Upload dataset to Hugging Face hub.
push_to_hub: bool = True
# Upload on private repository on the Hugging Face hub.
private: bool = False
# Add tags to your dataset on the hub.
tags: list[str] | None = None
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
num_image_writer_processes: int = 0
# Number of threads writing the frames as png images on disk, per camera.
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
# Not enough threads might cause low camera fps.
num_image_writer_threads_per_camera: int = 4
# Number of episodes to record before batch encoding videos
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
video_encoding_batch_size: int = 1
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
# or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
# Use 'auto' to auto-detect the best available hardware encoder.
vcodec: str = "libsvtav1"
# Enable streaming video encoding: encode frames in real-time during capture instead
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
streaming_encoding: bool = False
# Maximum number of frames to buffer per camera when using streaming encoding.
# ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up.
encoder_queue_maxsize: int = 30
# Number of threads per encoder instance. None = auto (codec default).
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
encoder_threads: int | None = None
def stamp_repo_id(self) -> None:
"""Append a date-time tag to ``repo_id`` so each recording session gets a unique name.
Must be called explicitly at dataset *creation* time — not on resume,
where the existing ``repo_id`` (already stamped) must be preserved.
"""
if self.repo_id:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
self.repo_id = f"{self.repo_id}_{timestamp}"

View File

@@ -17,7 +17,7 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from lerobot.transforms import ImageTransformsConfig from lerobot.transforms import ImageTransformsConfig
from lerobot.utils.import_utils import get_safe_default_codec from lerobot.utils.import_utils import get_safe_default_video_backend
@dataclass @dataclass
@@ -34,7 +34,7 @@ class DatasetConfig:
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig) image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
revision: str | None = None revision: str | None = None
use_imagenet_stats: bool = True use_imagenet_stats: bool = True
video_backend: str = field(default_factory=get_safe_default_codec) video_backend: str = field(default_factory=get_safe_default_video_backend)
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0). # When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion. # This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
return_uint8: bool = False return_uint8: bool = False

View File

@@ -0,0 +1,163 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import builtins
import json
import logging
import os
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, TypeVar
import draccus
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
from lerobot.configs.types import PolicyFeature
from lerobot.optim.optimizers import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available
from lerobot.utils.hub import HubMixin
T = TypeVar("T", bound="RewardModelConfig")
logger = logging.getLogger(__name__)
@dataclass
class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
"""Base configuration for reward models.
Args:
input_features: A dictionary defining the PolicyFeature of the input data for the reward. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the reward. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
"""
# Reuses PolicyFeature
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
device: str | None = None
pretrained_path: str | None = None
push_to_hub: bool = False
repo_id: str | None = None
# Hub metadata
license: str | None = None
tags: list[str] | None = None
private: bool | None = None
def __post_init__(self) -> None:
if not self.device or not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
self.device = auto_device.type
@property
def type(self) -> str:
choice_name = self.get_choice_name(self.__class__)
if not isinstance(choice_name, str):
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
return choice_name
@property
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg]
return None
@property
def action_delta_indices(self) -> list | None: # type: ignore[type-arg]
return None
@property
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg]
return None
@abc.abstractmethod
def get_optimizer_preset(self) -> OptimizerConfig:
raise NotImplementedError
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
return None
def validate_features(self) -> None:
pass
def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
draccus.dump(self, f, indent=4)
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict[Any, Any] | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
**reward_kwargs: Any,
) -> T:
model_id = str(pretrained_name_or_path)
config_file: str | None = None
if Path(model_id).is_dir():
if CONFIG_NAME in os.listdir(model_id):
config_file = os.path.join(model_id, CONFIG_NAME)
else:
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
else:
try:
config_file = hf_hub_download(
repo_id=model_id,
filename=CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
) from e
if config_file is None:
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
# HACK: Parse the original config to get the config subclass, so that we can
# apply cli overrides.
with draccus.config_type("json"):
orig_config = draccus.parse(cls, config_file, args=[])
with open(config_file) as f:
config = json.load(f)
config.pop("type", None)
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
json.dump(config, f)
config_file = f.name
cli_overrides = reward_kwargs.pop("cli_overrides", [])
with draccus.config_type("json"):
return draccus.parse(orig_config.__class__, config_file, args=cli_overrides)

View File

@@ -13,7 +13,9 @@
# limitations under the License. # limitations under the License.
import builtins import builtins
import datetime as dt import datetime as dt
import json
import os import os
import tempfile
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -26,18 +28,57 @@ from lerobot import envs
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.optim import LRSchedulerConfig, OptimizerConfig from lerobot.optim import LRSchedulerConfig, OptimizerConfig
from lerobot.utils.hub import HubMixin from lerobot.utils.hub import HubMixin
from lerobot.utils.sample_weighting import SampleWeightingConfig
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from .policies import PreTrainedConfig from .policies import PreTrainedConfig
from .rewards import RewardModelConfig
TRAIN_CONFIG_NAME = "train_config.json" TRAIN_CONFIG_NAME = "train_config.json"
def _migrate_legacy_rabc_fields(config: dict[str, Any]) -> dict[str, Any] | None:
"""Return migrated payload for legacy RA-BC fields, or None when no migration is needed."""
legacy_fields = (
"use_rabc",
"rabc_progress_path",
"rabc_kappa",
"rabc_epsilon",
"rabc_head_mode",
)
if not any(key in config for key in legacy_fields):
return None
migrated_config = dict(config)
use_rabc = bool(migrated_config.pop("use_rabc", False))
rabc_progress_path = migrated_config.pop("rabc_progress_path", None)
rabc_kappa = migrated_config.pop("rabc_kappa", None)
rabc_epsilon = migrated_config.pop("rabc_epsilon", None)
rabc_head_mode = migrated_config.pop("rabc_head_mode", None)
# New configs may already define sample_weighting explicitly. In that case,
# legacy fields are ignored after being stripped from the payload.
if migrated_config.get("sample_weighting") is None and use_rabc:
sample_weighting: dict[str, Any] = {"type": "rabc"}
if rabc_progress_path is not None:
sample_weighting["progress_path"] = rabc_progress_path
if rabc_kappa is not None:
sample_weighting["kappa"] = rabc_kappa
if rabc_epsilon is not None:
sample_weighting["epsilon"] = rabc_epsilon
if rabc_head_mode is not None:
sample_weighting["head_mode"] = rabc_head_mode
migrated_config["sample_weighting"] = sample_weighting
return migrated_config
@dataclass @dataclass
class TrainPipelineConfig(HubMixin): class TrainPipelineConfig(HubMixin):
dataset: DatasetConfig dataset: DatasetConfig
env: envs.EnvConfig | None = None env: envs.EnvConfig | None = None
policy: PreTrainedConfig | None = None policy: PreTrainedConfig | None = None
reward_model: RewardModelConfig | None = None
# Set `dir` to where you would like to save all of the run outputs. If you run another training session # Set `dir` to where you would like to save all of the run outputs. If you run another training session
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true. # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
output_dir: Path | None = None output_dir: Path | None = None
@@ -72,27 +113,41 @@ class TrainPipelineConfig(HubMixin):
wandb: WandBConfig = field(default_factory=WandBConfig) wandb: WandBConfig = field(default_factory=WandBConfig)
peft: PeftConfig | None = None peft: PeftConfig | None = None
# RA-BC (Reward-Aligned Behavior Cloning) parameters # Sample weighting configuration (e.g., for RA-BC training)
use_rabc: bool = False # Enable reward-weighted training sample_weighting: SampleWeightingConfig | None = None
rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
rabc_epsilon: float = 1e-6 # Small constant for numerical stability
rabc_head_mode: str | None = "sparse" # For dual-head models: "sparse" or "dense"
# Rename map for the observation to override the image and state keys # Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict) rename_map: dict[str, str] = field(default_factory=dict)
checkpoint_path: Path | None = field(init=False, default=None) checkpoint_path: Path | None = field(init=False, default=None)
@property
def is_reward_model_training(self) -> bool:
"""True when the config targets a reward model rather than a policy."""
return self.reward_model is not None
@property
def trainable_config(self) -> PreTrainedConfig | RewardModelConfig:
"""Return whichever config (policy or reward_model) is active."""
if self.is_reward_model_training:
return self.reward_model # type: ignore[return-value]
return self.policy # type: ignore[return-value]
def validate(self) -> None: def validate(self) -> None:
# HACK: We parse again the cli args here to get the pretrained paths if there was some. # HACK: We parse again the cli args here to get the pretrained paths if there was some.
policy_path = parser.get_path_arg("policy") policy_path = parser.get_path_arg("policy")
if policy_path: reward_model_path = parser.get_path_arg("reward_model")
# Only load the policy config
if reward_model_path:
cli_overrides = parser.get_cli_overrides("reward_model")
self.reward_model = RewardModelConfig.from_pretrained(
reward_model_path, cli_overrides=cli_overrides
)
self.reward_model.pretrained_path = str(Path(reward_model_path))
elif policy_path:
cli_overrides = parser.get_cli_overrides("policy") cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = Path(policy_path) self.policy.pretrained_path = Path(policy_path)
elif self.resume: elif self.resume:
# The entire train config is already loaded, we just need to get the checkpoint dir
config_path = parser.parse_arg("config_path") config_path = parser.parse_arg("config_path")
if not config_path: if not config_path:
raise ValueError( raise ValueError(
@@ -108,18 +163,22 @@ class TrainPipelineConfig(HubMixin):
policy_dir = Path(config_path).parent policy_dir = Path(config_path).parent
if self.policy is not None: if self.policy is not None:
self.policy.pretrained_path = policy_dir self.policy.pretrained_path = policy_dir
if self.reward_model is not None:
self.reward_model.pretrained_path = str(policy_dir)
self.checkpoint_path = policy_dir.parent self.checkpoint_path = policy_dir.parent
if self.policy is None: if self.policy is None and self.reward_model is None:
raise ValueError( raise ValueError(
"Policy is not configured. Please specify a pretrained policy with `--policy.path`." "Neither policy nor reward_model is configured. "
"Please specify one with `--policy.path` or `--reward_model.path`."
) )
active_cfg = self.trainable_config
if not self.job_name: if not self.job_name:
if self.env is None: if self.env is None:
self.job_name = f"{self.policy.type}" self.job_name = f"{active_cfg.type}"
else: else:
self.job_name = f"{self.env.type}_{self.policy.type}" self.job_name = f"{self.env.type}_{active_cfg.type}"
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir(): if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
raise FileExistsError( raise FileExistsError(
@@ -137,26 +196,16 @@ class TrainPipelineConfig(HubMixin):
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None): if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.") raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
elif self.use_policy_training_preset and not self.resume: elif self.use_policy_training_preset and not self.resume:
self.optimizer = self.policy.get_optimizer_preset() self.optimizer = active_cfg.get_optimizer_preset()
self.scheduler = self.policy.get_scheduler_preset() self.scheduler = active_cfg.get_scheduler_preset()
if self.policy.push_to_hub and not self.policy.repo_id: if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id:
raise ValueError( raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.")
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
)
if self.use_rabc and not self.rabc_progress_path:
# Auto-detect from dataset path
repo_id = self.dataset.repo_id
if self.dataset.root:
self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
else:
self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet"
@classmethod @classmethod
def __get_path_fields__(cls) -> list[str]: def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`""" """Keys for draccus pretrained-path loading."""
return ["policy"] return ["policy", "reward_model"]
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
@@ -207,6 +256,15 @@ class TrainPipelineConfig(HubMixin):
) from e ) from e
cli_args = kwargs.pop("cli_args", []) cli_args = kwargs.pop("cli_args", [])
if config_file is not None:
with open(config_file) as f:
config = json.load(f)
migrated_config = _migrate_legacy_rabc_fields(config)
if migrated_config is not None:
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
json.dump(migrated_config, f)
config_file = f.name
with draccus.config_type("json"): with draccus.config_type("json"):
return draccus.parse(cls, config_file, args=cli_args) return draccus.parse(cls, config_file, args=cli_args)

View File

@@ -40,10 +40,19 @@ from .io_utils import load_episodes, write_stats
from .lerobot_dataset import LeRobotDataset from .lerobot_dataset import LeRobotDataset
from .multi_dataset import MultiLeRobotDataset from .multi_dataset import MultiLeRobotDataset
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from .pyav_utils import (
check_video_encoder_config_pyav,
detect_available_encoders_pyav,
get_codec,
)
from .sampler import EpisodeAwareSampler from .sampler import EpisodeAwareSampler
from .streaming_dataset import StreamingLeRobotDataset from .streaming_dataset import StreamingLeRobotDataset
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
from .video_utils import VideoEncodingManager from .video_utils import (
VideoEncoderConfig,
VideoEncodingManager,
camera_encoder_defaults,
)
# NOTE: Low-level I/O functions (cast_stats_to_numpy, get_parquet_file_size_in_mb, etc.) # NOTE: Low-level I/O functions (cast_stats_to_numpy, get_parquet_file_size_in_mb, etc.)
# and legacy migration constants are intentionally NOT re-exported here. # and legacy migration constants are intentionally NOT re-exported here.
@@ -58,15 +67,20 @@ __all__ = [
"LeRobotDatasetMetadata", "LeRobotDatasetMetadata",
"MultiLeRobotDataset", "MultiLeRobotDataset",
"StreamingLeRobotDataset", "StreamingLeRobotDataset",
"VideoEncoderConfig",
"VideoEncodingManager", "VideoEncodingManager",
"camera_encoder_defaults",
"add_features", "add_features",
"aggregate_datasets", "aggregate_datasets",
"aggregate_pipeline_dataset_features", "aggregate_pipeline_dataset_features",
"aggregate_stats", "aggregate_stats",
"check_video_encoder_config_pyav",
"convert_image_to_video_dataset", "convert_image_to_video_dataset",
"create_initial_features", "create_initial_features",
"create_lerobot_dataset_card", "create_lerobot_dataset_card",
"delete_episodes", "delete_episodes",
"detect_available_encoders_pyav",
"get_codec",
"get_feature_stats", "get_feature_stats",
"load_episodes", "load_episodes",
"make_dataset", "make_dataset",

View File

@@ -97,8 +97,8 @@ def update_data_df(df, src_meta, dst_meta):
pd.DataFrame: Updated DataFrame with adjusted indices. pd.DataFrame: Updated DataFrame with adjusted indices.
""" """
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"] df["episode_index"] = df["episode_index"] + dst_meta.info.total_episodes
df["index"] = df["index"] + dst_meta.info["total_frames"] df["index"] = df["index"] + dst_meta.info.total_frames
src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy()) src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy())
df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy() df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy()
@@ -225,9 +225,9 @@ def update_meta_data(
# Clean up temporary columns # Clean up temporary columns
df = df.drop(columns=["_orig_chunk", "_orig_file"]) df = df.drop(columns=["_orig_chunk", "_orig_file"])
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"] df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info.total_frames
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"] df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info.total_frames
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"] df["episode_index"] = df["episode_index"] + dst_meta.info.total_episodes
return df return df
@@ -237,8 +237,8 @@ def aggregate_datasets(
aggr_repo_id: str, aggr_repo_id: str,
roots: list[Path] | None = None, roots: list[Path] | None = None,
aggr_root: Path | None = None, aggr_root: Path | None = None,
data_files_size_in_mb: float | None = None, data_files_size_in_mb: int | None = None,
video_files_size_in_mb: float | None = None, video_files_size_in_mb: int | None = None,
chunk_size: int | None = None, chunk_size: int | None = None,
): ):
"""Aggregates multiple LeRobot datasets into a single unified dataset. """Aggregates multiple LeRobot datasets into a single unified dataset.
@@ -313,8 +313,8 @@ def aggregate_datasets(
# to avoid interference between different source datasets # to avoid interference between different source datasets
data_idx.pop("src_to_dst", None) data_idx.pop("src_to_dst", None)
dst_meta.info["total_episodes"] += src_meta.total_episodes dst_meta.info.total_episodes += src_meta.total_episodes
dst_meta.info["total_frames"] += src_meta.total_frames dst_meta.info.total_frames += src_meta.total_frames
finalize_aggregation(dst_meta, all_metadata) finalize_aggregation(dst_meta, all_metadata)
logging.info("Aggregation complete.") logging.info("Aggregation complete.")
@@ -332,7 +332,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
videos_idx: Dictionary tracking video chunk and file indices. videos_idx: Dictionary tracking video chunk and file indices.
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB) video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE) chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
Returns: Returns:
dict: Updated videos_idx with current chunk and file indices. dict: Updated videos_idx with current chunk and file indices.
""" """
@@ -417,6 +416,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
concatenate_video_files( concatenate_video_files(
[dst_path, src_path], [dst_path, src_path],
dst_path, dst_path,
compatibility_check=True,
) )
# Update duration of this destination file # Update duration of this destination file
dst_file_durations[dst_key] = current_dst_duration + src_duration dst_file_durations[dst_key] = current_dst_duration + src_duration
@@ -640,14 +640,10 @@ def finalize_aggregation(aggr_meta, all_metadata):
write_tasks(aggr_meta.tasks, aggr_meta.root) write_tasks(aggr_meta.tasks, aggr_meta.root)
logging.info("write info") logging.info("write info")
aggr_meta.info.update( aggr_meta.info.total_tasks = len(aggr_meta.tasks)
{ aggr_meta.info.total_episodes = sum(m.total_episodes for m in all_metadata)
"total_tasks": len(aggr_meta.tasks), aggr_meta.info.total_frames = sum(m.total_frames for m in all_metadata)
"total_episodes": sum(m.total_episodes for m in all_metadata), aggr_meta.info.splits = {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"}
"total_frames": sum(m.total_frames for m in all_metadata),
"splits": {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"},
}
)
write_info(aggr_meta.info, aggr_meta.root) write_info(aggr_meta.info, aggr_meta.root)
logging.info("write stats") logging.info("write stats")

View File

@@ -37,20 +37,18 @@ from .io_utils import (
load_subtasks, load_subtasks,
load_tasks, load_tasks,
write_info, write_info,
write_json,
write_stats, write_stats,
write_tasks, write_tasks,
) )
from .utils import ( from .utils import (
DEFAULT_EPISODES_PATH, DEFAULT_EPISODES_PATH,
INFO_PATH,
check_version_compatibility, check_version_compatibility,
get_safe_version, get_safe_version,
has_legacy_hub_download_metadata, has_legacy_hub_download_metadata,
is_valid_version, is_valid_version,
update_chunk_file_indices, update_chunk_file_indices,
) )
from .video_utils import get_video_info from .video_utils import VideoEncoderConfig, get_video_info
CODEBASE_VERSION = "v3.0" CODEBASE_VERSION = "v3.0"
@@ -228,7 +226,7 @@ class LeRobotDatasetMetadata:
@property @property
def _version(self) -> packaging.version.Version: def _version(self) -> packaging.version.Version:
"""Codebase version used to create this dataset.""" """Codebase version used to create this dataset."""
return packaging.version.parse(self.info["codebase_version"]) return packaging.version.parse(self.info.codebase_version)
def get_data_file_path(self, ep_index: int) -> Path: def get_data_file_path(self, ep_index: int) -> Path:
"""Return the relative parquet file path for the given episode index. """Return the relative parquet file path for the given episode index.
@@ -283,27 +281,27 @@ class LeRobotDatasetMetadata:
@property @property
def data_path(self) -> str: def data_path(self) -> str:
"""Formattable string for the parquet files.""" """Formattable string for the parquet files."""
return self.info["data_path"] return self.info.data_path
@property @property
def video_path(self) -> str | None: def video_path(self) -> str | None:
"""Formattable string for the video files.""" """Formattable string for the video files."""
return self.info["video_path"] return self.info.video_path
@property @property
def robot_type(self) -> str | None: def robot_type(self) -> str | None:
"""Robot type used in recording this dataset.""" """Robot type used in recording this dataset."""
return self.info["robot_type"] return self.info.robot_type
@property @property
def fps(self) -> int: def fps(self) -> int:
"""Frames per second used during data collection.""" """Frames per second used during data collection."""
return self.info["fps"] return self.info.fps
@property @property
def features(self) -> dict[str, dict]: def features(self) -> dict[str, dict]:
"""All features contained in the dataset.""" """All features contained in the dataset."""
return self.info["features"] return self.info.features
@property @property
def image_keys(self) -> list[str]: def image_keys(self) -> list[str]:
@@ -333,32 +331,32 @@ class LeRobotDatasetMetadata:
@property @property
def total_episodes(self) -> int: def total_episodes(self) -> int:
"""Total number of episodes available.""" """Total number of episodes available."""
return self.info["total_episodes"] return self.info.total_episodes
@property @property
def total_frames(self) -> int: def total_frames(self) -> int:
"""Total number of frames saved in this dataset.""" """Total number of frames saved in this dataset."""
return self.info["total_frames"] return self.info.total_frames
@property @property
def total_tasks(self) -> int: def total_tasks(self) -> int:
"""Total number of different tasks performed in this dataset.""" """Total number of different tasks performed in this dataset."""
return self.info["total_tasks"] return self.info.total_tasks
@property @property
def chunks_size(self) -> int: def chunks_size(self) -> int:
"""Max number of files per chunk.""" """Max number of files per chunk."""
return self.info["chunks_size"] return self.info.chunks_size
@property @property
def data_files_size_in_mb(self) -> int: def data_files_size_in_mb(self) -> int:
"""Max size of data file in mega bytes.""" """Max size of data file in mega bytes."""
return self.info["data_files_size_in_mb"] return self.info.data_files_size_in_mb
@property @property
def video_files_size_in_mb(self) -> int: def video_files_size_in_mb(self) -> int:
"""Max size of video file in mega bytes.""" """Max size of video file in mega bytes."""
return self.info["video_files_size_in_mb"] return self.info.video_files_size_in_mb
def get_task_index(self, task: str) -> int | None: def get_task_index(self, task: str) -> int | None:
""" """
@@ -502,20 +500,33 @@ class LeRobotDatasetMetadata:
self._save_episode_metadata(episode_dict) self._save_episode_metadata(episode_dict)
# Update info # Update info
self.info["total_episodes"] += 1 self.info.total_episodes += 1
self.info["total_frames"] += episode_length self.info.total_frames += episode_length
self.info["total_tasks"] = len(self.tasks) self.info.total_tasks = len(self.tasks)
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} self.info.splits = {"train": f"0:{self.info.total_episodes}"}
write_info(self.info, self.root) write_info(self.info, self.root)
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
write_stats(self.stats, self.root) write_stats(self.stats, self.root)
def update_video_info(self, video_key: str | None = None) -> None: def update_video_info(
""" self,
video_key: str | None = None,
camera_encoder_config: VideoEncoderConfig | None = None,
) -> None:
"""Populate per-feature video info in ``info.json``.
Warning: this function writes info from first episode videos, implicitly assuming that all videos have Warning: this function writes info from first episode videos, implicitly assuming that all videos have
been encoded the same way. Also, this means it assumes the first episode exists. been encoded the same way. Also, this means it assumes the first episode exists.
Args:
video_key: If provided, only update this video key. Otherwise update
all video keys in the dataset.
camera_encoder_config: Encoder configuration used to produce the
videos. When provided, its fields are recorded as
``video.<field>`` entries alongside the stream-derived
``video.*`` entries (see :func:`get_video_info`).
""" """
if video_key is not None and video_key not in self.video_keys: if video_key is not None and video_key not in self.video_keys:
raise ValueError(f"Video key {video_key} not found in dataset") raise ValueError(f"Video key {video_key} not found in dataset")
@@ -524,7 +535,9 @@ class LeRobotDatasetMetadata:
for key in video_keys: for key in video_keys:
if not self.features[key].get("info", None): if not self.features[key].get("info", None):
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0) video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
self.info["features"][key]["info"] = get_video_info(video_path) self.info.features[key]["info"] = get_video_info(
video_path, camera_encoder_config=camera_encoder_config
)
def update_chunk_settings( def update_chunk_settings(
self, self,
@@ -546,17 +559,17 @@ class LeRobotDatasetMetadata:
if chunks_size is not None: if chunks_size is not None:
if chunks_size <= 0: if chunks_size <= 0:
raise ValueError(f"chunks_size must be positive, got {chunks_size}") raise ValueError(f"chunks_size must be positive, got {chunks_size}")
self.info["chunks_size"] = chunks_size self.info.chunks_size = chunks_size
if data_files_size_in_mb is not None: if data_files_size_in_mb is not None:
if data_files_size_in_mb <= 0: if data_files_size_in_mb <= 0:
raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}") raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}")
self.info["data_files_size_in_mb"] = data_files_size_in_mb self.info.data_files_size_in_mb = data_files_size_in_mb
if video_files_size_in_mb is not None: if video_files_size_in_mb is not None:
if video_files_size_in_mb <= 0: if video_files_size_in_mb <= 0:
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}") raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
self.info["video_files_size_in_mb"] = video_files_size_in_mb self.info.video_files_size_in_mb = video_files_size_in_mb
# Update the info file on disk # Update the info file on disk
write_info(self.info, self.root) write_info(self.info, self.root)
@@ -653,7 +666,7 @@ class LeRobotDatasetMetadata:
f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. " f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. "
"Either remove video features from the features dict, or set 'use_videos=True'." "Either remove video features from the features dict, or set 'use_videos=True'."
) )
write_json(obj.info, obj.root / INFO_PATH) write_info(obj.info, obj.root)
obj.revision = None obj.revision = None
obj._pq_writer = None obj._pq_writer = None
obj.latest_episode = None obj.latest_episode = None

View File

@@ -62,7 +62,12 @@ from .utils import (
DEFAULT_EPISODES_PATH, DEFAULT_EPISODES_PATH,
update_chunk_file_indices, update_chunk_file_indices,
) )
from .video_utils import encode_video_frames, get_video_info from .video_utils import (
VideoEncoderConfig,
camera_encoder_defaults,
encode_video_frames,
get_video_info,
)
def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict: def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict:
@@ -92,6 +97,7 @@ def delete_episodes(
episode_indices: list[int], episode_indices: list[int],
output_dir: str | Path | None = None, output_dir: str | Path | None = None,
repo_id: str | None = None, repo_id: str | None = None,
camera_encoder_config: VideoEncoderConfig | None = None,
) -> LeRobotDataset: ) -> LeRobotDataset:
"""Delete episodes from a LeRobotDataset and create a new dataset. """Delete episodes from a LeRobotDataset and create a new dataset.
@@ -100,6 +106,8 @@ def delete_episodes(
episode_indices: List of episode indices to delete. episode_indices: List of episode indices to delete.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig. output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig. repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
camera_encoder_config: Video encoder settings used when re-encoding video segments
(``None`` uses :func:`~lerobot.datasets.video_utils.camera_encoder_defaults`).
""" """
if not episode_indices: if not episode_indices:
raise ValueError("No episodes to delete") raise ValueError("No episodes to delete")
@@ -132,7 +140,7 @@ def delete_episodes(
video_metadata = None video_metadata = None
if dataset.meta.video_keys: if dataset.meta.video_keys:
video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping) video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping, camera_encoder_config)
data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping) data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping)
@@ -154,6 +162,7 @@ def split_dataset(
dataset: LeRobotDataset, dataset: LeRobotDataset,
splits: dict[str, float | list[int]], splits: dict[str, float | list[int]],
output_dir: str | Path | None = None, output_dir: str | Path | None = None,
camera_encoder_config: VideoEncoderConfig | None = None,
) -> dict[str, LeRobotDataset]: ) -> dict[str, LeRobotDataset]:
"""Split a LeRobotDataset into multiple smaller datasets. """Split a LeRobotDataset into multiple smaller datasets.
@@ -162,6 +171,8 @@ def split_dataset(
splits: Either a dict mapping split names to episode indices, or a dict mapping splits: Either a dict mapping split names to episode indices, or a dict mapping
split names to fractions (must sum to <= 1.0). split names to fractions (must sum to <= 1.0).
output_dir: Root directory where the split datasets will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. output_dir: Root directory where the split datasets will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id.
camera_encoder_config: Video encoder settings used when re-encoding video segments
(``None`` uses :func:`~lerobot.datasets.video_utils.camera_encoder_defaults`).
Examples: Examples:
Split by specific episodes Split by specific episodes
@@ -222,7 +233,9 @@ def split_dataset(
video_metadata = None video_metadata = None
if dataset.meta.video_keys: if dataset.meta.video_keys:
video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping) video_metadata = _copy_and_reindex_videos(
dataset, new_meta, episode_mapping, camera_encoder_config
)
data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping) data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping)
@@ -578,8 +591,7 @@ def _keep_episodes_from_video_with_av(
output_path: Path, output_path: Path,
episodes_to_keep: list[tuple[int, int]], episodes_to_keep: list[tuple[int, int]],
fps: float, fps: float,
vcodec: str = "libsvtav1", camera_encoder_config: VideoEncoderConfig | None = None,
pix_fmt: str = "yuv420p",
) -> None: ) -> None:
"""Keep only specified episodes from a video file using PyAV. """Keep only specified episodes from a video file using PyAV.
@@ -593,9 +605,11 @@ def _keep_episodes_from_video_with_av(
Ranges are half-open intervals: [start_frame, end_frame), where start_frame Ranges are half-open intervals: [start_frame, end_frame), where start_frame
is inclusive and end_frame is exclusive. is inclusive and end_frame is exclusive.
fps: Frame rate of the video. fps: Frame rate of the video.
vcodec: Video codec to use for encoding. camera_encoder_config: Video encoder settings
pix_fmt: Pixel format for output video. (``None`` uses :func:`~lerobot.datasets.video_utils.camera_encoder_defaults`).
""" """
if camera_encoder_config is None:
camera_encoder_config = camera_encoder_defaults()
from fractions import Fraction from fractions import Fraction
import av import av
@@ -619,12 +633,12 @@ def _keep_episodes_from_video_with_av(
# Convert fps to Fraction for PyAV compatibility. # Convert fps to Fraction for PyAV compatibility.
fps_fraction = Fraction(fps).limit_denominator(1000) fps_fraction = Fraction(fps).limit_denominator(1000)
v_out = out.add_stream(vcodec, rate=fps_fraction) v_out = out.add_stream(camera_encoder_config.vcodec, rate=fps_fraction)
# PyAV type stubs don't distinguish video streams from audio/subtitle streams. # PyAV type stubs don't distinguish video streams from audio/subtitle streams.
v_out.width = v_in.codec_context.width v_out.width = v_in.codec_context.width
v_out.height = v_in.codec_context.height v_out.height = v_in.codec_context.height
v_out.pix_fmt = pix_fmt v_out.pix_fmt = camera_encoder_config.pix_fmt
# Set time_base to match the frame rate for proper timestamp handling. # Set time_base to match the frame rate for proper timestamp handling.
v_out.time_base = Fraction(1, int(fps)) v_out.time_base = Fraction(1, int(fps))
@@ -687,8 +701,7 @@ def _copy_and_reindex_videos(
src_dataset: LeRobotDataset, src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata, dst_meta: LeRobotDatasetMetadata,
episode_mapping: dict[int, int], episode_mapping: dict[int, int],
vcodec: str = "libsvtav1", camera_encoder_config: VideoEncoderConfig | None = None,
pix_fmt: str = "yuv420p",
) -> dict[int, dict]: ) -> dict[int, dict]:
"""Copy and filter video files, only re-encoding files with deleted episodes. """Copy and filter video files, only re-encoding files with deleted episodes.
@@ -700,10 +713,14 @@ def _copy_and_reindex_videos(
src_dataset: Source dataset to copy from src_dataset: Source dataset to copy from
dst_meta: Destination metadata object dst_meta: Destination metadata object
episode_mapping: Mapping from old episode indices to new indices episode_mapping: Mapping from old episode indices to new indices
camera_encoder_config: Video encoder settings used when re-encoding segments
(``None`` uses :func:`~lerobot.datasets.video_utils.camera_encoder_defaults`).
Returns: Returns:
dict mapping episode index to its video metadata (chunk_index, file_index, timestamps) dict mapping episode index to its video metadata (chunk_index, file_index, timestamps)
""" """
if camera_encoder_config is None:
camera_encoder_config = camera_encoder_defaults()
if src_dataset.meta.episodes is None: if src_dataset.meta.episodes is None:
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root) src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
@@ -792,8 +809,7 @@ def _copy_and_reindex_videos(
dst_video_path, dst_video_path,
episodes_to_keep_ranges, episodes_to_keep_ranges,
src_dataset.meta.fps, src_dataset.meta.fps,
vcodec, camera_encoder_config,
pix_fmt,
) )
cumulative_ts = 0.0 cumulative_ts = 0.0
@@ -897,14 +913,10 @@ def _copy_and_reindex_episodes_metadata(
dst_meta.finalize() dst_meta.finalize()
dst_meta.info.update( dst_meta.info.total_episodes = len(episode_mapping)
{ dst_meta.info.total_frames = total_frames
"total_episodes": len(episode_mapping), dst_meta.info.total_tasks = len(dst_meta.tasks) if dst_meta.tasks is not None else 0
"total_frames": total_frames, dst_meta.info.splits = {"train": f"0:{len(episode_mapping)}"}
"total_tasks": len(dst_meta.tasks) if dst_meta.tasks is not None else 0,
"splits": {"train": f"0:{len(episode_mapping)}"},
}
)
write_info(dst_meta.info, dst_meta.root) write_info(dst_meta.info, dst_meta.root)
if not all_stats: if not all_stats:
@@ -1069,21 +1081,20 @@ def _copy_episodes_metadata_and_stats(
if episodes_dir.exists(): if episodes_dir.exists():
shutil.copytree(episodes_dir, dst_episodes_dir, dirs_exist_ok=True) shutil.copytree(episodes_dir, dst_episodes_dir, dirs_exist_ok=True)
dst_meta.info.update( dst_meta.info.total_episodes = src_dataset.meta.total_episodes
{ dst_meta.info.total_frames = src_dataset.meta.total_frames
"total_episodes": src_dataset.meta.total_episodes, dst_meta.info.total_tasks = src_dataset.meta.total_tasks
"total_frames": src_dataset.meta.total_frames, # Preserve original splits if available, otherwise create default
"total_tasks": src_dataset.meta.total_tasks, dst_meta.info.splits = (
"splits": src_dataset.meta.info.get("splits", {"train": f"0:{src_dataset.meta.total_episodes}"}), src_dataset.meta.info.splits
} if src_dataset.meta.info.splits
else {"train": f"0:{src_dataset.meta.total_episodes}"}
) )
if dst_meta.video_keys and src_dataset.meta.video_keys: if dst_meta.video_keys and src_dataset.meta.video_keys:
for key in dst_meta.video_keys: for key in dst_meta.video_keys:
if key in src_dataset.meta.features: if key in src_dataset.meta.features:
dst_meta.info["features"][key]["info"] = src_dataset.meta.info["features"][key].get( dst_meta.info.features[key]["info"] = src_dataset.meta.info.features[key].get("info", {})
"info", {}
)
write_info(dst_meta.info, dst_meta.root) write_info(dst_meta.info, dst_meta.root)
@@ -1269,11 +1280,7 @@ def _estimate_frame_size_via_calibration(
episode_indices: list[int], episode_indices: list[int],
temp_dir: Path, temp_dir: Path,
fps: int, fps: int,
vcodec: str, camera_encoder_config: VideoEncoderConfig,
pix_fmt: str,
g: int,
crf: int,
fast_decode: int,
num_calibration_frames: int = 30, num_calibration_frames: int = 30,
) -> float: ) -> float:
"""Estimate MB per frame by encoding a small calibration sample. """Estimate MB per frame by encoding a small calibration sample.
@@ -1287,11 +1294,7 @@ def _estimate_frame_size_via_calibration(
episode_indices: List of episode indices being processed. episode_indices: List of episode indices being processed.
temp_dir: Temporary directory for calibration files. temp_dir: Temporary directory for calibration files.
fps: Frames per second for video encoding. fps: Frames per second for video encoding.
vcodec: Video codec (libsvtav1, h264, hevc). camera_encoder_config: Video encoder settings used for calibration encoding.
pix_fmt: Pixel format (yuv420p, etc.).
g: GOP size (group of pictures).
crf: Constant Rate Factor (quality).
fast_decode: Fast decode tuning parameter.
num_calibration_frames: Number of frames to use for calibration (default: 30). num_calibration_frames: Number of frames to use for calibration (default: 30).
Returns: Returns:
@@ -1327,11 +1330,7 @@ def _estimate_frame_size_via_calibration(
imgs_dir=calibration_dir, imgs_dir=calibration_dir,
video_path=calibration_video_path, video_path=calibration_video_path,
fps=fps, fps=fps,
vcodec=vcodec, camera_encoder_config=camera_encoder_config,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
overwrite=True, overwrite=True,
) )
@@ -1525,7 +1524,7 @@ def modify_tasks(
write_tasks(new_task_df, root) write_tasks(new_task_df, root)
# Update info.json # Update info.json
dataset.meta.info["total_tasks"] = len(unique_tasks) dataset.meta.info.total_tasks = len(unique_tasks)
write_info(dataset.meta.info, root) write_info(dataset.meta.info, root)
# Reload metadata to reflect changes # Reload metadata to reflect changes
@@ -1649,11 +1648,7 @@ def convert_image_to_video_dataset(
dataset: LeRobotDataset, dataset: LeRobotDataset,
output_dir: Path | None = None, output_dir: Path | None = None,
repo_id: str | None = None, repo_id: str | None = None,
vcodec: str = "libsvtav1", camera_encoder_config: VideoEncoderConfig | None = None,
pix_fmt: str = "yuv420p",
g: int = 2,
crf: int = 30,
fast_decode: int = 0,
episode_indices: list[int] | None = None, episode_indices: list[int] | None = None,
num_workers: int = 4, num_workers: int = 4,
max_episodes_per_batch: int | None = None, max_episodes_per_batch: int | None = None,
@@ -1668,11 +1663,8 @@ def convert_image_to_video_dataset(
dataset: The source LeRobot dataset with images dataset: The source LeRobot dataset with images
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig. output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig. repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
vcodec: Video codec (default: libsvtav1) camera_encoder_config: Video encoder settings
pix_fmt: Pixel format (default: yuv420p) (``None`` uses :func:`~lerobot.datasets.video_utils.camera_encoder_defaults`).
g: Group of pictures size (default: 2)
crf: Constant rate factor (default: 30)
fast_decode: Fast decode tuning (default: 0)
episode_indices: List of episode indices to convert (None = all episodes) episode_indices: List of episode indices to convert (None = all episodes)
num_workers: Number of threads for parallel processing (default: 4) num_workers: Number of threads for parallel processing (default: 4)
max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit) max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit)
@@ -1681,6 +1673,9 @@ def convert_image_to_video_dataset(
Returns: Returns:
New LeRobotDataset with images encoded as videos New LeRobotDataset with images encoded as videos
""" """
if camera_encoder_config is None:
camera_encoder_config = camera_encoder_defaults()
# Check that it's an image dataset # Check that it's an image dataset
if len(dataset.meta.video_keys) > 0: if len(dataset.meta.video_keys) > 0:
raise ValueError( raise ValueError(
@@ -1704,7 +1699,10 @@ def convert_image_to_video_dataset(
logging.info( logging.info(
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}" f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
) )
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}") logging.info(
f"Video codec: {camera_encoder_config.vcodec}, pixel format: {camera_encoder_config.pix_fmt}, "
f"GOP: {camera_encoder_config.g}, CRF: {camera_encoder_config.crf}"
)
# Create new features dict, converting image features to video features # Create new features dict, converting image features to video features
new_features = {} new_features = {}
@@ -1774,11 +1772,7 @@ def convert_image_to_video_dataset(
episode_indices=episode_indices, episode_indices=episode_indices,
temp_dir=temp_dir, temp_dir=temp_dir,
fps=fps, fps=fps,
vcodec=vcodec, camera_encoder_config=camera_encoder_config,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
) )
logging.info(f"Processing camera: {img_key}") logging.info(f"Processing camera: {img_key}")
@@ -1820,11 +1814,7 @@ def convert_image_to_video_dataset(
imgs_dir=imgs_dir, imgs_dir=imgs_dir,
video_path=video_path, video_path=video_path,
fps=fps, fps=fps,
vcodec=vcodec, camera_encoder_config=camera_encoder_config,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
overwrite=True, overwrite=True,
) )
@@ -1858,10 +1848,10 @@ def convert_image_to_video_dataset(
episodes_df.to_parquet(episodes_path, index=False) episodes_df.to_parquet(episodes_path, index=False)
# Update metadata info # Update metadata info
new_meta.info["total_episodes"] = len(episode_indices) new_meta.info.total_episodes = len(episode_indices)
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata.values()) new_meta.info.total_frames = sum(ep["length"] for ep in all_episode_metadata.values())
new_meta.info["total_tasks"] = dataset.meta.total_tasks new_meta.info.total_tasks = dataset.meta.total_tasks
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"} new_meta.info.splits = {"train": f"0:{len(episode_indices)}"}
# Update video info for all image keys (now videos) # Update video info for all image keys (now videos)
# We need to manually set video info since update_video_info() checks video_keys first # We need to manually set video info since update_video_info() checks video_keys first
@@ -1870,7 +1860,9 @@ def convert_image_to_video_dataset(
video_path = new_meta.root / new_meta.video_path.format( video_path = new_meta.root / new_meta.video_path.format(
video_key=img_key, chunk_index=0, file_index=0 video_key=img_key, chunk_index=0, file_index=0
) )
new_meta.info["features"][img_key]["info"] = get_video_info(video_path) new_meta.info.features[img_key]["info"] = get_video_info(
video_path, camera_encoder_config=camera_encoder_config
)
write_info(new_meta.info, new_meta.root) write_info(new_meta.info, new_meta.root)

View File

@@ -52,6 +52,8 @@ from .utils import (
) )
from .video_utils import ( from .video_utils import (
StreamingVideoEncoder, StreamingVideoEncoder,
VideoEncoderConfig,
camera_encoder_defaults,
concatenate_video_files, concatenate_video_files,
encode_video_frames, encode_video_frames,
get_video_duration_in_s, get_video_duration_in_s,
@@ -65,14 +67,19 @@ def _encode_video_worker(
episode_index: int, episode_index: int,
root: Path, root: Path,
fps: int, fps: int,
vcodec: str = "libsvtav1", camera_encoder_config: VideoEncoderConfig | None = None,
encoder_threads: int | None = None, encoder_threads: int | None = None,
) -> Path: ) -> Path:
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4" temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0) fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
img_dir = (root / fpath).parent img_dir = (root / fpath).parent
encode_video_frames( encode_video_frames(
img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads img_dir,
temp_path,
fps,
camera_encoder_config=camera_encoder_config,
encoder_threads=encoder_threads,
overwrite=True,
) )
shutil.rmtree(img_dir) shutil.rmtree(img_dir)
return temp_path return temp_path
@@ -89,20 +96,22 @@ class DatasetWriter:
self, self,
meta: LeRobotDatasetMetadata, meta: LeRobotDatasetMetadata,
root: Path, root: Path,
vcodec: str, camera_encoder_config: VideoEncoderConfig | None,
encoder_threads: int | None, encoder_threads: int | None,
batch_encoding_size: int, batch_encoding_size: int,
streaming_encoder: StreamingVideoEncoder | None = None, streaming_encoder: StreamingVideoEncoder | None = None,
initial_frames: int = 0, initial_frames: int = 0,
): ):
"""Initialize the writer with metadata, codec, and encoding config. """Initialize the writer with metadata, codec, and encoder config.
Args: Args:
meta: Dataset metadata instance (used for feature schema, chunk meta: Dataset metadata instance (used for feature schema, chunk
settings, and episode persistence). settings, and episode persistence).
root: Local dataset root directory. root: Local dataset root directory.
vcodec: Video codec for encoding (e.g. ``'libsvtav1'``, ``'h264'``). camera_encoder_config: Video encoder settings applied to all cameras.
encoder_threads: Threads per encoder instance. ``None`` for auto. ``None`` uses :func:`~lerobot.datasets.video_utils.camera_encoder_defaults`.
encoder_threads: Number of encoder threads (global). ``None``
lets the codec decide.
batch_encoding_size: Number of episodes to accumulate before batch_encoding_size: Number of episodes to accumulate before
batch-encoding videos. batch-encoding videos.
streaming_encoder: Optional pre-built :class:`StreamingVideoEncoder` streaming_encoder: Optional pre-built :class:`StreamingVideoEncoder`
@@ -111,7 +120,7 @@ class DatasetWriter:
""" """
self._meta = meta self._meta = meta
self._root = root self._root = root
self._vcodec = vcodec self._camera_encoder_config = camera_encoder_config or camera_encoder_defaults()
self._encoder_threads = encoder_threads self._encoder_threads = encoder_threads
self._batch_encoding_size = batch_encoding_size self._batch_encoding_size = batch_encoding_size
self._streaming_encoder = streaming_encoder self._streaming_encoder = streaming_encoder
@@ -284,7 +293,7 @@ class DatasetWriter:
episode_index, episode_index,
self._root, self._root,
self._meta.fps, self._meta.fps,
self._vcodec, self._camera_encoder_config,
self._encoder_threads, self._encoder_threads,
): video_key ): video_key
for video_key in self._meta.video_keys for video_key in self._meta.video_keys
@@ -495,7 +504,7 @@ class DatasetWriter:
# Update video info (only needed when first episode is encoded) # Update video info (only needed when first episode is encoded)
if episode_index == 0: if episode_index == 0:
self._meta.update_video_info(video_key) self._meta.update_video_info(video_key, camera_encoder_config=self._camera_encoder_config)
write_info(self._meta.info, self._meta.root) write_info(self._meta.info, self._meta.root)
metadata = { metadata = {
@@ -564,7 +573,12 @@ class DatasetWriter:
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path: def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
"""Use ffmpeg to convert frames stored as png into mp4 videos.""" """Use ffmpeg to convert frames stored as png into mp4 videos."""
return _encode_video_worker( return _encode_video_worker(
video_key, episode_index, self._root, self._meta.fps, self._vcodec, self._encoder_threads video_key,
episode_index,
self._root,
self._meta.fps,
self._camera_encoder_config,
self._encoder_threads,
) )
def close_writer(self) -> None: def close_writer(self) -> None:

View File

@@ -19,6 +19,7 @@ from pprint import pformat
import torch import torch
from lerobot.configs import PreTrainedConfig from lerobot.configs import PreTrainedConfig
from lerobot.configs.rewards import RewardModelConfig
from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.train import TrainPipelineConfig
from lerobot.transforms import ImageTransforms from lerobot.transforms import ImageTransforms
from lerobot.utils.constants import ACTION, IMAGENET_STATS, OBS_PREFIX, REWARD from lerobot.utils.constants import ACTION, IMAGENET_STATS, OBS_PREFIX, REWARD
@@ -30,12 +31,14 @@ from .streaming_dataset import StreamingLeRobotDataset
def resolve_delta_timestamps( def resolve_delta_timestamps(
cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata cfg: PreTrainedConfig | RewardModelConfig, ds_meta: LeRobotDatasetMetadata
) -> dict[str, list] | None: ) -> dict[str, list] | None:
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig. """Resolves delta_timestamps by reading from the 'delta_indices' properties of the config.
Args: Args:
cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from. cfg (PreTrainedConfig | RewardModelConfig): The config to read delta_indices from. Both
``PreTrainedConfig`` and concrete ``RewardModelConfig`` subclasses expose the
``{observation,action,reward}_delta_indices`` properties used below.
ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
delta_timestamps against. delta_timestamps against.
@@ -82,7 +85,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
ds_meta = LeRobotDatasetMetadata( ds_meta = LeRobotDatasetMetadata(
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
) )
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta) delta_timestamps = resolve_delta_timestamps(cfg.trainable_config, ds_meta)
if not cfg.dataset.streaming: if not cfg.dataset.streaming:
dataset = LeRobotDataset( dataset = LeRobotDataset(
cfg.dataset.repo_id, cfg.dataset.repo_id,

View File

@@ -28,6 +28,7 @@ from .utils import (
DEFAULT_DATA_PATH, DEFAULT_DATA_PATH,
DEFAULT_VIDEO_FILE_SIZE_IN_MB, DEFAULT_VIDEO_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH, DEFAULT_VIDEO_PATH,
DatasetInfo,
) )
@@ -78,8 +79,8 @@ def create_empty_dataset_info(
chunks_size: int | None = None, chunks_size: int | None = None,
data_files_size_in_mb: int | None = None, data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None, video_files_size_in_mb: int | None = None,
) -> dict: ) -> DatasetInfo:
"""Create a template dictionary for a new dataset's `info.json`. """Create a template ``DatasetInfo`` object for a new dataset's ``meta/info.json``.
Args: Args:
codebase_version (str): The version of the LeRobot codebase. codebase_version (str): The version of the LeRobot codebase.
@@ -87,25 +88,24 @@ def create_empty_dataset_info(
features (dict): The LeRobot features dictionary for the dataset. features (dict): The LeRobot features dictionary for the dataset.
use_videos (bool): Whether the dataset will store videos. use_videos (bool): Whether the dataset will store videos.
robot_type (str | None): The type of robot used, if any. robot_type (str | None): The type of robot used, if any.
chunks_size (int | None): Max files per chunk directory. Defaults to ``DEFAULT_CHUNK_SIZE``.
data_files_size_in_mb (int | None): Max parquet file size in MB. Defaults to ``DEFAULT_DATA_FILE_SIZE_IN_MB``.
video_files_size_in_mb (int | None): Max video file size in MB. Defaults to ``DEFAULT_VIDEO_FILE_SIZE_IN_MB``.
Returns: Returns:
dict: A dictionary with the initial dataset metadata. DatasetInfo: A typed dataset information object with initial metadata.
""" """
return { return DatasetInfo(
"codebase_version": codebase_version, codebase_version=codebase_version,
"robot_type": robot_type, fps=fps,
"total_episodes": 0, features=features,
"total_frames": 0, robot_type=robot_type,
"total_tasks": 0, chunks_size=chunks_size or DEFAULT_CHUNK_SIZE,
"chunks_size": chunks_size or DEFAULT_CHUNK_SIZE, data_files_size_in_mb=data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
"data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB, video_files_size_in_mb=video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
"video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB, data_path=DEFAULT_DATA_PATH,
"fps": fps, video_path=DEFAULT_VIDEO_PATH if use_videos else None,
"splits": {}, )
"data_path": DEFAULT_DATA_PATH,
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
"features": features,
}
def check_delta_timestamps( def check_delta_timestamps(

View File

@@ -39,6 +39,7 @@ from .utils import (
EPISODES_DIR, EPISODES_DIR,
INFO_PATH, INFO_PATH,
STATS_PATH, STATS_PATH,
DatasetInfo,
serialize_dict, serialize_dict,
) )
@@ -115,25 +116,21 @@ def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
return dataset return dataset
def write_info(info: dict, local_dir: Path) -> None: def write_info(info: DatasetInfo, local_dir: Path) -> None:
write_json(info, local_dir / INFO_PATH) write_json(info.to_dict(), local_dir / INFO_PATH)
def load_info(local_dir: Path) -> dict: def load_info(local_dir: Path) -> DatasetInfo:
"""Load dataset info metadata from its standard file path. """Load dataset info metadata from its standard file path.
Also converts shape lists to tuples for consistency.
Args: Args:
local_dir (Path): The root directory of the dataset. local_dir (Path): The root directory of the dataset.
Returns: Returns:
dict: The dataset information dictionary. DatasetInfo: The typed dataset information object.
""" """
info = load_json(local_dir / INFO_PATH) raw = load_json(local_dir / INFO_PATH)
for ft in info["features"].values(): return DatasetInfo.from_dict(raw)
ft["shape"] = tuple(ft["shape"])
return info
def write_stats(stats: dict, local_dir: Path) -> None: def write_stats(stats: dict, local_dir: Path) -> None:

View File

@@ -36,8 +36,8 @@ from .utils import (
) )
from .video_utils import ( from .video_utils import (
StreamingVideoEncoder, StreamingVideoEncoder,
get_safe_default_codec, VideoEncoderConfig,
resolve_vcodec, get_safe_default_video_backend,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -58,10 +58,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: str | None = None, video_backend: str | None = None,
return_uint8: bool = False, return_uint8: bool = False,
batch_encoding_size: int = 1, batch_encoding_size: int = 1,
vcodec: str = "libsvtav1", camera_encoder_config: VideoEncoderConfig | None = None,
encoder_threads: int | None = None,
streaming_encoding: bool = False, streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30, encoder_queue_maxsize: int = 30,
encoder_threads: int | None = None,
): ):
""" """
2 modes are available for instantiating this class, depending on 2 different use cases: 2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -177,16 +177,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos. batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1. Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc', camera_encoder_config (VideoEncoderConfig | None, optional): Video encoder settings for cameras
'libsvtav1', 'auto', or hardware-specific codecs like 'h264_videotoolbox', 'h264_nvenc'. (codec, quality, etc.). When ``None``, :func:`~lerobot.datasets.video_utils.camera_encoder_defaults`
Defaults to 'libsvtav1'. Use 'auto' to auto-detect the best available hardware encoder. is used by the writer.
encoder_threads (int | None, optional): Number of encoder threads (global). ``None`` lets the
codec decide.
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
instead of writing PNG images first. This makes save_episode() near-instant. Defaults to False. instead of writing PNG images first. This makes save_episode() near-instant. Defaults to False.
encoder_queue_maxsize (int, optional): Maximum number of frames to buffer per camera when using encoder_queue_maxsize (int, optional): Maximum number of frames to buffer per camera when using
streaming encoding. Defaults to 30 (~1s at 30fps). streaming encoding. Defaults to 30 (~1s at 30fps).
encoder_threads (int | None, optional): Number of threads per encoder instance. None lets the
codec auto-detect (default). Lower values reduce CPU usage per encoder. Maps to 'lp' (via svtav1-params) for
libsvtav1 and 'threads' for h264/hevc.
Note: Note:
Write-mode parameters (``streaming_encoding``, ``batch_encoding_size``) passed to Write-mode parameters (``streaming_encoding``, ``batch_encoding_size``) passed to
@@ -202,10 +201,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episodes = episodes self.episodes = episodes
self.tolerance_s = tolerance_s self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION self.revision = revision if revision else CODEBASE_VERSION
self._video_backend = video_backend if video_backend else get_safe_default_codec() self._video_backend = video_backend if video_backend else get_safe_default_video_backend()
self._return_uint8 = return_uint8 self._return_uint8 = return_uint8
self._batch_encoding_size = batch_encoding_size self._batch_encoding_size = batch_encoding_size
self._vcodec = resolve_vcodec(vcodec)
self._encoder_threads = encoder_threads self._encoder_threads = encoder_threads
if self._requested_root is not None: if self._requested_root is not None:
@@ -251,13 +249,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
streaming_enc = None streaming_enc = None
if streaming_encoding and len(self.meta.video_keys) > 0: if streaming_encoding and len(self.meta.video_keys) > 0:
streaming_enc = self._build_streaming_encoder( streaming_enc = self._build_streaming_encoder(
self.meta.fps, self._vcodec, encoder_queue_maxsize, encoder_threads self.meta.fps,
camera_encoder_config,
self._encoder_threads,
encoder_queue_maxsize,
) )
self.writer = DatasetWriter( self.writer = DatasetWriter(
meta=self.meta, meta=self.meta,
root=self.root, root=self.root,
vcodec=self._vcodec, camera_encoder_config=camera_encoder_config,
encoder_threads=encoder_threads, encoder_threads=self._encoder_threads,
batch_encoding_size=batch_encoding_size, batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc, streaming_encoder=streaming_enc,
initial_frames=self.meta.total_frames, initial_frames=self.meta.total_frames,
@@ -298,19 +299,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
@staticmethod @staticmethod
def _build_streaming_encoder( def _build_streaming_encoder(
fps: int, fps: int,
vcodec: str, camera_encoder_config: VideoEncoderConfig | None,
encoder_queue_maxsize: int,
encoder_threads: int | None, encoder_threads: int | None,
encoder_queue_maxsize: int,
) -> StreamingVideoEncoder: ) -> StreamingVideoEncoder:
return StreamingVideoEncoder( return StreamingVideoEncoder(
fps=fps, fps=fps,
vcodec=vcodec, camera_encoder_config=camera_encoder_config,
pix_fmt="yuv420p",
g=2,
crf=30,
preset=None,
queue_maxsize=encoder_queue_maxsize,
encoder_threads=encoder_threads, encoder_threads=encoder_threads,
queue_maxsize=encoder_queue_maxsize,
) )
# ── Metadata properties ─────────────────────────────────────────── # ── Metadata properties ───────────────────────────────────────────
@@ -625,11 +622,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_writer_threads: int = 0, image_writer_threads: int = 0,
video_backend: str | None = None, video_backend: str | None = None,
batch_encoding_size: int = 1, batch_encoding_size: int = 1,
vcodec: str = "libsvtav1", camera_encoder_config: VideoEncoderConfig | None = None,
metadata_buffer_size: int = 10, metadata_buffer_size: int = 10,
streaming_encoding: bool = False, streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30, encoder_queue_maxsize: int = 30,
encoder_threads: int | None = None, encoder_threads: int | None = None,
video_files_size_in_mb: int | None = None,
data_files_size_in_mb: int | None = None,
) -> "LeRobotDataset": ) -> "LeRobotDataset":
"""Create a new LeRobotDataset from scratch for recording data. """Create a new LeRobotDataset from scratch for recording data.
@@ -654,20 +653,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: Video decoding backend (used when reading back). video_backend: Video decoding backend (used when reading back).
batch_encoding_size: Number of episodes to accumulate before batch_encoding_size: Number of episodes to accumulate before
batch-encoding videos. ``1`` means encode immediately. batch-encoding videos. ``1`` means encode immediately.
vcodec: Video codec for encoding. Options include ``'libsvtav1'``, camera_encoder_config: Video encoder settings for cameras (codec, quality, etc.).
``'h264'``, ``'hevc'``, ``'auto'``. When ``None``, :func:`~lerobot.datasets.video_utils.camera_encoder_defaults` is used.
encoder_threads: Number of encoder threads (global). ``None``
lets the codec decide.
metadata_buffer_size: Number of episode metadata records to buffer metadata_buffer_size: Number of episode metadata records to buffer
before flushing to parquet. before flushing to parquet.
streaming_encoding: If ``True``, encode video frames in real-time streaming_encoding: If ``True``, encode video frames in real-time
during capture instead of writing images first. during capture instead of writing images first.
encoder_queue_maxsize: Max buffered frames per camera when using encoder_queue_maxsize: Max buffered frames per camera when using
streaming encoding. streaming encoding.
encoder_threads: Threads per encoder instance. ``None`` for auto.
Returns: Returns:
A new :class:`LeRobotDataset` in write mode. A new :class:`LeRobotDataset` in write mode.
""" """
vcodec = resolve_vcodec(vcodec)
obj = cls.__new__(cls) obj = cls.__new__(cls)
obj.meta = LeRobotDatasetMetadata.create( obj.meta = LeRobotDatasetMetadata.create(
repo_id=repo_id, repo_id=repo_id,
@@ -677,6 +676,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
root=root, root=root,
use_videos=use_videos, use_videos=use_videos,
metadata_buffer_size=metadata_buffer_size, metadata_buffer_size=metadata_buffer_size,
video_files_size_in_mb=video_files_size_in_mb,
data_files_size_in_mb=data_files_size_in_mb,
) )
obj.repo_id = obj.meta.repo_id obj.repo_id = obj.meta.repo_id
obj._requested_root = obj.meta.root obj._requested_root = obj.meta.root
@@ -686,23 +687,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.image_transforms = None obj.image_transforms = None
obj.delta_timestamps = None obj.delta_timestamps = None
obj.episodes = None obj.episodes = None
obj._video_backend = video_backend if video_backend is not None else get_safe_default_codec() obj._video_backend = video_backend if video_backend is not None else get_safe_default_video_backend()
obj._return_uint8 = False obj._return_uint8 = False
obj._batch_encoding_size = batch_encoding_size obj._batch_encoding_size = batch_encoding_size
obj._vcodec = vcodec
obj._encoder_threads = encoder_threads obj._encoder_threads = encoder_threads
# Reader is lazily created on first access (write-only mode) # Reader is lazily created on first access (write-only mode)
obj.reader = None obj.reader = None
# Create writer
streaming_enc = None streaming_enc = None
if streaming_encoding and len(obj.meta.video_keys) > 0: if streaming_encoding and len(obj.meta.video_keys) > 0:
streaming_enc = cls._build_streaming_encoder(fps, vcodec, encoder_queue_maxsize, encoder_threads) streaming_enc = cls._build_streaming_encoder(
fps, camera_encoder_config, encoder_threads, encoder_queue_maxsize
)
obj.writer = DatasetWriter( obj.writer = DatasetWriter(
meta=obj.meta, meta=obj.meta,
root=obj.root, root=obj.root,
vcodec=vcodec, camera_encoder_config=camera_encoder_config,
encoder_threads=encoder_threads, encoder_threads=encoder_threads,
batch_encoding_size=batch_encoding_size, batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc, streaming_encoder=streaming_enc,
@@ -725,12 +726,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
force_cache_sync: bool = False, force_cache_sync: bool = False,
video_backend: str | None = None, video_backend: str | None = None,
batch_encoding_size: int = 1, batch_encoding_size: int = 1,
vcodec: str = "libsvtav1", camera_encoder_config: VideoEncoderConfig | None = None,
encoder_threads: int | None = None,
image_writer_processes: int = 0, image_writer_processes: int = 0,
image_writer_threads: int = 0, image_writer_threads: int = 0,
streaming_encoding: bool = False, streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30, encoder_queue_maxsize: int = 30,
encoder_threads: int | None = None,
) -> "LeRobotDataset": ) -> "LeRobotDataset":
"""Resume recording on an existing dataset. """Resume recording on an existing dataset.
@@ -753,13 +754,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: Video decoding backend for reading back data. video_backend: Video decoding backend for reading back data.
batch_encoding_size: Number of episodes to accumulate before batch_encoding_size: Number of episodes to accumulate before
batch-encoding videos. batch-encoding videos.
vcodec: Video codec for encoding. camera_encoder_config: Video encoder settings for cameras (codec, quality, etc.).
When ``None``, :func:`~lerobot.datasets.video_utils.camera_encoder_defaults` is used.
encoder_threads: Number of encoder threads (global). ``None``
lets the codec decide.
image_writer_processes: Subprocesses for async image writing. image_writer_processes: Subprocesses for async image writing.
image_writer_threads: Threads for async image writing. image_writer_threads: Threads for async image writing.
streaming_encoding: If ``True``, encode video in real-time during streaming_encoding: If ``True``, encode video in real-time during
capture. capture.
encoder_queue_maxsize: Max buffered frames per camera for streaming. encoder_queue_maxsize: Max buffered frames per camera for streaming.
encoder_threads: Threads per encoder instance. ``None`` for auto.
Returns: Returns:
A :class:`LeRobotDataset` in write mode, ready to append episodes. A :class:`LeRobotDataset` in write mode, ready to append episodes.
@@ -770,7 +773,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
"Writing into the revision-safe Hub snapshot cache (used when root=None) would corrupt " "Writing into the revision-safe Hub snapshot cache (used when root=None) would corrupt "
"the shared cache. Please provide a local directory path." "the shared cache. Please provide a local directory path."
) )
vcodec = resolve_vcodec(vcodec)
obj = cls.__new__(cls) obj = cls.__new__(cls)
obj.repo_id = repo_id obj.repo_id = repo_id
obj._requested_root = Path(root) obj._requested_root = Path(root)
@@ -779,11 +781,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.image_transforms = None obj.image_transforms = None
obj.delta_timestamps = None obj.delta_timestamps = None
obj.episodes = None obj.episodes = None
obj._video_backend = video_backend if video_backend else get_safe_default_codec() obj._video_backend = video_backend if video_backend else get_safe_default_video_backend()
obj._return_uint8 = False obj._return_uint8 = False
obj._batch_encoding_size = batch_encoding_size obj._batch_encoding_size = batch_encoding_size
obj._vcodec = vcodec
obj._encoder_threads = encoder_threads
if obj._requested_root is not None: if obj._requested_root is not None:
obj._requested_root.mkdir(exist_ok=True, parents=True) obj._requested_root.mkdir(exist_ok=True, parents=True)
@@ -792,21 +792,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.meta = LeRobotDatasetMetadata( obj.meta = LeRobotDatasetMetadata(
obj.repo_id, obj._requested_root, obj.revision, force_cache_sync=force_cache_sync obj.repo_id, obj._requested_root, obj.revision, force_cache_sync=force_cache_sync
) )
obj._encoder_threads = encoder_threads
obj.root = obj.meta.root obj.root = obj.meta.root
# Reader is lazily created on first access (write-only mode) # Reader is lazily created on first access (write-only mode)
obj.reader = None obj.reader = None
# Create writer for appending
streaming_enc = None streaming_enc = None
if streaming_encoding and len(obj.meta.video_keys) > 0: if streaming_encoding and len(obj.meta.video_keys) > 0:
streaming_enc = cls._build_streaming_encoder( streaming_enc = cls._build_streaming_encoder(
obj.meta.fps, vcodec, encoder_queue_maxsize, encoder_threads obj.meta.fps, camera_encoder_config, encoder_threads, encoder_queue_maxsize
) )
obj.writer = DatasetWriter( obj.writer = DatasetWriter(
meta=obj.meta, meta=obj.meta,
root=obj.root, root=obj.root,
vcodec=vcodec, camera_encoder_config=camera_encoder_config,
encoder_threads=encoder_threads, encoder_threads=encoder_threads,
batch_encoding_size=batch_encoding_size, batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc, streaming_encoder=streaming_enc,

View File

@@ -123,7 +123,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info. NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
""" """
return self._datasets[0].meta.info["fps"] return self._datasets[0].meta.info.fps
@property @property
def video(self) -> bool: def video(self) -> bool:
@@ -133,7 +133,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info. NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
""" """
return self._datasets[0].meta.info.get("video", False) return len(self._datasets[0].meta.video_keys) > 0
@property @property
def features(self) -> datasets.Features: def features(self) -> datasets.Features:

View File

@@ -0,0 +1,186 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyAV-based compatibility checks for :class:`VideoEncoderConfig`.
Centralises all :mod:`av` introspection of the bundled FFmpeg build.
Checks degrade to a no-op when the target codec isn't available locally.
"""
from __future__ import annotations
import functools
import logging
from typing import TYPE_CHECKING, Any
import av
if TYPE_CHECKING:
from lerobot.datasets.video_utils import VideoEncoderConfig
logger = logging.getLogger(__name__)
FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
@functools.cache
def get_codec(vcodec: str) -> av.codec.Codec | None:
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
try:
return av.codec.Codec(vcodec, "w")
except Exception:
return None
@functools.cache
def _get_codec_options_by_name(vcodec: str) -> dict[str, av.option.Option]:
"""Private-option name → PyAV ``Option`` for *vcodec* (empty if unavailable)."""
codec = get_codec(vcodec)
if codec is None:
return {}
return {opt.name: opt for opt in codec.descriptor.options}
@functools.cache
def _get_codec_video_formats(vcodec: str) -> tuple[str, ...]:
"""Pixel formats accepted by *vcodec* in PyAV's preferred order (empty if unknown)."""
codec = get_codec(vcodec)
if codec is None:
return ()
return tuple(fmt.name for fmt in (codec.video_formats or []))
def detect_available_encoders_pyav(encoders: list[str] | str) -> list[str]:
"""Return the subset of *encoders* available as video encoders in the local FFmpeg build.
Each name is probed directly via :func:`get_codec`; input order is preserved.
"""
if isinstance(encoders, str):
encoders = [encoders]
available: list[str] = []
for name in encoders:
codec = get_codec(name)
if codec is not None and codec.type == "video":
available.append(name)
else:
logger.debug("encoder '%s' not available as video encoder", name)
return available
def _check_option_value(vcodec: str, label: str, value: Any, opt: av.option.Option) -> None:
"""Range-check numeric *value* and choice-check string *value* against *opt*."""
type_name = opt.type.name
if type_name in FFMPEG_NUMERIC_OPTION_TYPES:
if isinstance(value, bool):
raise ValueError(
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
)
elif isinstance(value, str):
try:
num_val = float(value)
except ValueError as e:
raise ValueError(
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
) from e
elif isinstance(value, (float, int)):
num_val = value
else:
raise ValueError(
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
)
# Check integer type compatibility
if type_name in FFMPEG_INTEGER_OPTION_TYPES and not num_val.is_integer():
raise ValueError(
f"{label}={num_val!r} must be an integer for codec {vcodec!r} "
f"(FFmpeg option {opt.name!r} is {type_name}); float values are not allowed."
)
# Check numeric range compatibility
lo, hi = float(opt.min), float(opt.max)
if lo < hi and not (lo <= num_val <= hi):
raise ValueError(
f"{label}={num_val} is out of range for codec {vcodec!r}; must be in [{lo}, {hi}]"
)
elif type_name == "STRING":
if isinstance(value, bool):
raise ValueError(f"{label}={value!r} is not a valid string value for codec {vcodec!r}.")
if isinstance(value, str):
str_val = value
elif isinstance(value, (int, float)):
str_val = str(value)
else:
raise ValueError(f"{label}={value!r} has unsupported type for STRING option on codec {vcodec!r}")
# Check string choice compatibility
choices = [c.name for c in (opt.choices or [])]
if choices and str_val not in choices:
raise ValueError(
f"{label}={str_val!r} is not a supported choice for codec "
f"{vcodec!r}; valid choices: {choices}"
)
else:
return
def _check_pixel_format(vcodec: str, pix_fmt: str) -> None:
formats = _get_codec_video_formats(vcodec)
if formats and pix_fmt not in formats:
raise ValueError(
f"pix_fmt={pix_fmt!r} is not supported by codec {vcodec!r}; "
f"supported pixel formats: {list(formats)}"
)
def _check_codec_options(vcodec: str, codec_options: dict[str, Any], config: VideoEncoderConfig) -> None:
"""Validate merged encoder options (typed) against the codec's published AVOptions."""
supported_options = _get_codec_options_by_name(vcodec)
for key, value in codec_options.items():
# GOP size is not a codec-specific option, it has to be validated separately.
if key == "g":
if isinstance(value, bool) or not isinstance(value, int) or value < 1:
raise ValueError(f"g={value!r} must be a positive integer for codec {vcodec!r}")
continue
if key not in supported_options:
continue
opt = supported_options[key]
label = f"extra_options[{key!r}]" if key in config.extra_options else key
_check_option_value(vcodec, label, value, opt)
def check_video_encoder_config_pyav(config: VideoEncoderConfig) -> None:
"""Verify *config* is compatible with the bundled FFmpeg build.
Checks pixel format, abstract tuning-field compatibility, and each merged
encoder option from :meth:`~lerobot.datasets.video_utils.VideoEncoderConfig.get_codec_options`
against PyAV (including numeric ``extra_options`` present in that dict).
No-op when ``config.vcodec`` isn't in the local FFmpeg build.
Raises:
ValueError: on the first incompatibility encountered.
"""
vcodec = config.vcodec
options = _get_codec_options_by_name(vcodec)
if not options:
logger.warning(
"Codec %r is not available in the bundled FFmpeg build; ",
vcodec,
)
return
_check_pixel_format(config.vcodec, config.pix_fmt)
_check_codec_options(config.vcodec, config.get_codec_options(), config)

View File

@@ -434,7 +434,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
def _make_padding_camera_frame(self, camera_key: str): def _make_padding_camera_frame(self, camera_key: str):
"""Variable-shape padding frame for given camera keys, given in (H, W, C)""" """Variable-shape padding frame for given camera keys, given in (H, W, C)"""
return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1) return torch.zeros(self.meta.info.features[camera_key]["shape"]).permute(-1, 0, 1)
def _get_video_frame_padding_mask( def _get_video_frame_padding_mask(
self, self,

View File

@@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib import contextlib
import dataclasses
import importlib.resources import importlib.resources
import json import json
import logging import logging
from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
import datasets import datasets
@@ -70,6 +72,9 @@ class ForwardCompatibilityError(CompatibilityError):
super().__init__(message) super().__init__(message)
logger = logging.getLogger(__name__)
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
@@ -94,6 +99,123 @@ LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
LEGACY_TASKS_PATH = "meta/tasks.jsonl" LEGACY_TASKS_PATH = "meta/tasks.jsonl"
@dataclass
class DatasetInfo:
"""Typed representation of the ``meta/info.json`` file for a LeRobot dataset.
Replaces the previously untyped ``dict`` returned by ``load_info()`` and
created by ``create_empty_dataset_info()``. Using a dataclass provides
explicit field definitions, IDE auto-completion, and validation at
construction time.
"""
codebase_version: str
fps: int
features: dict[str, dict]
# Episode / frame counters — start at zero for new datasets
total_episodes: int = 0
total_frames: int = 0
total_tasks: int = 0
# Storage settings
chunks_size: int = field(default=DEFAULT_CHUNK_SIZE)
data_files_size_in_mb: int = field(default=DEFAULT_DATA_FILE_SIZE_IN_MB)
video_files_size_in_mb: int = field(default=DEFAULT_VIDEO_FILE_SIZE_IN_MB)
# File path templates
data_path: str = field(default=DEFAULT_DATA_PATH)
video_path: str | None = field(default=DEFAULT_VIDEO_PATH)
# Optional metadata
robot_type: str | None = None
splits: dict[str, str] = field(default_factory=dict)
def __post_init__(self) -> None:
# Coerce feature shapes from list to tuple — JSON deserialisation
# returns lists, but the rest of the codebase expects tuples.
for ft in self.features.values():
if isinstance(ft.get("shape"), list):
ft["shape"] = tuple(ft["shape"])
if self.fps <= 0:
raise ValueError(f"fps must be positive, got {self.fps}")
if self.chunks_size <= 0:
raise ValueError(f"chunks_size must be positive, got {self.chunks_size}")
if self.data_files_size_in_mb <= 0:
raise ValueError(f"data_files_size_in_mb must be positive, got {self.data_files_size_in_mb}")
if self.video_files_size_in_mb <= 0:
raise ValueError(f"video_files_size_in_mb must be positive, got {self.video_files_size_in_mb}")
def to_dict(self) -> dict:
"""Return a JSON-serialisable dict.
Converts tuple shapes back to lists so ``json.dump`` can handle them.
"""
d = dataclasses.asdict(self)
for ft in d["features"].values():
if isinstance(ft.get("shape"), tuple):
ft["shape"] = list(ft["shape"])
return d
@classmethod
def from_dict(cls, data: dict) -> "DatasetInfo":
"""Construct from a raw dict (e.g. loaded directly from JSON).
Unknown keys are ignored for forward compatibility with datasets that
carry additional fields (e.g. ``total_videos`` from v2.x). A warning is
logged when such fields are present.
"""
known = {f.name for f in dataclasses.fields(cls)}
unknown = sorted(k for k in data if k not in known)
if unknown:
logger.warning(f"Unknown fields in DatasetInfo: {unknown}. These will be ignored.")
return cls(**{k: v for k, v in data.items() if k in known})
# ---------------------------------------------------------------------------
# Temporary dict-style compatibility layer
# Allows existing ``info["key"]`` call-sites to keep working without changes.
# Once all callers have been migrated to attribute access, remove these.
# ---------------------------------------------------------------------------
def __getitem__(self, key: str):
import warnings
warnings.warn(
f"Accessing DatasetInfo with dict-style syntax info['{key}'] is deprecated. "
f"Use attribute access info.{key} instead.",
DeprecationWarning,
stacklevel=2,
)
try:
return getattr(self, key)
except AttributeError as err:
raise KeyError(key) from err
def __setitem__(self, key: str, value) -> None:
import warnings
warnings.warn(
f"Setting DatasetInfo with dict-style syntax info['{key}'] = ... is deprecated. "
f"Use attribute assignment info.{key} = ... instead.",
DeprecationWarning,
stacklevel=2,
)
if not hasattr(self, key):
raise KeyError(f"DatasetInfo has no field '{key}'")
setattr(self, key, value)
def __contains__(self, key: str) -> bool:
"""Check if a field exists (dict-like interface)."""
return hasattr(self, key)
def get(self, key: str, default=None):
"""Get attribute value with default fallback (dict-like interface)."""
try:
return getattr(self, key)
except AttributeError:
return default
def has_legacy_hub_download_metadata(root: Path) -> bool: def has_legacy_hub_download_metadata(root: Path) -> bool:
"""Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror. """Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror.
@@ -294,7 +416,7 @@ def create_branch(repo_id: str, *, branch: str, repo_type: str | None = None) ->
def create_lerobot_dataset_card( def create_lerobot_dataset_card(
tags: list | None = None, tags: list | None = None,
dataset_info: dict | None = None, dataset_info: DatasetInfo | None = None,
**kwargs, **kwargs,
) -> DatasetCard: ) -> DatasetCard:
"""Create a `DatasetCard` for a LeRobot dataset. """Create a `DatasetCard` for a LeRobot dataset.
@@ -305,7 +427,7 @@ def create_lerobot_dataset_card(
Args: Args:
tags (list | None): A list of tags to add to the dataset card. tags (list | None): A list of tags to add to the dataset card.
dataset_info (dict | None): The dataset's info dictionary, which will dataset_info (DatasetInfo | None): The dataset's info object, which will
be displayed on the card. be displayed on the card.
**kwargs: Additional keyword arguments to populate the card template. **kwargs: Additional keyword arguments to populate the card template.
@@ -318,7 +440,7 @@ def create_lerobot_dataset_card(
card_tags += tags card_tags += tags
if dataset_info: if dataset_info:
dataset_structure = "[meta/info.json](meta/info.json):\n" dataset_structure = "[meta/info.json](meta/info.json):\n"
dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n" dataset_structure += f"```json\n{json.dumps(dataset_info.to_dict(), indent=4)}\n```\n"
kwargs = {**kwargs, "dataset_structure": dataset_structure} kwargs = {**kwargs, "dataset_structure": dataset_structure}
card_data = DatasetCardData( card_data = DatasetCardData(
license=kwargs.get("license"), license=kwargs.get("license"),

View File

@@ -22,7 +22,7 @@ import shutil
import tempfile import tempfile
import threading import threading
import warnings import warnings
from dataclasses import dataclass, field from dataclasses import asdict, dataclass, field
from fractions import Fraction from fractions import Fraction
from pathlib import Path from pathlib import Path
from threading import Lock from threading import Lock
@@ -37,7 +37,11 @@ import torchvision
from datasets.features.features import register_feature from datasets.features.features import register_feature
from PIL import Image from PIL import Image
from lerobot.utils.import_utils import get_safe_default_codec from lerobot.datasets.pyav_utils import (
check_video_encoder_config_pyav,
detect_available_encoders_pyav,
)
from lerobot.utils.import_utils import get_safe_default_video_backend
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -54,68 +58,155 @@ HW_ENCODERS = [
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto"} | set(HW_ENCODERS) VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto"} | set(HW_ENCODERS)
LIBSVTAV1_DEFAULT_PRESET: int = 12
def _get_codec_options(
vcodec: str,
g: int | None = 2,
crf: int | None = 30,
preset: int | None = None,
) -> dict:
"""Build codec-specific options dict for video encoding."""
options = {}
# GOP size (keyframe interval) - supported by VideoToolbox and software encoders
if g is not None and (vcodec in ("h264_videotoolbox", "hevc_videotoolbox") or vcodec not in HW_ENCODERS):
options["g"] = str(g)
# Quality control (codec-specific parameter names)
if crf is not None:
if vcodec in ("h264", "hevc", "libsvtav1"):
options["crf"] = str(crf)
elif vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
quality = max(1, min(100, int(100 - crf * 2)))
options["q:v"] = str(quality)
elif vcodec in ("h264_nvenc", "hevc_nvenc"):
options["rc"] = "constqp"
options["qp"] = str(crf)
elif vcodec in ("h264_vaapi",):
options["qp"] = str(crf)
elif vcodec in ("h264_qsv",):
options["global_quality"] = str(crf)
# Preset (only for libsvtav1)
if vcodec == "libsvtav1":
options["preset"] = str(preset) if preset is not None else "12"
return options
def detect_available_hw_encoders() -> list[str]: @dataclass
"""Probe PyAV/FFmpeg for available hardware video encoders.""" class VideoEncoderConfig:
available = [] """Video encoder configuration.
for codec_name in HW_ENCODERS:
try: Attributes:
av.codec.Codec(codec_name, "w") vcodec: FFmpeg encoder name. ``"auto"`` is resolved during
available.append(codec_name) construction (HW encoder if available, else ``libsvtav1``).
except Exception: # nosec B110 pix_fmt: Pixel format (e.g. ``"yuv420p"``).
logger.debug("HW encoder '%s' not available", codec_name) # nosec B110 g: GOP size (keyframe interval).
return available crf: Quality level — mapped to the native quality parameter of the
codec (``crf`` for software, ``qp`` for NVENC/VAAPI,
``q:v`` for VideoToolbox, ``global_quality`` for QSV).
preset: Speed/quality preset. Accepted type is per-codec.
fast_decode: Fast-decode tuning. For ``libsvtav1`` this is a level (0-2)
embedded in ``svtav1-params``. For ``h264`` and ``hevc`` non-zero values
set ``tune=fastdecode``. Ignored for other codecs.
video_backend: Python library driving FFmpeg for encoding. Only ``"pyav"``
is currently supported.
extra_options: Free-form dictionary of additional FFmpeg options
(e.g. ``{"tune": "film", "profile:v": "high", "bf": 2}``).
"""
vcodec: str = "libsvtav1"
pix_fmt: str = "yuv420p"
g: int | None = 2
crf: int | None = 30
preset: int | str | None = None
fast_decode: int = 0
# TODO(CarolinePascal): add torchcodec support + find a way to unify the
# two backends (encoding and decoding).
video_backend: str = "pyav"
extra_options: dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
self.resolve_vcodec()
# Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work".
if self.preset is None and self.vcodec == "libsvtav1":
self.preset = LIBSVTAV1_DEFAULT_PRESET
self.validate()
def detect_available_encoders(self, encoders: list[str] | str) -> list[str]:
"""Detect available encoders based on the video backend."""
if self.video_backend == "pyav":
return detect_available_encoders_pyav(encoders)
else:
return []
def validate(self) -> None:
"""Validate the video encoder config."""
if self.video_backend == "pyav":
check_video_encoder_config_pyav(self)
def resolve_vcodec(self) -> None:
"""Check ``vcodec`` and, when it is ``"auto"``, pick a concrete encoder.
For ``"auto"``, the first hardware encoder in the preference list that FFmpeg
exposes is chosen; if none are available, ``libsvtav1`` is used. If the
resolved codec (explicit or after auto-selection) is not present in the
local FFmpeg build, raises ``ValueError``.
"""
if self.vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{self.vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
if self.vcodec == "auto":
available = self.detect_available_encoders(HW_ENCODERS)
for encoder in HW_ENCODERS:
if encoder in available:
logger.info(f"Auto-selected video codec: {encoder}")
self.vcodec = encoder
return
logger.warning("No hardware encoder available, falling back to software encoder 'libsvtav1'")
self.vcodec = "libsvtav1"
if self.detect_available_encoders(self.vcodec):
logger.info(f"Using video codec: {self.vcodec}")
self.vcodec = self.vcodec
return
raise ValueError(f"Unsupported video codec: {self.vcodec} with video backend {self.video_backend}")
def get_codec_options(
self, encoder_threads: int | None = None, as_strings: bool = False
) -> dict[str, str]:
"""Translate the tuning fields to codec-specific FFmpeg options.
``VideoEncoderConfig.extra_options`` are merged last but never override a structured field.
Args:
encoder_threads: Number of encoder threads set globally for all VideoEncoderConfigs.
For libsvtav1, this is mapped to ``lp`` via ``svtav1-params``.
For h264/hevc, this is mapped to ``threads``.
Hardware encoders ignore this parameter.
as_strings: If ``True``, casts values to strings.
"""
opts: dict[str, Any] = {}
def set_if(key: str, value: Any) -> None:
if value is not None:
opts[key] = value if not as_strings else str(value)
# GOP size is not a codec-specific option, so it is always set.
set_if("g", self.g)
if self.vcodec == "libsvtav1":
set_if("crf", self.crf)
set_if("preset", self.preset)
svtav1_parts: list[str] = []
if self.fast_decode is not None:
svtav1_parts.append(f"fast-decode={max(0, min(2, self.fast_decode))}")
if encoder_threads is not None:
svtav1_parts.append(f"lp={encoder_threads}")
if svtav1_parts:
opts["svtav1-params"] = ":".join(svtav1_parts)
elif self.vcodec in ("h264", "hevc"):
set_if("crf", self.crf)
set_if("preset", self.preset)
if self.fast_decode:
opts["tune"] = "fastdecode"
set_if("threads", encoder_threads)
elif self.vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
if self.crf is not None:
opts["q:v"] = max(1, min(100, 100 - self.crf * 2))
elif self.vcodec in ("h264_nvenc", "hevc_nvenc"):
opts["rc"] = "constqp"
set_if("qp", self.crf)
set_if("preset", self.preset)
elif self.vcodec == "h264_vaapi":
set_if("qp", self.crf)
elif self.vcodec == "h264_qsv":
set_if("global_quality", self.crf)
set_if("preset", self.preset)
else:
set_if("crf", self.crf)
set_if("preset", self.preset)
# Extra options are merged last but never override structured fields (values are kept as given).
for k, v in self.extra_options.items():
if k not in opts:
set_if(k, v)
return opts
def resolve_vcodec(vcodec: str) -> str: def camera_encoder_defaults() -> VideoEncoderConfig:
"""Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1.""" """Return a :class:`VideoEncoderConfig` with RGB-camera defaults."""
if vcodec not in VALID_VIDEO_CODECS: return VideoEncoderConfig()
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
if vcodec != "auto":
logger.info(f"Using video codec: {vcodec}")
return vcodec
available = detect_available_hw_encoders()
for encoder in HW_ENCODERS:
if encoder in available:
logger.info(f"Auto-selected video codec: {encoder}")
return encoder
logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
return "libsvtav1"
def decode_video_frames( def decode_video_frames(
@@ -142,7 +233,7 @@ def decode_video_frames(
Currently supports torchcodec on cpu and pyav. Currently supports torchcodec on cpu and pyav.
""" """
if backend is None: if backend is None:
backend = get_safe_default_codec() backend = get_safe_default_video_backend()
if backend == "torchcodec": if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8) return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
elif backend in ["pyav", "video_reader"]: elif backend in ["pyav", "video_reader"]:
@@ -400,18 +491,17 @@ def encode_video_frames(
imgs_dir: Path | str, imgs_dir: Path | str,
video_path: Path | str, video_path: Path | str,
fps: int, fps: int,
vcodec: str = "libsvtav1", camera_encoder_config: VideoEncoderConfig | None = None,
pix_fmt: str = "yuv420p", encoder_threads: int | None = None,
g: int | None = 2, *,
crf: int | None = 30,
fast_decode: int = 0,
log_level: int | None = av.logging.WARNING, log_level: int | None = av.logging.WARNING,
overwrite: bool = False, overwrite: bool = False,
preset: int | None = None,
encoder_threads: int | None = None,
) -> None: ) -> None:
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`""" """More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
vcodec = resolve_vcodec(vcodec) if camera_encoder_config is None:
camera_encoder_config = camera_encoder_defaults()
vcodec = camera_encoder_config.vcodec
pix_fmt = camera_encoder_config.pix_fmt
video_path = Path(video_path) video_path = Path(video_path)
imgs_dir = Path(imgs_dir) imgs_dir = Path(imgs_dir)
@@ -422,42 +512,18 @@ def encode_video_frames(
video_path.parent.mkdir(parents=True, exist_ok=True) video_path.parent.mkdir(parents=True, exist_ok=True)
# Encoders/pixel formats incompatibility check
if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
logger.warning(
f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
)
pix_fmt = "yuv420p"
# Get input frames # Get input frames
template = "frame-" + ("[0-9]" * 6) + ".png" template = "frame-" + ("[0-9]" * 6) + ".png"
input_list = sorted( input_list = sorted(
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0]) glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
) )
# Define video output frame size (assuming all input frames are the same size)
if len(input_list) == 0: if len(input_list) == 0:
raise FileNotFoundError(f"No images found in {imgs_dir}.") raise FileNotFoundError(f"No images found in {imgs_dir}.")
with Image.open(input_list[0]) as dummy_image: with Image.open(input_list[0]) as dummy_image:
width, height = dummy_image.size width, height = dummy_image.size
# Define video codec options video_options = camera_encoder_config.get_codec_options(encoder_threads, as_strings=True)
video_options = _get_codec_options(vcodec, g, crf, preset)
if fast_decode:
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
video_options[key] = value
if encoder_threads is not None:
if vcodec == "libsvtav1":
lp_param = f"lp={encoder_threads}"
if "svtav1-params" in video_options:
video_options["svtav1-params"] += f":{lp_param}"
else:
video_options["svtav1-params"] = lp_param
else:
video_options["threads"] = str(encoder_threads)
# Set logging level # Set logging level
if log_level is not None: if log_level is not None:
@@ -494,7 +560,10 @@ def encode_video_frames(
def concatenate_video_files( def concatenate_video_files(
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True input_video_paths: list[Path | str],
output_video_path: Path,
overwrite: bool = True,
compatibility_check: bool = False,
): ):
""" """
Concatenate multiple video files into a single video file using pyav. Concatenate multiple video files into a single video file using pyav.
@@ -507,6 +576,7 @@ def concatenate_video_files(
input_video_paths: Ordered list of input video file paths to concatenate. input_video_paths: Ordered list of input video file paths to concatenate.
output_video_path: Path to the output video file. output_video_path: Path to the output video file.
overwrite: Whether to overwrite the output video file if it already exists. Default is True. overwrite: Whether to overwrite the output video file if it already exists. Default is True.
compatibility_check: Whether to check if the input videos are compatible. Default is False.
Note: Note:
- Creates a temporary directory for intermediate files that is cleaned up after use. - Creates a temporary directory for intermediate files that is cleaned up after use.
@@ -525,6 +595,22 @@ def concatenate_video_files(
if len(input_video_paths) == 0: if len(input_video_paths) == 0:
raise FileNotFoundError("No input video paths provided.") raise FileNotFoundError("No input video paths provided.")
# This check may be skipped at recording time as videos are encoded with the same encoder config.
if compatibility_check:
reference_video_info = get_video_info(input_video_paths[0])
for input_path in input_video_paths[1:]:
video_info = get_video_info(input_path)
if (
video_info["video.height"] != reference_video_info["video.height"]
or video_info["video.width"] != reference_video_info["video.width"]
or video_info["video.fps"] != reference_video_info["video.fps"]
or video_info["video.codec"] != reference_video_info["video.codec"]
or video_info["video.pix_fmt"] != reference_video_info["video.pix_fmt"]
):
raise ValueError(
f"Input video {input_path} is not compatible with the reference video {input_video_paths[0]}."
)
# Create a temporary .ffconcat file to list the input video paths # Create a temporary .ffconcat file to list the input video paths
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file: with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
tmp_concatenate_file.write("ffconcat version 1.0\n") tmp_concatenate_file.write("ffconcat version 1.0\n")
@@ -591,26 +677,20 @@ class _CameraEncoderThread(threading.Thread):
fps: int, fps: int,
vcodec: str, vcodec: str,
pix_fmt: str, pix_fmt: str,
g: int | None, codec_options: dict[str, str],
crf: int | None,
preset: int | None,
frame_queue: queue.Queue, frame_queue: queue.Queue,
result_queue: queue.Queue, result_queue: queue.Queue,
stop_event: threading.Event, stop_event: threading.Event,
encoder_threads: int | None = None,
): ):
super().__init__(daemon=True) super().__init__(daemon=True)
self.video_path = video_path self.video_path = video_path
self.fps = fps self.fps = fps
self.vcodec = vcodec self.vcodec = vcodec
self.pix_fmt = pix_fmt self.pix_fmt = pix_fmt
self.g = g self.codec_options = codec_options
self.crf = crf
self.preset = preset
self.frame_queue = frame_queue self.frame_queue = frame_queue
self.result_queue = result_queue self.result_queue = result_queue
self.stop_event = stop_event self.stop_event = stop_event
self.encoder_threads = encoder_threads
def run(self) -> None: def run(self) -> None:
from .compute_stats import RunningQuantileStats, auto_downsample_height_width from .compute_stats import RunningQuantileStats, auto_downsample_height_width
@@ -646,19 +726,9 @@ class _CameraEncoderThread(threading.Thread):
# Open container on first frame (to get width/height) # Open container on first frame (to get width/height)
if container is None: if container is None:
height, width = frame_data.shape[:2] height, width = frame_data.shape[:2]
video_options = _get_codec_options(self.vcodec, self.g, self.crf, self.preset)
if self.encoder_threads is not None:
if self.vcodec == "libsvtav1":
lp_param = f"lp={self.encoder_threads}"
if "svtav1-params" in video_options:
video_options["svtav1-params"] += f":{lp_param}"
else:
video_options["svtav1-params"] = lp_param
else:
video_options["threads"] = str(self.encoder_threads)
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True) Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
container = av.open(str(self.video_path), "w") container = av.open(str(self.video_path), "w")
output_stream = container.add_stream(self.vcodec, self.fps, options=video_options) output_stream = container.add_stream(self.vcodec, self.fps, options=self.codec_options)
output_stream.pix_fmt = self.pix_fmt output_stream.pix_fmt = self.pix_fmt
output_stream.width = width output_stream.width = width
output_stream.height = height output_stream.height = height
@@ -724,22 +794,25 @@ class StreamingVideoEncoder:
def __init__( def __init__(
self, self,
fps: int, fps: int,
vcodec: str = "libsvtav1", camera_encoder_config: VideoEncoderConfig | None = None,
pix_fmt: str = "yuv420p",
g: int | None = 2,
crf: int | None = 30,
preset: int | None = None,
queue_maxsize: int = 30,
encoder_threads: int | None = None, encoder_threads: int | None = None,
*,
queue_maxsize: int = 30,
): ):
"""
Args:
fps: Frames per second for the output videos.
camera_encoder_config: Video encoder settings applied to all cameras.
When ``None``, :func:`camera_encoder_defaults` is used.
encoder_threads: Number of encoder threads (global setting).
``None`` lets the codec decide.
queue_maxsize: Max frames to buffer per camera before
back-pressure drops frames.
"""
self.fps = fps self.fps = fps
self.vcodec = resolve_vcodec(vcodec) self._camera_encoder_config = camera_encoder_config or camera_encoder_defaults()
self.pix_fmt = pix_fmt self._encoder_threads = encoder_threads
self.g = g
self.crf = crf
self.preset = preset
self.queue_maxsize = queue_maxsize self.queue_maxsize = queue_maxsize
self.encoder_threads = encoder_threads
self._frame_queues: dict[str, queue.Queue] = {} self._frame_queues: dict[str, queue.Queue] = {}
self._result_queues: dict[str, queue.Queue] = {} self._result_queues: dict[str, queue.Queue] = {}
@@ -770,18 +843,19 @@ class StreamingVideoEncoder:
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir)) temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4" video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
vcodec = self._camera_encoder_config.vcodec
codec_options = self._camera_encoder_config.get_codec_options(
self._encoder_threads, as_strings=True
)
encoder_thread = _CameraEncoderThread( encoder_thread = _CameraEncoderThread(
video_path=video_path, video_path=video_path,
fps=self.fps, fps=self.fps,
vcodec=self.vcodec, vcodec=vcodec,
pix_fmt=self.pix_fmt, pix_fmt=self._camera_encoder_config.pix_fmt,
g=self.g, codec_options=codec_options,
crf=self.crf,
preset=self.preset,
frame_queue=frame_queue, frame_queue=frame_queue,
result_queue=result_queue, result_queue=result_queue,
stop_event=stop_event, stop_event=stop_event,
encoder_threads=self.encoder_threads,
) )
encoder_thread.start() encoder_thread.start()
@@ -986,8 +1060,18 @@ def get_audio_info(video_path: Path | str) -> dict:
return audio_info return audio_info
def get_video_info(video_path: Path | str) -> dict: def get_video_info(
# Set logging level video_path: Path | str,
camera_encoder_config: "VideoEncoderConfig | None" = None,
) -> dict:
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
Args:
video_path: Path to the encoded video file to probe.
camera_encoder_config: If provided, record the exact encoder settings used to encode this
video. Stream-derived values take precedence — encoder fields are only written for keys
not already populated from the video file itself.
"""
logging.getLogger("libav").setLevel(av.logging.WARNING) logging.getLogger("libav").setLevel(av.logging.WARNING)
# Getting video stream information # Getting video stream information
@@ -1018,6 +1102,11 @@ def get_video_info(video_path: Path | str) -> dict:
# Adding audio stream information # Adding audio stream information
video_info.update(**get_audio_info(video_path)) video_info.update(**get_audio_info(video_path))
# Add additional encoder configuration if provided
if camera_encoder_config is not None:
for field_name, field_value in asdict(camera_encoder_config).items():
video_info.setdefault(f"video.{field_name}", field_value)
return video_info return video_info

View File

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterpolator
from .act.configuration_act import ACTConfig as ACTConfig from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
@@ -21,10 +23,7 @@ from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
from .pi05.configuration_pi05 import PI05Config as PI05Config from .pi05.configuration_pi05 import PI05Config as PI05Config
from .pretrained import PreTrainedPolicy as PreTrainedPolicy from .pretrained import PreTrainedPolicy as PreTrainedPolicy
from .rtc import ActionInterpolator as ActionInterpolator
from .sac.configuration_sac import SACConfig as SACConfig from .sac.configuration_sac import SACConfig as SACConfig
from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
from .sarm.configuration_sarm import SARMConfig as SARMConfig
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .utils import make_robot_action, prepare_observation_for_inference from .utils import make_robot_action, prepare_observation_for_inference
@@ -45,9 +44,7 @@ __all__ = [
"PI0Config", "PI0Config",
"PI0FastConfig", "PI0FastConfig",
"PI05Config", "PI05Config",
"RewardClassifierConfig",
"SACConfig", "SACConfig",
"SARMConfig",
"SmolVLAConfig", "SmolVLAConfig",
"TDMPCConfig", "TDMPCConfig",
"VQBeTConfig", "VQBeTConfig",

View File

@@ -142,9 +142,10 @@ class ACTPolicy(PreTrainedPolicy):
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = ( abs_err = F.l1_loss(batch[ACTION], actions_hat, reduction="none")
F.l1_loss(batch[ACTION], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) valid_mask = ~batch["action_is_pad"].unsqueeze(-1)
).mean() num_valid = valid_mask.sum() * abs_err.shape[-1]
l1_loss = (abs_err * valid_mask).sum() / num_valid.clamp_min(1)
loss_dict = {"l1_loss": l1_loss.item()} loss_dict = {"l1_loss": l1_loss.item()}
if self.config.use_vae: if self.config.use_vae:

View File

@@ -380,7 +380,9 @@ class DiffusionModel(nn.Module):
f"{self.config.do_mask_loss_for_padding=}." f"{self.config.do_mask_loss_for_padding=}."
) )
in_episode_bound = ~batch["action_is_pad"] in_episode_bound = ~batch["action_is_pad"]
loss = loss * in_episode_bound.unsqueeze(-1) mask = in_episode_bound.unsqueeze(-1)
num_valid = mask.sum() * loss.shape[-1]
return (loss * mask).sum() / num_valid.clamp_min(1)
return loss.mean() return loss.mean()

View File

@@ -52,8 +52,6 @@ from .pi0.configuration_pi0 import PI0Config
from .pi05.configuration_pi05 import PI05Config from .pi05.configuration_pi05 import PI05Config
from .pretrained import PreTrainedPolicy from .pretrained import PreTrainedPolicy
from .sac.configuration_sac import SACConfig from .sac.configuration_sac import SACConfig
from .sac.reward_model.configuration_classifier import RewardClassifierConfig
from .sarm.configuration_sarm import SARMConfig
from .smolvla.configuration_smolvla import SmolVLAConfig from .smolvla.configuration_smolvla import SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig from .tdmpc.configuration_tdmpc import TDMPCConfig
from .utils import validate_visual_features_consistency from .utils import validate_visual_features_consistency
@@ -89,7 +87,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args: Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act", name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x". "multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x".
Returns: Returns:
The policy class corresponding to the given name. The policy class corresponding to the given name.
@@ -132,18 +130,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .sac.modeling_sac import SACPolicy from .sac.modeling_sac import SACPolicy
return SACPolicy return SACPolicy
elif name == "reward_classifier":
from .sac.reward_model.modeling_classifier import Classifier
return Classifier
elif name == "smolvla": elif name == "smolvla":
from .smolvla.modeling_smolvla import SmolVLAPolicy from .smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy return SmolVLAPolicy
elif name == "sarm":
from .sarm.modeling_sarm import SARMRewardModel
return SARMRewardModel
elif name == "groot": elif name == "groot":
from .groot.modeling_groot import GrootPolicy from .groot.modeling_groot import GrootPolicy
@@ -173,7 +163,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args: Args:
policy_type: The type of the policy. Supported types include "tdmpc", policy_type: The type of the policy. Supported types include "tdmpc",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac", "multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
"smolvla", "reward_classifier", "wall_x". "smolvla", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor. **kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns: Returns:
@@ -200,8 +190,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return SACConfig(**kwargs) return SACConfig(**kwargs)
elif policy_type == "smolvla": elif policy_type == "smolvla":
return SmolVLAConfig(**kwargs) return SmolVLAConfig(**kwargs)
elif policy_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
elif policy_type == "groot": elif policy_type == "groot":
return GrootConfig(**kwargs) return GrootConfig(**kwargs)
elif policy_type == "xvla": elif policy_type == "xvla":
@@ -378,14 +366,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"), dataset_stats=kwargs.get("dataset_stats"),
) )
elif isinstance(policy_cfg, RewardClassifierConfig):
from .sac.reward_model.processor_classifier import make_classifier_processor
processors = make_classifier_processor(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, SmolVLAConfig): elif isinstance(policy_cfg, SmolVLAConfig):
from .smolvla.processor_smolvla import make_smolvla_pre_post_processors from .smolvla.processor_smolvla import make_smolvla_pre_post_processors
@@ -394,14 +374,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"), dataset_stats=kwargs.get("dataset_stats"),
) )
elif isinstance(policy_cfg, SARMConfig):
from .sarm.processor_sarm import make_sarm_pre_post_processors
processors = make_sarm_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(policy_cfg, GrootConfig): elif isinstance(policy_cfg, GrootConfig):
from .groot.processor_groot import make_groot_pre_post_processors from .groot.processor_groot import make_groot_pre_post_processors

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@@ -174,17 +173,14 @@ N_COLOR_CHANNELS = 3
# config # config
@dataclass
class GR00TN15Config(PretrainedConfig): class GR00TN15Config(PretrainedConfig):
model_type = "gr00t_n1_5" model_type = "gr00t_n1_5"
backbone_cfg: dict = field(init=False, metadata={"help": "Backbone configuration."})
action_head_cfg: dict = field(init=False, metadata={"help": "Action head configuration."}) backbone_cfg: dict
action_head_cfg: dict
action_horizon: int = field(init=False, metadata={"help": "Action horizon."}) action_horizon: int
action_dim: int
action_dim: int = field(init=False, metadata={"help": "Action dimension."}) compute_dtype: str = "float32"
compute_dtype: str = field(default="float32", metadata={"help": "Compute dtype."})
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)

View File

@@ -688,8 +688,9 @@ class DiffusionObjective(nn.Module):
loss = F.mse_loss(predicted, target, reduction="none") loss = F.mse_loss(predicted, target, reduction="none")
if self.do_mask_loss_for_padding and "action_is_pad" in batch: if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_actions = ~batch["action_is_pad"] mask = ~batch["action_is_pad"].unsqueeze(-1)
loss = loss * valid_actions.unsqueeze(-1) num_valid = mask.sum() * loss.shape[-1]
return (loss * mask).sum() / num_valid.clamp_min(1)
return loss.mean() return loss.mean()
@@ -752,8 +753,9 @@ class FlowMatchingObjective(nn.Module):
loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none") loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")
if self.do_mask_loss_for_padding and "action_is_pad" in batch: if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_mask = ~batch["action_is_pad"] mask = ~batch["action_is_pad"].unsqueeze(-1)
loss = loss * valid_mask.unsqueeze(-1) num_valid = mask.sum() * loss.shape[-1]
return (loss * mask).sum() / num_valid.clamp_min(1)
return loss.mean() return loss.mean()

View File

@@ -227,6 +227,7 @@ class PI0FastPaliGemma(nn.Module):
# forward(..., adarms_cond=...) is supported (same as pi0/pi05). # forward(..., adarms_cond=...) is supported (same as pi0/pi05).
if use_adarms[0]: if use_adarms[0]:
text_config = self.paligemma.config.text_config text_config = self.paligemma.config.text_config
del self.paligemma.model.language_model
self.paligemma.model.language_model = PiGemmaModel(text_config) self.paligemma.model.language_model = PiGemmaModel(text_config)
self.to_bfloat16_for_selected_params(precision) self.to_bfloat16_for_selected_params(precision)

View File

@@ -197,6 +197,9 @@ class PiGemmaModel(GemmaModel): # type: ignore[misc]
def __init__(self, config: GemmaConfig, **kwargs): def __init__(self, config: GemmaConfig, **kwargs):
super().__init__(config, **kwargs) super().__init__(config, **kwargs)
# Free parent-allocated layers/norm before replacing to avoid ~2x peak memory.
del self.layers
del self.norm
# if not getattr(config, "use_adarms", False): # if not getattr(config, "use_adarms", False):
# return # return
cond_dim = getattr(config, "adarms_cond_dim", None) cond_dim = getattr(config, "adarms_cond_dim", None)
@@ -328,6 +331,7 @@ class PiGemmaForCausalLM(GemmaForCausalLM): # type: ignore[misc]
def __init__(self, config: GemmaConfig, **kwargs): def __init__(self, config: GemmaConfig, **kwargs):
super().__init__(config, **kwargs) super().__init__(config, **kwargs)
del self.model
self.model = PiGemmaModel(config) self.model = PiGemmaModel(config)
@@ -336,6 +340,7 @@ class PaliGemmaModelWithPiGemma(PaliGemmaModel):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
del self.language_model
self.language_model = PiGemmaModel(config.text_config) self.language_model = PiGemmaModel(config.text_config)
@@ -344,6 +349,7 @@ class PaliGemmaForConditionalGenerationWithPiGemma(PaliGemmaForConditionalGenera
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
del self.model
self.model = PaliGemmaModelWithPiGemma(config) self.model = PaliGemmaModelWithPiGemma(config)
# Make modules available through conditional class for BC # Make modules available through conditional class for BC

View File

@@ -19,6 +19,7 @@ from .action_queue import ActionQueue
from .configuration_rtc import RTCConfig from .configuration_rtc import RTCConfig
from .latency_tracker import LatencyTracker from .latency_tracker import LatencyTracker
from .modeling_rtc import RTCProcessor from .modeling_rtc import RTCProcessor
from .relative import reanchor_relative_rtc_prefix
__all__ = [ __all__ = [
"ActionInterpolator", "ActionInterpolator",
@@ -26,4 +27,5 @@ __all__ = [
"LatencyTracker", "LatencyTracker",
"RTCConfig", "RTCConfig",
"RTCProcessor", "RTCProcessor",
"reanchor_relative_rtc_prefix",
] ]

View File

@@ -1,116 +1,4 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # Moved to lerobot.utils.action_interpolator — re-exported for backwards compatibility.
# from lerobot.utils.action_interpolator import ActionInterpolator
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Action interpolation for smoother robot control. __all__ = ["ActionInterpolator"]
Provides configurable Nx control rate by interpolating between consecutive actions.
Useful with RTC and action-chunking policies to reduce jerkiness.
"""
from torch import Tensor
class ActionInterpolator:
"""Interpolates between consecutive actions for smoother control.
When enabled with multiplier N, produces N actions per policy action
by linearly interpolating between the previous and current action.
Example with multiplier=3:
prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
This effectively multiplies the control rate for smoother motion.
Usage:
interpolator = ActionInterpolator(multiplier=2) # 2x control rate
# In control loop:
if interpolator.needs_new_action():
new_action = queue.get()
if new_action:
interpolator.add(new_action.cpu())
action = interpolator.get()
if action:
robot.send_action(action)
"""
def __init__(self, multiplier: int = 1):
"""Initialize the interpolator.
Args:
multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
"""
if multiplier < 1:
raise ValueError(f"multiplier must be >= 1, got {multiplier}")
self.multiplier = multiplier
self._prev: Tensor | None = None
self._buffer: list[Tensor] = []
self._idx = 0
@property
def enabled(self) -> bool:
"""Whether interpolation is active (multiplier > 1)."""
return self.multiplier > 1
def reset(self):
"""Reset interpolation state (call between episodes)."""
self._prev = None
self._buffer = []
self._idx = 0
def needs_new_action(self) -> bool:
"""Check if a new action is needed from the queue."""
return self._idx >= len(self._buffer)
def add(self, action: Tensor) -> None:
"""Add a new action and compute interpolated sequence.
Args:
action: New action tensor from policy/queue (already on CPU).
"""
if self.multiplier > 1 and self._prev is not None:
self._buffer = []
for i in range(1, self.multiplier + 1):
t = i / self.multiplier
interp = self._prev + t * (action - self._prev)
self._buffer.append(interp)
else:
# First step: no previous action yet, so run at base FPS without interpolation.
self._buffer = [action.clone()]
self._prev = action.clone()
self._idx = 0
def get(self) -> Tensor | None:
"""Get the next interpolated action.
Returns:
Next action tensor, or None if buffer is exhausted.
"""
if self._idx >= len(self._buffer):
return None
action = self._buffer[self._idx]
self._idx += 1
return action
def get_control_interval(self, fps: float) -> float:
"""Get the control interval based on interpolation multiplier.
Args:
fps: Base frames per second.
Returns:
Control interval in seconds (divided by multiplier).
"""
return 1.0 / (fps * self.multiplier)

View File

@@ -92,10 +92,10 @@ class ActionQueue:
Returns: Returns:
int: Number of unconsumed actions. int: Number of unconsumed actions.
""" """
if self.queue is None: with self.lock:
return 0 if self.queue is None:
length = len(self.queue) return 0
return length - self.last_index return len(self.queue) - self.last_index
def empty(self) -> bool: def empty(self) -> bool:
"""Check if the queue is empty. """Check if the queue is empty.
@@ -103,11 +103,10 @@ class ActionQueue:
Returns: Returns:
bool: True if no actions remain, False otherwise. bool: True if no actions remain, False otherwise.
""" """
if self.queue is None: with self.lock:
return True if self.queue is None:
return True
length = len(self.queue) return len(self.queue) - self.last_index <= 0
return length - self.last_index <= 0
def get_action_index(self) -> int: def get_action_index(self) -> int:
"""Get the current action consumption index. """Get the current action consumption index.
@@ -115,7 +114,8 @@ class ActionQueue:
Returns: Returns:
int: Index of the next action to be consumed. int: Index of the next action to be consumed.
""" """
return self.last_index with self.lock:
return self.last_index
def get_left_over(self) -> Tensor | None: def get_left_over(self) -> Tensor | None:
"""Get leftover original actions for RTC prev_chunk_left_over. """Get leftover original actions for RTC prev_chunk_left_over.

View File

@@ -35,7 +35,7 @@ class RTCConfig:
""" """
# Infrastructure # Infrastructure
enabled: bool = False enabled: bool = True
# Core RTC settings # Core RTC settings
# Todo change to exp # Todo change to exp

View File

@@ -0,0 +1,58 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Relative-action helpers for Real-Time Chunking (RTC)."""
from __future__ import annotations
import torch
from lerobot.processor import (
NormalizerProcessorStep,
RelativeActionsProcessorStep,
TransitionKey,
create_transition,
to_relative_actions,
)
def reanchor_relative_rtc_prefix(
prev_actions_absolute: torch.Tensor,
current_state: torch.Tensor,
relative_step: RelativeActionsProcessorStep,
normalizer_step: NormalizerProcessorStep | None,
policy_device: torch.device | str,
) -> torch.Tensor:
"""Convert absolute leftover actions into model-space for relative-action RTC policies.
When using relative actions, the RTC prefix (previous chunk's unexecuted tail)
is stored in absolute coordinates. Before feeding it back to the policy, this
helper re-expresses those actions relative to the robot's current joint state
and optionally normalizes them so the policy receives correctly scaled inputs.
"""
state = current_state.detach().cpu()
if state.dim() == 1:
state = state.unsqueeze(0)
action_cpu = prev_actions_absolute.detach().cpu()
mask = relative_step._build_mask(action_cpu.shape[-1])
relative_actions = to_relative_actions(action_cpu, state, mask)
transition = create_transition(action=relative_actions)
if normalizer_step is not None:
transition = normalizer_step(transition)
return transition[TransitionKey.ACTION].to(policy_device)

View File

@@ -1 +0,0 @@
../../../../docs/source/policy_sarm_README.md

View File

@@ -394,13 +394,21 @@ class SmolVLAPolicy(PreTrainedPolicy):
loss_dict["losses_after_rm_padding"] = losses.clone().mean().item() loss_dict["losses_after_rm_padding"] = losses.clone().mean().item()
if reduction == "none": if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims # Return per-sample losses (B,) by averaging over valid (time, action) entries
per_sample_loss = losses.mean(dim=(1, 2)) if actions_is_pad is None:
per_sample_loss = losses.mean(dim=(1, 2))
else:
num_valid = ((~actions_is_pad).sum(dim=1) * losses.shape[-1]).clamp_min(1)
per_sample_loss = losses.sum(dim=(1, 2)) / num_valid
loss_dict["loss"] = per_sample_loss.mean().item() loss_dict["loss"] = per_sample_loss.mean().item()
return per_sample_loss, loss_dict return per_sample_loss, loss_dict
else: else:
# Default: return scalar mean loss # Default: return scalar mean loss over valid (time, action) entries
loss = losses.mean() if actions_is_pad is None:
loss = losses.mean()
else:
num_valid = ((~actions_is_pad).sum() * losses.shape[-1]).clamp_min(1)
loss = losses.sum() / num_valid
loss_dict["loss"] = loss.item() loss_dict["loss"] = loss.item()
return loss, loss_dict return loss, loss_dict

View File

@@ -557,7 +557,7 @@ class RewardClassifierProcessorStep(ProcessorStep):
def __post_init__(self): def __post_init__(self):
"""Initializes the reward classifier model after the dataclass is created.""" """Initializes the reward classifier model after the dataclass is created."""
if self.pretrained_path is not None: if self.pretrained_path is not None:
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier from lerobot.rewards.classifier.modeling_classifier import Classifier
self.reward_classifier = Classifier.from_pretrained(self.pretrained_path) self.reward_classifier = Classifier.from_pretrained(self.pretrained_path)
self.reward_classifier.to(self.device) self.reward_classifier.to(self.device)

View File

@@ -142,6 +142,10 @@ class RelativeActionsProcessorStep(ProcessorStep):
new_transition[TransitionKey.ACTION] = to_relative_actions(action, state, mask) new_transition[TransitionKey.ACTION] = to_relative_actions(action, state, mask)
return new_transition return new_transition
def get_cached_state(self) -> torch.Tensor | None:
"""Return the cached ``observation.state`` used as the reference point for relative/absolute action conversions."""
return self._last_state
def get_config(self) -> dict[str, Any]: def get_config(self) -> dict[str, Any]:
return { return {
"enabled": self.enabled, "enabled": self.enabled,
@@ -182,7 +186,8 @@ class AbsoluteActionsProcessorStep(ProcessorStep):
"but relative_step is None. Ensure relative_step is set when constructing the postprocessor." "but relative_step is None. Ensure relative_step is set when constructing the postprocessor."
) )
if self.relative_step._last_state is None: cached_state = self.relative_step.get_cached_state()
if cached_state is None:
raise RuntimeError( raise RuntimeError(
"AbsoluteActionsProcessorStep requires state from RelativeActionsProcessorStep " "AbsoluteActionsProcessorStep requires state from RelativeActionsProcessorStep "
"but no state has been cached. Ensure the preprocessor runs before the postprocessor." "but no state has been cached. Ensure the preprocessor runs before the postprocessor."
@@ -194,9 +199,7 @@ class AbsoluteActionsProcessorStep(ProcessorStep):
return new_transition return new_transition
mask = self.relative_step._build_mask(action.shape[-1]) mask = self.relative_step._build_mask(action.shape[-1])
new_transition[TransitionKey.ACTION] = to_absolute_actions( new_transition[TransitionKey.ACTION] = to_absolute_actions(action, cached_state, mask)
action, self.relative_step._last_state, mask
)
return new_transition return new_transition
def get_config(self) -> dict[str, Any]: def get_config(self) -> dict[str, Any]:

View File

@@ -0,0 +1,36 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .classifier.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
from .factory import (
get_reward_model_class as get_reward_model_class,
make_reward_model as make_reward_model,
make_reward_model_config as make_reward_model_config,
make_reward_pre_post_processors as make_reward_pre_post_processors,
)
from .pretrained import PreTrainedRewardModel as PreTrainedRewardModel
from .sarm.configuration_sarm import SARMConfig as SARMConfig
__all__ = [
# Configuration classes
"RewardClassifierConfig",
"SARMConfig",
# Base class
"PreTrainedRewardModel",
# Factory functions
"get_reward_model_class",
"make_reward_model",
"make_reward_model_config",
"make_reward_pre_post_processors",
]

View File

@@ -1,5 +1,3 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,14 +13,15 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass, field from dataclasses import dataclass, field
from lerobot.configs import NormalizationMode, PreTrainedConfig from lerobot.configs import NormalizationMode
from lerobot.configs.rewards import RewardModelConfig
from lerobot.optim import AdamWConfig, LRSchedulerConfig, OptimizerConfig from lerobot.optim import AdamWConfig, LRSchedulerConfig, OptimizerConfig
from lerobot.utils.constants import OBS_IMAGE from lerobot.utils.constants import OBS_IMAGE
@PreTrainedConfig.register_subclass(name="reward_classifier") @RewardModelConfig.register_subclass(name="reward_classifier")
@dataclass @dataclass
class RewardClassifierConfig(PreTrainedConfig): class RewardClassifierConfig(RewardModelConfig):
"""Configuration for the Reward Classifier model.""" """Configuration for the Reward Classifier model."""
name: str = "reward_classifier" name: str = "reward_classifier"

View File

@@ -1,5 +1,3 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,11 +17,10 @@ import logging
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
from lerobot.rewards.pretrained import PreTrainedRewardModel
from lerobot.utils.constants import OBS_IMAGE, REWARD from lerobot.utils.constants import OBS_IMAGE, REWARD
from ...pretrained import PreTrainedPolicy
from .configuration_classifier import RewardClassifierConfig
class ClassifierOutput: class ClassifierOutput:
"""Wrapper for classifier outputs with additional metadata.""" """Wrapper for classifier outputs with additional metadata."""
@@ -99,7 +96,7 @@ class SpatialLearnedEmbeddings(nn.Module):
return output return output
class Classifier(PreTrainedPolicy): class Classifier(PreTrainedRewardModel):
"""Image classifier built on top of a pre-trained encoder.""" """Image classifier built on top of a pre-trained encoder."""
name = "reward_classifier" name = "reward_classifier"
@@ -235,6 +232,16 @@ class Classifier(PreTrainedPolicy):
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs) return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
"""Returns 1.0 for success, 0.0 for failure based on image observations."""
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
output = self.predict(images)
if self.config.num_classes == 2:
return (output.probabilities > 0.5).float()
else:
return torch.argmax(output.probabilities, dim=1).float()
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]: def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
"""Standard forward pass for training compatible with train.py.""" """Standard forward pass for training compatible with train.py."""
# Extract images and labels # Extract images and labels
@@ -269,10 +276,6 @@ class Classifier(PreTrainedPolicy):
def predict_reward(self, batch, threshold=0.5): def predict_reward(self, batch, threshold=0.5):
"""Eval method. Returns predicted reward with the decision threshold as argument.""" """Eval method. Returns predicted reward with the decision threshold as argument."""
# Check for both OBS_IMAGE and OBS_IMAGES prefixes
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
# Extract images from batch dict # Extract images from batch dict
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)] images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
@@ -282,28 +285,3 @@ class Classifier(PreTrainedPolicy):
return (probs > threshold).float() return (probs > threshold).float()
else: else:
return torch.argmax(self.predict(images).probabilities, dim=1) return torch.argmax(self.predict(images).probabilities, dim=1)
def get_optim_params(self):
"""Return optimizer parameters for the policy."""
return self.parameters()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""
This method is required by PreTrainedPolicy but not used for reward classifiers.
The reward classifier is not an actor and does not select actions.
"""
raise NotImplementedError("Reward classifiers do not select actions")
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""
This method is required by PreTrainedPolicy but not used for reward classifiers.
The reward classifier is not an actor and does not produce action chunks.
"""
raise NotImplementedError("Reward classifiers do not predict action chunks")
def reset(self):
"""
This method is required by PreTrainedPolicy but not used for reward classifiers.
The reward classifier is not an actor and does not select actions.
"""
pass

View File

@@ -1,5 +1,3 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -27,8 +25,7 @@ from lerobot.processor import (
policy_action_to_transition, policy_action_to_transition,
transition_to_policy_action, transition_to_policy_action,
) )
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
from .configuration_classifier import RewardClassifierConfig
def make_classifier_processor( def make_classifier_processor(
@@ -52,8 +49,6 @@ def make_classifier_processor(
Args: Args:
config: The configuration object for the RewardClassifier. config: The configuration object for the RewardClassifier.
dataset_stats: A dictionary of statistics for normalization. dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns: Returns:
A tuple containing the configured pre-processor and post-processor pipelines. A tuple containing the configured pre-processor and post-processor pipelines.

View File

@@ -0,0 +1,238 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import logging
from typing import Any
import torch
from lerobot.configs.rewards import RewardModelConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
from lerobot.rewards.pretrained import PreTrainedRewardModel
from lerobot.rewards.sarm.configuration_sarm import SARMConfig
def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
"""
Retrieves a reward model class by its registered name.
This function uses dynamic imports to avoid loading all reward model classes into
memory at once, improving startup time and reducing dependencies.
Args:
name: The name of the reward model. Supported names are "reward_classifier",
"sarm".
Returns:
The reward model class corresponding to the given name.
Raises:
ValueError: If the reward model name is not recognized.
"""
if name == "reward_classifier":
from lerobot.rewards.classifier.modeling_classifier import Classifier
return Classifier
elif name == "sarm":
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
return SARMRewardModel
else:
try:
return _get_reward_model_cls_from_name(name=name)
except Exception as e:
raise ValueError(f"Reward model type '{name}' is not available.") from e
def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
"""
Instantiates a reward model configuration object based on the reward type.
This factory function simplifies the creation of reward model configuration objects
by mapping a string identifier to the corresponding config class.
Args:
reward_type: The type of the reward model. Supported types include
"reward_classifier", "sarm".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
An instance of a `RewardModelConfig` subclass.
Raises:
ValueError: If the `reward_type` is not recognized.
"""
if reward_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
elif reward_type == "sarm":
return SARMConfig(**kwargs)
else:
try:
config_cls = RewardModelConfig.get_choice_class(reward_type)
return config_cls(**kwargs)
except Exception as e:
raise ValueError(f"Reward model type '{reward_type}' is not available.") from e
def make_reward_model(cfg: RewardModelConfig, **kwargs) -> PreTrainedRewardModel:
"""
Instantiate a reward model from its configuration.
Args:
cfg: The configuration for the reward model to be created. If
`cfg.pretrained_path` is set, the model will be loaded with weights
from that path.
**kwargs: Additional keyword arguments forwarded to the model constructor
(e.g., ``dataset_stats``, ``dataset_meta``).
Returns:
An instantiated and device-placed reward model.
"""
reward_cls = get_reward_model_class(cfg.type)
kwargs["config"] = cfg
if cfg.pretrained_path:
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
reward_model = reward_cls.from_pretrained(**kwargs)
else:
reward_model = reward_cls(**kwargs)
reward_model.to(cfg.device)
assert isinstance(reward_model, torch.nn.Module)
return reward_model
def make_reward_pre_post_processors(
reward_cfg: RewardModelConfig,
**kwargs,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Create pre- and post-processor pipelines for a given reward model.
Each reward model type has a dedicated factory function for its processors.
Args:
reward_cfg: The configuration of the reward model for which to create processors.
**kwargs: Additional keyword arguments passed to the processor factory
(e.g., ``dataset_stats``, ``dataset_meta``).
Returns:
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
Raises:
ValueError: If a processor factory is not implemented for the given reward
model configuration type.
"""
# Create a new processor based on reward model type
if isinstance(reward_cfg, RewardClassifierConfig):
from lerobot.rewards.classifier.processor_classifier import make_classifier_processor
return make_classifier_processor(
config=reward_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(reward_cfg, SARMConfig):
from lerobot.rewards.sarm.processor_sarm import make_sarm_pre_post_processors
return make_sarm_pre_post_processors(
config=reward_cfg,
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
)
else:
try:
processors = _make_processors_from_reward_model_config(
config=reward_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
except Exception as e:
raise ValueError(
f"Processor for reward model type '{reward_cfg.type}' is not implemented."
) from e
return processors
def _get_reward_model_cls_from_name(name: str) -> type[PreTrainedRewardModel]:
"""Get reward model class from its registered name using dynamic imports.
This is used as a helper function to import reward models from 3rd party lerobot
plugins.
Args:
name: The name of the reward model.
Returns:
The reward model class corresponding to the given name.
"""
if name not in RewardModelConfig.get_known_choices():
raise ValueError(
f"Unknown reward model name '{name}'. "
f"Available reward models: {RewardModelConfig.get_known_choices()}"
)
config_cls = RewardModelConfig.get_choice_class(name)
config_cls_name = config_cls.__name__
model_name = config_cls_name.removesuffix("Config")
if model_name == config_cls_name:
raise ValueError(
f"The config class name '{config_cls_name}' does not follow the expected naming convention. "
f"Make sure it ends with 'Config'!"
)
cls_name = model_name + "RewardModel"
module_path = config_cls.__module__.replace("configuration_", "modeling_")
module = importlib.import_module(module_path)
reward_cls = getattr(module, cls_name)
return reward_cls
def _make_processors_from_reward_model_config(
config: RewardModelConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[Any, Any]:
"""Create pre- and post-processors from a reward model configuration using dynamic imports.
This is used as a helper function to import processor factories from 3rd party
lerobot reward model plugins.
Args:
config: The reward model configuration object.
dataset_stats: Dataset statistics for normalization.
Returns:
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
"""
reward_type = config.type
function_name = f"make_{reward_type}_pre_post_processors"
module_path = config.__class__.__module__.replace("configuration_", "processor_")
logging.debug(
f"Instantiating reward pre/post processors using function '{function_name}' "
f"from module '{module_path}'"
)
module = importlib.import_module(module_path)
function = getattr(module, function_name)
return function(config, dataset_stats=dataset_stats)

View File

@@ -0,0 +1,244 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import builtins
import logging
import os
from importlib.resources import files
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, TypeVar
import packaging
import safetensors
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
from torch import Tensor, nn
from lerobot.configs.rewards import RewardModelConfig
from lerobot.utils.hub import HubMixin
if TYPE_CHECKING:
from lerobot.configs.train import TrainPipelineConfig
T = TypeVar("T", bound="PreTrainedRewardModel")
class PreTrainedRewardModel(nn.Module, HubMixin, abc.ABC):
"""Base class for reward models."""
config_class: None
name: None
def __init__(self, config: RewardModelConfig, *inputs, **kwargs):
super().__init__()
if not isinstance(config, RewardModelConfig):
raise ValueError(
f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
"`RewardModelConfig`. To create a model from a pretrained model use "
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.config = config
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not getattr(cls, "config_class", None):
raise TypeError(f"Class {cls.__name__} must define 'config_class'")
if not getattr(cls, "name", None):
raise TypeError(f"Class {cls.__name__} must define 'name'")
def _save_pretrained(self, save_directory: Path) -> None:
self.config._save_pretrained(save_directory)
model_to_save = self.module if hasattr(self, "module") else self
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
config: RewardModelConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
strict: bool = False,
**kwargs,
) -> T:
"""
The reward model is set in evaluation mode by default using `reward.eval()` (dropout modules are
deactivated). To train it, you should first set it back in training mode with `reward.train()`.
"""
if config is None:
config = RewardModelConfig.from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)
model_id = str(pretrained_name_or_path)
instance = cls(config, **kwargs)
if os.path.isdir(model_id):
print("Loading weights from local directory")
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
reward = cls._load_as_safetensor(instance, model_file, config.device or "cpu", strict)
else:
try:
model_file = hf_hub_download(
repo_id=model_id,
filename=SAFETENSORS_SINGLE_FILE,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
reward = cls._load_as_safetensor(instance, model_file, config.device or "cpu", strict)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
) from e
reward.to(config.device)
reward.eval()
return reward
@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
# Create base kwargs
kwargs = {"strict": strict}
# Add device parameter for newer versions that support it
if packaging.version.parse(safetensors.__version__) >= packaging.version.parse("0.4.3"):
kwargs["device"] = map_location
# Load the model with appropriate kwargs
missing_keys, unexpected_keys = load_model_as_safetensor(model, model_file, **kwargs)
if missing_keys:
logging.warning(f"Missing key(s) when loading model: {missing_keys}")
if unexpected_keys:
logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}")
# For older versions, manually move to device if needed
if "device" not in kwargs and map_location != "cpu":
logging.warning(
"Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
" This means that the model is loaded on 'cpu' first and then copied to the device."
" This leads to a slower loading time."
" Please update safetensors to version 0.4.3 or above for improved performance."
)
model.to(map_location)
return model
def get_optim_params(self):
"""
Returns the reward-model-specific parameters dict to be passed on to the optimizer.
"""
return self.parameters()
def reset(self) -> None:
"""Reset any internal state."""
pass
@abc.abstractmethod
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
"""Compute a scalar reward signal for a batch of observations.
Args:
batch: Dictionary containing at minimum observation tensors.
May also contain "action", "next_observation.*", etc.
Returns:
Tensor of shape ``(batch_size,)`` with reward values.
"""
...
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
"""Training forward pass — override for trainable reward models."""
raise NotImplementedError(
f"{self.__class__.__name__} is not trainable. Only use compute_reward() for inference."
)
@property
def is_trainable(self) -> bool:
"""Whether this reward model can be trained via ``lerobot-train``.
Trainable reward models override :meth:`forward`; zero-shot models
inherit the base implementation that raises ``NotImplementedError``.
"""
return type(self).forward is not PreTrainedRewardModel.forward
def push_model_to_hub(self, cfg: "TrainPipelineConfig"):
api = HfApi()
repo_id = api.create_repo(
repo_id=self.config.repo_id, private=self.config.private, exist_ok=True
).repo_id
# Push the files to the repo in a single commit
with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
saved_path = Path(tmp) / repo_id
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
card = self.generate_model_card(
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags
)
card.save(str(saved_path / "README.md"))
cfg.save_pretrained(saved_path) # Calls _save_pretrained and stores train config
commit_info = api.upload_folder(
repo_id=repo_id,
repo_type="model",
folder_path=saved_path,
commit_message="Upload reward model weights, train config and readme",
allow_patterns=["*.safetensors", "*.json", "*.yaml", "*.md"],
ignore_patterns=["*.tmp", "*.log"],
)
logging.info(f"Model pushed to {commit_info.repo_url.url}")
def generate_model_card(
self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
) -> ModelCard:
card_data = ModelCardData(
license=license or "apache-2.0",
library_name="lerobot",
pipeline_tag="robotics",
tags=list(set(tags or []).union({"robotics", "lerobot", "reward-model", model_type})),
model_name=model_type,
datasets=dataset_repo_id,
)
template_card = (
files("lerobot.templates")
.joinpath("lerobot_rewardmodel_modelcard_template.md")
.read_text(encoding="utf-8")
)
card = ModelCard.from_template(card_data, template_str=template_card)
card.validate()
return card

View File

@@ -1,4 +1,4 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -14,5 +14,6 @@
from .configuration_sarm import SARMConfig from .configuration_sarm import SARMConfig
from .modeling_sarm import SARMRewardModel from .modeling_sarm import SARMRewardModel
from .processor_sarm import make_sarm_pre_post_processors
__all__ = ["SARMConfig", "SARMRewardModel"] __all__ = ["SARMConfig", "SARMRewardModel", "make_sarm_pre_post_processors"]

View File

@@ -25,18 +25,18 @@ need ~num_frames/30 queries instead of one per frame (~30x speedup).
Usage: Usage:
# Full RA-BC computation with visualizations # Full RA-BC computation with visualizations
python src/lerobot/policies/sarm/compute_rabc_weights.py \\ python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4 --reward-model-path <USER>/sarm_single_uni4
# Faster computation with stride (compute every 5 frames, interpolate the rest) # Faster computation with stride (compute every 5 frames, interpolate the rest)
python src/lerobot/policies/sarm/compute_rabc_weights.py \\ python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4 \\ --reward-model-path <USER>/sarm_single_uni4 \\
--stride 5 --stride 5
# Visualize predictions only (no RA-BC computation) # Visualize predictions only (no RA-BC computation)
python src/lerobot/policies/sarm/compute_rabc_weights.py \\ python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4 \\ --reward-model-path <USER>/sarm_single_uni4 \\
--visualize-only \\ --visualize-only \\
@@ -58,10 +58,9 @@ import torch
from tqdm import tqdm from tqdm import tqdm
from lerobot.datasets import LeRobotDataset from lerobot.datasets import LeRobotDataset
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
from .modeling_sarm import SARMRewardModel from lerobot.rewards.sarm.processor_sarm import make_sarm_pre_post_processors
from .processor_sarm import make_sarm_pre_post_processors from lerobot.rewards.sarm.sarm_utils import normalize_stage_tau
from .sarm_utils import normalize_stage_tau
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None: def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
@@ -713,12 +712,12 @@ def main():
epilog=""" epilog="""
Examples: Examples:
# Full RA-BC computation with visualizations # Full RA-BC computation with visualizations
python src/lerobot/policies/sarm/compute_rabc_weights.py \\ python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4 --reward-model-path <USER>/sarm_single_uni4
# Visualize predictions only (no RA-BC computation) # Visualize predictions only (no RA-BC computation)
python src/lerobot/policies/sarm/compute_rabc_weights.py \\ python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4 \\ --reward-model-path <USER>/sarm_single_uni4 \\
--visualize-only \\ --visualize-only \\

View File

@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu # Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
# and The HuggingFace Inc. team. All rights reserved. # and The HuggingFace Inc. team. All rights reserved.
# #
@@ -22,14 +20,15 @@ Paper: https://arxiv.org/abs/2509.25358
from dataclasses import dataclass, field from dataclasses import dataclass, field
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature
from lerobot.configs.rewards import RewardModelConfig
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
@PreTrainedConfig.register_subclass("sarm") @RewardModelConfig.register_subclass("sarm")
@dataclass @dataclass
class SARMConfig(PreTrainedConfig): class SARMConfig(RewardModelConfig):
"""Configuration class for SARM (Stage-Aware Reward Modeling). """Configuration class for SARM (Stage-Aware Reward Modeling).
Supports three annotation modes: Supports three annotation modes:
@@ -110,7 +109,6 @@ class SARMConfig(PreTrainedConfig):
def __post_init__(self): def __post_init__(self):
super().__post_init__() super().__post_init__()
if self.annotation_mode not in ["single_stage", "dense_only", "dual"]: if self.annotation_mode not in ["single_stage", "dense_only", "dual"]:
raise ValueError( raise ValueError(
f"annotation_mode must be 'single_stage', 'dense_only', or 'dual', got {self.annotation_mode}" f"annotation_mode must be 'single_stage', 'dense_only', or 'dual', got {self.annotation_mode}"

View File

@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu # Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
# and The HuggingFace Inc. team. All rights reserved. # and The HuggingFace Inc. team. All rights reserved.
# #
@@ -34,14 +32,13 @@ import torch.nn as nn
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
from torch import Tensor from torch import Tensor
from lerobot.utils.constants import OBS_STR from lerobot.rewards.pretrained import PreTrainedRewardModel
from lerobot.rewards.sarm.configuration_sarm import SARMConfig
from ..pretrained import PreTrainedPolicy from lerobot.rewards.sarm.sarm_utils import (
from .configuration_sarm import SARMConfig
from .sarm_utils import (
normalize_stage_tau, normalize_stage_tau,
pad_state_to_max_dim, pad_state_to_max_dim,
) )
from lerobot.utils.constants import OBS_STR
class StageTransformer(nn.Module): class StageTransformer(nn.Module):
@@ -353,7 +350,7 @@ def gen_stage_emb(num_classes: int, targets: torch.Tensor) -> torch.Tensor:
return stage_onehot return stage_onehot
class SARMRewardModel(PreTrainedPolicy): class SARMRewardModel(PreTrainedRewardModel):
""" """
SARM Reward Model for stage-aware task completion rewards. SARM Reward Model for stage-aware task completion rewards.
@@ -471,6 +468,23 @@ class SARMRewardModel(PreTrainedPolicy):
self.subtask_model.to(device) self.subtask_model.to(device)
return self return self
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
"""Compute dense progress reward in [0, 1] from batch.
Expects batch to contain:
- "observation_features" or video embeddings: (B, T, 512)
- "language_embedding" or text embeddings: (B, 512)
- optionally "observation.state": (B, T, state_dim)
"""
text_emb = batch.get("language_embedding", batch.get("text_features"))
video_emb = batch.get("observation_features", batch.get("video_features"))
state = batch.get("observation.state", batch.get("state_features"))
rewards = self.calculate_rewards(text_emb, video_emb, state)
if isinstance(rewards, np.ndarray):
rewards = torch.from_numpy(rewards).float()
return rewards
@torch.no_grad() @torch.no_grad()
def calculate_rewards( def calculate_rewards(
self, self,
@@ -631,17 +645,9 @@ class SARMRewardModel(PreTrainedPolicy):
return self.parameters() return self.parameters()
def reset(self): def reset(self):
"""Required by PreTrainedPolicy but not used for reward models.""" """SARM has no episode-level state to reset."""
pass pass
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Required by PreTrainedPolicy but not used for reward models."""
raise NotImplementedError("SARM model does not predict action chunks")
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Required by PreTrainedPolicy but not used for SARM."""
raise NotImplementedError("SARM model does not select actions")
def _train_step( def _train_step(
self, self,
img_emb: torch.Tensor, # (B, N, T, D) img_emb: torch.Tensor, # (B, N, T, D)

View File

@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -60,16 +58,15 @@ from lerobot.processor import (
policy_action_to_transition, policy_action_to_transition,
transition_to_policy_action, transition_to_policy_action,
) )
from lerobot.types import EnvTransition, PolicyAction, TransitionKey from lerobot.rewards.sarm.configuration_sarm import SARMConfig
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.rewards.sarm.sarm_utils import (
from .configuration_sarm import SARMConfig
from .sarm_utils import (
apply_rewind_augmentation, apply_rewind_augmentation,
compute_absolute_indices, compute_absolute_indices,
find_stage_and_tau, find_stage_and_tau,
pad_state_to_max_dim, pad_state_to_max_dim,
) )
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
class SARMEncodingProcessorStep(ProcessorStep): class SARMEncodingProcessorStep(ProcessorStep):
@@ -455,7 +452,13 @@ class SARMEncodingProcessorStep(ProcessorStep):
inputs = {k: v.to(self.device) for k, v in inputs.items()} inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Get image embeddings # Get image embeddings
embeddings = self.clip_model.get_image_features(**inputs).detach().cpu() # transformers 5.x returns BaseModelOutputWithPooling instead of a plain tensor
output = self.clip_model.get_image_features(**inputs)
if not isinstance(output, torch.Tensor):
output = output.pooler_output
if output is None:
raise ValueError("pooler_output should not be None for CLIP models.")
embeddings = output.detach().cpu()
# Handle single frame case # Handle single frame case
if embeddings.dim() == 1: if embeddings.dim() == 1:
@@ -482,7 +485,13 @@ class SARMEncodingProcessorStep(ProcessorStep):
inputs = self.clip_processor.tokenizer([text], return_tensors="pt", padding=True, truncation=True) inputs = self.clip_processor.tokenizer([text], return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(self.device) for k, v in inputs.items()} inputs = {k: v.to(self.device) for k, v in inputs.items()}
text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu() # transformers 5.x returns BaseModelOutputWithPooling instead of a plain tensor
output = self.clip_model.get_text_features(**inputs)
if not isinstance(output, torch.Tensor):
output = output.pooler_output
if output is None:
raise ValueError("pooler_output should not be None for CLIP models.")
text_embedding = output.detach().cpu()
text_embedding = text_embedding.expand(batch_size, -1) text_embedding = text_embedding.expand(batch_size, -1)
return text_embedding return text_embedding

View File

@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,14 +12,38 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
RA-BC (Reward-Aligned Behavior Cloning) sample weighting implementation.
This module implements the SampleWeighter protocol for RA-BC training,
which weights training samples based on their task progress as measured
by the SARM reward model.
The weights are computed based on progress deltas:
delta = progress[t + chunk_size] - progress[t]
High-quality samples (positive progress) get higher weights, while
samples with negative progress (going backwards) get zero weight.
See: https://arxiv.org/abs/2509.25358 for the SARM paper.
"""
import logging import logging
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np import numpy as np
import pandas as pd
import torch import torch
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from lerobot.utils.import_utils import _pandas_available
from lerobot.utils.sample_weighting import SampleWeighter
if TYPE_CHECKING or _pandas_available:
import pandas as pd
else:
pd = None # type: ignore[assignment]
def resolve_hf_path(path: str | Path) -> Path: def resolve_hf_path(path: str | Path) -> Path:
"""Resolve a path that may be a HuggingFace URL (hf://datasets/...) to a local path.""" """Resolve a path that may be a HuggingFace URL (hf://datasets/...) to a local path."""
@@ -34,23 +56,27 @@ def resolve_hf_path(path: str | Path) -> Path:
return Path(path) return Path(path)
class RABCWeights: class RABCWeights(SampleWeighter):
""" """
Load precomputed SARM progress values and compute RA-BC weights during training. Load precomputed SARM progress values and compute RA-BC weights during training.
This class implements the SampleWeighter ABC for use with the generic
sample weighting infrastructure in lerobot.
Progress values are loaded from a parquet file (generated by compute_rabc_weights.py). Progress values are loaded from a parquet file (generated by compute_rabc_weights.py).
During training, computes: During training, computes:
- progress_delta = progress[t + chunk_size] - progress[t] - progress_delta = progress[t + chunk_size] - progress[t]
- rabc_weight based on the delta (paper Eq. 8-9) - rabc_weight based on the delta (paper Eq. 8-9)
Args: Args:
progress_path: Path to parquet file with precomputed progress values progress_path: Path to parquet file with precomputed progress values.
chunk_size: Number of frames ahead for computing progress delta Supports HuggingFace URLs (hf://datasets/...).
head_mode: Which SARM head to use ("sparse" or "dense") chunk_size: Number of frames ahead for computing progress delta.
kappa: Hard threshold for high-quality samples (default: 0.01) head_mode: Which SARM head to use ("sparse" or "dense").
epsilon: Small constant for numerical stability (default: 1e-6) kappa: Hard threshold for high-quality samples (default: 0.01).
fallback_weight: Weight to use for frames without valid delta (default: 1.0) epsilon: Small constant for numerical stability (default: 1e-6).
device: Device to return tensors on fallback_weight: Weight to use for frames without valid delta (default: 1.0).
device: Device to return tensors on.
""" """
def __init__( def __init__(
@@ -61,7 +87,7 @@ class RABCWeights:
kappa: float = 0.01, kappa: float = 0.01,
epsilon: float = 1e-6, epsilon: float = 1e-6,
fallback_weight: float = 1.0, fallback_weight: float = 1.0,
device: torch.device = None, device: torch.device | None = None,
): ):
self.progress_path = resolve_hf_path(progress_path) self.progress_path = resolve_hf_path(progress_path)
self.chunk_size = chunk_size self.chunk_size = chunk_size
@@ -87,8 +113,8 @@ class RABCWeights:
logging.info(f"Using progress column: {self.progress_column}") logging.info(f"Using progress column: {self.progress_column}")
self.progress_lookup = {} self.progress_lookup: dict[int, float] = {}
self.episode_lookup = {} self.episode_lookup: dict[int, int] = {}
for _, row in self.df.iterrows(): for _, row in self.df.iterrows():
global_idx = int(row["index"]) global_idx = int(row["index"])
@@ -100,7 +126,7 @@ class RABCWeights:
self.episode_lookup[global_idx] = episode_idx self.episode_lookup[global_idx] = episode_idx
# Build episode boundaries for delta computation # Build episode boundaries for delta computation
self.episode_boundaries = {} self.episode_boundaries: dict[int, dict[str, int]] = {}
for episode_idx in self.df["episode_index"].unique(): for episode_idx in self.df["episode_index"].unique():
ep_df = self.df[self.df["episode_index"] == episode_idx] ep_df = self.df[self.df["episode_index"] == episode_idx]
self.episode_boundaries[int(episode_idx)] = { self.episode_boundaries[int(episode_idx)] = {
@@ -114,7 +140,7 @@ class RABCWeights:
# Compute global statistics for weight computation # Compute global statistics for weight computation
self._compute_global_stats() self._compute_global_stats()
def _compute_global_stats(self): def _compute_global_stats(self) -> None:
"""Compute global mean and std of progress deltas for weight calculation.""" """Compute global mean and std of progress deltas for weight calculation."""
all_deltas = [] all_deltas = []
@@ -138,8 +164,8 @@ class RABCWeights:
all_deltas.append(delta) all_deltas.append(delta)
if all_deltas: if all_deltas:
self.delta_mean = max(np.mean(all_deltas), 0.0) self.delta_mean = max(float(np.mean(all_deltas)), 0.0)
self.delta_std = max(np.std(all_deltas), self.epsilon) self.delta_std = max(float(np.std(all_deltas)), self.epsilon)
logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}") logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}")
else: else:
self.delta_mean = 0.0 self.delta_mean = 0.0
@@ -157,18 +183,19 @@ class RABCWeights:
4. Compute weight using paper Eq. 8-9 4. Compute weight using paper Eq. 8-9
Args: Args:
batch: Training batch containing "index" key with global frame indices batch: Training batch containing "index" key with global frame indices.
Returns: Returns:
Tuple of: Tuple of:
- Weights tensor (batch_size,) normalized to sum to batch_size - Weights tensor (batch_size,) normalized to sum to batch_size.
- Stats dict with raw_mean_weight, num_zero_weight, num_full_weight - Stats dict with weighting statistics for logging.
""" """
indices = batch.get("index") indices = batch.get("index")
if indices is None: if indices is None:
logging.warning("RA-BC: Batch missing 'index' key, using uniform weights") logging.warning("RA-BC: Batch missing 'index' key, using uniform weights")
batch_size = self._get_batch_size(batch) batch_size = self._get_batch_size(batch)
return torch.ones(batch_size, device=self.device), {"raw_mean_weight": 1.0} stats = {"mean_weight": 1.0, "num_zero_weight": 0, "num_full_weight": batch_size}
return torch.ones(batch_size, device=self.device), stats
# Convert to list of ints # Convert to list of ints
if isinstance(indices, torch.Tensor): if isinstance(indices, torch.Tensor):
@@ -183,29 +210,29 @@ class RABCWeights:
delta = self._compute_delta(idx) delta = self._compute_delta(idx)
deltas.append(delta) deltas.append(delta)
deltas = np.array(deltas, dtype=np.float32) deltas_array = np.array(deltas, dtype=np.float32)
# Compute weights from deltas # Compute weights from deltas
weights = self._compute_weights(deltas) weights = self._compute_weights(deltas_array)
# Compute stats before normalization for logging # Compute stats before normalization for logging
raw_mean_weight = float(np.nanmean(weights)) raw_mean_weight = float(np.nanmean(weights))
num_zero_weight = int(np.sum(weights == 0)) num_zero_weight = int(np.sum(weights == 0))
num_full_weight = int(np.sum(weights == 1.0)) num_full_weight = int(np.sum(weights == 1.0))
batch_stats = { batch_stats = {
"raw_mean_weight": raw_mean_weight, "mean_weight": raw_mean_weight,
"num_zero_weight": num_zero_weight, "num_zero_weight": num_zero_weight,
"num_full_weight": num_full_weight, "num_full_weight": num_full_weight,
} }
weights = torch.tensor(weights, device=self.device, dtype=torch.float32) weights_tensor = torch.tensor(weights, device=self.device, dtype=torch.float32)
# Normalize to sum to batch_size # Normalize to sum to batch_size
batch_size = len(weights) batch_size = len(weights_tensor)
weight_sum = weights.sum() + self.epsilon weight_sum = weights_tensor.sum() + self.epsilon
weights = weights * batch_size / weight_sum weights_tensor = weights_tensor * batch_size / weight_sum
return weights, batch_stats return weights_tensor, batch_stats
def _compute_delta(self, global_idx: int) -> float: def _compute_delta(self, global_idx: int) -> float:
"""Compute progress delta for a single frame.""" """Compute progress delta for a single frame."""
@@ -241,7 +268,7 @@ class RABCWeights:
- Final weight: wi = 1{ri > κ} + 1{0 ri κ}˜wi - Final weight: wi = 1{ri > κ} + 1{0 ri κ}˜wi
Returns: Returns:
Array of weights Array of weights.
""" """
valid_mask = ~np.isnan(deltas) valid_mask = ~np.isnan(deltas)
@@ -273,12 +300,13 @@ class RABCWeights:
if key in batch: if key in batch:
val = batch[key] val = batch[key]
if isinstance(val, (torch.Tensor, np.ndarray)): if isinstance(val, (torch.Tensor, np.ndarray)):
return val.shape[0] return int(val.shape[0])
return 1 return 1
def get_stats(self) -> dict: def get_stats(self) -> dict:
"""Get statistics.""" """Get global statistics about the RA-BC weighting."""
return { return {
"type": "rabc",
"num_frames": len(self.progress_lookup), "num_frames": len(self.progress_lookup),
"chunk_size": self.chunk_size, "chunk_size": self.chunk_size,
"head_mode": self.head_mode, "head_mode": self.head_mode,

View File

@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -76,6 +76,7 @@ from lerobot.transport.utils import (
) )
from lerobot.types import TransitionKey from lerobot.types import TransitionKey
from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed from lerobot.utils.random_utils import set_seed
from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.transition import ( from lerobot.utils.transition import (
@@ -94,7 +95,6 @@ from .gym_manipulator import (
make_robot_env, make_robot_env,
step_env_and_process_transition, step_env_and_process_transition,
) )
from .process import ProcessSignalHandler
from .queue import get_last_item_from_queue from .queue import get_last_item_from_queue
# Main entry point # Main entry point

View File

@@ -193,15 +193,15 @@ def convert_lerobot_dataset_to_cropped_lerobot_dataset(
fps=int(original_dataset.fps), fps=int(original_dataset.fps),
root=new_dataset_root, root=new_dataset_root,
robot_type=original_dataset.meta.robot_type, robot_type=original_dataset.meta.robot_type,
features=original_dataset.meta.info["features"], features=original_dataset.meta.info.features,
use_videos=len(original_dataset.meta.video_keys) > 0, use_videos=len(original_dataset.meta.video_keys) > 0,
) )
# Update the metadata for every image key that will be cropped: # Update the metadata for every image key that will be cropped:
# (Here we simply set the shape to be the final resize_size.) # (Here we simply set the shape to be the final resize_size.)
for key in crop_params_dict: for key in crop_params_dict:
if key in new_dataset.meta.info["features"]: if key in new_dataset.meta.info.features:
new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size) new_dataset.meta.info.features[key]["shape"] = (3, *resize_size)
# TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset # TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset
prev_episode_index = 0 prev_episode_index = 0

View File

@@ -90,6 +90,7 @@ from lerobot.utils.constants import (
TRAINING_STATE_DIR, TRAINING_STATE_DIR,
) )
from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed from lerobot.utils.random_utils import set_seed
from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device
from lerobot.utils.utils import ( from lerobot.utils.utils import (
@@ -99,7 +100,6 @@ from lerobot.utils.utils import (
from .buffer import ReplayBuffer, concatenate_batch_transitions from .buffer import ReplayBuffer, concatenate_batch_transitions
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
from .process import ProcessSignalHandler
@parser.wrap() @parser.wrap()

View File

@@ -0,0 +1,87 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Policy deployment engine with pluggable rollout strategies."""
from lerobot.utils.import_utils import require_package
require_package("datasets", extra="dataset")
from .configs import (
BaseStrategyConfig,
DAggerKeyboardConfig,
DAggerPedalConfig,
DAggerStrategyConfig,
HighlightStrategyConfig,
RolloutConfig,
RolloutStrategyConfig,
SentryStrategyConfig,
)
from .context import (
DatasetContext,
HardwareContext,
PolicyContext,
ProcessorContext,
RolloutContext,
RuntimeContext,
build_rollout_context,
)
from .inference import (
InferenceEngine,
InferenceEngineConfig,
RTCInferenceConfig,
RTCInferenceEngine,
SyncInferenceConfig,
SyncInferenceEngine,
create_inference_engine,
)
from .strategies import (
BaseStrategy,
DAggerStrategy,
HighlightStrategy,
RolloutStrategy,
SentryStrategy,
create_strategy,
)
__all__ = [
"BaseStrategy",
"BaseStrategyConfig",
"DAggerKeyboardConfig",
"DAggerPedalConfig",
"DAggerStrategy",
"DAggerStrategyConfig",
"DatasetContext",
"HardwareContext",
"HighlightStrategy",
"HighlightStrategyConfig",
"InferenceEngine",
"InferenceEngineConfig",
"PolicyContext",
"ProcessorContext",
"RTCInferenceConfig",
"RTCInferenceEngine",
"RolloutConfig",
"RolloutContext",
"RolloutStrategy",
"RolloutStrategyConfig",
"RuntimeContext",
"SentryStrategy",
"SentryStrategyConfig",
"SyncInferenceConfig",
"SyncInferenceEngine",
"build_rollout_context",
"create_inference_engine",
"create_strategy",
]

View File

@@ -0,0 +1,323 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration dataclasses for the rollout deployment engine."""
from __future__ import annotations
import abc
import logging
from dataclasses import dataclass, field
import draccus
from lerobot.configs import PreTrainedConfig, parser
from lerobot.configs.dataset import DatasetRecordConfig
from lerobot.robots.config import RobotConfig
from lerobot.teleoperators.config import TeleoperatorConfig
from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available
from .inference import InferenceEngineConfig, SyncInferenceConfig
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Strategy configs (polymorphic dispatch via draccus ChoiceRegistry)
# ---------------------------------------------------------------------------
@dataclass
class RolloutStrategyConfig(draccus.ChoiceRegistry, abc.ABC):
"""Abstract base for rollout strategy configurations.
Use ``--strategy.type=<name>`` on the CLI to select a strategy.
"""
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@RolloutStrategyConfig.register_subclass("base")
@dataclass
class BaseStrategyConfig(RolloutStrategyConfig):
"""Autonomous rollout with no data recording."""
pass
@RolloutStrategyConfig.register_subclass("sentry")
@dataclass
class SentryStrategyConfig(RolloutStrategyConfig):
"""Continuous autonomous rollout with always-on recording.
Episode duration is derived from camera resolution, FPS, and
``target_video_file_size_mb`` so that each saved episode produces a
video file that has crossed the target size. This aligns episode
boundaries with the dataset's video file chunking, so each
``push_to_hub`` call uploads complete video files rather than
re-uploading a growing file that hasn't crossed the chunk boundary.
"""
upload_every_n_episodes: int = 5
# Target video file size in MB for episode rotation. Episodes are
# saved once the estimated video duration would exceed this limit.
# Defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB when set to None.
target_video_file_size_mb: int | None = None
@RolloutStrategyConfig.register_subclass("highlight")
@dataclass
class HighlightStrategyConfig(RolloutStrategyConfig):
"""Autonomous rollout with on-demand recording via ring buffer.
A memory-bounded ring buffer continuously captures telemetry. When
the user presses the save key, the buffer contents are flushed to
the dataset and live recording continues until the key is pressed
again.
"""
ring_buffer_seconds: float = 10.0
ring_buffer_max_memory_mb: int = 1024
save_key: str = "s"
push_key: str = "h"
@dataclass
class DAggerKeyboardConfig:
"""Keyboard key bindings for DAgger controls.
Keys are specified as single characters (e.g. ``"c"``, ``"h"``) or
special key names (``"space"``).
"""
pause_resume: str = "space"
correction: str = "tab"
upload: str = "enter"
@dataclass
class DAggerPedalConfig:
"""Foot pedal configuration for DAgger controls.
Pedal codes are evdev key code strings (e.g. ``"KEY_A"``).
"""
device_path: str = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
pause_resume: str = "KEY_A"
correction: str = "KEY_B"
upload: str = "KEY_C"
@RolloutStrategyConfig.register_subclass("dagger")
@dataclass
class DAggerStrategyConfig(RolloutStrategyConfig):
"""Human-in-the-loop data collection (DAgger / RaC).
Alternates between autonomous policy execution and human intervention.
Intervention frames are tagged with ``intervention=True``.
Input is controlled via either a keyboard or foot pedal, selected by
``input_device``. Each device exposes three actions:
1. **pause_resume** — toggle policy execution on/off.
2. **correction** — toggle human correction recording.
3. **upload** — push dataset to hub on demand (corrections-only mode).
When ``record_autonomous=False`` (default) only human-correction windows
are recorded — each correction becomes its own episode. Set to ``True``
to record both autonomous and correction frames with size-based episode
rotation (same as Sentry) and background uploading. ``push_to_hub`` is
blocked while a correction is in progress.
"""
# Number of correction episodes to collect (corrections-only mode).
# When None, falls back to ``--dataset.num_episodes``.
num_episodes: int | None = None
record_autonomous: bool = False
upload_every_n_episodes: int = 5
# Target video file size in MB for episode rotation (record_autonomous
# mode only). Defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB when None.
target_video_file_size_mb: int | None = None
input_device: str = "keyboard"
keyboard: DAggerKeyboardConfig = field(default_factory=DAggerKeyboardConfig)
pedal: DAggerPedalConfig = field(default_factory=DAggerPedalConfig)
def __post_init__(self):
if self.input_device not in ("keyboard", "pedal"):
raise ValueError(f"DAgger input_device must be 'keyboard' or 'pedal', got '{self.input_device}'")
# ---------------------------------------------------------------------------
# Top-level rollout config
# ---------------------------------------------------------------------------
@dataclass
class RolloutConfig:
"""Top-level configuration for the ``lerobot-rollout`` CLI.
Combines hardware, policy, strategy, and runtime settings. The
``__post_init__`` method performs fail-fast validation to reject
invalid flag combinations early.
"""
# Hardware
robot: RobotConfig | None = None
teleop: TeleoperatorConfig | None = None
# Policy (loaded from --policy.path via __post_init__)
policy: PreTrainedConfig | None = None
# Strategy (polymorphic: --strategy.type=base|sentry|highlight|dagger)
strategy: RolloutStrategyConfig = field(default_factory=BaseStrategyConfig)
# Inference backend (polymorphic: --inference.type=sync|rtc)
inference: InferenceEngineConfig = field(default_factory=SyncInferenceConfig)
# Dataset (required for sentry, highlight, dagger; None for base)
dataset: DatasetRecordConfig | None = None
# Runtime
fps: float = 30.0
duration: float = 0.0 # 0 = infinite (24/7 mode)
interpolation_multiplier: int = 1
device: str | None = None
task: str = ""
display_data: bool = False
# Display data on a remote Rerun server
display_ip: str | None = None
# Port of the remote Rerun server
display_port: int | None = None
# Whether to display compressed images in Rerun
display_compressed_images: bool = False
# Use vocal synthesis to read events
play_sounds: bool = True
resume: bool = False
# Rename map for mapping robot/dataset observation keys to policy keys
rename_map: dict[str, str] = field(default_factory=dict)
# Hardware teardown
# When True (default), smoothly interpolate the robot back to the joint
# positions captured at startup before disconnecting. Set to False to
# leave the robot in its final achieved pose at shutdown.
return_to_initial_position: bool = True
# Torch compile
use_torch_compile: bool = False
torch_compile_backend: str = "inductor"
torch_compile_mode: str = "default"
compile_warmup_inferences: int = 2
def __post_init__(self):
"""Validate config invariants and load the policy config from ``--policy.path``."""
# --- Strategy-specific validation ---
if isinstance(self.strategy, DAggerStrategyConfig) and self.teleop is None:
raise ValueError("DAgger strategy requires --teleop.type to be set")
# TODO(Steven): DAgger shouldn't require a dataset (user may want to just rollout+intervene without recording), but for now we require it to simplify the implementation.
needs_dataset = isinstance(
self.strategy, (SentryStrategyConfig, HighlightStrategyConfig, DAggerStrategyConfig)
)
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
if isinstance(self.strategy, BaseStrategyConfig) and self.dataset is not None:
raise ValueError(
"Base strategy does not record data. Use sentry, highlight, or dagger for recording."
)
# Sentry MUST use streaming encoding to avoid disk I/O blocking the control loop
if (
isinstance(self.strategy, SentryStrategyConfig)
and self.dataset is not None
and not self.dataset.streaming_encoding
):
logger.warning("Sentry mode forces streaming_encoding=True")
self.dataset.streaming_encoding = True
# Highlight writes frames while the policy is still running, so streaming is mandatory.
if (
isinstance(self.strategy, HighlightStrategyConfig)
and self.dataset is not None
and not self.dataset.streaming_encoding
):
logger.warning("Highlight mode forces streaming_encoding=True")
self.dataset.streaming_encoding = True
# DAgger: streaming is mandatory only when the autonomous phase is also recorded.
if isinstance(self.strategy, DAggerStrategyConfig) and self.dataset is not None:
if self.strategy.record_autonomous and not self.dataset.streaming_encoding:
logger.warning("DAgger with record_autonomous=True forces streaming_encoding=True")
self.dataset.streaming_encoding = True
elif not self.strategy.record_autonomous and not self.dataset.streaming_encoding:
logger.info(
"Streaming encoding is disabled for DAgger corrections-only mode. "
"Consider enabling it for faster episode saving: "
"--dataset.streaming_encoding=true --dataset.encoder_threads=2"
)
# DAgger: resolve num_episodes from dataset config when not explicitly set.
if isinstance(self.strategy, DAggerStrategyConfig) and self.strategy.num_episodes is None:
if self.dataset is not None:
self.strategy.num_episodes = self.dataset.num_episodes
logger.info(
"DAgger num_episodes not set — using --dataset.num_episodes=%d",
self.strategy.num_episodes,
)
else:
raise ValueError(
"DAgger num_episodes must be set either via --strategy.num_episodes or --dataset.num_episodes"
)
# --- Policy loading ---
if self.robot is None:
raise ValueError("--robot.type is required for rollout")
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
if self.policy is None:
raise ValueError("--policy.path is required for rollout")
# --- Task resolution ---
# When any --dataset.* flag is passed, draccus creates a DatasetRecordConfig with single_task="".
# If the user set the task via the top-level --task flag, propagate it so that all
# downstream consumers (inference engine, dataset frame builders) see it.
if self.dataset is not None and not self.dataset.single_task and self.task:
logger.info("Propagating top-level task '%s' to dataset config", self.task)
self.dataset.single_task = self.task
elif self.dataset is not None and self.dataset.single_task and not self.task:
logger.info("Propagating dataset single_task '%s' to top-level task", self.dataset.single_task)
self.task = self.dataset.single_task
# --- Device resolution ---
# Resolve device from the policy config when not explicitly set so all
# components (policy.to, preprocessor, inference engine) use the same
# device string instead of inconsistent fallbacks.
if self.device is None or not is_torch_device_available(self.device):
resolved = self.policy.device
if resolved:
self.device = resolved
logger.info("Resolved device from policy config: %s", self.device)
else:
self.device = auto_select_torch_device().type
logger.info("No policy config to resolve device from; auto-selected device: %s", self.device)
@classmethod
def __get_path_fields__(cls) -> list[str]:
return ["policy"]

View File

@@ -0,0 +1,454 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rollout context: shared state created once before strategy dispatch.
Grouped into five topical sub-contexts — :class:`RuntimeContext`,
:class:`HardwareContext`, :class:`PolicyContext`, :class:`ProcessorContext`,
and :class:`DatasetContext` — assembled into :class:`RolloutContext`.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from threading import Event
import torch
from lerobot.configs import FeatureType
from lerobot.datasets import (
LeRobotDataset,
aggregate_pipeline_dataset_features,
create_initial_features,
)
from lerobot.policies import get_policy_class, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import (
PolicyProcessorPipeline,
RobotAction,
RobotObservation,
RobotProcessorPipeline,
make_default_processors,
rename_stats,
)
from lerobot.processor.relative_action_processor import RelativeActionsProcessorStep
from lerobot.robots import make_robot_from_config
from lerobot.teleoperators import Teleoperator, make_teleoperator_from_config
from lerobot.utils.feature_utils import combine_feature_dicts, hw_to_dataset_features
from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig
from .inference import (
InferenceEngine,
RTCInferenceConfig,
SyncInferenceConfig,
create_inference_engine,
)
from .robot_wrapper import ThreadSafeRobot
logger = logging.getLogger(__name__)
def _resolve_action_key_order(
policy_action_names: list[str] | None, dataset_action_names: list[str]
) -> list[str]:
"""Choose action name ordering for mapping policy tensor outputs to robot action dicts."""
if not policy_action_names:
return dataset_action_names
policy_action_names = list(policy_action_names)
if len(policy_action_names) != len(dataset_action_names):
logger.warning(
"policy.action_feature_names length (%d) != dataset action dim (%d); using dataset order",
len(policy_action_names),
len(dataset_action_names),
)
return dataset_action_names
if set(dataset_action_names) != set(policy_action_names):
logger.warning("policy.action_feature_names keys don't match dataset; using dataset order")
return dataset_action_names
return policy_action_names
# ---------------------------------------------------------------------------
# Sub-contexts
# ---------------------------------------------------------------------------
@dataclass
class RuntimeContext:
"""Runtime knobs shared with every strategy."""
cfg: RolloutConfig
shutdown_event: Event
@dataclass
class HardwareContext:
"""Connected hardware.
The raw robot is available via ``robot_wrapper.inner`` when needed
(e.g. for disconnect); strategies should otherwise go through the
thread-safe wrapper.
``initial_position`` stores the robot's joint positions at connect
time. Strategies use it to return the robot to a safe pose before
shutting down.
"""
robot_wrapper: ThreadSafeRobot
teleop: Teleoperator | None
initial_position: dict | None = None
@dataclass
class PolicyContext:
"""Loaded policy and its inference engine."""
policy: PreTrainedPolicy
preprocessor: PolicyProcessorPipeline
postprocessor: PolicyProcessorPipeline
inference: InferenceEngine
@dataclass
class ProcessorContext:
"""Robot-side pipelines (run outside the policy)."""
teleop_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]
robot_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]
robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation]
@dataclass
class DatasetContext:
"""Dataset and feature bookkeeping."""
dataset: LeRobotDataset | None
dataset_features: dict = field(default_factory=dict)
hw_features: dict = field(default_factory=dict)
ordered_action_keys: list[str] = field(default_factory=list)
@dataclass
class RolloutContext:
"""Bundle of sub-contexts passed to every rollout strategy.
Built once by :func:`build_rollout_context` before strategy dispatch.
"""
runtime: RuntimeContext
hardware: HardwareContext
policy: PolicyContext
processors: ProcessorContext
data: DatasetContext
# ---------------------------------------------------------------------------
# Build
# ---------------------------------------------------------------------------
def build_rollout_context(
cfg: RolloutConfig,
shutdown_event: Event,
teleop_action_processor: RobotProcessorPipeline | None = None,
robot_action_processor: RobotProcessorPipeline | None = None,
robot_observation_processor: RobotProcessorPipeline | None = None,
) -> RolloutContext:
"""Wire up policy, processors, hardware, dataset, and inference engine.
The order is policy-first / hardware-last so a bad ``--policy.path``
fails fast without touching the robot.
"""
is_rtc = isinstance(cfg.inference, RTCInferenceConfig)
# --- 1. Policy (heavy I/O, but no hardware yet) -------------------
logger.info("Loading policy from '%s'...", cfg.policy.pretrained_path)
policy_config = cfg.policy
policy_class = get_policy_class(policy_config.type)
if hasattr(policy_config, "compile_model"):
policy_config.compile_model = cfg.use_torch_compile
if policy_config.type == "vqbet" and cfg.device == "mps":
raise NotImplementedError(
"Current implementation of VQBeT does not support `mps` backend. "
"Please use `cpu` or `cuda` backend."
)
if policy_config.use_peft:
from peft import PeftConfig, PeftModel
peft_path = policy_config.pretrained_path
peft_config = PeftConfig.from_pretrained(peft_path)
policy = policy_class.from_pretrained(
pretrained_name_or_path=peft_config.base_model_name_or_path, config=policy_config
)
policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config)
else:
policy = policy_class.from_pretrained(policy_config.pretrained_path, config=policy_config)
if is_rtc:
policy.config.rtc_config = cfg.inference.rtc
if hasattr(policy, "init_rtc_processor"):
policy.init_rtc_processor()
policy = policy.to(cfg.device)
policy.eval()
logger.info("Policy loaded: type=%s, device=%s", policy_config.type, cfg.device)
if cfg.use_torch_compile and policy.type not in ("pi0", "pi05"):
try:
if hasattr(torch, "compile"):
compile_kwargs = {
"backend": cfg.torch_compile_backend,
"mode": cfg.torch_compile_mode,
"options": {"triton.cudagraphs": False},
}
policy.predict_action_chunk = torch.compile(policy.predict_action_chunk, **compile_kwargs)
logger.info("torch.compile applied to predict_action_chunk")
except Exception as e:
logger.warning("Failed to apply torch.compile: %s", e)
# --- 2. Robot-side processors (user-supplied or defaults) --------
if (
teleop_action_processor is None
or robot_action_processor is None
or robot_observation_processor is None
):
_t, _r, _o = make_default_processors()
teleop_action_processor = teleop_action_processor or _t
robot_action_processor = robot_action_processor or _r
robot_observation_processor = robot_observation_processor or _o
# --- 3. Hardware (heaviest side-effect, deferred) -----------------
logger.info("Connecting robot (%s)...", cfg.robot.type if cfg.robot else "?")
robot = make_robot_from_config(cfg.robot)
robot.connect()
logger.info("Robot connected: %s", robot.name)
# Store the initial joint positions so we can return to a safe pose on shutdown.
initial_obs = robot.get_observation()
initial_position = {k: v for k, v in initial_obs.items() if k.endswith(".pos")}
logger.info("Captured initial robot position (%d keys)", len(initial_position))
robot_wrapper = ThreadSafeRobot(robot)
teleop = None
if cfg.teleop is not None:
logger.info("Connecting teleoperator (%s)...", cfg.teleop.type if cfg.teleop else "?")
teleop = make_teleoperator_from_config(cfg.teleop)
teleop.connect()
logger.info("Teleoperator connected")
# TODO(Steven): once Teleoperator motor-control methods are standardised
# (``enable_torque`` / ``disable_torque`` / ``write_goal_positions``), gate
# the DAgger strategy on their presence here and fail fast with a helpful
# message instead of relying on the operator to pre-align the leader by
# hand. See :func:`DAggerStrategy._apply_transition` for the matching
# disabled call sites.
# if isinstance(cfg.strategy, DAggerStrategyConfig) and teleop is not None:
# required_teleop_methods = ("enable_torque", "disable_torque", "write_goal_positions")
# missing = [m for m in required_teleop_methods if not callable(getattr(teleop, m, None))]
# if missing:
# teleop.disconnect()
# raise ValueError(
# f"DAgger strategy requires a teleoperator with motor control methods "
# f"{required_teleop_methods}. '{type(teleop).__name__}' is missing: {missing}"
# )
# --- 4. Features + action-key reconciliation ---------------------
# TODO(Steven):Only ``.pos`` joint features are routed to the policy as state and as the
# action target; velocity and torque channels (when present) are kept in
# the raw observation but excluded from the policy-facing tensors.
all_obs_features = robot.observation_features
# ``observation_features`` values are either a tuple (camera shape) or the
# ``float`` type itself used as a sentinel for scalar motor features —
# see ``dict[str, type | tuple]`` annotation on ``Robot.observation_features``.
observation_features_hw = {
k: v
for k, v in all_obs_features.items()
if isinstance(v, tuple) or (v is float and k.endswith(".pos"))
}
action_features_hw = {k: v for k, v in robot.action_features.items() if k.endswith(".pos")}
# The action side is always needed: sync inference reads action names from
# ``dataset_features[ACTION]`` to map policy tensors back to robot actions.
action_dataset_features = aggregate_pipeline_dataset_features(
pipeline=teleop_action_processor,
initial_features=create_initial_features(action=action_features_hw),
use_videos=cfg.dataset.video if cfg.dataset else True,
)
# Observation-side aggregation is needed because of build_dataset_frame
observation_dataset_features = aggregate_pipeline_dataset_features(
pipeline=robot_observation_processor,
initial_features=create_initial_features(observation=observation_features_hw),
use_videos=cfg.dataset.video if cfg.dataset else True,
)
dataset_features = combine_feature_dicts(action_dataset_features, observation_dataset_features)
hw_features = hw_to_dataset_features(observation_features_hw, "observation")
raw_action_keys = list(action_features_hw.keys())
policy_action_names = getattr(policy_config, "action_feature_names", None)
ordered_action_keys = _resolve_action_key_order(
list(policy_action_names) if policy_action_names else None,
raw_action_keys,
)
# Validate visual features if no rename_map is active
rename_map = cfg.rename_map
if not rename_map:
expected_visuals = {
k for k, v in policy_config.input_features.items() if v.type == FeatureType.VISUAL
}
provided_visuals = {
f"observation.images.{k}" for k, v in robot.observation_features.items() if isinstance(v, tuple)
}
policy_subset = expected_visuals.issubset(provided_visuals)
hw_subset = provided_visuals.issubset(expected_visuals)
if not (policy_subset or hw_subset):
raise ValueError(
f"Visual feature mismatch between policy and robot hardware.\n"
f"Policy expects: {expected_visuals}\n"
f"Robot provides: {provided_visuals}"
)
# --- 5. Dataset -------------
dataset = None
if cfg.dataset is not None and not isinstance(cfg.strategy, BaseStrategyConfig):
logger.info("Setting up dataset (repo_id=%s)...", cfg.dataset.repo_id)
if cfg.resume:
dataset = LeRobotDataset.resume(
cfg.dataset.repo_id,
root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
vcodec=cfg.dataset.vcodec,
streaming_encoding=cfg.dataset.streaming_encoding,
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
encoder_threads=cfg.dataset.encoder_threads,
image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
* len(robot.cameras if hasattr(robot, "cameras") else []),
)
else:
if isinstance(cfg.strategy, DAggerStrategyConfig):
dataset_features["intervention"] = {
"dtype": "bool",
"shape": (1,),
"names": None,
}
repo_name = cfg.dataset.repo_id.split("/", 1)[-1]
if not repo_name.startswith("rollout_"):
raise ValueError(
"Dataset names for rollout must start with 'rollout_'. "
"Use --dataset.repo_id=<user>/rollout_<name> for policy deployment datasets."
)
cfg.dataset.stamp_repo_id()
target_video_mb = getattr(cfg.strategy, "target_video_file_size_mb", None)
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
root=cfg.dataset.root,
robot_type=robot.name,
features=dataset_features,
use_videos=cfg.dataset.video,
image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
* len(robot.cameras if hasattr(robot, "cameras") else []),
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
vcodec=cfg.dataset.vcodec,
streaming_encoding=cfg.dataset.streaming_encoding,
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
encoder_threads=cfg.dataset.encoder_threads,
video_files_size_in_mb=target_video_mb,
)
if dataset is not None:
logger.info("Dataset ready: %s (%d existing episodes)", dataset.repo_id, dataset.num_episodes)
# --- 6. Policy pre/post processors (needs dataset stats if any) ---
dataset_stats = None
if dataset is not None:
dataset_stats = rename_stats(
dataset.meta.stats,
cfg.rename_map,
)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy_config,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=dataset_stats,
preprocessor_overrides={
"device_processor": {"device": cfg.device},
"rename_observations_processor": {"rename_map": cfg.rename_map},
},
)
if isinstance(cfg.inference, SyncInferenceConfig) and any(
isinstance(step, RelativeActionsProcessorStep) and step.enabled
for step in getattr(preprocessor, "steps", ())
):
raise NotImplementedError(
"SyncInferenceEngine does not support policies with relative actions for now."
"Use --inference.type=rtc or remove relative action processor steps from the policy pipeline."
)
# --- 7. Inference strategy (needs policy + pre/post + hardware) --
logger.info(
"Creating inference engine (type=%s)...",
cfg.inference.type if hasattr(cfg.inference, "type") else "sync",
)
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
inference_strategy = create_inference_engine(
cfg.inference,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
robot_wrapper=robot_wrapper,
hw_features=hw_features,
dataset_features=dataset_features,
ordered_action_keys=ordered_action_keys,
task=task_str,
fps=cfg.fps,
device=cfg.device,
use_torch_compile=cfg.use_torch_compile,
compile_warmup_inferences=cfg.compile_warmup_inferences,
shutdown_event=shutdown_event,
)
# --- 8. Assemble ---------------------------------------------------
logger.info("Rollout context assembled successfully")
return RolloutContext(
runtime=RuntimeContext(cfg=cfg, shutdown_event=shutdown_event),
hardware=HardwareContext(
robot_wrapper=robot_wrapper, teleop=teleop, initial_position=initial_position
),
policy=PolicyContext(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
inference=inference_strategy,
),
processors=ProcessorContext(
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
),
data=DatasetContext(
dataset=dataset,
dataset_features=dataset_features,
hw_features=hw_features,
ordered_action_keys=ordered_action_keys,
),
)

View File

@@ -0,0 +1,39 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference engine package — backend-agnostic action production.
Concrete backends (``sync``, ``rtc``, ...) expose the same small interface so
rollout strategies never branch on which backend is in use.
"""
from .base import InferenceEngine
from .factory import (
InferenceEngineConfig,
RTCInferenceConfig,
SyncInferenceConfig,
create_inference_engine,
)
from .rtc import RTCInferenceEngine
from .sync import SyncInferenceEngine
__all__ = [
"InferenceEngine",
"InferenceEngineConfig",
"RTCInferenceConfig",
"RTCInferenceEngine",
"SyncInferenceConfig",
"SyncInferenceEngine",
"create_inference_engine",
]

View File

@@ -0,0 +1,89 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference engine ABC.
Rollout strategies consume actions through this small interface so they
do not need to know whether inference happens inline on the control thread
or asynchronously in a background thread (RTC).
"""
from __future__ import annotations
import abc
import torch
class InferenceEngine(abc.ABC):
"""Abstract backend for producing actions during rollout.
Subclasses decide whether inference happens inline on the control
thread or asynchronously in a background thread. The contract is
minimal so additional backends can be plugged in without touching
rollout strategies.
Lifecycle
---------
``start`` — prepare the backend (e.g. launch a background thread).
``stop`` — shut the backend down cleanly.
``reset`` — clear episode-scoped state (policy hidden state, queues…).
Action production
-----------------
``get_action(obs_frame)`` — return the next action tensor, or
``None`` if none is available (e.g. async queue empty). Sync
backends always compute from ``obs_frame``; async backends ignore
it (they receive observations via ``notify_observation``).
Optional hooks
--------------
``notify_observation`` / ``pause`` / ``resume`` have a no-op default
so rollout strategies can invoke them unconditionally.
"""
@abc.abstractmethod
def start(self) -> None:
"""Initialise the backend."""
@abc.abstractmethod
def stop(self) -> None:
"""Tear the backend down."""
@abc.abstractmethod
def reset(self) -> None:
"""Clear episode-scoped state."""
@abc.abstractmethod
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
"""Return the next action tensor, or ``None`` if unavailable."""
def notify_observation(self, obs: dict) -> None: # noqa: B027
"""Publish the latest processed observation. Default: no-op."""
def pause(self) -> None: # noqa: B027
"""Pause background inference. Default: no-op."""
def resume(self) -> None: # noqa: B027
"""Resume background inference. Default: no-op."""
@property
def ready(self) -> bool:
"""True once the backend can produce actions (e.g. warmup done)."""
return True
@property
def failed(self) -> bool:
"""True if an unrecoverable error occurred in the backend."""
return False

View File

@@ -0,0 +1,128 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference engine configs and factory.
Selection is explicit via ``--inference.type=sync|rtc``. Adding a new
backend requires registering its config subclass and dispatching it in
:func:`create_inference_engine`.
"""
from __future__ import annotations
import abc
import logging
from dataclasses import dataclass, field
from threading import Event
import draccus
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.processor import PolicyProcessorPipeline
from ..robot_wrapper import ThreadSafeRobot
from .base import InferenceEngine
from .rtc import RTCInferenceEngine
from .sync import SyncInferenceEngine
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Configs
# ---------------------------------------------------------------------------
@dataclass
class InferenceEngineConfig(draccus.ChoiceRegistry, abc.ABC):
"""Abstract base for inference backend configuration.
Use ``--inference.type=<name>`` on the CLI to select a backend.
"""
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@InferenceEngineConfig.register_subclass("sync")
@dataclass
class SyncInferenceConfig(InferenceEngineConfig):
"""Inline synchronous inference (one policy call per control tick)."""
@InferenceEngineConfig.register_subclass("rtc")
@dataclass
class RTCInferenceConfig(InferenceEngineConfig):
"""Real-Time Chunking: async policy inference in a background thread."""
# Eagerly constructed so draccus exposes nested fields directly on the CLI
# (e.g. ``--inference.rtc.execution_horizon=...``).
rtc: RTCConfig = field(default_factory=RTCConfig)
queue_threshold: int = 30
# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------
def create_inference_engine(
config: InferenceEngineConfig,
*,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
robot_wrapper: ThreadSafeRobot,
hw_features: dict,
dataset_features: dict,
ordered_action_keys: list[str],
task: str,
fps: float,
device: str | None,
use_torch_compile: bool = False,
compile_warmup_inferences: int = 2,
shutdown_event: Event | None = None,
) -> InferenceEngine:
"""Instantiate the appropriate inference engine from a config object."""
logger.info("Creating inference engine: %s", config.type)
if isinstance(config, SyncInferenceConfig):
return SyncInferenceEngine(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset_features=dataset_features,
ordered_action_keys=ordered_action_keys,
task=task,
device=device,
robot_type=robot_wrapper.robot_type,
)
if isinstance(config, RTCInferenceConfig):
return RTCInferenceEngine(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
robot_wrapper=robot_wrapper,
rtc_config=config.rtc,
hw_features=hw_features,
task=task,
fps=fps,
device=device,
use_torch_compile=use_torch_compile,
compile_warmup_inferences=compile_warmup_inferences,
rtc_queue_threshold=config.queue_threshold,
shutdown_event=shutdown_event,
)
raise ValueError(f"Unknown inference engine type: {type(config).__name__}")

View File

@@ -0,0 +1,360 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Real-Time Chunking inference engine.
A background thread produces action chunks asynchronously via
:meth:`policy.predict_action_chunk`. The main control loop polls
``get_action`` for the next ready action; observations flow the other
way via ``notify_observation``.
"""
from __future__ import annotations
import logging
import math
import time
import traceback
from threading import Event, Lock, Thread
from typing import Any
import torch
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc import ActionQueue, LatencyTracker, reanchor_relative_rtc_prefix
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.utils import prepare_observation_for_inference
from lerobot.processor import (
NormalizerProcessorStep,
PolicyProcessorPipeline,
RelativeActionsProcessorStep,
)
from lerobot.utils.feature_utils import build_dataset_frame
from ..robot_wrapper import ThreadSafeRobot
from .base import InferenceEngine
logger = logging.getLogger(__name__)
# How long the RTC loop sleeps when paused, idle, or backpressured by a full queue.
_RTC_IDLE_SLEEP_S: float = 0.01
# Backoff between transient inference errors (per consecutive failure).
_RTC_ERROR_RETRY_DELAY_S: float = 0.5
# Consecutive transient errors tolerated before giving up and propagating shutdown.
_RTC_MAX_CONSECUTIVE_ERRORS: int = 10
# Hard timeout for joining the RTC thread on stop().
_RTC_JOIN_TIMEOUT_S: float = 3.0
# ---------------------------------------------------------------------------
# RTC helpers
# ---------------------------------------------------------------------------
def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int) -> torch.Tensor:
"""Pad or truncate RTC prefix actions to a fixed length for stable compiled inference."""
if prev_actions.ndim != 2:
raise ValueError(f"Expected 2D [T, A] tensor, got shape={tuple(prev_actions.shape)}")
steps, action_dim = prev_actions.shape
if steps == target_steps:
return prev_actions
if steps > target_steps:
return prev_actions[:target_steps]
padded = torch.zeros((target_steps, action_dim), dtype=prev_actions.dtype, device=prev_actions.device)
padded[:steps] = prev_actions
return padded
# ---------------------------------------------------------------------------
# RTCInferenceEngine
# ---------------------------------------------------------------------------
class RTCInferenceEngine(InferenceEngine):
"""Async RTC inference: a background thread produces action chunks.
``get_action`` pops the next action from the shared queue (or
returns ``None`` if the queue is empty). The main loop should call
``notify_observation`` every tick and ``pause``/``resume`` around
human-intervention phases.
"""
def __init__(
self,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
robot_wrapper: ThreadSafeRobot,
rtc_config: RTCConfig,
hw_features: dict,
task: str,
fps: float,
device: str | None,
use_torch_compile: bool = False,
compile_warmup_inferences: int = 2,
rtc_queue_threshold: int = 30,
shutdown_event: Event | None = None,
) -> None:
self._policy = policy
self._preprocessor = preprocessor
self._postprocessor = postprocessor
self._robot = robot_wrapper
self._rtc_config = rtc_config
self._hw_features = hw_features
self._task = task
self._fps = fps
self._device = device or "cpu"
self._use_torch_compile = use_torch_compile
self._compile_warmup_inferences = compile_warmup_inferences
self._rtc_queue_threshold = rtc_queue_threshold
self._action_queue: ActionQueue | None = None
self._obs_holder: dict[str, Any] = {}
self._obs_lock = Lock()
self._policy_active = Event()
self._compile_warmup_done = Event()
self._shutdown_event = Event()
self._rtc_error = Event()
self._global_shutdown_event = shutdown_event
self._rtc_thread: Thread | None = None
if not self._use_torch_compile:
self._compile_warmup_done.set()
logger.info("RTCInferenceEngine initialized (torch.compile disabled, no warmup needed)")
else:
logger.info(
"RTCInferenceEngine initialized (torch.compile enabled, %d warmup inferences)",
compile_warmup_inferences,
)
# Processor introspection for relative-action re-anchoring.
self._relative_step = next(
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
None,
)
self._normalizer_step = next(
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
None,
)
if self._relative_step is not None:
if self._relative_step.action_names is None:
cfg_names = getattr(policy.config, "action_feature_names", None)
if cfg_names:
self._relative_step.action_names = list(cfg_names)
else:
self._relative_step.action_names = [
k for k in robot_wrapper.action_features if k.endswith(".pos")
]
logger.info("Relative actions enabled: RTC prefix will be re-anchored")
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
@property
def ready(self) -> bool:
"""True once torch.compile warmup is complete (or immediately if compile is disabled)."""
return self._compile_warmup_done.is_set()
@property
def failed(self) -> bool:
"""True if the RTC background thread exited due to an unrecoverable error."""
return self._rtc_error.is_set()
@property
def action_queue(self) -> ActionQueue | None:
"""The shared action queue between the RTC thread and the main loop."""
return self._action_queue
def start(self) -> None:
"""Launch the RTC background thread."""
self._action_queue = ActionQueue(self._rtc_config)
self._obs_holder = {
"obs": None,
"robot_type": self._robot.robot_type,
}
self._shutdown_event.clear()
self._rtc_thread = Thread(
target=self._rtc_loop,
daemon=True,
name="RTCInference",
)
self._rtc_thread.start()
logger.info("RTC inference thread started")
def stop(self) -> None:
"""Signal the RTC thread to stop and wait for it."""
logger.info("Stopping RTC inference thread...")
self._shutdown_event.set()
self._policy_active.clear()
if self._rtc_thread is not None and self._rtc_thread.is_alive():
self._rtc_thread.join(timeout=_RTC_JOIN_TIMEOUT_S)
if self._rtc_thread.is_alive():
logger.warning("RTC thread did not join within %.1fs", _RTC_JOIN_TIMEOUT_S)
else:
logger.info("RTC inference thread stopped")
self._rtc_thread = None
def pause(self) -> None:
"""Pause the RTC background thread."""
logger.info("Pausing RTC inference thread")
self._policy_active.clear()
def resume(self) -> None:
"""Resume the RTC background thread."""
logger.info("Resuming RTC inference thread")
self._policy_active.set()
def reset(self) -> None:
"""Reset the policy, processors, and action queue."""
logger.info("Resetting RTC inference state (policy + processors + queue)")
self._policy.reset()
self._preprocessor.reset()
self._postprocessor.reset()
if self._action_queue is not None:
self._action_queue.clear()
# ------------------------------------------------------------------
# Action production (called from main thread)
# ------------------------------------------------------------------
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
"""Pop the next action from the RTC queue (ignores ``obs_frame``)."""
if self._action_queue is None:
return None
return self._action_queue.get()
def notify_observation(self, obs: dict) -> None:
"""Publish the latest observation for the RTC thread to consume."""
with self._obs_lock:
self._obs_holder["obs"] = obs
# ------------------------------------------------------------------
# RTC: background inference thread
# ------------------------------------------------------------------
def _rtc_loop(self) -> None:
"""Background thread that generates action chunks via RTC."""
try:
latency_tracker = LatencyTracker()
time_per_chunk = 1.0 / self._fps
policy_device = torch.device(self._device)
warmup_required = max(1, self._compile_warmup_inferences) if self._use_torch_compile else 0
inference_count = 0
consecutive_errors = 0
while not self._shutdown_event.is_set():
if not self._policy_active.is_set():
time.sleep(_RTC_IDLE_SLEEP_S)
continue
queue = self._action_queue
with self._obs_lock:
obs = self._obs_holder.get("obs")
if queue is None or obs is None:
time.sleep(_RTC_IDLE_SLEEP_S)
continue
if queue.qsize() <= self._rtc_queue_threshold:
try:
current_time = time.perf_counter()
idx_before = queue.get_action_index()
prev_actions = queue.get_left_over()
latency = latency_tracker.max()
delay = math.ceil(latency / time_per_chunk) if latency else 0
obs_batch = build_dataset_frame(self._hw_features, obs, prefix="observation")
obs_batch = prepare_observation_for_inference(
obs_batch, policy_device, self._task, self._robot.robot_type
)
obs_batch["task"] = [self._task]
preprocessed = self._preprocessor(obs_batch)
if prev_actions is not None and self._relative_step is not None:
# Rebase against the raw cached state so the leftover tail stays in
# the training-time coordinate frame.
raw_state = self._relative_step.get_cached_state()
if raw_state is not None:
prev_abs = queue.get_processed_left_over()
if prev_abs is not None and prev_abs.numel() > 0:
prev_actions = reanchor_relative_rtc_prefix(
prev_actions_absolute=prev_abs,
current_state=raw_state,
relative_step=self._relative_step,
normalizer_step=self._normalizer_step,
policy_device=policy_device,
)
if prev_actions is not None:
prev_actions = _normalize_prev_actions_length(
prev_actions, target_steps=self._rtc_config.execution_horizon
)
actions = self._policy.predict_action_chunk(
preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions
)
original = actions.squeeze(0).clone()
processed = self._postprocessor(actions).squeeze(0)
new_latency = time.perf_counter() - current_time
new_delay = math.ceil(new_latency / time_per_chunk)
inference_count += 1
consecutive_errors = 0
is_warmup = self._use_torch_compile and inference_count <= warmup_required
if is_warmup:
latency_tracker.reset()
else:
latency_tracker.add(new_latency)
queue.merge(original, processed, new_delay, idx_before)
if (
is_warmup
and inference_count >= warmup_required
and not self._compile_warmup_done.is_set()
):
self._compile_warmup_done.set()
logger.info("Compile warmup complete (%d inferences)", inference_count)
logger.debug("RTC inference latency=%.2fs, queue=%d", new_latency, queue.qsize())
except Exception as e:
consecutive_errors += 1
logger.error(
"RTC inference error (%d/%d): %s",
consecutive_errors,
_RTC_MAX_CONSECUTIVE_ERRORS,
e,
)
logger.debug(traceback.format_exc())
if consecutive_errors >= _RTC_MAX_CONSECUTIVE_ERRORS:
# Persistent failure: stop retrying and propagate shutdown.
raise
time.sleep(_RTC_ERROR_RETRY_DELAY_S)
else:
time.sleep(_RTC_IDLE_SLEEP_S)
except Exception as e:
logger.error("Fatal error in RTC thread: %s", e)
logger.error(traceback.format_exc())
self._rtc_error.set()
# Unblock any warmup waiters so the main loop doesn't spin forever
self._compile_warmup_done.set()
# Signal the top-level shutdown so strategies exit their control loops
if self._global_shutdown_event is not None:
self._global_shutdown_event.set()

Some files were not shown because too many files have changed in this diff Show More