Compare commits

...

175 Commits

Author SHA1 Message Date
Pepijn
86e7302e10 Merge branch 'feat/mirror' into openarms_wallx_rebased_3 2026-02-24 11:53:01 +01:00
Pepijn
0394fae446 Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-22 16:12:10 +01:00
Pepijn
602b8e66a6 fix multi gpu processor bug 2026-02-22 16:11:52 +01:00
Pepijn
ab4dce6fed revert 2026-02-21 18:48:46 +01:00
Pepijn
40f4386e4a nccl 2026-02-21 18:44:35 +01:00
Pepijn
87a91b4b08 Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-21 18:19:51 +01:00
Pepijn
fadb900c36 compute before dist 2026-02-21 18:19:12 +01:00
Pepijn
de0663226a max 1m frames 2026-02-21 17:44:12 +01:00
Pepijn
0ca9d66cae max 1m frames 2026-02-21 17:43:58 +01:00
Pepijn
2222f25da3 Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-21 17:28:35 +01:00
Pepijn
acae8417aa fix 2026-02-21 17:28:26 +01:00
Pepijn
2697f65cf6 stats for entire dataset 2026-02-21 17:15:55 +01:00
Pepijn
74f42f218e stats for entire dataset 2026-02-21 17:15:45 +01:00
Pepijn
ca9d49e305 Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-21 17:12:52 +01:00
Pepijn
6705876d47 use quantiles 2026-02-21 17:12:43 +01:00
Pepijn
aadbd27675 Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-21 08:48:39 +01:00
Pepijn
5221647b5e fix 2026-02-21 08:48:08 +01:00
Pepijn
9c981300dd stats per chunck 2026-02-21 08:37:38 +01:00
Pepijn
f5b27aad1b stats per chunck 2026-02-21 08:37:19 +01:00
Pepijn
75f1285507 Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-21 08:02:39 +01:00
Pepijn
33cedc2f71 sample 1m 2026-02-21 08:02:25 +01:00
Pepijn
aa32e6c4ab Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-21 07:52:10 +01:00
Pepijn
f906270ec4 load from parquet 2026-02-21 07:51:57 +01:00
Pepijn
733b6d84db Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-21 07:42:00 +01:00
Pepijn
8abc9037a3 sample 100k 2026-02-21 07:41:42 +01:00
Pepijn
e4d4ac0bda Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-21 00:03:37 +01:00
Pepijn
e79b2a439b calulate chunk based stats 2026-02-21 00:03:21 +01:00
Pepijn
f9ae78ca74 Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-20 23:04:36 +01:00
Pepijn
e1ced538e3 only recompute state for stats 2026-02-20 23:04:20 +01:00
Pepijn
2a98602ad6 Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-20 22:54:46 +01:00
Pepijn
a2f5b3571e normalzie after delta conversion 2026-02-20 22:54:29 +01:00
Pepijn
cecf2eff4f Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-20 17:59:19 +01:00
Pepijn
7e6b598a51 add recomputation of stats and option to compute delta stats 2026-02-20 17:59:06 +01:00
Pepijn
4fa41ba806 formatting 2026-02-13 17:46:18 +01:00
Pepijn
1de2b87a92 Add option for pi family models to train with relative actions (relative to state) 2026-02-13 17:45:59 +01:00
Pepijn
6600b60e7f always use degrees (#2968) 2026-02-13 13:49:01 +01:00
Caroline Pascal
adebbcf090 fix(dataset tools draccus): fixing draccus parsing for dataset edit operation type specification (#2949)
* fix(edit dataset operation): fixing dataset tools CLI operation type specification

* test(edit dataset operation): adding tests for dataset tools operation type specification

* chore(format): running pre-commit

* chore(backward compatibility): adding a type property in OperationConfig for backward compatibility

Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>
2026-02-12 18:56:04 +01:00
taken-yjyoon
3615160d89 fix(typo): Fixing wrong argparse examples in the comments (using 'True' not 'true') (#1040)
Co-authored-by: juni <>
2026-02-12 18:13:51 +01:00
Steven Palma
fc8a388a25 feat(cameras): make backend configurable to the CLI (#2945)
* feat(cameras): make backend configurable to the CLI

* chore(cameras): address feedback

* feat(Enum error messages): adding better instanciation error messages for Enum classes

* chore(Enum error messages): propagating Enum error messages to all camera classes

* chore(comments): removing superfluous comments

* chore(format): applying ruff checks

---------

Co-authored-by: CarolinePascal <caroline8.pascal@gmail.com>
2026-02-11 13:57:25 +01:00
Steven Palma
3c84d271d5 fix(motors): use decorator to fix precommit (#2951) 2026-02-10 18:40:50 +01:00
Steven Palma
1ba3975020 chore: use is_connected decorators (#2948)
* chore: use is_connected decorators

* chore(robots): add is_connected to bi setups too
2026-02-10 17:49:30 +01:00
Steven Palma
35363c5798 chore(linter): ensure motors module passes MyPy type checks (#2939)
* fix: ensure motors module passes MyPy type checks

This commit fixes 62 mypy type errors in the motors module by:

- Updating Protocol classes (PortHandler, PacketHandler, GroupSyncRead,
  GroupSyncWrite) to use class-level attribute declarations instead of
  __init__ body declarations
- Adding missing `broadcastPing` method to PacketHandler Protocol
- Fixing return type annotations (e.g., `_get_motor_model` returns str, not int)
- Fixing parameter types to use `Sequence` for covariant list parameters
- Fixing `Mapping` for covariant dict value types in `_normalize`
- Updating method signatures to be consistent across parent and child classes
  (disable_torque, enable_torque, _get_half_turn_homings)
- Adding explicit `int()` casts for MotorCalibration arguments
- Adding explicit `return None` for functions returning Optional types
- Adding type annotations for variables like `data_list: dict[int, int]`
- Using `# type: ignore[method-assign]` for intentional monkeypatch
- Fixing variable references (using `self.groups` instead of `groups`)

Fixes #1723

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* chore(style): pre-commit after main merge

* chore(linter): solve comments

* chore(linter): apply pre-commit fixes to damiao

* chore(linter): more fixes to damiao

---------

Co-authored-by: yurekami <yurekami@users.noreply.github.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-10 17:35:39 +01:00
whats2000
778db19a17 [Bug Fix] fix(ci): prevent runner group error on fork pushes (#2911)
* fix(ci): prevent runner group error on fork pushes

Add repository check to unbound_deps_tests workflow to ensure
aws-general-8-plus runner group is only used on main repository,
preventing 'Required runner group not found' errors on forks.

* fix(ci): use gating job to prevent runner allocation on forks

The previous approach failed because GitHub evaluates runs-on before if conditions.
Now using a check-repo job that runs on ubuntu-latest first, and all jobs with
special runners depend on it and check its output before being scheduled.

* fix(ci): add gating job to full_tests to prevent runner allocation on forks

Apply the same gating pattern used in unbound_deps_tests to full_tests.yml
to prevent GitHub from trying to allocate custom runners when workflows
run on forks. The check-repo job runs first on ubuntu-latest and all jobs
with custom runners depend on it and check its output.

* fix(ci): add repository check to unbound_deps_tests workflow

Add 'if: github.repository == huggingface/lerobot' check to build-and-push-docker job to prevent runner group access errors on forks, matching the pattern used in nightly.yml

* fix(ci): add repository check to full_tests workflow

Add 'if: github.repository == huggingface/lerobot' check to build-and-push-docker and gpu-tests jobs to prevent runner group access errors on forks

* refactor(ci): remove redundant check from gpu-tests job

gpu-tests depends on build-and-push-docker via needs, so it will automatically skip when the parent job is skipped

* refactor(ci): remove unnecessary fork check from full-tests job

full-tests runs on ubuntu-latest which is available to all forks, no need to restrict it

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-10 15:21:40 +01:00
Jai Kumaar Ratadia
d2d01399d6 docs: clarify installation steps are sequential, not optional (#2925)
* docs: clarify installation steps are sequential, not optional

Add intro paragraph noting conda is one path (not the only one) and
number the three sections as steps so readers understand miniforge and
environment setup are prerequisites, not independent choices.

* Update installation guide link for LeRobot

Signed-off-by: Jai Kumaar Ratadia <jaikumaarratadia@gmail.com>

* Fix link formatting in installation guide again

Signed-off-by: Jai Kumaar Ratadia <jaikumaarratadia@gmail.com>

---------

Signed-off-by: Jai Kumaar Ratadia <jaikumaarratadia@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-10 15:18:32 +01:00
Aoqun Jin
5eba4ce6f4 Change LIBERO init_state_id when reset. (#2899)
* Change LIBERO init_state_id when reset.

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* Change LIBERO init_state_id when reset.

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* pre-commit run

---------

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-10 16:39:17 +03:00
Stepan Feduniak
cca0296cd6 fix(pipeline): use FeatureType for STATE features in Libero processor (#2888)
* fix the types

* pre-commit

---------

Co-authored-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-10 15:55:11 +03:00
Steven Palma
489cb7b6b9 fix(scripts): correct can import check (#2937) 2026-02-09 16:58:32 +01:00
Reece O'Mahoney
e14bdf57d0 Convert tensors to scalars (#2903)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-09 14:46:12 +01:00
Pepijn
3ec7c25e7d speedup stats and encoding 2026-02-06 11:26:27 +01:00
Reece O'Mahoney
97e7e0f9ed feat(datasets): improve image transform support (#2885)
* improve image transform support

* add tests

* Add stricter transform check and extra test

* improve subclass check
2026-02-05 15:39:58 +01:00
Pepijn
e3c511db67 add push to hub 2026-02-05 09:25:49 +01:00
Pepijn
aed4130d39 add swap wrist camera's 2026-02-04 22:47:28 +01:00
Pepijn
d26349c692 add push to hub 2026-02-04 19:17:40 +01:00
Pepijn
a9bce4732b fix setting metadata 2026-02-04 19:04:43 +01:00
Pepijn
86d69e3c1d add mirroring 2026-02-04 18:56:51 +01:00
jwang078
0f39248445 Small docstring fix in diffusion configuration (#2847) 2026-02-03 19:19:00 +01:00
Iori Yanokura
a6370dd783 fix(wandb): truncate init tags to 64-character limit (#995) 2026-02-03 14:17:04 +01:00
Pepijn
2d8ac028f9 remove async stuff 2026-02-03 11:01:32 +01:00
Pepijn
ec1de9c9e3 encode while recording 2026-02-03 10:41:11 +01:00
Pepijn
1ea040fe8c reduce memoery load and move to video folder 2026-02-03 10:29:45 +01:00
Pepijn
c028ae3a44 Async encoding 2026-02-03 08:50:34 +01:00
Michel Aractingi
14a15f90e7 Add missing RL config options: add_ee_pose_to_observation and gripper_penalty_in_reward (#2873)
* fix(RL) add missing config arguments

* respond to copilot review

* fix(revert penalty in reward): reverting gripper penalty addition in reward. This is already done in compute_loss_discrete_critic.

---------

Co-authored-by: CarolinePascal <caroline8.pascal@gmail.com>
2026-02-02 22:14:03 +01:00
Hirokazu Ishida
9c24a09665 docs: update document in response to Simplify configs PR (#1596)
* docs: update document input/output_shapes -> input/output_features

* fix inconsistent quote (suggested by copilot reviewer)

* docs: shapes => PolicyFeature

* docs: relfect normalization_mapping and remove outdated
2026-02-02 20:05:58 +01:00
Jade Choghari
b18cef2e26 feat(dataset): add subtask support (#2860)
* add subtask

* remove folder

* add docs

* update doc

* add testing

* update test

* update constant naming + doc

* more docs
2026-01-30 19:29:37 +01:00
Caroline Pascal
5c6182176f fix(find zmq): adding a clearer not implemented warning for the ZMQ find_cameras method (#2879)
Co-authored-by: Martino Russi <77496684+nepyope@users.noreply.github.com>
2026-01-30 16:58:13 +01:00
Caroline Pascal
55c0471db9 docs(cameras): revising and improving docs on cameras (#2878)
* docs(cameras): revising and improving docs on cameras

* resolving copilot comments
2026-01-30 16:57:56 +01:00
Michel Aractingi
ec04b7ce3a Feat(dataset_tools.py) Add modify tasks tool (#2875)
* feat(datasets): add modify_tasks function for in-place task editing

Add a new utility function to modify tasks in LeRobotDataset in-place.
This allows users to:
- Set a single task for all episodes
- Set specific tasks for individual episodes
- Combine a default task with per-episode overrides

* feat(edit-dataset): add CLI support for modify_tasks operation

Integrate the modify_tasks function into lerobot_edit_dataset CLI.
Users can now modify dataset tasks via command line:
Supports setting a default task, per-episode tasks, or both combined.

* test(datasets): add tests for modify_tasks function

Add comprehensive test coverage for the modify_tasks utility:
- Single task for all episodes
- Episode-specific task assignment
- Default task with per-episode overrides
- Error handling for missing/invalid arguments
- Verification of task_index correctness
- In-place modification behavior
- Metadata preservation

* respond to copilot review
2026-01-30 13:19:42 +01:00
Michel Aractingi
04cbf669cf fix(sac): make temperature a property to fix checkpoint resume bug (#2877)
* fix(sac): make temperature a property to fix checkpoint resume bug

Temperature was stored as a plain float and not restored after loading
a checkpoint, causing incorrect loss computations until update_temperature()
was called. Changed to a property that always computes from log_alpha,
ensuring correct behavior after checkpoint loading.

* simplify docstrings
2026-01-30 12:23:22 +01:00
Pepijn
2598dbc31a Merge branch 'feat/training_time_rtc' into openarms_wallx_rebased_3 2026-01-29 11:17:15 +01:00
Steven Palma
3409ef0dc2 refactor(cameras): cameras API extension (#2808)
* feat(cameras): add new read_latest() method

* fix(cameras): fix threading bug + clear state

* refactor(cameras): multiple improvements

* feat(camera): add context manager to camera base class

* chore(camera): slight modifications to opencv

* test(cameras): update opencv tests according to the changes

* refactor(cameras): reflect desing changes to realsense + deal with depth

* test(cameras): fix realsense tests accordingly to new changes

* refactor(cameras): update reachymini and zmq accordingly

* chore: wrap resource sensitive examples into a try/finally

* test(cameras): add test for new read_latest

* test(cameras): fix problem with image artifact in opencv tests

* test(cameras): fix test_read_latest_high_frequency expectations

* Apply suggestions from code review 1

Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* chore(cameras): address feedback

* feat(cameras): add max_age_ms check in read_latest

* test(cameras): fix read_latest tests

* chore(redundancies): removing redundancies in Reachy 2 camera class

* fix(warmup): replacing the arbitrary time.sleep in by an actual warmup in the RealSense camera class

* chore(format): formatting latest changes

* chore(warning): adding a "to be implemented" warning for read_latest() in Camera base class

* chore(warning): making read_latest() warning message shorter and clearer

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
2026-01-29 11:07:47 +01:00
Steven Palma
4483184875 feat(robots): add bi manual openarm follower and leader (#2835)
* fix(motors): cleanup imports + fix signatures

* feat(motors): add damiao canbus + multiple fixes

* fix(motors): address comments -> last_state + different gains + sleep

* refactor(motors): reduce duplicated code + adressed some comments in the PR

* chore(motors): better timeouts

* tests(motors): damiao test and imports

* chore(deps): fix space

* feat(robot): add openarm leader

Co-authored-by: Pepijn <pepijn@huggingface.co>

* feat(robot): add openarm follower

Co-authored-by: Pepijn <pepijn@huggingface.co>

* refactor(robot): remove mechanical compensations and double arm assumption + rename

* chore(robots): remove left arm references

* refactor(teleop): multiple improvements to leader

* refactor(teleop): multiple improvements to leader

* feat(robots): add open arm to util CLI

* chore(robot): add alias openarm

* Apply suggestions from code review

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* chore(motors): remove normalization tables damiao

* fix(motors): imports and signatures

* feat(motors): add motor_type_str + recv_id to motor class and _get_motor_recv_id raises if no motor_obj.recv_id

* chore(motors): remove normalize from base motor class and damaio

* tests(motors): remove bad tests (to be replaced)

* chore(motors): updated import check

* fix(robots): open arm mirrored config for joint limits

* chore(motors): update position_kd gain values

* chore(robots): set to 0 if openarm is calibrated at connect time

* chore(robots): remove macos in open arm as can doesn't support it

* chore(robots): update for motor_type_str in Motor class

* chore(robots): no default value for can port in open arms

* feat(robots): add bi manual openarm follower and leader

* use constant for kp and kd range and check responses in mit_control_batch()

* Add docs on setting up canbus and use damiao otor bus, also add lerobot_setup_can.py and log if there is not response from a write command

* precommit format

* supress bandit as these are intentional cli commands

* fix setup-can

* add test

* skip test in ci

* nit precommit

* update doc example

* dont import can for tests

* remove comment

* Add openarms docs

* format

* update purchase link

* can to none if nit availabl;e

* add canfd option in bus

* make handshake logic similar to lerobot-can

* type hint

* type check

* add temp teleop test

* remove script

* mock class

* mock class

* ignore linter

* pre-commit

* Add command for bimanual openarm

* fix import

* fix import leader

* fix import draccus

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Pepijn <pepijn@huggingface.co>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-01-28 17:25:57 +01:00
Martino Russi
149628dfd5 add g1 teleoperation (#2791)
* add gravity compensation

* add g1 teleoperation

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2026-01-28 15:17:38 +01:00
Steven Palma
bf337e716d feat(robots): add OpenArm robot & teleoperator (#2795)
* fix(motors): cleanup imports + fix signatures

* feat(motors): add damiao canbus + multiple fixes

* fix(motors): address comments -> last_state + different gains + sleep

* refactor(motors): reduce duplicated code + adressed some comments in the PR

* chore(motors): better timeouts

* tests(motors): damiao test and imports

* chore(deps): fix space

* feat(robot): add openarm leader

Co-authored-by: Pepijn <pepijn@huggingface.co>

* feat(robot): add openarm follower

Co-authored-by: Pepijn <pepijn@huggingface.co>

* refactor(robot): remove mechanical compensations and double arm assumption + rename

* chore(robots): remove left arm references

* refactor(teleop): multiple improvements to leader

* refactor(teleop): multiple improvements to leader

* feat(robots): add open arm to util CLI

* chore(robot): add alias openarm

* Apply suggestions from code review

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* chore(motors): remove normalization tables damiao

* fix(motors): imports and signatures

* feat(motors): add motor_type_str + recv_id to motor class and _get_motor_recv_id raises if no motor_obj.recv_id

* chore(motors): remove normalize from base motor class and damaio

* tests(motors): remove bad tests (to be replaced)

* chore(motors): updated import check

* fix(robots): open arm mirrored config for joint limits

* chore(motors): update position_kd gain values

* chore(robots): set to 0 if openarm is calibrated at connect time

* chore(robots): remove macos in open arm as can doesn't support it

* chore(robots): update for motor_type_str in Motor class

* chore(robots): no default value for can port in open arms

* use constant for kp and kd range and check responses in mit_control_batch()

* Add docs on setting up canbus and use damiao otor bus, also add lerobot_setup_can.py and log if there is not response from a write command

* precommit format

* supress bandit as these are intentional cli commands

* fix setup-can

* add test

* skip test in ci

* nit precommit

* update doc example

* dont import can for tests

* remove comment

* Add openarms docs

* format

* update purchase link

* can to none if nit availabl;e

* add canfd option in bus

* make handshake logic similar to lerobot-can

* type hint

* type check

* add temp teleop test

* remove script

* mock class

* ignore linter

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Pepijn <pepijn@huggingface.co>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-01-28 14:28:51 +01:00
Michel Aractingi
736b43f3cf Fix(aggregate.py) Aggregation of datasets when sub-datasets are already a result of a previous merge (#2861)
* Fix aggeregation of datasets when subdatasets are already a result of a previous merge

* docstring

* respond to copilot review + add regression test

* Remove unnecessary int conversion for indicies
2026-01-28 13:31:27 +01:00
Pepijn
bc68651815 add command 2026-01-16 16:43:45 +01:00
Pepijn
d1f50babaa fix rac data collection with rtc by disabling compile 2026-01-15 17:06:58 +01:00
Pepijn
3316301693 debug rtc 2026-01-09 16:58:57 +01:00
Pepijn
feedababd2 debug 2026-01-09 16:54:11 +01:00
Pepijn
480ee3299f log 2026-01-09 16:50:44 +01:00
Pepijn
2d1fb0f508 refactor 2026-01-09 16:41:59 +01:00
Pepijn
b1a55b0666 by default dont use rtc 2026-01-09 16:26:54 +01:00
Pepijn
24af996f82 add logging 2026-01-09 16:10:32 +01:00
Pepijn
8d7eec79c8 f 2026-01-09 16:06:02 +01:00
Pepijn
ccced0c9fc f 2026-01-09 15:58:37 +01:00
Pepijn
4166eeb7da have only rtc thread read obs and expose it 2026-01-09 15:48:49 +01:00
Pepijn
1f93a74d8c fix queue 2026-01-09 14:00:06 +01:00
Pepijn
b16e2f25f7 remove move to zero due to potential race condition 2026-01-09 13:56:16 +01:00
Pepijn
9cc841c674 wait for first actions 2026-01-09 13:45:06 +01:00
Pepijn
63c28ea395 add cmd arg 2026-01-09 13:38:33 +01:00
Pepijn
98c33a4748 Add RaC with RTC 2026-01-09 13:26:25 +01:00
Pepijn
4428248a01 Increase d 2026-01-09 13:17:18 +01:00
Pepijn
7d6f113072 fix at 2x actual freq 2026-01-09 13:03:29 +01:00
Pepijn
7ac05c838d add interpolation option 2026-01-09 12:56:43 +01:00
Pepijn
c85f1692d6 in place 2026-01-03 22:12:22 +01:00
Pepijn
9fd329713a modift in place 2026-01-03 22:11:11 +01:00
Pepijn
97d068e5a2 rename to fold 2026-01-03 21:59:11 +01:00
Pepijn
e5bea36387 add unify task 2026-01-03 21:52:19 +01:00
Pepijn
cf1d8c3d5b stop policy when we dont teleop yet 2026-01-02 13:12:22 +01:00
Pepijn
464b65cfb0 wait for start button before teleop 2026-01-02 13:05:00 +01:00
Pepijn
90145426b4 add gripper in send feedback 2026-01-02 11:22:45 +01:00
Pepijn
c76bc4cdea Move robot to zero before begin episode 2026-01-02 10:52:48 +01:00
Pepijn
20f0381f81 wait for takeover press 2026-01-02 10:18:59 +01:00
Pepijn
a447c652cb change pedal flow 2026-01-02 09:53:40 +01:00
Pepijn
8277dbf0dc add foot pedal support 2026-01-02 09:36:36 +01:00
Pepijn
eb0918249d keep teleop active in reset 2026-01-02 09:21:15 +01:00
Pepijn
640a7889fc use same flip for send_feedback 2026-01-01 16:49:04 +01:00
Pepijn
03c6ee5f9a fix grippers 2026-01-01 16:40:53 +01:00
Pepijn
dfd229ae4f fix direction and encoding 2026-01-01 16:37:11 +01:00
Pepijn
aba42c805f some changes to smooth 2025-12-31 15:16:23 +01:00
Pepijn
8b6b41f8dc reverse 2025-12-31 15:11:00 +01:00
Pepijn
1771da222b openarms mini swap joints 6 and 7 2025-12-31 15:08:06 +01:00
Pepijn
0514616c87 dont move teleop when not pause pressed 2025-12-31 12:33:40 +01:00
Pepijn
f15872293d Only move teleop after space press 2025-12-31 12:24:43 +01:00
Pepijn
a97255e3d1 use robot_action 2025-12-30 12:04:30 +01:00
Pepijn
1716d599c1 only use position in dataset 2025-12-30 12:01:26 +01:00
Pepijn
c07ab7e1fa policy path can be none 2025-12-30 11:14:21 +01:00
Pepijn
5ba9fbd9ca fix processor step 2025-12-30 11:09:16 +01:00
Pepijn
38b814f3d4 add feedback to openarms mini 2025-12-30 10:48:55 +01:00
Pepijn
48a963793b Add rac openarms 2025-12-30 10:41:32 +01:00
Pepijn
9833b84bf8 merge rac 2025-12-30 10:37:48 +01:00
Pepijn
27eeff7535 Add RaC doc and example 2025-12-30 09:57:40 +01:00
Michel Aractingi
202a493c14 missing imports processor wallx 2025-12-17 18:25:21 +01:00
Pepijn
eadd4c0856 only export WallXConfig from wall_x package to avoid peft import in CI 2025-12-17 18:06:42 +01:00
Pepijn
3434a5d5df pre-commit 2025-12-17 18:06:42 +01:00
Pepijn
1ba51a6d02 fix: peft test import 2025-12-17 18:06:41 +01:00
Pepijn
c62ca6c5d2 fix: add uv conflicts for wallx transformers version 2025-12-17 18:06:41 +01:00
Pepijn
4831195310 fix: exclude wallx extra properly in CI workflows 2025-12-17 18:06:41 +01:00
Pepijn
c514d9ffe2 fix precommit issues 2025-12-17 18:06:40 +01:00
Pepijn
9ae4477356 fix ci 2025-12-17 18:06:40 +01:00
Geoffrey19
0e545e5177 remove lerobot[wallx] 2025-12-17 18:06:40 +01:00
Geoffrey19
a0c9a7d85d fix pre-commit errors 2025-12-17 18:06:39 +01:00
Geoffrey19
9ce6dd9e25 add some small modifications 2025-12-17 18:06:39 +01:00
Geoffrey19
51bd288f1a fix bug for inference 2025-12-17 18:06:39 +01:00
Geoffrey19
fc6262e23d remove flash-attn requirement && fix bug in inference and fast mode 2025-12-17 18:06:38 +01:00
Geoffrey19
d2b16afb12 update 2025-12-17 18:06:38 +01:00
Geoffrey19
a754c86f64 add wallx dependencies 2025-12-17 18:06:37 +01:00
Geoffrey19
76e6dc1ba1 fixed dtype bugs 2025-12-17 18:06:37 +01:00
Geoffrey19
d10d3ef251 reduce to least config and params & pass lerobot basic test 2025-12-17 18:06:37 +01:00
Geoffrey19
feebca050a update the policy methods 2025-12-17 18:06:36 +01:00
Geoffrey19
a8e7a2967c incorporate wallx model into lerobot 2025-12-17 18:06:36 +01:00
Geoffrey19
2cf509795e fix bugs in flow 2025-12-17 18:06:36 +01:00
vincentchen
d3846b0beb support wallx 2025-12-17 18:06:35 +01:00
Michel Aractingi
08d2ed8015 lerobot dataset fix 2025-12-17 16:46:43 +01:00
Michel Aractingi
4bcd14b8de add evaluate_with_rtc script 2025-12-17 16:46:43 +01:00
Michel Aractingi
c34935090d integrate delete button openarm UI (#2535)
* add visualize_dataset call from `lerobot_dataset_viz` in web record server

* add delete button

* fixes

* remove viz

* unused import
2025-12-17 16:46:43 +01:00
CarolinePascal
9cfd56587e fix(num processes) 2025-12-17 16:46:43 +01:00
Caroline Pascal
ff8584a025 fix(os version)
Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>
2025-12-17 16:46:43 +01:00
Caroline Pascal
6bc1e5186a fix(import os)
Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>
2025-12-17 16:46:43 +01:00
Caroline Pascal
69dc8165ae fix(max workers)
Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>
2025-12-17 16:46:42 +01:00
CarolinePascal
021bca2ad9 feat(multi-processes): adding support for multiprocess encoding 2025-12-17 16:46:42 +01:00
CarolinePascal
4e0ee0d643 feat(preset): adding encoding preset 2025-12-17 16:46:42 +01:00
croissant
0a8aa85871 ruse video datasets 2025-12-17 16:46:42 +01:00
croissant
76ddd8b948 use image datasets and change ui 2025-12-17 16:46:42 +01:00
croissant
bf08733068 frontend set correct port openarms mini 2025-12-17 16:46:42 +01:00
croissant
e38f56c071 add default mini arms 2025-12-17 16:46:41 +01:00
croissant
19fe69dac0 add improv openarm mini 2025-12-17 16:46:41 +01:00
pepijn kooijmans
14319ee608 add openarms mini 2025-12-17 16:46:41 +01:00
croissant
9b04fd25b6 cam res 2025-12-17 16:46:41 +01:00
Pepijn
40e98ba690 fix calibration of gripper and add max clip positions for openarm for safety 2025-12-17 16:46:41 +01:00
pepijn kooijmans
894d65d58a add openarms to setup motors 2025-12-17 16:46:41 +01:00
Pepijn
f58d508df2 cleanuo 2025-12-17 16:46:40 +01:00
Pepijn
e22b909e7c Add mini openarms to test 2025-12-17 16:46:40 +01:00
croissant
09f1673cbf add longer timeout 2025-12-17 16:46:40 +01:00
croissant
4744f99990 add timing debugging, foot pedal and eval script 2025-12-17 16:46:40 +01:00
croissant
01c1735739 add disable torque 2025-12-17 16:46:40 +01:00
croissant
6808a42455 add pid ramp 2025-12-17 16:46:40 +01:00
croissant
fff719cb4f add web interface example 2025-12-17 16:46:39 +01:00
croissant
e2c00f6ed8 speedup 2025-12-17 16:46:39 +01:00
croissant
0f90db23c5 add full bimanual gravity comp 2025-12-17 16:46:39 +01:00
Michel Aractingi
96b192f2ae Add gravity compensation to the openarms teleoperation (#2352)
* adding first attempt at gcompensation to open arms

* add teleop with gravity compensation script
2025-12-17 16:46:39 +01:00
Pepijn
ecdc34a699 faster canbus 2025-12-17 16:46:39 +01:00
croissant
fa6a2fb9b7 pos teleop 2025-12-17 16:46:39 +01:00
Pepijn
b011643dc9 add tests and debug 2025-12-17 16:46:38 +01:00
Pepijn
30c10c1c6e Add damiao motors and open arm robot 2025-12-17 16:46:38 +01:00
Pepijn
56e2360072 add damiao 2025-12-17 16:46:38 +01:00
189 changed files with 22396 additions and 1409 deletions

View File

@@ -101,9 +101,11 @@ jobs:
runs-on:
group: aws-general-8-plus
if: |
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) ||
github.event_name == 'push' ||
github.event_name == 'workflow_dispatch'
github.repository == 'huggingface/lerobot' && (
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) ||
github.event_name == 'push' ||
github.event_name == 'workflow_dispatch'
)
outputs:
image_tag: ${{ steps.set_tag.outputs.image_tag }}
env:

View File

@@ -91,6 +91,7 @@ jobs:
name: Build and Push Docker
runs-on:
group: aws-general-8-plus
if: github.repository == 'huggingface/lerobot'
outputs:
image_tag: ${{ env.DOCKER_IMAGE_NAME }}
env:

140
STREAMING_ENCODING_PR.md Normal file
View File

@@ -0,0 +1,140 @@
# Streaming Video Encoding — Encode on the fly during recording
## Problem
After each episode, `save_episode()` blocks for **~79 seconds** on a 3-camera setup (3197 frames, 107s episode):
| Step | Time |
|------|------|
| Write 9591 PNGs to disk | ~19s |
| Read PNGs back → compute image stats | ~15s |
| Read PNGs again → encode 3× AV1 videos → delete PNGs | ~44.5s |
| Save parquet + metadata | ~0.6s |
| **Total** | **~79s** |
The entire pipeline writes frames as temporary PNGs, reads them back twice (stats + encoding), then deletes them. This round-trip is the bottleneck.
## Architecture
### Before: sequential post-episode pipeline
```
Recording loop save_episode() — BLOCKS ~79s
┌─────────────┐ ┌──────────────────────────────────────────────────────────┐
│ 30fps loop │ │ │
│ │ frames │ frame_buffer ──► write PNGs ──► read PNGs ──► stats │
│ camera ─►───┼──► list │ (~19s) │ (~15s) │
│ teleop │ │ ▼ │
│ policy │ │ read PNGs ──► AV1 encode ──► delete PNGs │
│ │ │ (~44.5s) │
└──────┬───────┘ └──────────────────────────────────────────────────────────┘
│ │
▼ ▼
episode ends next episode
(~107s recording) (~79s blocked)
```
**Data path:** `frame → list → PNG disk → read → stats` + `PNG disk → read → encode → MP4 → delete PNGs`
### After: streaming pipeline (encodes during recording)
```
Recording loop (encoding happens HERE) save_episode() — ~0.5s
┌───────────────────────────────────────┐ ┌──────────────────┐
│ 30fps control loop │ │ │
│ │ │ flush encoders │
│ camera ──► frame ─┬─► queue ──► [T1] ├── AV1 ─┤ (already done) │
│ │ queue ──► [T2] ├── AV1 ─┤ ~0.16s │
│ │ queue ──► [T3] ├── AV1 ─┤ │
│ │ │ │ running stats │
│ └─► downsample ──► │─ stats ─┤ → finalize │
│ RunningQuantile │ │ ~0.01s │
│ teleop / policy (never blocked) │ │ │
└───────────────────────────────────────┘ │ save parquet │
│ ~0.36s │
[T1] [T2] [T3] = encoder threads └──────────────────┘
(one per camera, GIL released by PyAV)
```
**Data path:** `frame → queue → encode → MP4` (zero PNGs, zero re-reads)
## Stats computation changes
| | Before | After |
|---|---|---|
| **Method** | `compute_episode_stats()` reads all PNGs from disk, decodes them, computes min/max/mean/std/quantiles | `RunningQuantileStats` accumulates stats incrementally per frame during recording |
| **Input** | Full-resolution PNGs read back from disk | Downsampled frames (via `auto_downsample_height_width`, ~150×100px) directly from memory |
| **When** | After episode ends, inside `save_episode()` | During recording, inside `add_frame()` (~2ms per frame) |
| **Output** | `{mean, std, min, max, q01..q99}` shaped `(C,1,1)` in `[0,1]` | Identical shape and scale — `RunningQuantileStats.get_statistics()` → reshape `(C,1,1)` / 255 |
| **I/O** | Reads 9591 PNGs (~15s) | Zero disk I/O |
| **Numeric features** | Computed from episode buffer (unchanged) | Computed from episode buffer (unchanged) |
The running stats use the same `auto_downsample_height_width` function and produce the same statistical keys (`mean`, `std`, `min`, `max`, `count`, `q01`, `q10`, `q50`, `q90`, `q99`). Video features are excluded from the post-episode `compute_episode_stats()` call when streaming is active — only numeric features go through that path.
## Results
Tested on the same 3-camera setup (2028 frames, 67.6s episode):
| Step | Before | After | Speedup |
|------|--------|-------|---------|
| Frame writing (PNGs) | ~19s | **0s** | ∞ (eliminated) |
| Episode stats | ~15s | **0.01s** | 1500× |
| Video encoding | ~44.5s | **0.16s** | 278× |
| Parquet + meta | ~0.6s | **0.36s** | ~same |
| **Total `save_episode()`** | **~79s** | **0.55s** | **143×** |
The video encoding time drops to near-zero because most encoding already happened during recording. `finish_episode()` only flushes the last few buffered frames.
### Per-frame overhead during recording
| Operation | Time |
|-----------|------|
| `queue.put(frame)` (non-blocking) | ~0.01ms |
| `auto_downsample_height_width` | ~0.5ms |
| `RunningQuantileStats.update` | ~1ms |
| **Total per frame** | **~2ms** (well within 33ms budget at 30fps) |
## Usage
Streaming is **on by default**. Users on weaker PCs can disable it to fall back to the old post-episode pipeline:
```bash
# Default (streaming ON)
lerobot-record --dataset.repo_id=user/dataset ...
# Old behavior (streaming OFF)
lerobot-record --dataset.repo_id=user/dataset --dataset.streaming_encoding=false
```
For the RaC data collection script, set `streaming_encoding: false` in the dataset config.
## Files Changed
### `src/lerobot/datasets/video_utils.py`
- Added `StreamingVideoEncoder` — manages one `_CameraEncoder` thread per camera
- Added `_CameraEncoder` — daemon thread that reads frames from a queue and encodes with PyAV
- Non-blocking unbounded queue ensures the control loop is never delayed
### `src/lerobot/datasets/lerobot_dataset.py`
- `create()` / `start_streaming_encoder()`: new `streaming_encoding` parameter
- `add_frame()`: when streaming, feeds frames to encoder + accumulates running stats instead of writing PNGs
- `save_episode()`: when streaming, uses running stats and calls `finish_episode()` to get already-encoded video paths
- `clear_episode_buffer()`: cancels in-progress encoding on re-record
- `finalize()`: cleans up encoder on shutdown
- **Full backward compatibility**: when `streaming_encoding=False`, all existing code paths are unchanged
### `src/lerobot/scripts/lerobot_record.py`
- Added `streaming_encoding: bool = True` to `DatasetRecordConfig`
- Wired through to both `create()` and `resume` paths
### `examples/rac/rac_data_collection_openarms_rtc.py`
- Added `streaming_encoding: bool = True` to `RaCRTCDatasetConfig`
- Frames are added inline during the control loop (streaming) or buffered for post-loop writing (old path)
- Automatically detects mode and adjusts behavior
## Design Notes
- **Why threads, not processes?** PyAV/FFmpeg releases the GIL during encoding. Threads share memory (zero-copy frame passing), avoiding the serialization overhead of multiprocessing.
- **Why unbounded queue?** At 30fps production vs ~72fps encoding throughput, the queue stays near-empty. Even during brief encoder stalls, memory growth is bounded by episode length. The control loop must never block.
- **Why running stats?** Avoids the expensive read-back-from-disk step. `RunningQuantileStats` + `auto_downsample_height_width` compute identical statistics incrementally with ~2ms overhead per frame.
- **Backward compatible**: Setting `streaming_encoding=false` restores the original PNG → encode pipeline exactly. No behavior changes for existing users who don't opt in.

View File

@@ -7,8 +7,6 @@
- sections:
- local: il_robots
title: Imitation Learning for Robots
- local: cameras
title: Cameras
- local: bring_your_own_policies
title: Bring Your Own Policies
- local: integrate_hardware
@@ -29,6 +27,8 @@
title: Porting Large Datasets
- local: using_dataset_tools
title: Using the Dataset Tools
- local: dataset_subtask
title: Using Subtasks in the Dataset
title: "Datasets"
- sections:
- local: act
@@ -103,11 +103,17 @@
title: Earth Rover Mini
- local: omx
title: OMX
- local: openarm
title: OpenArm
title: "Robots"
- sections:
- local: phone_teleop
title: Phone
title: "Teleoperators"
- sections:
- local: cameras
title: Cameras
title: "Sensors"
- sections:
- local: torch_accelerators
title: PyTorch accelerators

View File

@@ -1,12 +1,22 @@
# Cameras
LeRobot offers multiple options for video capture, including phone cameras, built-in laptop cameras, external webcams, and Intel RealSense cameras. To efficiently record frames from most cameras, you can use either the `OpenCVCamera` or `RealSenseCamera` class. For additional compatibility details on the `OpenCVCamera` class, refer to the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
LeRobot offers multiple options for video capture:
### Finding your camera
| Class | Supported Cameras |
| ----------------- | ----------------------------------- |
| `OpenCVCamera` | Phone, built-in laptop, USB webcams |
| `ZMQCamera` | Network-connected cameras |
| `RealSenseCamera` | Intel RealSense (with depth) |
| `Reachy2Camera` | Reachy 2 robot cameras |
To instantiate a camera, you need a camera identifier. This identifier might change if you reboot your computer or re-plug your camera, a behavior mostly dependant on your operating system.
> [!TIP]
> For `OpenCVCamera` compatibility details, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
To find the camera indices of the cameras plugged into your system, run the following script:
### Find your camera
Every camera requires a unique identifier to be instantiated, allowing you to distinguish between multiple connected devices.
`OpenCVCamera` and `RealSenseCamera` support auto-discovery. Run the command below to list available devices and their identifiers. Note that these identifiers may change after rebooting your computer or re-plugging the camera, depending on your operating system.
```bash
lerobot-find-cameras opencv # or realsense for Intel Realsense cameras
@@ -14,7 +24,7 @@ lerobot-find-cameras opencv # or realsense for Intel Realsense cameras
The output will look something like this if you have two cameras connected:
```
```bash
--- Detected Cameras ---
Camera #0:
Name: OpenCV Camera @ 0
@@ -33,13 +43,37 @@ Camera #0:
> [!WARNING]
> When using Intel RealSense cameras in `macOS`, you could get this [error](https://github.com/IntelRealSense/librealsense/issues/12307): `Error finding RealSense cameras: failed to set power state`, this can be solved by running the same command with `sudo` permissions. Note that using RealSense cameras in `macOS` is unstable.
## Use Cameras
`ZMQCamera` and `Reachy2Camera` do not support auto-discovery. They must be configured manually by providing their network address and port or robot SDK settings.
Below are two examples, demonstrating how to work with the API.
## Use cameras
- **Asynchronous frame capture** using an OpenCV-based camera
### Frame access modes
All camera classes implement three access modes for capturing frames:
| Method | Behavior | Blocks? | Best For |
| ------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------- | ---------------------------------------- |
| `read()` | Waits for the camera hardware to return a frame. May block for a long time depending on the camera and SDK. | Yes | Simple scripts, sequential capture |
| `async_read(timeout_ms)` | Returns the latest unconsumed frame from background thread. Blocks only if buffer is empty, up to `timeout_ms`. Raises `TimeoutError` if no frame arrives. | With a timeout | Control loops synchronized to camera FPS |
| `read_latest(max_age_ms)` | Peeks at the most recent frame in buffer (may be stale). Raises `TimeoutError` if frame is older than `max_age_ms`. | No | UI visualization, logging, monitoring |
### Usage examples
The following examples show how to use the camera API to configure and capture frames from different camera types.
- **Blocking and non-blocking frame capture** using an OpenCV-based camera
- **Color and depth capture** using an Intel RealSense camera
> [!WARNING]
> Failing to cleanly disconnect cameras can cause resource leaks. Use the context manager protocol to ensure automatic cleanup:
>
> ```python
> with OpenCVCamera(config) as camera:
> ...
> ```
>
> You can also call `connect()` and `disconnect()` manually, but always use a `finally` block for the latter.
<hfoptions id="shell_restart">
<hfoption id="Open CV Camera">
@@ -60,16 +94,30 @@ config = OpenCVCameraConfig(
)
# Instantiate and connect an `OpenCVCamera`, performing a warm-up read (default).
camera = OpenCVCamera(config)
camera.connect()
with OpenCVCamera(config) as camera:
# Read a frame synchronously — blocks until hardware delivers a new frame
frame = camera.read()
print(f"read() call returned frame with shape:", frame.shape)
# Read a frame asynchronously with a timeout — returns the latest unconsumed frame or waits up to timeout_ms for a new one
try:
for i in range(10):
frame = camera.async_read(timeout_ms=200)
print(f"async_read call returned frame {i} with shape:", frame.shape)
except TimeoutError as e:
print(f"No frame received within timeout: {e}")
# Instantly return a frame - returns the most recent frame captured by the camera
try:
initial_frame = camera.read_latest(max_age_ms=1000)
for i in range(10):
frame = camera.read_latest(max_age_ms=1000)
print(f"read_latest call returned frame {i} with shape:", frame.shape)
print(f"Was a new frame received by the camera? {not (initial_frame == frame).any()}")
except TimeoutError as e:
print(f"Frame too old: {e}")
# Read frames asynchronously in a loop via `async_read(timeout_ms)`
try:
for i in range(10):
frame = camera.async_read(timeout_ms=200)
print(f"Async frame {i} shape:", frame.shape)
finally:
camera.disconnect()
```
<!-- prettier-ignore-end -->
@@ -111,10 +159,10 @@ finally:
</hfoption>
</hfoptions>
## Use your phone
## Use your phone's camera
<hfoptions id="use phone">
<hfoption id="Mac">
<hfoption id="iPhone & macOS">
To use your iPhone as a camera on macOS, enable the Continuity Camera feature:
@@ -124,83 +172,49 @@ To use your iPhone as a camera on macOS, enable the Continuity Camera feature:
For more details, visit [Apple support](https://support.apple.com/en-gb/guide/mac-help/mchl77879b8a/mac).
Your iPhone should be detected automatically when running the camera setup script in the next section.
</hfoption>
<hfoption id="Linux">
<hfoption id="OBS virtual camera">
If you want to use your phone as a camera on Linux, follow these steps to set up a virtual camera
If you want to use your phone as a camera using OBS, follow these steps to set up a virtual camera.
1. _Install `v4l2loopback-dkms` and `v4l-utils`_. Those packages are required to create virtual camera devices (`v4l2loopback`) and verify their settings with the `v4l2-ctl` utility from `v4l-utils`. Install them using:
1. _(Linux only) Install `v4l2loopback-dkms` and `v4l-utils`_. These packages create virtual camera devices and verify their settings. Install with:
<!-- prettier-ignore-start -->
```python
```bash
sudo apt install v4l2loopback-dkms v4l-utils
```
<!-- prettier-ignore-end -->
2. _Install [DroidCam](https://droidcam.app) on your phone_. This app is available for both iOS and Android.
3. _Install [OBS Studio](https://obsproject.com)_. This software will help you manage the camera feed. Install it using [Flatpak](https://flatpak.org):
2. _Install the [DroidCam app](https://droidcam.app) on your phone_. This app is available for both iOS and Android.
3. _Download and install [OBS Studio](https://obsproject.com)_.
4. _Download and install the [DroidCam OBS plugin](https://droidcam.app/obs)_.
5. _Start OBS Studio_.
<!-- prettier-ignore-start -->
```python
flatpak install flathub com.obsproject.Studio
```
<!-- prettier-ignore-end -->
4. _Install the DroidCam OBS plugin_. This plugin integrates DroidCam with OBS Studio. Install it with:
<!-- prettier-ignore-start -->
```python
flatpak install flathub com.obsproject.Studio.Plugin.DroidCam
```
<!-- prettier-ignore-end -->
5. _Start OBS Studio_. Launch with:
<!-- prettier-ignore-start -->
```python
flatpak run com.obsproject.Studio
```
<!-- prettier-ignore-end -->
6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480`.
7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it in.
6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480` to avoid the watermarks.
7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video` or `OBS > Preferences... > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it.
8. _Start virtual camera_. In OBS Studio, follow the instructions [here](https://obsproject.com/kb/virtual-camera-guide).
9. _Verify the virtual camera setup_. Use `v4l2-ctl` to list the devices:
9. _Verify the virtual camera setup and resolution_.
- **Linux**: Use `v4l2-ctl` to list devices and check resolution:
```bash
v4l2-ctl --list-devices # find VirtualCam and note its /dev/videoX path
v4l2-ctl -d /dev/videoX --get-fmt-video # replace with your VirtualCam path
```
You should see `VirtualCam` listed and resolution `640x480`.
- **macOS**: Open Photo Booth or FaceTime and select "OBS Virtual Camera" as the input.
- **Windows**: The native Camera app doesn't support virtual cameras. Use a video conferencing app (Zoom, Teams) or run `lerobot-find-cameras opencv` directly to verify.
<!-- prettier-ignore-start -->
```python
v4l2-ctl --list-devices
```
<!-- prettier-ignore-end -->
<details>
<summary><strong>Troubleshooting</strong></summary>
You should see an entry like:
> The virtual camera resolution is incorrect.
```
VirtualCam (platform:v4l2loopback-000):
/dev/video1
```
Delete the virtual camera source and recreate it. The resolution cannot be changed after creation.
10. _Check the camera resolution_. Use `v4l2-ctl` to ensure that the virtual camera output resolution is `640x480`. Change `/dev/video1` to the port of your virtual camera from the output of `v4l2-ctl --list-devices`.
> Error reading frame in background thread for OpenCVCamera(X): OpenCVCamera(X) frame width=640 or height=480 do not match configured width=1920 or height=1080.
<!-- prettier-ignore-start -->
```python
v4l2-ctl -d /dev/video1 --get-fmt-video
```
<!-- prettier-ignore-end -->
This error is caused by OBS Virtual Camera advertising a `1920x1080` resolution despite rescaling. The only fix for now is to comment out the width and height check in `_postprocess_image()`.
You should see an entry like:
```
>>> Format Video Capture:
>>> Width/Height : 640/480
>>> Pixel Format : 'YUYV' (YUYV 4:2:2)
```
Troubleshooting: If the resolution is not correct you will have to delete the Virtual Camera port and try again as it cannot be changed.
If everything is set up correctly, you can proceed with the rest of the tutorial.
</details>
</hfoption>
</hfoptions>
If everything is set up correctly, your phone will appear as a standard OpenCV camera and can be used with `OpenCVCamera`.

View File

@@ -0,0 +1,278 @@
# Using Subtasks in LeRobot Datasets
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
## What are Subtasks?
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
1. "Approach the apple"
2. "Grasp the apple"
3. "Lift the apple"
4. "Move to basket"
5. "Release the apple"
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
width="80%"
/>
<p>
<em>Figure: Overview of subtask annotation.</em>
</p>
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
## Dataset Structure
Subtask information is stored in the dataset metadata:
```
my-dataset/
├── data/
│ └── ...
├── meta/
│ ├── info.json
│ ├── stats.json
│ ├── tasks.parquet
│ ├── subtasks.parquet # Subtask index → subtask string mapping
│ └── episodes/
│ └── ...
└── videos/
└── ...
```
### Subtasks Parquet File
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
| subtask_index | subtask (index column) |
| ------------- | ---------------------- |
| 0 | "Approach the apple" |
| 1 | "Grasp the apple" |
| 2 | "Lift the apple" |
| ... | ... |
### Frame-Level Annotations
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
```python
# Example frame data in the parquet file
{
"index": 42,
"timestamp": 1.4,
"episode_index": 0,
"task_index": 0,
"subtask_index": 2, # References "Lift the apple"
"observation.state": [...],
"action": [...],
}
```
## Annotating Datasets with Subtasks
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
After completing your annotation:
1. Click "Push to Hub" to upload your annotated dataset
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
## Loading Datasets with Subtasks
When you load a dataset with subtask annotations, the subtask information is automatically available:
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Load a dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Access a sample
sample = dataset[100]
# The sample includes both task and subtask information
print(sample["task"]) # "Collect the fruit"
print(sample["subtask"]) # "Grasp the apple"
print(sample["task_index"]) # tensor(0)
print(sample["subtask_index"]) # tensor(2)
```
### Checking for Subtask Support
You can check if a dataset has subtask annotations:
```python
# Check if subtasks are available
has_subtasks = (
"subtask_index" in dataset.features
and dataset.meta.subtasks is not None
)
if has_subtasks:
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
print("Subtasks:", list(dataset.meta.subtasks.index))
```
## Using Subtasks for Training
### With the Tokenizer Processor
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
```python
from lerobot.processor.tokenizer_processor import TokenizerProcessor
from lerobot.processor.pipeline import ProcessorPipeline
# Create a tokenizer processor
tokenizer_processor = TokenizerProcessor(
tokenizer_name_or_path="google/paligemma-3b-pt-224",
padding="max_length",
max_length=64,
)
# The processor will automatically tokenize subtasks if present in the batch
# and add them to the observation under:
# - "observation.subtask.tokens"
# - "observation.subtask.attention_mask"
```
When subtasks are available in the batch, the tokenizer processor adds:
- `observation.subtask.tokens`: Tokenized subtask text
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
### DataLoader with Subtasks
```python
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=16,
shuffle=True,
)
for batch in dataloader:
# Access subtask information in the batch
subtasks = batch["subtask"] # List of subtask strings
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
# Use for training hierarchical policies or reward models
print(f"Batch subtasks: {set(subtasks)}")
```
## Example Datasets with Subtask Annotations
Try loading a dataset with subtask annotations:
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Example dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Explore the subtasks
print("Available subtasks:")
for subtask_name in dataset.meta.subtasks.index:
print(f" - {subtask_name}")
# Get subtask distribution
subtask_counts = {}
for i in range(len(dataset)):
sample = dataset[i]
subtask = sample["subtask"]
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
print("\nSubtask distribution:")
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
print(f" {subtask}: {count} frames")
```
## Use Cases
### 1. Hierarchical Policy Training
Train policies that predict both actions and current subtask:
```python
class HierarchicalPolicy(nn.Module):
def __init__(self, num_subtasks):
super().__init__()
self.action_head = nn.Linear(hidden_dim, action_dim)
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
def forward(self, observations):
features = self.encoder(observations)
actions = self.action_head(features)
subtask_logits = self.subtask_head(features)
return actions, subtask_logits
```
### 2. Stage-Aware Reward Modeling (SARM)
Build reward models that understand task progression:
```python
# SARM predicts:
# - Stage: Which subtask is being executed (discrete)
# - Progress: How far along the subtask (continuous 0-1)
class SARMRewardModel(nn.Module):
def forward(self, observations):
features = self.encoder(observations)
stage_logits = self.stage_classifier(features)
progress = self.progress_regressor(features)
return stage_logits, progress
```
### 3. Progress Visualization
Monitor robot execution by tracking subtask progression:
```python
def visualize_execution(model, observations):
for t, obs in enumerate(observations):
action, subtask_logits = model(obs)
predicted_subtask = subtask_names[subtask_logits.argmax()]
print(f"t={t}: Executing '{predicted_subtask}'")
```
## API Reference
### LeRobotDataset Properties
| Property | Type | Description |
| --------------------------- | ---------------------- | ------------------------------------------ |
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
### Sample Keys
When subtasks are available, each sample includes:
| Key | Type | Description |
| --------------- | -------------- | ------------------------------------ |
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
| `subtask` | `str` | Natural language subtask description |
## Related Resources
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation

View File

@@ -1,13 +1,15 @@
# Installation
## Install [`miniforge`](https://conda-forge.org/download/)
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
```bash
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
bash Miniforge3-$(uname)-$(uname -m).sh
```
## Environment Setup
## Step 2: Environment Setup
Create a virtual environment with Python 3.10, using conda:
@@ -38,7 +40,7 @@ conda install ffmpeg -c conda-forge
>
> - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
## Install LeRobot 🤗
## Step 3: Install LeRobot 🤗
### From Source

276
docs/source/openarm.mdx Normal file
View File

@@ -0,0 +1,276 @@
# OpenArm
[OpenArm](https://openarm.dev) is an open-source 7DOF humanoid arm designed for physical AI research and deployment.
To get your OpenArm, assembled or DIY, and join the global community, browse verified and certified manufacturers worldwide at [openarm.dev](https://openarm.dev).
## What's Unique?
- **Human-Scale Design**: OpenArm is designed with human-like proportions, scaled for a person around 160-165cm tall. This provides an optimal balance between practical reach and manageable inertia for safe, responsive operation.
- **Safety-First Architecture**: Built with QDD backdrivable motors and high compliance, OpenArm prioritizes safe human-robot interaction while maintaining practical payload capabilities (6.0kg peak / 4.1kg nominal) for real-world tasks.
- **Built for Durability**: Critical structural components use aluminum and stainless steel construction, ensuring robust performance for repetitive data collection and continuous research use.
- **Fully Accessible & Buildable**: Every component, from CNC parts and 3D-printed casings to electrical wiring is designed to be purchasable and buildable by individual researchers and labs, with complete fabrication data provided.
- **Practical & Affordable**: At $6,500 USD for a complete bimanual system, OpenArm delivers research-grade capabilities at a fraction of traditional humanoid robot costs.
## Platform Requirements
<Tip warning={true}>
**Linux Only**: OpenArm currently only works on Linux. The CAN bus USB adapter
does not have macOS drivers and has not been tested on Windows.
</Tip>
## Safety Guide
Before operating OpenArm, please read the [official safety guide](https://docs.openarm.dev/getting-started/safety-guide). Key points:
- **Secure installation**: Fasten the arm to a flat, stable surface with screws or clamps
- **Safe distance**: Keep body parts and objects outside the range of motion during operation
- **Protective equipment**: Always wear safety goggles; use additional PPE as needed
- **Payload limits**: Do not exceed specified payload limits (6.0kg peak / 4.1kg nominal per arm)
- **Emergency stop**: Know the location and operation of the emergency stop device
- **Regular inspection**: Check for loose screws, damaged mechanical limits, unusual noises, and wiring damage
## Hardware Setup
Follow the official [OpenArm hardware documentation](https://docs.openarm.dev) for:
- Bill of materials and sourcing
- 3D printing instructions
- Mechanical assembly
- Electrical wiring
The hardware repositories are available at [github.com/enactic/openarm](https://github.com/enactic/openarm).
## CAN Bus Setup
OpenArm uses CAN bus communication with Damiao motors. Once you have the CAN bus USB adapter plugged into your Linux PC, follow the [Damiao Motors and CAN Bus guide](./damiao) to configure the interface.
Quick setup:
```bash
# Setup CAN interfaces
lerobot-setup-can --mode=setup --interfaces=can0,can1
# Test motor communication
lerobot-setup-can --mode=test --interfaces=can0,can1
```
## Install LeRobot 🤗
Follow our [Installation Guide](./installation), then install the Damiao motor support:
```bash
pip install -e ".[damiao]"
```
## Usage
### Follower Arm (Robot)
<hfoptions id="follower">
<hfoption id="Command">
```bash
lerobot-calibrate \
--robot.type=openarm_follower \
--robot.port=can0 \
--robot.side=right \
--robot.id=my_openarm_follower
```
</hfoption>
<hfoption id="API example">
```python
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
config = OpenArmFollowerConfig(
port="can0",
side="right", # or "left" for left arm
id="my_openarm_follower",
)
follower = OpenArmFollower(config)
follower.connect()
# Read current state
obs = follower.get_observation()
print(obs)
# Send action (position in degrees)
action = {
"joint_1.pos": 0.0,
"joint_2.pos": 0.0,
"joint_3.pos": 0.0,
"joint_4.pos": 45.0,
"joint_5.pos": 0.0,
"joint_6.pos": 0.0,
"joint_7.pos": 0.0,
"gripper.pos": 0.0,
}
follower.send_action(action)
follower.disconnect()
```
</hfoption>
</hfoptions>
### Leader Arm (Teleoperator)
The leader arm is used for teleoperation - manually moving it to control the follower arm.
<hfoptions id="leader">
<hfoption id="Command">
```bash
lerobot-calibrate \
--teleop.type=openarm_leader \
--teleop.port=can1 \
--teleop.id=my_openarm_leader
```
</hfoption>
<hfoption id="API example">
```python
from lerobot.teleoperators.openarm_leader import OpenArmLeader, OpenArmLeaderConfig
config = OpenArmLeaderConfig(
port="can1",
id="my_openarm_leader",
manual_control=True, # Disable torque for manual movement
)
leader = OpenArmLeader(config)
leader.connect()
# Read current position (as action to send to follower)
action = leader.get_action()
print(action)
leader.disconnect()
```
</hfoption>
</hfoptions>
### Teleoperation
To teleoperate OpenArm with leader-follower control:
```bash
lerobot-teleoperate \
--robot.type=openarm_follower \
--robot.port=can0 \
--robot.side=right \
--robot.id=my_follower \
--teleop.type=openarm_leader \
--teleop.port=can1 \
--teleop.id=my_leader
```
### Bimanual Teleoperation
To teleoperate a bimanual OpenArm setup with two leader and two follower arms:
```bash
lerobot-teleoperate \
--robot.type=bi_openarm_follower \
--robot.left_arm_config.port=can0 \
--robot.left_arm_config.side=left \
--robot.right_arm_config.port=can1 \
--robot.right_arm_config.side=right \
--robot.id=my_bimanual_follower \
--teleop.type=bi_openarm_leader \
--teleop.left_arm_config.port=can2 \
--teleop.right_arm_config.port=can3 \
--teleop.id=my_bimanual_leader
```
### Recording Data
To record a dataset during teleoperation:
```bash
lerobot-record \
--robot.type=openarm_follower \
--robot.port=can0 \
--robot.side=right \
--robot.id=my_follower \
--teleop.type=openarm_leader \
--teleop.port=can1 \
--teleop.id=my_leader \
--repo-id=my_hf_username/my_openarm_dataset \
--fps=30 \
--num-episodes=10
```
## Configuration Options
### Follower Configuration
| Parameter | Default | Description |
| --------------------- | --------- | ---------------------------------------------------------- |
| `port` | - | CAN interface (e.g., `can0`) |
| `side` | `None` | Arm side: `"left"`, `"right"`, or `None` for custom limits |
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
| `max_relative_target` | `None` | Safety limit for relative target positions |
| `position_kp` | Per-joint | Position control proportional gains |
| `position_kd` | Per-joint | Position control derivative gains |
### Leader Configuration
| Parameter | Default | Description |
| ------------------ | --------- | ----------------------------------- |
| `port` | - | CAN interface (e.g., `can1`) |
| `manual_control` | `True` | Disable torque for manual movement |
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
## Motor Configuration
OpenArm uses Damiao motors with the following default configuration:
| Joint | Motor Type | Send ID | Recv ID |
| --------------------------- | ---------- | ------- | ------- |
| joint_1 (Shoulder pan) | DM8009 | 0x01 | 0x11 |
| joint_2 (Shoulder lift) | DM8009 | 0x02 | 0x12 |
| joint_3 (Shoulder rotation) | DM4340 | 0x03 | 0x13 |
| joint_4 (Elbow flex) | DM4340 | 0x04 | 0x14 |
| joint_5 (Wrist roll) | DM4310 | 0x05 | 0x15 |
| joint_6 (Wrist pitch) | DM4310 | 0x06 | 0x16 |
| joint_7 (Wrist rotation) | DM4310 | 0x07 | 0x17 |
| gripper | DM4310 | 0x08 | 0x18 |
## Troubleshooting
### No Response from Motors
1. Check power supply connections
2. Verify CAN wiring (CAN-H, CAN-L, GND)
3. Run diagnostics: `lerobot-setup-can --mode=test --interfaces=can0`
4. See the [Damiao troubleshooting guide](./damiao#troubleshooting) for more details
### CAN Interface Not Found
Ensure the CAN interface is configured:
```bash
ip link show can0
```
## Resources
- [OpenArm Website](https://openarm.dev)
- [OpenArm Documentation](https://docs.openarm.dev)
- [OpenArm GitHub](https://github.com/enactic/openarm)
- [Safety Guide](https://docs.openarm.dev/getting-started/safety-guide)
- [Damiao Motors and CAN Bus](./damiao)

328
docs/source/openarms.mdx Normal file
View File

@@ -0,0 +1,328 @@
# OpenArms Robot
OpenArms is a 7 DOF robotic arm with a gripper, designed by [Enactic, Inc.](https://www.enactic.com/) It uses Damiao motors controlled via CAN bus communication and MIT control mode for smooth, precise motion.
## Hardware Overview
- **7 DOF per arm** (14 DOF total for dual arm setup)
- **1 gripper per arm** (2 grippers total)
- **Damiao motors** with 4 different types:
- **DM8009** (DM-J8009P-2EC) for shoulders (J1, J2) - high torque
- **DM4340** for shoulder rotation and elbow (J3, J4)
- **DM4310** (DM-J4310-2EC V1.1) for wrist (J5, J6, J7) and gripper (J8)
- **24V power supply** required
- **CAN interface device**:
- **Linux**: Any SocketCAN-compatible adapter
- **macOS**: CANable, PEAK PCAN-USB, or Kvaser USBcan
- Proper CAN wiring (CANH, CANL, 120Ω termination)
## Motor Configuration
Each arm has the following motor configuration based on the [OpenArm setup guide](https://docs.openarm.dev/software/setup/):
| Joint | Motor | Motor Type | Sender CAN ID | Receiver ID | Description |
|-------|-------|------------|---------------|-------------|-------------|
| J1 | joint_1 | DM8009 | 0x01 | 0x11 | Shoulder pan |
| J2 | joint_2 | DM8009 | 0x02 | 0x12 | Shoulder lift |
| J3 | joint_3 | DM4340 | 0x03 | 0x13 | Shoulder rotation |
| J4 | joint_4 | DM4340 | 0x04 | 0x14 | Elbow flex |
| J5 | joint_5 | DM4310 | 0x05 | 0x15 | Wrist roll |
| J6 | joint_6 | DM4310 | 0x06 | 0x16 | Wrist pitch |
| J7 | joint_7 | DM4310 | 0x07 | 0x17 | Wrist rotation |
| J8 | gripper | DM4310 | 0x08 | 0x18 | Gripper |
For dual arm setups, the left arm uses IDs 0x09-0x10 for joints 1-8 with the same motor types.
## Quick Start
```bash
# Install system dependencies
sudo apt install can-utils iproute2
# Install LeRobot with OpenArms support
pip install -e ".[openarms]"
```
## Setup Guide
### Step 1: Motor ID Configuration
**IMPORTANT**: Before using the robot, motors must be configured with the correct CAN IDs.
Refer to the [OpenArm Motor ID Configuration Guide](https://docs.openarm.dev/software/setup/motor-id) for detailed instructions using the Damiao Debugging Tools on Windows.
Key points:
- Each motor needs a unique **Sender CAN ID** (0x01-0x08)
- Each motor needs a unique **Receiver/Master ID** (0x11-0x18)
- Use the Damiao Debugging Tools to set these IDs
### Step 2: Setup CAN Interface
Configure your CAN interface as described in the [OpenArm CAN Setup Guide](https://docs.openarm.dev/software/setup/can-setup):
#### Linux (SocketCAN)
```bash
# Find your CAN interface
ip link show
# Configure can0, 1, 2, 3
sudo ip link set can0 down
sudo ip link set can0 type can bitrate 1000000
sudo ip link set can0 up
sudo ip link set can1 down
sudo ip link set can1 type can bitrate 1000000
sudo ip link set can1 up
sudo ip link set can2 down
sudo ip link set can2 type can bitrate 1000000
sudo ip link set can2 up
sudo ip link set can3 down
sudo ip link set can3 type can bitrate 1000000
sudo ip link set can3 up
# Verify configuration
ip link show can0
```
or run:
`examples/openarms/setup_can.sh`
### Testing canbus and motor connection
Please run this script to check if all motors can be found and to find your can-fd speed: `python examples/openarms/debug_can_communication.py`
## Usage
### Basic Setup
```python
from lerobot.robots.openarms import OpenArmsFollower
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
# Configure for dual arm setup
config = OpenArmsFollowerConfig(
port="can0",
can_interface="socketcan", # Or "auto" for auto-detection
id="openarms_dual",
is_dual_arm=True,
)
robot = OpenArmsFollower(config)
robot.connect()
```
### Calibration
On first use, you'll need to calibrate the robot:
```python
robot.calibrate()
```
The calibration process will:
1. Disable torque on all motors
2. Ask you to position arms in **hanging position with grippers closed**
3. Set this as the zero position
4. Ask you to move each joint through its full range
5. Record min/max positions for each joint
6. Save calibration to file
### Reading Observations
The robot provides comprehensive state information:
```python
observation = robot.get_observation()
# Observation includes for each motor:
# - {motor_name}.pos: Position in degrees
# - {motor_name}.vel: Velocity in degrees/second
# - {motor_name}.torque: Motor torque
# - {camera_name}: Camera images (if configured)
print(f"Right arm joint 1 position: {observation['right_joint_1.pos']:.1f}°")
print(f"Right arm joint 1 velocity: {observation['right_joint_1.vel']:.1f}°/s")
print(f"Right arm joint 1 torque: {observation['right_joint_1.torque']:.3f} N·m")
```
### Sending Actions
```python
# Send target positions (in degrees)
action = {
"right_joint_1.pos": 45.0,
"right_joint_2.pos": -30.0,
# ... all joints
"right_gripper.pos": 45.0, # Half-closed
}
actual_action = robot.send_action(action)
```
### Gripper Control
```python
# Open gripper
robot.open_gripper(arm="right")
# Close gripper
robot.close_gripper(arm="right")
```
## Safety Features
### 1. Maximum Relative Target
Limits how far a joint can move in a single command to prevent sudden movements:
```python
config = OpenArmsFollowerConfig(
port="can0",
# Limit all joints to 10 degrees per command
max_relative_target=10.0,
# Or set per-motor limits
max_relative_target={
"right_joint_1": 15.0, # Slower moving joint
"right_joint_2": 10.0,
"right_gripper": 5.0, # Very slow gripper
}
)
```
**How it works**: If current position is 50° and you command 80°, with `max_relative_target=10.0`, the robot will only move to 60° in that step.
### 2. Torque Limits
Control maximum torque output, especially important for grippers and teleoperation:
```python
config = OpenArmsFollowerConfig(
port="can0",
# Gripper torque limit (fraction of motor's max torque)
gripper_torque_limit=0.5, # 50% of max torque
)
```
Lower torque limits prevent damage when gripping delicate objects.
### 3. MIT Control Gains
Control responsiveness and stability via PID-like gains:
```python
config = OpenArmsFollowerConfig(
port="can0",
position_kp=10.0, # Position gain (higher = more responsive)
position_kd=0.5, # Velocity damping (higher = more damped)
)
```
**Guidelines**:
- **For following (robot)**: Higher gains for responsiveness
- `position_kp=10.0`, `position_kd=0.5`
- **For teleoperation (leader)**: Lower gains or disable torque for manual movement
- `manual_control=True` (torque disabled)
### 4. Velocity Limits
Velocity limits are enforced by the Damiao motors based on motor type. For DM4310:
- Max velocity: 30 rad/s ≈ 1718°/s
The motors will automatically limit velocity to safe values.
## Teleoperation
### Leader Arm Setup
The leader arm is moved manually (torque disabled) to generate commands:
```python
from lerobot.teleoperators.openarms import OpenArmsLeader
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
config = OpenArmsLeaderConfig(
port="can1", # Separate CAN interface for leader
id="openarms_leader",
manual_control=True, # Torque disabled for manual movement
is_dual_arm=True,
)
leader = OpenArmsLeader(config)
leader.connect()
# Read current position as action
action = leader.get_action()
# action contains positions for all joints in degrees
```
### Safety Considerations for Teleoperation
1. **Use separate CAN interfaces** for leader and follower to avoid conflicts
2. **Enable max_relative_target** on follower to smooth abrupt movements
3. **Lower torque limits** on follower to prevent damage from tracking errors
4. **Test with one arm** before enabling dual arm teleoperation
5. **Have emergency stop** ready (power switch or CAN disable)
```python
# Recommended follower config for teleoperation
follower_config = OpenArmsFollowerConfig(
port="can0",
max_relative_target=5.0, # Small steps for smooth following
gripper_torque_limit=0.3, # Low torque for safety
position_kp=5.0, # Lower gains for gentler following
position_kd=0.3,
)
```
## Troubleshooting
### Motor Shaking/Unstable
- **Lower control gains**: Reduce `position_kp` and `position_kd`
- **Check calibration**: Re-run calibration procedure
- **Verify power**: Insufficient current can cause instability
- **Check mechanical**: Loose connections, binding, or damaged components
### CAN Bus Errors
```bash
# Check for errors
ip -s link show can0
# Reset CAN interface
sudo ip link set can0 down
sudo ip link set can0 up
```
### Control Mode
OpenArms uses **MIT control mode** which allows simultaneous control of:
- Position (degrees)
- Velocity (degrees/second)
- Torque (N·m)
- Position gain (Kp)
- Velocity damping (Kd)
### Communication
- **Protocol**: CAN 2.0 at 1 Mbps (or CAN-FD at 5 Mbps)
- **Frame format**: Standard 11-bit IDs
- **Update rate**: Typically 50-100 Hz depending on motor count
- **Latency**: ~10-20ms per motor command
## References
- [OpenArm Official Documentation](https://docs.openarm.dev/)
- [OpenArm Setup Guide](https://docs.openarm.dev/software/setup/)
- [Motor ID Configuration](https://docs.openarm.dev/software/setup/motor-id)
- [CAN Interface Setup](https://docs.openarm.dev/software/setup/can-setup)
- [Motor Communication Test](https://docs.openarm.dev/software/setup/configure-test)
- [Damiao Motor Documentation](https://wiki.seeedstudio.com/damiao_series/)
- [Enactic GitHub](https://github.com/enactic/openarm_can)

291
docs/source/rac.mdx Normal file
View File

@@ -0,0 +1,291 @@
# RaC: Recovery and Correction Training
RaC (Recovery and Correction) is a human-in-the-loop data collection and training paradigm that improves robot policy performance on long-horizon tasks by explicitly teaching recovery and correction behaviors.
**Key References:**
- [RaC: Robot Learning for Long-Horizon Tasks by Scaling Recovery and Correction](https://arxiv.org/abs/2509.07953) (Hu et al., 2025)
- [HG-DAgger: Interactive Imitation Learning with Human Experts](https://arxiv.org/abs/1810.02890) (Kelly et al., 2019)
- [π0.6: a VLA That Learns From Experience](https://pi.website/blog/pistar06) (Physical Intelligence, 2025)
- [SARM: Stage-Aware Reward Modeling](https://arxiv.org/abs/2509.25358) (Chen et al., 2025)
---
## Why RaC? The Problem with Standard Data Collection
### Standard Behavioral Cloning Data Collection Limitations
Standard behavior cloning trains policies on successful demonstrations. This approach can be sensitive to distribution shift and compounding errors. Because during deployment small errors can cascade and push the robot into states never seen during training.
This is where RaC and methods like Dagger and HG-DAgger come in.
### Prior Human-in-the-Loop Methods
**DAgger** (Dataset Aggregation) addresses distribution shift by:
- Running the novice policy to collect states
- Querying expert for correct actions at those states
- Aggregating new labels into training set
**HG-DAgger** (Human-Gated DAgger) improves on DAgger by:
- Giving human full control authority during interventions
- Human takes over when unsafe, provides correction, returns control
- Better action labels because human has uninterrupted control
### RaC
RaC explicitly collects **recovery + correction** data:
```
BC/DAgger: policy → mistake → human corrects → continue
RaC: policy → mistake → human RECOVERS (teleop back) → CORRECTS → END
```
The critical insight is **Rule 1 (Recover then Correct)**:
- Every intervention starts with human teleoperating back to an in-distribution state
- Then human provides correction to complete the current subtask
- Both segments are recorded as training data
- This teaches the policy: "when things go wrong, go back and retry"
**Rule 2 (Terminate after Intervention)**:
- Episode ends after correction completes
- Avoids mixed policy/human data on later subtasks
- Keeps data distribution clean
---
## Comparison Table
| Method | Data Type | Recovery Behavior | Correction Behavior |
|--------|-----------|-------------------|---------------------|
| BC | Success only | ✗ | ✗ |
| DAgger | Success + corrections | ✗ | ✓ |
| HG-DAgger | Success + corrections | Sometimes | ✓ |
| RaC | Success + recovery + correction | ✓ Explicit | ✓ |
---
## The RaC Pipeline
```
┌─────────────────────────────────────────────────────────────────────────┐
│ RaC Training Pipeline │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 1. PRE-TRAINING (Standard BC) │
│ └─> Train initial policy on clean demonstrations │
│ │
│ 2. RAC DATA COLLECTION (Human-in-the-loop) │
│ ├─> Policy runs autonomously │
│ ├─> Human monitors and intervenes when failure imminent │
│ │ ├─> RECOVERY: Human teleoperates robot back to good state │
│ │ └─> CORRECTION: Human completes the current subtask │
│ └─> Episode terminates after correction (Rule 2) │
│ │
│ 3. REWARD LABELING (Optional: SARM) │
│ └─> Compute progress rewards for advantage-weighted training │
│ │
│ 4. FINE-TUNING │
│ └─> Train on combined demos + RaC data (optionally with RA-BC) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
```
---
## Step-by-Step Guide
### Step 1: Pre-train a Base Policy
First, train a policy on your demonstration dataset:
```bash
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/demo-dataset \
--policy.type=pi0 \
--output_dir=outputs/pretrain \
--batch_size=32 \
--steps=50000
```
### Step 2: Collect RaC Data
Run the RaC data collection script with your pre-trained policy:
```bash
python examples/rac/rac_data_collection.py \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/rac-dataset \
--dataset.single_task="Pick up the cube and place it in the bowl" \
--dataset.num_episodes=50
```
**Controls (Keyboard + Foot Pedal):**
| Key / Pedal | Action |
|-------------|--------|
| **SPACE** / Right pedal | Pause policy (teleop mirrors robot, no recording) |
| **c** / Left pedal | Take control (start correction, recording resumes) |
| **→** / Right pedal | End episode (save) - when in correction mode |
| **←** | Re-record episode |
| **ESC** | Stop session and push to hub |
| Any key/pedal during reset | Start next episode |
**The RaC Protocol:**
1. Watch the policy run autonomously (teleop is idle/free)
2. When you see imminent failure, press **SPACE** or **right pedal** to pause
- Policy stops
- Teleoperator moves to match robot position (torque enabled)
- No frames recorded during pause
3. Press **c** or **left pedal** to take control
- Teleoperator torque disabled, free to move
- **RECOVERY**: Teleoperate back to a good state
- **CORRECTION**: Complete the subtask
- All movements are recorded
4. Press **→** or **right pedal** to save and end episode
5. **RESET**: Teleop moves to robot position, you can move robot to starting position
6. Press any key/pedal to start next episode
The recovery and correction segments teach the policy how to recover from errors.
**Foot Pedal Setup (Linux):**
If using a USB foot pedal (PCsensor FootSwitch), ensure access:
```bash
sudo setfacl -m u:$USER:rw /dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd
```
### Step 3: (Optional) Compute SARM Rewards
For advantage-weighted training (RA-BC / Pi0.6-style), compute SARM progress values:
```bash
python src/lerobot/policies/sarm/compute_rabc_weights.py \
--dataset-repo-id your-username/rac-dataset \
--reward-model-path your-username/sarm-model \
--head-mode sparse \
--push-to-hub
```
### Step 4: Fine-tune Policy
Fine-tune on the RaC data:
```bash
# Without RA-BC (standard fine-tuning)
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/rac-dataset \
--policy.type=pi0 \
--policy.pretrained_path=outputs/pretrain/checkpoints/last/pretrained_model \
--output_dir=outputs/rac_finetune \
--steps=20000
# With RA-BC (advantage-weighted, Pi0.6-style)
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/rac-dataset \
--policy.type=pi0 \
--policy.pretrained_path=outputs/pretrain/checkpoints/last/pretrained_model \
--output_dir=outputs/rac_finetune_rabc \
--use_rabc=true \
--rabc_kappa=0.01 \
--steps=20000
```
---
## Connection to Pi0.6 / RECAP
Pi0.6's RECAP method shares similar principles:
- Collect autonomous rollouts + expert interventions
- Use value function to compute **advantages**: A(s,a) = V(s') - V(s)
- **Advantage conditioning**: Weight training based on expected improvement
In LeRobot, we can use **SARM** as the value function:
- SARM progress φ(s) ∈ [0,1] measures task completion
- Progress delta = φ(s') - φ(s) approximates advantage
- RA-BC uses these to weight training samples (higher weight for good corrections)
---
## Tips for Effective RaC Collection
### When to Intervene
Intervene when you see:
- Robot about to make an irreversible mistake
- Robot hesitating or showing uncertain behavior
- Robot deviating from expected trajectory
### Recovery: Teleoperating Back to Good State
During recovery, teleoperate the robot back to a state where:
- The robot is in a familiar, in-distribution configuration
- The current subtask can still be completed
- The recovery trajectory itself is informative training data
### Quality of Corrections
During correction:
- Provide **confident, clean** trajectories
- Complete the current subtask fully
- Don't overcorrect or add unnecessary movements
---
## Iterative Improvement
RaC can be applied iteratively:
```
┌─────────────────────────────────────────────────────────────────────────┐
│ Policy v0 (demos) │
│ ↓ │
│ RaC Collection (target current failure modes) → Policy v1 │
│ ↓ │
│ RaC Collection (target new failure modes) → Policy v2 │
│ ↓ │
│ ... (repeat until satisfactory performance) │
└─────────────────────────────────────────────────────────────────────────┘
```
Each iteration:
1. Deploy current policy
2. Collect RaC interventions on failure cases
3. Fine-tune on accumulated data
---
## References
```bibtex
@article{hu2025rac,
title={RaC: Robot Learning for Long-Horizon Tasks by Scaling Recovery and Correction},
author={Hu, Zheyuan and Wu, Robyn and Enock, Naveen and Li, Jasmine and Kadakia, Riya and Erickson, Zackory and Kumar, Aviral},
journal={arXiv preprint arXiv:2509.07953},
year={2025}
}
@article{kelly2019hgdagger,
title={HG-DAgger: Interactive Imitation Learning with Human Experts},
author={Kelly, Michael and Sidrane, Chelsea and Driggs-Campbell, Katherine and Kochenderfer, Mykel J},
journal={arXiv preprint arXiv:1810.02890},
year={2019}
}
@article{pi2025recap,
title={π0.6: a VLA That Learns From Experience},
author={Physical Intelligence},
year={2025}
}
@article{chen2025sarm,
title={SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation},
author={Chen, Qianzhong and Yu, Justin and Schwager, Mac and Abbeel, Pieter and Shentu, Yide and Wu, Philipp},
journal={arXiv preprint arXiv:2509.25358},
year={2025}
}
```

View File

@@ -188,7 +188,105 @@ Press `Ctrl+C` to stop the policy.
## Running in Simulation Mode (MuJoCo)
You can now test policies before unleashing them on the physical robot using MuJoCo. To do so simply set `is_simulation=True` in config.
You can test policies before deploying on the physical robot using MuJoCo simulation. Set `is_simulation=True` in config or pass `--robot.is_simulation=true` via CLI.
### Calibrate Exoskeleton Teleoperator
```bash
lerobot-calibrate \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo
```
### Teleoperate in Simulation
```bash
lerobot-teleoperate \
--robot.type=unitree_g1 \
--robot.is_simulation=true \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--fps=100
```
### Record Dataset in Simulation
```bash
python -m lerobot.scripts.lerobot_record \
--robot.type=unitree_g1 \
--robot.is_simulation=true \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--dataset.repo_id=your-username/dataset-name \
--dataset.single_task="Test" \
--dataset.num_episodes=2 \
--dataset.episode_time_s=5 \
--dataset.reset_time_s=5 \
--dataset.push_to_hub=true
```
Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim)
---
## Running on Real Robot
Once the robot server is running on the G1 (see Part 3), you can teleoperate and record on the real robot.
### Start the Camera Server
On the robot, start the ZMQ image server:
```bash
python src/lerobot/cameras/zmq/image_server.py
```
Keep this running in a separate terminal for camera streaming during recording.
### Teleoperate Real Robot
```bash
lerobot-teleoperate \
--robot.type=unitree_g1 \
--robot.is_simulation=false \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--fps=100
```
### Record Dataset on Real Robot
```bash
python -m lerobot.scripts.lerobot_record \
--robot.type=unitree_g1 \
--robot.is_simulation=false \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--dataset.repo_id=your-username/dataset-name \
--dataset.single_task="Test" \
--dataset.num_episodes=2 \
--dataset.episode_time_s=5 \
--dataset.reset_time_s=5 \
--dataset.push_to_hub=true
```
**Note**: Update `server_address` to match your robot's camera server IP.
Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/datasets/nepyope/teleop_test_real)
---
## Additional Resources

View File

@@ -81,24 +81,25 @@ def replay(cfg: ReplayConfig):
actions = dataset.hf_dataset.select_columns(ACTION)
robot.connect()
log_say("Replaying episode", cfg.play_sounds, blocking=True)
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
try:
log_say("Replaying episode", cfg.play_sounds, blocking=True)
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
action_array = actions[idx][ACTION]
action = {}
for i, name in enumerate(dataset.features[ACTION]["names"]):
key = f"{name.removeprefix('main_')}.pos"
action[key] = action_array[i].item()
action_array = actions[idx][ACTION]
action = {}
for i, name in enumerate(dataset.features[ACTION]["names"]):
key = f"{name.removeprefix('main_')}.pos"
action[key] = action_array[i].item()
action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90)
action["elbow_flex.pos"] -= 90
robot.send_action(action)
action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90)
action["elbow_flex.pos"] -= 90
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
robot.disconnect()
dt_s = time.perf_counter() - start_episode_t
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
finally:
robot.disconnect()
if __name__ == "__main__":

View File

@@ -78,40 +78,24 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="lekiwi_evaluate")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
print("Starting evaluate loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
@@ -120,24 +104,42 @@ def main():
robot_observation_processor=robot_observation_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Save episode
dataset.save_episode()
recorded_episodes += 1
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
# Save episode
dataset.save_episode()
recorded_episodes += 1
dataset.finalize()
dataset.push_to_hub()
finally:
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":

View File

@@ -74,40 +74,23 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="lekiwi_record")
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!")
try:
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting record loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {recorded_episodes}")
print("Starting record loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {recorded_episodes}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
dataset=dataset,
teleop=[leader_arm, keyboard],
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
dataset=dataset,
teleop=[leader_arm, keyboard],
control_time_s=RESET_TIME_SEC,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
@@ -115,26 +98,44 @@ def main():
robot_observation_processor=robot_observation_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=[leader_arm, keyboard],
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Save episode
dataset.save_episode()
recorded_episodes += 1
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Clean up
log_say("Stop recording")
robot.disconnect()
leader_arm.disconnect()
keyboard.disconnect()
listener.stop()
# Save episode
dataset.save_episode()
recorded_episodes += 1
finally:
# Clean up
log_say("Stop recording")
robot.disconnect()
leader_arm.disconnect()
keyboard.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":

View File

@@ -42,25 +42,27 @@ def main():
# Connect to the robot
robot.connect()
if not robot.is_connected:
raise ValueError("Robot is not connected!")
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
# Get recorded action from dataset
action = {
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get recorded action from dataset
action = {
name: float(actions[idx][ACTION][i])
for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Send action to robot
_ = robot.send_action(action)
# Send action to robot
_ = robot.send_action(action)
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
robot.disconnect()
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
finally:
robot.disconnect()
if __name__ == "__main__":

View File

@@ -0,0 +1,416 @@
#!/usr/bin/env python3
"""
Comprehensive debug script for OpenArms CAN FD communication.
Tests all 4 CAN interfaces with CAN FD support.
"""
import can
import time
import sys
import subprocess
def check_can_interface(port):
"""Check if CAN interface is UP and configured."""
try:
result = subprocess.run(['ip', 'link', 'show', port],
capture_output=True, text=True)
if result.returncode != 0:
return False, "Interface not found", None
output = result.stdout
if 'UP' not in output:
return False, "Interface is DOWN", None
# Check if CAN FD is enabled
is_fd = 'fd on' in output.lower() or 'canfd' in output.lower()
return True, "Interface is UP", is_fd
except FileNotFoundError:
return None, "Cannot check (ip command not found)", None
def test_motor_on_interface(bus, motor_id, timeout=2.0, use_fd=False):
"""
Test a single motor and return all responses.
Returns:
list of (arbitration_id, data) tuples for all responses received
"""
# Send enable command
enable_msg = can.Message(
arbitration_id=motor_id,
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
is_extended_id=False,
is_fd=use_fd
)
try:
bus.send(enable_msg)
except Exception as e:
return None, f"Send error: {e}"
# Listen for responses
responses = []
start_time = time.time()
while time.time() - start_time < timeout:
msg = bus.recv(timeout=0.1)
if msg:
responses.append((msg.arbitration_id, msg.data, msg.is_fd if hasattr(msg, 'is_fd') else False))
# Send disable command
disable_msg = can.Message(
arbitration_id=motor_id,
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFD],
is_extended_id=False,
is_fd=use_fd
)
try:
bus.send(disable_msg)
except:
pass
return responses, None
def test_interface(port, interface_type="socketcan", use_can_fd=True):
"""Test all 8 motors on a single CAN interface."""
results = {
'interface': port,
'status': None,
'is_fd': use_can_fd,
'motors': {}
}
# Check interface status
status_ok, status_msg, interface_has_fd = check_can_interface(port)
if interface_has_fd is not None:
results['interface_fd_enabled'] = interface_has_fd
if use_can_fd and not interface_has_fd:
status_msg += " (CAN FD NOT enabled on interface!)"
elif interface_has_fd:
status_msg += " (CAN FD enabled)"
results['status'] = status_msg
if status_ok is False:
return results
# Try to connect
try:
if use_can_fd:
print(f" Connecting to {port} with CAN FD (1 Mbps / 5 Mbps)...")
bus = can.interface.Bus(
channel=port,
interface=interface_type,
bitrate=1000000,
data_bitrate=5000000,
fd=True
)
else:
print(f" Connecting to {port} with CAN 2.0 (1 Mbps)...")
bus = can.interface.Bus(
channel=port,
interface=interface_type,
bitrate=1000000
)
except Exception as e:
results['status'] = f"Connection failed: {e}"
return results
try:
# Clear any pending messages
while bus.recv(timeout=0.01):
pass
# Test each motor (0x01 to 0x08)
for motor_id in range(0x01, 0x09):
responses, error = test_motor_on_interface(bus, motor_id, timeout=1.0, use_fd=use_can_fd)
if error:
results['motors'][motor_id] = {'error': error}
elif responses:
results['motors'][motor_id] = {
'found': True,
'responses': responses
}
else:
results['motors'][motor_id] = {
'found': False,
'responses': []
}
time.sleep(0.05) # Small delay between motors
finally:
bus.shutdown()
return results
def print_results(all_results):
"""Print formatted results for all interfaces."""
print("SUMMARY - Motors Found on Each Interface")
motor_names = {
0x01: "joint_1 (Shoulder pan)",
0x02: "joint_2 (Shoulder lift)",
0x03: "joint_3 (Shoulder rotation)",
0x04: "joint_4 (Elbow flex)",
0x05: "joint_5 (Wrist roll)",
0x06: "joint_6 (Wrist pitch)",
0x07: "joint_7 (Wrist rotation)",
0x08: "gripper",
}
total_found = 0
for result in all_results:
interface = result['interface']
status = result['status']
print(f"{interface}: {status}")
if result.get('is_fd'):
print(f" Mode: CAN FD")
else:
print(f" Mode: CAN 2.0")
if 'Connection failed' in status or 'DOWN' in status:
print(f" ⚠ Cannot test {interface}")
continue
motors_found = 0
for motor_id in range(0x01, 0x09):
motor_data = result['motors'].get(motor_id, {})
motor_name = motor_names.get(motor_id, "Unknown")
if motor_data.get('error'):
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ {motor_data['error']}")
elif motor_data.get('found'):
motors_found += 1
total_found += 1
responses = motor_data['responses']
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✓ FOUND")
for resp_id, data, is_fd in responses:
data_hex = data.hex()
fd_flag = " [FD]" if is_fd else " [2.0]"
print(f" → Response from 0x{resp_id:02X}{fd_flag}: {data_hex}")
else:
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ No response")
print(f"\n Summary: {motors_found}/8 motors found on {interface}")
# Overall summary
print("OVERALL SUMMARY")
print(f"Total motors found across all interfaces: {total_found}")
# Analyze configuration
print("DIAGNOSIS")
for result in all_results:
interface = result['interface']
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
if motors_found == 0:
print(f"\n{interface}: NO MOTORS FOUND")
print(" Possible issues:")
print(" 1. CAN FD mode mismatch (interface vs motor configuration)")
print(" 2. Missing 120Ω termination resistors at BOTH cable ends")
print(" 3. Motor timeout parameter set incorrectly (should NOT be 0)")
print(" 4. CANH/CANL wiring issue")
print(" 5. Cable too long (>40m for CAN FD at 5Mbps)")
# Check FD mismatch
if result.get('is_fd') and not result.get('interface_fd_enabled'):
print(" ⚠️ CRITICAL: Trying CAN FD but interface NOT configured for FD!")
print(f" Fix: sudo ip link set {interface} type can bitrate 1000000 dbitrate 5000000 fd on")
elif motors_found < 8:
print(f"\n{interface}: Only {motors_found}/8 motors responding")
print(" Check power and connections for missing motors")
else:
print(f"\n{interface}: All 8 motors responding correctly!")
# Check for unexpected response IDs
print("RESPONSE ID ANALYSIS")
for result in all_results:
interface = result['interface']
unexpected = []
for motor_id, motor_data in result['motors'].items():
if motor_data.get('found'):
expected_id = motor_id + 0x10
actual_ids = [resp[0] for resp in motor_data['responses']]
if expected_id not in actual_ids:
unexpected.append((motor_id, actual_ids))
if unexpected:
print(f"\n{interface}: Unexpected response IDs detected")
for motor_id, actual_ids in unexpected:
expected_id = motor_id + 0x10
print(f" Motor 0x{motor_id:02X}: Expected 0x{expected_id:02X}, "
f"got {[f'0x{id:02X}' for id in actual_ids]}")
print(" → Motor Master IDs need reconfiguration")
else:
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
if motors_found > 0:
print(f"\n{interface}: All responding motors use correct IDs")
def test_communication_speed(interface, motor_id, num_iterations=100):
"""
Test communication speed with a motor.
Returns:
tuple: (hz, avg_latency_ms) or (None, None) if test failed
"""
try:
# Connect to interface
bus = can.interface.Bus(
channel=interface,
interface="socketcan",
bitrate=1000000,
data_bitrate=5000000,
fd=True
)
# Send refresh commands and measure round-trip time
latencies = []
successful = 0
for _ in range(num_iterations):
start = time.perf_counter()
# Send enable command (lightweight operation)
enable_msg = can.Message(
arbitration_id=motor_id,
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
is_extended_id=False,
is_fd=True
)
bus.send(enable_msg)
# Wait for response
msg = bus.recv(timeout=0.1)
if msg:
latency = (time.perf_counter() - start) * 1000 # Convert to ms
latencies.append(latency)
successful += 1
bus.shutdown()
if successful > 0:
avg_latency = sum(latencies) / len(latencies)
hz = 1000.0 / avg_latency if avg_latency > 0 else 0
return hz, avg_latency
return None, None
except Exception as e:
print(f" Speed test error: {e}")
return None, None
def main():
"""Main function to test all CAN interfaces with CAN FD."""
print("\nThis will test all 4 CAN interfaces (can0-can3) with CAN FD")
print("Testing motors 0x01-0x08 on each interface")
print()
print("Make sure:")
print(" ✓ Motors are powered (24V)")
print(" ✓ CAN interfaces configured with FD mode:")
print(" ./examples/openarms/setup_can.sh")
print(" ✓ Motor 'timeout' parameter NOT set to 0 (use Damiao tools)")
print(" ✓ CAN wiring includes 120Ω termination at BOTH ends")
print()
input("Press ENTER to start testing...")
# Test all 4 interfaces with CAN FD
all_results = []
for i in range(4):
interface = f"can{i}"
print(f"Testing {interface}...")
result = test_interface(interface, use_can_fd=True)
all_results.append(result)
# Quick status
if 'Connection failed' in result['status'] or 'DOWN' in result['status']:
print(f"{interface}: {result['status']}")
else:
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
print(f" {interface}: {motors_found}/8 motors found")
time.sleep(0.2)
# Print detailed results
print_results(all_results)
print("Testing Complete!")
all_found = sum(sum(1 for m in r['motors'].values() if m.get('found')) for r in all_results)
if all_found == 0:
print("\n⚠️ CRITICAL: No motors found on any interface!")
print("\nTop issues to check:")
print(" 1. Motor 'timeout' parameter (use Damiao tools to set > 0)")
print(" 2. CAN FD not enabled (run ./examples/openarms/setup_can.sh)")
print(" 3. Missing termination resistors")
print("\nTry:")
print(" a) Check motor parameters with Damiao Debugging Tools")
print(" b) Verify CAN FD is enabled: ip -d link show can0 | grep fd")
print(" c) Run setup script: ./examples/openarms/setup_can.sh")
else:
# Run speed test on interfaces with motors
print("COMMUNICATION SPEED TEST")
print("\nTesting maximum communication frequency...")
for result in all_results:
interface = result['interface']
# Find first responding motor
responding_motor = None
for motor_id, motor_data in result['motors'].items():
if motor_data.get('found'):
responding_motor = motor_id
break
if responding_motor:
print(f"\n{interface}: Testing with motor 0x{responding_motor:02X}...")
hz, latency = test_communication_speed(interface, responding_motor, num_iterations=100)
if hz:
print(f" ✓ Max frequency: {hz:.1f} Hz")
print(f" ✓ Avg latency: {latency:.2f} ms")
print(f" ✓ Commands per second: ~{int(hz)}")
else:
print(f" ✗ Speed test failed")
else:
print(f"\n{interface}: No motors found, skipping speed test")
print()
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\n\nTesting interrupted by user.")
sys.exit(1)
except Exception as e:
print(f"\nUnexpected error: {e}")
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -0,0 +1,360 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
OpenArms Policy Evaluation
Evaluates a trained policy on the OpenArms robot by running inference and recording
the evaluation episodes to a dataset. Supports optional leader arm for manual resets.
Example usage:
python examples/openarms/evaluate.py
"""
import time
from pathlib import Path
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import combine_feature_dicts
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.processor import make_default_processors
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.scripts.lerobot_record import record_loop
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
HF_MODEL_ID = "lerobot-data-collection/level1_rac2_100k" # TODO: Replace with your trained model
HF_EVAL_DATASET_ID = "lerobot-data-collection/three-folds-pi0_eval_raccc3" # TODO: Replace with your eval dataset name
TASK_DESCRIPTION = "Fold the T-shirt properly" # TODO: Replace with your task, this should match!!
NUM_EPISODES = 1
FPS = 30
EPISODE_TIME_SEC = 1000
RESET_TIME_SEC = 60
# Robot CAN interfaces
FOLLOWER_LEFT_PORT = "can0"
FOLLOWER_RIGHT_PORT = "can1"
# If enabled, you can manually reset the environment between evaluation episodes
USE_LEADER_FOR_RESETS = False # Set to False if you don't want to use leader
LEADER_LEFT_PORT = "can2"
LEADER_RIGHT_PORT = "can3"
# Camera configuration
CAMERA_CONFIG = {
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video0", width=1280, height=720, fps=FPS),
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=1280, height=720, fps=FPS),
"base": OpenCVCameraConfig(index_or_path="/dev/video2", width=640, height=480, fps=FPS),
}
def main():
"""Main evaluation function."""
print("OpenArms Policy Evaluation")
print(f"\nModel: {HF_MODEL_ID}")
print(f"Evaluation Dataset: {HF_EVAL_DATASET_ID}")
print(f"Task: {TASK_DESCRIPTION}")
print(f"Episodes: {NUM_EPISODES}")
print(f"Episode Duration: {EPISODE_TIME_SEC}s")
print(f"Reset Duration: {RESET_TIME_SEC}s")
print(f"Use Leader for Resets: {USE_LEADER_FOR_RESETS}")
follower_config = OpenArmsFollowerConfig(
port_left=FOLLOWER_LEFT_PORT,
port_right=FOLLOWER_RIGHT_PORT,
can_interface="socketcan",
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=10.0,
cameras=CAMERA_CONFIG,
)
follower = OpenArmsFollower(follower_config)
follower.connect(calibrate=False)
if not follower.is_connected:
raise RuntimeError("Follower robot failed to connect!")
leader = None
if USE_LEADER_FOR_RESETS:
leader_config = OpenArmsLeaderConfig(
port_left=LEADER_LEFT_PORT,
port_right=LEADER_RIGHT_PORT,
can_interface="socketcan",
id="openarms_leader",
manual_control=False, # Enable torque control for gravity compensation
)
leader = OpenArmsLeader(leader_config)
leader.connect(calibrate=False)
if not leader.is_connected:
raise RuntimeError("Leader robot failed to connect!")
# Enable gravity compensation
if leader.pin_robot is not None:
leader.bus_right.enable_torque()
leader.bus_left.enable_torque()
time.sleep(0.1)
print(f"Leader connected with gravity compensation ({LEADER_LEFT_PORT}, {LEADER_RIGHT_PORT})")
else:
print(f"Leader connected but gravity compensation unavailable (no URDF)")
# Build default processors for action and observation
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
# Build dataset features from robot features and processors
# For actions, only include positions (no velocity or torque)
action_features_hw = {}
for key, value in follower.action_features.items():
if key.endswith(".pos"):
action_features_hw[key] = value
dataset_features = combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=teleop_action_processor,
initial_features=create_initial_features(action=action_features_hw),
use_videos=True,
),
aggregate_pipeline_dataset_features(
pipeline=robot_observation_processor,
initial_features=create_initial_features(observation=follower.observation_features),
use_videos=True,
),
)
# Check if dataset already exists
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / HF_EVAL_DATASET_ID
if dataset_path.exists():
print(f"Evaluation dataset already exists at: {dataset_path}")
print("This will append new episodes to the existing dataset.")
choice = input(" Continue? (y/n): ").strip().lower()
if choice != 'y':
print(" Aborting evaluation.")
follower.disconnect()
if leader:
leader.disconnect()
return
# Create dataset
dataset = LeRobotDataset.create(
repo_id=HF_EVAL_DATASET_ID,
fps=FPS,
features=dataset_features,
robot_type=follower.name,
use_videos=True,
image_writer_processes=0,
image_writer_threads=12,
)
# Load policy config from pretrained model and create policy using factory
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
policy_config.pretrained_path = HF_MODEL_ID
policy = make_policy(policy_config, ds_meta=dataset.meta)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy.config,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
preprocessor_overrides={
"device_processor": {"device": str(policy.config.device)}
},
)
print(f"\nRunning evaluation...")
# Initialize keyboard listener and visualization
listener, events = init_keyboard_listener()
init_rerun(session_name="openarms_evaluation")
episode_idx = 0
try:
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Evaluating episode {episode_idx + 1} of {NUM_EPISODES}")
print(f"\nRunning inference for episode {episode_idx + 1}...")
# Run inference with policy
record_loop(
robot=follower,
events=events,
fps=FPS,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
# Handle re-recording
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Save episode
if dataset.episode_buffer is not None and dataset.episode_buffer.get("size", 0) > 0:
print(f"Saving episode {episode_idx + 1} ({dataset.episode_buffer['size']} frames)...")
dataset.save_episode()
episode_idx += 1
# Reset environment between episodes (if not last episode)
if not events["stop_recording"] and episode_idx < NUM_EPISODES:
if USE_LEADER_FOR_RESETS and leader:
log_say("Reset the environment using leader arms")
print(f"\nManual reset period ({RESET_TIME_SEC}s)...")
# Use leader for manual reset with gravity compensation
import numpy as np
dt = 1 / FPS
reset_start_time = time.perf_counter()
while time.perf_counter() - reset_start_time < RESET_TIME_SEC:
if events["exit_early"] or events["stop_recording"]:
break
loop_start = time.perf_counter()
# Get leader state
leader_action = leader.get_action()
# Extract positions and velocities
leader_positions_deg = {}
leader_velocities_deg_per_sec = {}
for motor in leader.bus_right.motors:
pos_key = f"right_{motor}.pos"
vel_key = f"right_{motor}.vel"
if pos_key in leader_action:
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
if vel_key in leader_action:
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
for motor in leader.bus_left.motors:
pos_key = f"left_{motor}.pos"
vel_key = f"left_{motor}.vel"
if pos_key in leader_action:
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
if vel_key in leader_action:
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
# Calculate gravity and friction torques
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
leader_friction_torques_nm = leader._friction_from_velocity(
leader_velocities_rad_per_sec,
friction_scale=1.0
)
# Combine torques
leader_total_torques_nm = {}
for motor_name in leader_gravity_torques_nm:
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
friction = leader_friction_torques_nm.get(motor_name, 0.0)
leader_total_torques_nm[motor_name] = gravity + friction
# Apply compensation
for motor in leader.bus_right.motors:
full_name = f"right_{motor}"
position = leader_positions_deg.get(full_name, 0.0)
torque = leader_total_torques_nm.get(full_name, 0.0)
kd = leader.get_damping_kd(motor)
leader.bus_right._mit_control(
motor=motor, kp=0.0, kd=kd,
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque,
)
for motor in leader.bus_left.motors:
full_name = f"left_{motor}"
position = leader_positions_deg.get(full_name, 0.0)
torque = leader_total_torques_nm.get(full_name, 0.0)
kd = leader.get_damping_kd(motor)
leader.bus_left._mit_control(
motor=motor, kp=0.0, kd=kd,
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque,
)
# Send leader positions to follower
follower_action = {}
for joint in leader_positions_deg.keys():
pos_key = f"{joint}.pos"
if pos_key in leader_action:
follower_action[pos_key] = leader_action[pos_key]
if follower_action:
follower.send_action(follower_action)
# Maintain loop rate
loop_duration = time.perf_counter() - loop_start
sleep_time = dt - loop_duration
if sleep_time > 0:
time.sleep(sleep_time)
print("Reset complete")
else:
log_say("Waiting for manual reset")
print(f"Manually reset the environment and press ENTER to continue")
input("Press ENTER when ready...")
print(f"Evaluation complete! {episode_idx} episodes recorded")
log_say("Evaluation complete", blocking=True)
except KeyboardInterrupt:
print("\n\nEvaluation interrupted by user")
finally:
if leader:
leader.bus_right.disable_torque()
leader.bus_left.disable_torque()
time.sleep(0.1)
leader.disconnect()
follower.disconnect()
if listener is not None:
listener.stop()
dataset.finalize()
print("\nUploading to Hugging Face Hub...")
dataset.push_to_hub(private=True)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,703 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
OpenArms Policy Evaluation with Real-Time Chunking (RTC)
Evaluates a trained policy on the OpenArms robot using RTC for smooth, continuous motion.
RTC enables large flow-matching policies (Pi0, Pi0.5, SmolVLA) to produce reactive motion
despite high inference latency by asynchronously generating action chunks.
Features:
- Thread-based asynchronous action generation and execution
- RTC for smooth transitions between action chunks
- Dataset recording for evaluation episodes
Example usage:
python examples/openarms/evaluate_with_rtc.py
# With custom RTC parameters
python examples/openarms/evaluate_with_rtc.py \
--rtc.execution_horizon=12 \
--rtc.max_guidance_weight=10.0
"""
import logging
import math
import sys
import time
import traceback
from dataclasses import dataclass, field
from pathlib import Path
from threading import Event, Lock, Thread
import torch
from torch import Tensor
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.processor import make_default_processors
from lerobot.rl.process import ProcessSignalHandler
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging, log_say
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ============================================================================
# Default Configuration Constants
# ============================================================================
DEFAULT_HF_MODEL_ID = "lerobot-data-collection/level1_rac3_100k"
DEFAULT_HF_EVAL_DATASET_ID = "lerobot-data-collection/test"
DEFAULT_TASK_DESCRIPTION = "Fold the T-shirt properly"
DEFAULT_NUM_EPISODES = 1
DEFAULT_FPS = 30
DEFAULT_EPISODE_TIME_SEC = 1000
DEFAULT_RESET_TIME_SEC = 60
DEFAULT_FOLLOWER_LEFT_PORT = "can0"
DEFAULT_FOLLOWER_RIGHT_PORT = "can1"
DEFAULT_CAMERA_CONFIG = {
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video0", width=1280, height=720, fps=DEFAULT_FPS),
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video4", width=1280, height=720, fps=DEFAULT_FPS),
"base": OpenCVCameraConfig(index_or_path="/dev/video2", width=640, height=480, fps=DEFAULT_FPS),
}
# ============================================================================
# Thread-Safe Robot Wrapper
# ============================================================================
class RobotWrapper:
"""Thread-safe wrapper for robot operations."""
def __init__(self, robot: OpenArmsFollower):
self.robot = robot
self.lock = Lock()
def get_observation(self) -> dict[str, Tensor]:
with self.lock:
return self.robot.get_observation()
def send_action(self, action: dict) -> None:
with self.lock:
self.robot.send_action(action)
@property
def observation_features(self) -> dict:
with self.lock:
return self.robot.observation_features
@property
def action_features(self) -> dict:
with self.lock:
return self.robot.action_features
@property
def name(self) -> str:
return self.robot.name
# ============================================================================
# Configuration
# ============================================================================
@dataclass
class OpenArmsRTCEvalConfig(HubMixin):
"""Configuration for OpenArms evaluation with RTC."""
policy: PreTrainedConfig | None = None
rtc: RTCConfig = field(
default_factory=lambda: RTCConfig(
enabled=True,
execution_horizon=20,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
)
)
model_id: str = DEFAULT_HF_MODEL_ID
eval_dataset_id: str = DEFAULT_HF_EVAL_DATASET_ID
task: str = DEFAULT_TASK_DESCRIPTION
num_episodes: int = DEFAULT_NUM_EPISODES
fps: float = DEFAULT_FPS
episode_time_sec: float = DEFAULT_EPISODE_TIME_SEC
reset_time_sec: float = DEFAULT_RESET_TIME_SEC
follower_left_port: str = DEFAULT_FOLLOWER_LEFT_PORT
follower_right_port: str = DEFAULT_FOLLOWER_RIGHT_PORT
device: str = "cuda"
# Should be higher than inference_delay + execution_horizon
action_queue_size_to_get_new_actions: int = 30
record_dataset: bool = True
push_to_hub: bool = True
interpolation: bool = True
use_torch_compile: bool = False
torch_compile_backend: str = "inductor"
torch_compile_mode: str = "default"
torch_compile_disable_cudagraphs: bool = True
def __post_init__(self):
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
self.model_id = policy_path
elif self.model_id:
self.policy = PreTrainedConfig.from_pretrained(self.model_id)
self.policy.pretrained_path = self.model_id
@classmethod
def __get_path_fields__(cls) -> list[str]:
return ["policy"]
# ============================================================================
# Action Generation Thread
# ============================================================================
def get_actions_thread(
policy,
robot: RobotWrapper,
robot_observation_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: OpenArmsRTCEvalConfig,
episode_active: Event,
):
"""Thread function to asynchronously generate action chunks from the policy."""
try:
logger.info("[GET_ACTIONS] Starting action generation thread")
latency_tracker = LatencyTracker()
time_per_chunk = 1.0 / cfg.fps
hw_features = hw_to_dataset_features(robot.observation_features, "observation")
policy_device = policy.config.device
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=None,
preprocessor_overrides={
"device_processor": {"device": cfg.device},
},
)
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully")
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
if not cfg.rtc.enabled:
get_actions_threshold = 0
while not shutdown_event.is_set():
if not episode_active.is_set():
time.sleep(0.01)
continue
if action_queue.qsize() <= get_actions_threshold:
current_time = time.perf_counter()
action_index_before_inference = action_queue.get_action_index()
prev_actions = action_queue.get_left_over()
inference_latency = latency_tracker.max()
inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0
obs = robot.get_observation()
obs_processed = robot_observation_processor(obs)
obs_with_policy_features = build_dataset_frame(
hw_features, obs_processed, prefix="observation"
)
for name in obs_with_policy_features:
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
if "image" in name:
obs_with_policy_features[name] = (
obs_with_policy_features[name].type(torch.float32) / 255
)
obs_with_policy_features[name] = (
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
)
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
obs_with_policy_features["task"] = [cfg.task]
obs_with_policy_features["robot_type"] = robot.name
preprocessed_obs = preprocessor(obs_with_policy_features)
actions = policy.predict_action_chunk(
preprocessed_obs,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
original_actions = actions.squeeze(0).clone()
postprocessed_actions = postprocessor(actions).squeeze(0)
new_latency = time.perf_counter() - current_time
new_delay = math.ceil(new_latency / time_per_chunk)
latency_tracker.add(new_latency)
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
logger.warning(
"[GET_ACTIONS] action_queue_size_to_get_new_actions too small. "
"Should be higher than inference delay + execution horizon."
)
action_queue.merge(
original_actions, postprocessed_actions, new_delay, action_index_before_inference
)
logger.debug(
f"[GET_ACTIONS] Generated chunk, latency={new_latency:.3f}s, "
f"delay={new_delay}, queue_size={action_queue.qsize()}"
)
else:
time.sleep(0.01)
logger.info("[GET_ACTIONS] Action generation thread shutting down")
except Exception as e:
logger.error(f"[GET_ACTIONS] Fatal exception: {e}")
logger.error(traceback.format_exc())
shutdown_event.set()
sys.exit(1)
# ============================================================================
# Action Execution Thread
# ============================================================================
def actor_thread(
robot: RobotWrapper,
robot_action_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: OpenArmsRTCEvalConfig,
episode_active: Event,
dataset: LeRobotDataset | None,
dataset_lock: Lock,
teleop_action_processor,
robot_observation_processor,
):
"""Thread function to execute actions on the robot."""
try:
action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
if cfg.interpolation:
interp_factor = 2
robot_interval = 1.0 / (cfg.fps * interp_factor)
logger.info(f"[ACTOR] Interpolation ON: policy={cfg.fps}Hz -> robot={cfg.fps * interp_factor}Hz (2x)")
else:
interp_factor = 1
robot_interval = 1.0 / cfg.fps
logger.info(f"[ACTOR] Interpolation OFF: policy={cfg.fps}Hz, robot={cfg.fps}Hz")
prev_action: Tensor | None = None
interpolated_actions: list[Tensor] = []
interp_idx = 0
robot_send_count = 0
policy_consume_count = 0
last_hz_print = time.perf_counter()
last_dataset_time = 0.0
while not shutdown_event.is_set():
if not episode_active.is_set():
prev_action = None
interpolated_actions = []
interp_idx = 0
robot_send_count = 0
policy_consume_count = 0
last_hz_print = time.perf_counter()
time.sleep(0.01)
continue
start_time = time.perf_counter()
if interp_idx >= len(interpolated_actions):
new_action = action_queue.get()
if new_action is not None:
current_action = new_action.cpu()
policy_consume_count += 1
if cfg.interpolation and prev_action is not None:
mid = prev_action + 0.5 * (current_action - prev_action)
interpolated_actions = [mid, current_action]
else:
interpolated_actions = [current_action]
prev_action = current_action
interp_idx = 0
if interp_idx < len(interpolated_actions):
action_to_send = interpolated_actions[interp_idx]
interp_idx += 1
action_dict = {}
for i, key in enumerate(action_keys):
if i < len(action_to_send):
action_dict[key] = action_to_send[i].item()
action_processed = robot_action_processor((action_dict, None))
robot.send_action(action_processed)
robot_send_count += 1
if cfg.record_dataset and dataset is not None:
now = time.perf_counter()
if now - last_dataset_time >= (1.0 / cfg.fps):
last_dataset_time = now
with dataset_lock:
obs = robot.get_observation()
obs_processed = robot_observation_processor(obs)
action_for_dataset = teleop_action_processor((action_dict, None))
frame = {}
for key, value in obs_processed.items():
frame[f"observation.{key}"] = value
for key, value in action_for_dataset.items():
frame[f"action.{key}"] = value
frame["task"] = cfg.task
dataset.add_frame(frame)
now = time.perf_counter()
if now - last_hz_print >= 5.0:
elapsed = now - last_hz_print
actual_robot_hz = robot_send_count / elapsed if elapsed > 0 else 0
actual_policy_hz = policy_consume_count / elapsed if elapsed > 0 else 0
logger.info(f"[ACTOR] Actual Hz - Robot: {actual_robot_hz:.1f}, Policy: {actual_policy_hz:.1f}")
robot_send_count = 0
policy_consume_count = 0
last_hz_print = now
dt_s = time.perf_counter() - start_time
sleep_time = max(0, robot_interval - dt_s - 0.001)
if sleep_time > 0:
time.sleep(sleep_time)
logger.info("[ACTOR] Shutting down")
except Exception as e:
logger.error(f"[ACTOR] Fatal exception: {e}")
logger.error(traceback.format_exc())
shutdown_event.set()
sys.exit(1)
# ============================================================================
# Main Evaluation Function
# ============================================================================
def _apply_torch_compile(policy, cfg: OpenArmsRTCEvalConfig):
"""Apply torch.compile to the policy's predict_action_chunk method."""
if policy.name in ["pi05", "pi0"]:
return policy
try:
if not hasattr(torch, "compile"):
logger.warning(
f"torch.compile not available. Requires PyTorch 2.0+. "
f"Current version: {torch.__version__}. Skipping compilation."
)
return policy
logger.info("Applying torch.compile to predict_action_chunk...")
compile_kwargs = {
"backend": cfg.torch_compile_backend,
"mode": cfg.torch_compile_mode,
}
if cfg.torch_compile_disable_cudagraphs:
compile_kwargs["options"] = {"triton.cudagraphs": False}
original_method = policy.predict_action_chunk
compiled_method = torch.compile(original_method, **compile_kwargs)
policy.predict_action_chunk = compiled_method
logger.info("Successfully compiled predict_action_chunk")
except Exception as e:
logger.error(f"Failed to apply torch.compile: {e}")
logger.warning("Continuing without torch.compile")
return policy
@parser.wrap()
def main(cfg: OpenArmsRTCEvalConfig):
"""Main evaluation function with RTC."""
init_logging()
print("=" * 60)
print("OpenArms Policy Evaluation with RTC")
print("=" * 60)
print(f"\nModel: {cfg.model_id}")
print(f"Evaluation Dataset: {cfg.eval_dataset_id}")
print(f"Task: {cfg.task}")
print(f"Episodes: {cfg.num_episodes}")
print(f"Episode Duration: {cfg.episode_time_sec}s")
print(f"RTC Enabled: {cfg.rtc.enabled}")
print(f"RTC Execution Horizon: {cfg.rtc.execution_horizon}")
print(f"RTC Max Guidance Weight: {cfg.rtc.max_guidance_weight}")
print(f"Policy Hz: {cfg.fps}")
print(f"Robot Hz: {cfg.fps * 2 if cfg.interpolation else cfg.fps}")
print(f"Interpolation: {cfg.interpolation}")
print(f"Device: {cfg.device}")
print("=" * 60)
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
shutdown_event = signal_handler.shutdown_event
episode_active = Event()
# Initialize Robot
follower_config = OpenArmsFollowerConfig(
port_left=cfg.follower_left_port,
port_right=cfg.follower_right_port,
can_interface="socketcan",
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=10.0,
cameras=DEFAULT_CAMERA_CONFIG,
)
follower = OpenArmsFollower(follower_config)
follower.connect(calibrate=False)
if not follower.is_connected:
raise RuntimeError("Follower robot failed to connect!")
robot = RobotWrapper(follower)
logger.info("Follower robot connected")
# Build Processors and Dataset Features
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
action_features_hw = {}
for key, value in follower.action_features.items():
if key.endswith(".pos"):
action_features_hw[key] = value
dataset_features = combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=teleop_action_processor,
initial_features=create_initial_features(action=action_features_hw),
use_videos=True,
),
aggregate_pipeline_dataset_features(
pipeline=robot_observation_processor,
initial_features=create_initial_features(observation=follower.observation_features),
use_videos=True,
),
)
# Create or Load Dataset
dataset = None
dataset_lock = Lock()
if cfg.record_dataset:
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / cfg.eval_dataset_id
if dataset_path.exists():
logger.info(f"Evaluation dataset exists at: {dataset_path}")
logger.info("New episodes will be appended.")
choice = input("Continue? (y/n): ").strip().lower()
if choice != "y":
logger.info("Aborting evaluation.")
follower.disconnect()
return
dataset = LeRobotDataset.create(
repo_id=cfg.eval_dataset_id,
fps=int(cfg.fps),
features=dataset_features,
robot_type=follower.name,
use_videos=True,
image_writer_processes=0,
image_writer_threads=12,
)
logger.info(f"Dataset created: {cfg.eval_dataset_id}")
# Load Policy
logger.info(f"Loading policy from: {cfg.model_id}")
policy_class = get_policy_class(cfg.policy.type)
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
if cfg.policy.type in ["pi05", "pi0"]:
config.compile_model = cfg.use_torch_compile
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
policy.config.rtc_config = cfg.rtc
policy.init_rtc_processor()
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
policy = policy.to(cfg.device)
policy.eval()
if cfg.use_torch_compile:
policy = _apply_torch_compile(policy, cfg)
logger.info(f"Policy loaded: {policy.name}")
# Create Action Queue and Start Threads
action_queue = ActionQueue(cfg.rtc)
get_actions_t = Thread(
target=get_actions_thread,
args=(
policy,
robot,
robot_observation_processor,
action_queue,
shutdown_event,
cfg,
episode_active,
),
daemon=True,
name="GetActions",
)
get_actions_t.start()
logger.info("Started action generation thread")
actor_t = Thread(
target=actor_thread,
args=(
robot,
robot_action_processor,
action_queue,
shutdown_event,
cfg,
episode_active,
dataset,
dataset_lock,
teleop_action_processor,
robot_observation_processor,
),
daemon=True,
name="Actor",
)
actor_t.start()
logger.info("Started actor thread")
# Run Evaluation Episodes
episode_idx = 0
try:
while episode_idx < cfg.num_episodes and not shutdown_event.is_set():
log_say(f"Evaluating episode {episode_idx + 1} of {cfg.num_episodes}")
logger.info(f"\n{'='*40}")
logger.info(f"Episode {episode_idx + 1} / {cfg.num_episodes}")
logger.info(f"{'='*40}")
action_queue = ActionQueue(cfg.rtc)
episode_active.set()
episode_start_time = time.time()
while (time.time() - episode_start_time) < cfg.episode_time_sec:
if shutdown_event.is_set():
break
elapsed = time.time() - episode_start_time
if int(elapsed) % 10 == 0 and int(elapsed) > 0:
logger.info(
f"[MAIN] Episode progress: {elapsed:.0f}/{cfg.episode_time_sec}s, "
f"queue_size={action_queue.qsize()}"
)
time.sleep(0.5)
episode_active.clear()
logger.info(f"Episode {episode_idx + 1} completed")
if cfg.record_dataset and dataset is not None:
with dataset_lock:
if dataset.episode_buffer is not None and dataset.episode_buffer.get("size", 0) > 0:
logger.info(
f"Saving episode {episode_idx + 1} "
f"({dataset.episode_buffer['size']} frames)"
)
dataset.save_episode()
episode_idx += 1
# Manual reset between episodes
if not shutdown_event.is_set() and episode_idx < cfg.num_episodes:
log_say("Waiting for manual reset")
logger.info("Manually reset the environment and press ENTER to continue")
input("Press ENTER when ready...")
logger.info(f"Evaluation complete! {episode_idx} episodes recorded")
log_say("Evaluation complete", blocking=True)
except KeyboardInterrupt:
logger.info("\n\nEvaluation interrupted by user")
finally:
shutdown_event.set()
episode_active.clear()
if get_actions_t.is_alive():
logger.info("Waiting for action generation thread to finish...")
get_actions_t.join(timeout=5.0)
if actor_t.is_alive():
logger.info("Waiting for actor thread to finish...")
actor_t.join(timeout=5.0)
follower.disconnect()
logger.info("Follower disconnected")
if cfg.record_dataset and dataset is not None:
dataset.finalize()
if cfg.push_to_hub:
logger.info("Uploading to Hugging Face Hub...")
dataset.push_to_hub(private=True)
logger.info("Cleanup completed")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,216 @@
import time
import numpy as np
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
# Friction model parameters from OpenArms config/follower.yaml
# τ_fric(ω) = Fo + Fv·ω + Fc·tanh(k·ω)
# For 8 motors: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
FRICTION_PARAMS = {
"Fc": [0.306, 0.306, 0.40, 0.166, 0.050, 0.093, 0.172, 0.0512], # Coulomb friction [Nm]
"k": [28.417, 28.417, 29.065, 130.038, 151.771, 242.287, 7.888, 4.000], # tanh steepness
"Fv": [0.063, 0.0630, 0.604, 0.813, 0.029, 0.072, 0.084, 0.084], # Viscous friction [Nm·s/rad]
"Fo": [0.088, 0.088, 0.008, -0.058, 0.005, 0.009, -0.059, -0.050], # Offset torque [Nm]
}
# Constants from OpenArms C++ implementation
AMP_TMP = 1.0
COEF_TMP = 0.1
FRICTION_SCALE = 1.0 # OpenArms C++ uses 0.3 factor in unilateral mode
DAMPING_KD = [0.5, 0.5, 0.5, 0.5, 0.1, 0.1, 0.1, 0.1] # Damping gains for stability
def compute_friction_torque(velocity_rad_per_sec: float, motor_index: int) -> float:
"""
Compute friction torque for a single motor using the tanh friction model.
Args:
velocity_rad_per_sec: Angular velocity in rad/s
motor_index: Index of the motor (0-7)
Returns:
Friction torque in N·m (scaled for stability)
"""
Fc = FRICTION_PARAMS["Fc"][motor_index]
k = FRICTION_PARAMS["k"][motor_index]
Fv = FRICTION_PARAMS["Fv"][motor_index]
Fo = FRICTION_PARAMS["Fo"][motor_index]
# Friction model: τ_fric = amp * Fc * tanh(coef * k * ω) + Fv * ω + Fo
friction_torque = (
AMP_TMP * Fc * np.tanh(COEF_TMP * k * velocity_rad_per_sec) +
Fv * velocity_rad_per_sec +
Fo
)
# Scale down friction compensation for stability at lower control rates
# (OpenArms C++ uses 0.3 factor in unilateral mode)!!
friction_torque *= FRICTION_SCALE
return friction_torque
def main() -> None:
config = OpenArmsFollowerConfig(
port_left="can0",
port_right="can1",
can_interface="socketcan",
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=5.0,
)
print("Initializing robot...")
follower = OpenArmsFollower(config)
follower.connect(calibrate=True)
print(f"Applying friction compensation")
print(" 1. Support the arm before starting")
print(" 2. The arm will be held in place by friction compensation")
print(" 3. You should be able to move it with gentle force")
print("\nPress ENTER when ready to start...")
input()
print(f"✓ Motors enabled")
print("\nStarting friction compensation loop...")
print("Press Ctrl+C to stop\n")
loop_times = []
last_print_time = time.perf_counter()
# Motor name to index mapping
motor_name_to_index = {
"joint_1": 0,
"joint_2": 1,
"joint_3": 2,
"joint_4": 3,
"joint_5": 4,
"joint_6": 5,
"joint_7": 6,
"gripper": 7,
}
try:
while True:
loop_start = time.perf_counter()
# Get current joint positions and velocities from robot
obs = follower.get_observation()
# Extract velocities in degrees per second
velocities_deg_per_sec = {}
positions_deg = {}
for motor in follower.bus_right.motors:
vel_key = f"right_{motor}.vel"
pos_key = f"right_{motor}.pos"
if vel_key in obs:
velocities_deg_per_sec[f"right_{motor}"] = obs[vel_key]
if pos_key in obs:
positions_deg[f"right_{motor}"] = obs[pos_key]
for motor in follower.bus_left.motors:
vel_key = f"left_{motor}.vel"
pos_key = f"left_{motor}.pos"
if vel_key in obs:
velocities_deg_per_sec[f"left_{motor}"] = obs[vel_key]
if pos_key in obs:
positions_deg[f"left_{motor}"] = obs[pos_key]
# Convert velocities to rad/s and compute friction torques
friction_torques_nm = {}
for motor_full_name, velocity_deg_per_sec in velocities_deg_per_sec.items():
# Extract motor name without arm prefix
if motor_full_name.startswith("right_"):
motor_name = motor_full_name.removeprefix("right_")
elif motor_full_name.startswith("left_"):
motor_name = motor_full_name.removeprefix("left_")
else:
continue
# Get motor index for friction parameters
motor_index = motor_name_to_index.get(motor_name, 0)
# Convert velocity to rad/s
velocity_rad_per_sec = np.deg2rad(velocity_deg_per_sec)
# Compute friction torque
friction_torque = compute_friction_torque(velocity_rad_per_sec, motor_index)
friction_torques_nm[motor_full_name] = friction_torque
# Apply friction compensation to right arm (all joints INCLUDING gripper)
for motor in follower.bus_right.motors:
full_name = f"right_{motor}"
position = positions_deg.get(full_name, 0.0)
torque = friction_torques_nm.get(full_name, 0.0)
# Get motor index for damping gain
motor_index = motor_name_to_index.get(motor, 0)
kd = DAMPING_KD[motor_index]
# Send MIT control command with friction compensation + damping
follower.bus_right._mit_control(
motor=motor,
kp=0.0, # No position control
kd=kd, # Add damping for stability
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque
)
# Apply friction compensation to left arm (all joints INCLUDING gripper)
for motor in follower.bus_left.motors:
full_name = f"left_{motor}"
position = positions_deg.get(full_name, 0.0)
torque = friction_torques_nm.get(full_name, 0.0)
# Get motor index for damping gain
motor_index = motor_name_to_index.get(motor, 0)
kd = DAMPING_KD[motor_index]
# Send MIT control command with friction compensation + damping
follower.bus_left._mit_control(
motor=motor,
kp=0.0, # No position control
kd=kd, # Add damping for stability
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque
)
# Measure loop time
loop_end = time.perf_counter()
loop_time = loop_end - loop_start
loop_times.append(loop_time)
# Print status every 2 seconds
if loop_end - last_print_time >= 2.0:
if loop_times:
avg_time = sum(loop_times) / len(loop_times)
current_hz = 1.0 / avg_time if avg_time > 0 else 0
print(f"{current_hz:.1f} Hz")
loop_times = []
last_print_time = loop_end
time.sleep(0.001)
except KeyboardInterrupt:
print("\n\nStopping friction compensation...")
finally:
print("\nDisabling all motors and disconnecting...")
follower.bus_right.disable_torque()
follower.bus_left.disable_torque()
time.sleep(0.1)
follower.disconnect()
print("✓ Safe shutdown complete")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,142 @@
import time
import numpy as np
import pinocchio as pin
from os.path import join, dirname, exists, expanduser
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
def main() -> None:
config = OpenArmsFollowerConfig(
port_left="can0",
port_right="can1",
can_interface="socketcan",
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=5.0,
)
print("Initializing robot...")
follower = OpenArmsFollower(config)
follower.connect(calibrate=True)
# Load URDF for Pinocchio dynamics
urdf_path = "/home/croissant/Documents/openarm_description/openarm_bimanual_pybullet.urdf"
pin_robot = pin.RobotWrapper.BuildFromURDF(urdf_path, dirname(urdf_path))
pin_robot.data = pin_robot.model.createData()
print(f"✓ Loaded Pinocchio model with {pin_robot.nq} DoFs")
follower.pin_robot = pin_robot
print(f"Applying gravity compensation")
print(" 1. Support the arm before starting")
print(" 2. The arm will be held in place by gravity compensation")
print(" 3. You should be able to move it with gentle force")
print("\nPress ENTER when ready to start...")
input()
print(f"✓ Motors enabled")
print("\nStarting gravity compensation loop...")
print("Press Ctrl+C to stop\n")
loop_times = []
last_print_time = time.perf_counter()
try:
while True:
loop_start = time.perf_counter()
# Get current joint positions from robot
obs = follower.get_observation()
# Extract positions in degrees
positions_deg = {}
for motor in follower.bus_right.motors:
key = f"right_{motor}.pos"
if key in obs:
positions_deg[f"right_{motor}"] = obs[key]
for motor in follower.bus_left.motors:
key = f"left_{motor}.pos"
if key in obs:
positions_deg[f"left_{motor}"] = obs[key]
# Convert to radians and calculate gravity torques
# Use the built-in method from OpenArmsFollower
positions_rad = {k: np.deg2rad(v) for k, v in positions_deg.items()}
torques_nm = follower._gravity_from_q(positions_rad)
# Apply gravity compensation to right arm (all joints except gripper)
for motor in follower.bus_right.motors:
if motor == "gripper":
continue # Skip gripper
full_name = f"right_{motor}"
position = positions_deg.get(full_name, 0.0)
torque = torques_nm.get(full_name, 0.0)
# Send MIT control command with gravity compensation torque
follower.bus_right._mit_control(
motor=motor,
kp=0.0, # No position control
kd=0.0, # No velocity damping
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque
)
# Apply gravity compensation to left arm (all joints except gripper)
for motor in follower.bus_left.motors:
if motor == "gripper":
continue # Skip gripper
full_name = f"left_{motor}"
position = positions_deg.get(full_name, 0.0)
torque = torques_nm.get(full_name, 0.0)
# Send MIT control command with gravity compensation torque
follower.bus_left._mit_control(
motor=motor,
kp=0.0, # No position control
kd=0.0, # No velocity damping
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque
)
# Measure loop time
loop_end = time.perf_counter()
loop_time = loop_end - loop_start
loop_times.append(loop_time)
# Print status every 2 seconds
if loop_end - last_print_time >= 2.0:
if loop_times:
avg_time = sum(loop_times) / len(loop_times)
current_hz = 1.0 / avg_time if avg_time > 0 else 0
print(f"{current_hz:.1f} Hz ({avg_time*1000:.1f} ms)")
loop_times = []
last_print_time = loop_end
time.sleep(0.005)
except KeyboardInterrupt:
print("\n\nStopping gravity compensation...")
finally:
print("\nDisabling all motors and disconnecting...")
follower.bus_right.disable_torque()
follower.bus_left.disable_torque()
time.sleep(0.1)
follower.disconnect()
print("✓ Safe shutdown complete")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,395 @@
"""
OpenArms Dataset Recording with Gravity + Friction Compensation
Records a dataset using OpenArms follower robot with leader teleoperator.
Leader arms have gravity and friction compensation for weightless, easy movement.
Includes 3 cameras: left wrist, right wrist, and base camera.
Uses the same compensation approach as teleop_with_compensation.py
"""
import shutil
import time
from pathlib import Path
import numpy as np
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
# Recording parameters
NUM_EPISODES = 1
FPS = 30
EPISODE_TIME_SEC = 600
RESET_TIME_SEC = 120
TASK_DESCRIPTION = "OpenArms task description"
# Friction compensation scale factor (1.0 = full, 0.3 = 30% for stability)
FRICTION_SCALE = 1.0
def record_loop_with_compensation(
robot,
leader,
events,
fps,
dataset,
dataset_features,
control_time_s,
single_task,
display_data=True,
):
"""
Custom record loop that applies gravity + friction compensation to leader.
Based on record_loop but with integrated compensation.
"""
dt = 1 / fps
episode_start_time = time.perf_counter()
# All joints (both arms)
all_joints = []
for motor in leader.bus_right.motors:
all_joints.append(f"right_{motor}")
for motor in leader.bus_left.motors:
all_joints.append(f"left_{motor}")
while True:
loop_start = time.perf_counter()
elapsed = loop_start - episode_start_time
# Check if we should exit
if elapsed >= control_time_s or events["exit_early"] or events["stop_recording"]:
break
# Get leader state
leader_action = leader.get_action()
# Extract positions and velocities in degrees
leader_positions_deg = {}
leader_velocities_deg_per_sec = {}
for motor in leader.bus_right.motors:
pos_key = f"right_{motor}.pos"
vel_key = f"right_{motor}.vel"
if pos_key in leader_action:
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
if vel_key in leader_action:
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
for motor in leader.bus_left.motors:
pos_key = f"left_{motor}.pos"
vel_key = f"left_{motor}.vel"
if pos_key in leader_action:
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
if vel_key in leader_action:
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
# Calculate gravity torques for leader using built-in method
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
# Calculate friction torques for leader using built-in method
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
leader_friction_torques_nm = leader._friction_from_velocity(
leader_velocities_rad_per_sec,
friction_scale=FRICTION_SCALE
)
# Combine gravity + friction torques
leader_total_torques_nm = {}
for motor_name in leader_gravity_torques_nm:
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
friction = leader_friction_torques_nm.get(motor_name, 0.0)
leader_total_torques_nm[motor_name] = gravity + friction
# Apply gravity + friction compensation to leader RIGHT arm (all joints including gripper)
for motor in leader.bus_right.motors:
full_name = f"right_{motor}"
position = leader_positions_deg.get(full_name, 0.0)
torque = leader_total_torques_nm.get(full_name, 0.0)
# Get damping gain for stability
kd = leader.get_damping_kd(motor)
leader.bus_right._mit_control(
motor=motor,
kp=0.0,
kd=kd, # Add damping for stability
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque,
)
# Apply gravity + friction compensation to leader LEFT arm (all joints including gripper)
for motor in leader.bus_left.motors:
full_name = f"left_{motor}"
position = leader_positions_deg.get(full_name, 0.0)
torque = leader_total_torques_nm.get(full_name, 0.0)
# Get damping gain for stability
kd = leader.get_damping_kd(motor)
leader.bus_left._mit_control(
motor=motor,
kp=0.0,
kd=kd, # Add damping for stability
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque,
)
# Send leader positions to follower (both arms)
follower_action = {}
for joint in all_joints:
pos_key = f"{joint}.pos"
if pos_key in leader_action:
follower_action[pos_key] = leader_action[pos_key]
# Send action to robot
if follower_action:
robot.send_action(follower_action)
# Get observation from robot (includes camera images)
observation = robot.get_observation()
# Add to dataset if we have a dataset
if dataset is not None:
# Build properly formatted observation frame
obs_frame = build_dataset_frame(dataset_features, observation, prefix="observation")
# Build properly formatted action frame (keep .pos suffix - it matches the feature names)
action_frame = build_dataset_frame(dataset_features, follower_action, prefix="action")
# Combine into single frame
frame = {**obs_frame, **action_frame}
# Add metadata (task is required, timestamp will be auto-calculated by add_frame)
frame["task"] = single_task
dataset.add_frame(frame)
# Display data if requested
if display_data:
log_rerun_data(observation=observation, action=follower_action)
# Maintain loop rate
loop_duration = time.perf_counter() - loop_start
sleep_time = dt - loop_duration
if sleep_time > 0:
time.sleep(sleep_time)
def main():
"""Main recording loop with gravity compensation."""
print("=" * 70)
print("OpenArms Dataset Recording with Compensation")
print("=" * 70)
# Create camera configurations (3 cameras: left wrist, right wrist, base)
# Using actual device paths found by lerobot-find-cameras opencv
camera_config = {
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video0", width=640, height=480, fps=FPS),
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=FPS),
"base": OpenCVCameraConfig(index_or_path="/dev/video7", width=640, height=480, fps=FPS),
}
# Configure follower robot with cameras
follower_config = OpenArmsFollowerConfig(
port_left="can2",
port_right="can3",
can_interface="socketcan",
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=10.0,
cameras=camera_config,
)
# Configure leader teleoperator (no cameras needed)
leader_config = OpenArmsLeaderConfig(
port_left="can0",
port_right="can1",
can_interface="socketcan",
id="openarms_leader",
manual_control=False, # Enable torque control for gravity compensation
)
# Initialize robot and teleoperator
print("\nInitializing devices...")
follower = OpenArmsFollower(follower_config)
leader = OpenArmsLeader(leader_config)
# Connect devices
print("Connecting and calibrating...")
follower.connect(calibrate=True)
leader.connect(calibrate=True)
# Verify URDF is loaded for gravity compensation
if leader.pin_robot is None:
raise RuntimeError("URDF model not loaded on leader. Gravity compensation not available.")
# Configure the dataset features
# For actions, we only want to record positions (not velocity or torque)
action_features_hw = {}
for key, value in follower.action_features.items():
if key.endswith(".pos"):
action_features_hw[key] = value
action_features = hw_to_dataset_features(action_features_hw, "action")
obs_features = hw_to_dataset_features(follower.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
# Create the dataset
print("\nCreating dataset...")
repo_id = "<hf_username>/<dataset_repo_id>" # TODO: Replace with your Hugging Face repo
# Check if dataset already exists and prompt user
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / repo_id
while dataset_path.exists():
print(f"\nDataset already exists at: {dataset_path}")
print("\nOptions:")
print(" 1. Overwrite existing dataset")
print(" 2. Use a different name")
print(" 3. Abort")
choice = input("\nEnter your choice (1/2/3): ").strip()
if choice == '1':
print(f"Removing existing dataset...")
shutil.rmtree(dataset_path)
print("✓ Existing dataset removed")
break
elif choice == '2':
print("\nCurrent repo_id:", repo_id)
new_repo_id = input("Enter new repo_id (format: <username>/<dataset_name>): ").strip()
if new_repo_id and '/' in new_repo_id:
repo_id = new_repo_id
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / repo_id
print(f"✓ Using new repo_id: {repo_id}")
# Loop will continue if this new path also exists
else:
print("Invalid repo_id format. Please use format: <username>/<dataset_name>")
elif choice == '3':
print("Aborting. Please remove the existing dataset manually or restart with a different repo_id.")
follower.disconnect()
leader.disconnect()
return
else:
print("Invalid choice. Please enter 1, 2, or 3.")
dataset = LeRobotDataset.create(
repo_id=repo_id,
fps=FPS,
features=dataset_features,
robot_type=follower.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize keyboard listener and visualization
_, events = init_keyboard_listener()
init_rerun(session_name="openarms_recording")
# Enable motors on both leader arms for gravity compensation
leader.bus_right.enable_torque()
leader.bus_left.enable_torque()
time.sleep(0.1)
print("\n" + "=" * 70)
print(f"Recording {NUM_EPISODES} episodes")
print(f"Task: {TASK_DESCRIPTION}")
print("=" * 70)
print("\nLeader BOTH arms: Gravity + Friction comp | Follower BOTH arms: Teleop")
print("\nKeyboard controls:")
print(" - Press 'q' to stop recording")
print(" - Press 'r' to re-record current episode")
print("=" * 70)
episode_idx = 0
try:
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
# Record episode with compensation active
record_loop_with_compensation(
robot=follower,
leader=leader,
events=events,
fps=FPS,
dataset=dataset,
dataset_features=dataset_features,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop_with_compensation(
robot=follower,
leader=leader,
events=events,
fps=FPS,
dataset=None, # Don't save reset period
dataset_features=dataset_features,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
# Handle re-recording
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Only save episode if frames were recorded
if dataset.episode_buffer is not None and dataset.episode_buffer["size"] > 0:
dataset.save_episode()
episode_idx += 1
else:
log_say("No frames recorded, skipping episode save")
# Clear the empty buffer
dataset.episode_buffer = None
except KeyboardInterrupt:
print("\n\nStopping recording...")
finally:
# Clean up
log_say("Stop recording")
try:
leader.bus_right.disable_torque()
leader.bus_left.disable_torque()
time.sleep(0.1)
leader.disconnect()
follower.disconnect()
print("✓ Shutdown complete")
except Exception as e:
print(f"Shutdown error: {e}")
# Upload dataset
print("\nUploading dataset to Hugging Face Hub...")
try:
dataset.push_to_hub()
print("✓ Dataset uploaded successfully")
except Exception as e:
print(f"Warning: Failed to upload dataset: {e}")
print("You can manually upload later using: dataset.push_to_hub()")
print("✓ Recording complete!")
if __name__ == "__main__":
main()

166
examples/openarms/replay.py Normal file
View File

@@ -0,0 +1,166 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
OpenArms Dataset Replay Example
Replays position actions from a recorded dataset on an OpenArms follower robot.
Only position commands (ending with .pos) are replayed, not velocity or torque.
Example usage:
python examples/openarms/replay.py
"""
import time
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
# Configuration
EPISODE_IDX = 0
DATASET_REPO_ID = "lerobot-data-collection/replay-this-2025-11-02-17-58" # TODO: Replace with your dataset
DATASET_ROOT = None # Use default cache location, or specify custom path
# Robot configuration - adjust these to match your setup
ROBOT_CONFIG = OpenArmsFollowerConfig(
port_left="can2", # CAN interface for left arm
port_right="can3", # CAN interface for right arm
can_interface="socketcan",
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=10.0, # Safety limit: max degrees to move per step
)
def main():
"""Main replay function."""
print("=" * 70)
print("OpenArms Dataset Replay")
print("=" * 70)
print(f"\nDataset: {DATASET_REPO_ID}")
print(f"Episode: {EPISODE_IDX}")
print(f"Robot: {ROBOT_CONFIG.id}")
print(f" Left arm: {ROBOT_CONFIG.port_left}")
print(f" Right arm: {ROBOT_CONFIG.port_right}")
print("\n" + "=" * 70)
# Initialize the robot
print("\n[1/3] Initializing robot...")
robot = OpenArmsFollower(ROBOT_CONFIG)
# Load the dataset
print(f"\n[2/3] Loading dataset '{DATASET_REPO_ID}'...")
dataset = LeRobotDataset(
DATASET_REPO_ID,
root=DATASET_ROOT,
episodes=[EPISODE_IDX]
)
# Filter dataset to only include frames from the specified episode
# (required for dataset V3.0 where episodes are chunked)
episode_frames = dataset.hf_dataset.filter(
lambda x: x["episode_index"] == EPISODE_IDX
)
if len(episode_frames) == 0:
raise ValueError(
f"No frames found for episode {EPISODE_IDX} in dataset {DATASET_REPO_ID}"
)
print(f" Found {len(episode_frames)} frames in episode {EPISODE_IDX}")
# Extract action features from dataset
action_features = dataset.features.get(ACTION, {})
action_names = action_features.get("names", [])
# Filter to only position actions (ending with .pos)
position_action_names = [name for name in action_names if name.endswith(".pos")]
if not position_action_names:
raise ValueError(
f"No position actions found in dataset. Action names: {action_names}"
)
print(f" Found {len(position_action_names)} position actions to replay")
print(f" Actions: {', '.join(position_action_names[:5])}{'...' if len(position_action_names) > 5 else ''}")
# Select only action columns from dataset
actions = episode_frames.select_columns(ACTION)
# Connect to the robot
print(f"\n[3/3] Connecting to robot...")
robot.connect(calibrate=False) # Skip calibration for replay
if not robot.is_connected:
raise RuntimeError("Robot failed to connect!")
print("\n" + "=" * 70)
print("Ready to replay!")
print("=" * 70)
print("\nThe robot will replay the recorded positions.")
print("Press Ctrl+C to stop at any time.\n")
input("Press ENTER to start replaying...")
# Replay loop
log_say(f"Replaying episode {EPISODE_IDX}", blocking=True)
try:
for idx in range(len(episode_frames)):
loop_start = time.perf_counter()
# Extract action array from dataset
action_array = actions[idx][ACTION]
# Build action dictionary, but only include position actions
action = {}
for i, name in enumerate(action_names):
# Only include position actions (ending with .pos)
if name.endswith(".pos"):
action[name] = float(action_array[i])
# Send action to robot
robot.send_action(action)
# Maintain replay rate (use dataset fps)
loop_duration = time.perf_counter() - loop_start
dt_s = 1.0 / dataset.fps - loop_duration
busy_wait(dt_s)
# Progress indicator every 100 frames
if (idx + 1) % 100 == 0:
progress = (idx + 1) / len(episode_frames) * 100
print(f"Progress: {idx + 1}/{len(episode_frames)} frames ({progress:.1f}%)")
print(f"\n✓ Successfully replayed {len(episode_frames)} frames")
log_say("Replay complete", blocking=True)
except KeyboardInterrupt:
print("\n\nReplay interrupted by user")
finally:
# Disconnect robot
print("\nDisconnecting robot...")
robot.disconnect()
print("✓ Replay complete!")
if __name__ == "__main__":
main()

73
examples/openarms/setup_can.sh Executable file
View File

@@ -0,0 +1,73 @@
#!/bin/bash
# Setup all OpenArms CAN interfaces with CAN FD
set -e
echo "=========================================="
echo "OpenArms CAN FD Interface Setup"
echo "=========================================="
echo ""
echo "Mode: CAN FD"
echo " - Nominal bitrate: 1 Mbps"
echo " - Data bitrate: 5 Mbps"
echo ""
echo "Configuring interfaces can0, can1, can2, can3..."
echo ""
# Configure each CAN interface with CAN FD
for i in 0 1 2 3; do
interface="can$i"
# Check if interface exists
if ! ip link show "$interface" &> /dev/null; then
echo "$interface: Not found, skipping"
continue
fi
# Bring down interface
sudo ip link set "$interface" down 2>/dev/null
# Configure CAN FD mode
sudo ip link set "$interface" type can \
bitrate 1000000 \
dbitrate 5000000 \
fd on
# Bring up interface
sudo ip link set "$interface" up
# Verify configuration
if ip link show "$interface" | grep -q "UP"; then
echo "$interface: Configured and UP"
else
echo "$interface: Failed to bring UP"
fi
done
echo ""
echo "=========================================="
echo "Verification"
echo "=========================================="
echo ""
# Show detailed status for each interface
for i in 0 1 2 3; do
interface="can$i"
if ip link show "$interface" &> /dev/null; then
echo "$interface:"
# Show key parameters
ip -d link show "$interface" | grep -E "can|state|bitrate|dbitrate" | head -3
echo ""
fi
done
echo "=========================================="
echo "Setup Complete!"
echo "=========================================="
echo ""
echo "All interfaces configured for CAN FD mode"
echo ""
echo "Next steps:"
echo " 1. Test motors: python debug_can_communication.py"
echo " 2. Run teleoperation: python examples/openarms/teleop.py"
echo ""

148
examples/openarms/teleop.py Normal file
View File

@@ -0,0 +1,148 @@
"""
OpenArms Teleoperation Example - Full Dual Arms
This script demonstrates teleoperation of OpenArms follower robot using an OpenArms leader arm.
It first calibrates both devices, then enters a teleoperation loop for both arms.
"""
import time
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
follower_config = OpenArmsFollowerConfig(
port_left="can2", # CAN interface for follower left arm
port_right="can3", # CAN interface for follower right arm
can_interface="socketcan", # Linux SocketCAN
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=5.0, # Safety limit
)
leader_config = OpenArmsLeaderConfig(
port_left="can0", # CAN interface for leader left arm
port_right="can1", # CAN interface for leader right arm
can_interface="socketcan", # Linux SocketCAN
id="openarms_leader",
manual_control=True, # Enable manual control (torque disabled)
)
print("=" * 60)
print("OpenArms Teleoperation - Full Dual Arms")
print("=" * 60)
# Initialize devices
print("\n[1/4] Initializing devices...")
follower = OpenArmsFollower(follower_config)
leader = OpenArmsLeader(leader_config)
# Connect and calibrate follower
print("\n[2/4] Connecting and calibrating follower robot...")
print("Note: If you have existing calibration, just press ENTER to use it.")
follower.connect(calibrate=True)
# Connect and calibrate leader
print("\n[3/4] Connecting and calibrating leader arm...")
print("Note: The leader arm will have torque disabled for manual control.")
leader.connect(calibrate=True)
# Wait for user to be ready
print("\n[4/4] Ready for teleoperation!")
print("\nBoth arms will be controlled (16 motors total):")
print(" RIGHT ARM: joints 1-7 + gripper")
print(" LEFT ARM: joints 1-7 + gripper")
print("\nPress ENTER to start teleoperation...")
input()
print("\nTeleoperation started! Move both leader arms.")
print("Press Ctrl+C to stop.\n")
# All joints for both arms (16 motors total)
all_joints = [
# Right arm
"right_joint_1",
"right_joint_2",
"right_joint_3",
"right_joint_4",
"right_joint_5",
"right_joint_6",
"right_joint_7",
"right_gripper",
# Left arm
"left_joint_1",
"left_joint_2",
"left_joint_3",
"left_joint_4",
"left_joint_5",
"left_joint_6",
"left_joint_7",
"left_gripper",
]
# Performance monitoring
loop_times = []
start_time = time.perf_counter()
last_print_time = start_time
try:
while True:
loop_start = time.perf_counter()
# Get action from leader
leader_action = leader.get_action()
# Filter to only position data for all joints (both arms)
joint_action = {}
for joint in all_joints:
pos_key = f"{joint}.pos"
if pos_key in leader_action:
joint_action[pos_key] = leader_action[pos_key]
# Send action to follower (both arms)
if joint_action:
follower.send_action(joint_action)
# Measure loop time
loop_end = time.perf_counter()
loop_time = loop_end - loop_start
loop_times.append(loop_time)
# Print stats every 2 seconds
if loop_end - last_print_time >= 2.0:
if loop_times:
avg_time = sum(loop_times) / len(loop_times)
current_hz = 1.0 / avg_time if avg_time > 0 else 0
min_time = min(loop_times)
max_time = max(loop_times)
max_hz = 1.0 / min_time if min_time > 0 else 0
min_hz = 1.0 / max_time if max_time > 0 else 0
print(f"[Hz Stats] Avg: {current_hz:.1f} Hz | "
f"Range: {min_hz:.1f}-{max_hz:.1f} Hz | "
f"Avg loop time: {avg_time*1000:.1f} ms")
# Reset for next measurement window
loop_times = []
last_print_time = loop_end
except KeyboardInterrupt:
print("\n\nStopping teleoperation...")
finally:
# Disconnect devices
print("Disconnecting devices...")
try:
follower.disconnect()
except Exception as e:
print(f"Error disconnecting follower: {e}")
try:
leader.disconnect()
except Exception as e:
print(f"Error disconnecting leader: {e}")
print("Done!")

View File

@@ -0,0 +1,197 @@
"""
OpenArms Mini Teleoperation Example
This script demonstrates teleoperation of an OpenArms follower robot using
an OpenArms Mini leader (Feetech-based) with dual arms (16 motors total).
The OpenArms Mini has:
- Right arm: 8 motors (joint_1 to joint_7 + gripper)
- Left arm: 8 motors (joint_1 to joint_7 + gripper)
Note on gripper normalization:
- OpenArms Mini gripper: 0-100 scale (0=closed, 100=open)
- OpenArms follower gripper: degrees (0=closed, -65=open)
- This script automatically converts between the two ranges
"""
import time
import os
import sys
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
from lerobot.teleoperators.openarms_mini.openarms_mini import OpenArmsMini
from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig
from lerobot.utils.robot_utils import busy_wait
# Target control frequency
TARGET_FPS = 30
# Configure the OpenArms follower (Damiao motors on CAN bus)
follower_config = OpenArmsFollowerConfig(
port_left="can0", # CAN interface for follower left arm
port_right="can1", # CAN interface for follower right arm
can_interface="socketcan", # Linux SocketCAN
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=10.0, # Safety limit (degrees per step)
)
# Configure the OpenArms Mini leader (Feetech motors on serial)
leader_config = OpenArmsMiniConfig(
port_right="/dev/ttyACM0", # Serial port for right arm
port_left="/dev/ttyACM1", # Serial port for left arm
id="openarms_mini",
use_degrees=True,
)
print("OpenArms Mini → OpenArms Follower Teleoperation")
# Initialize devices
follower = OpenArmsFollower(follower_config)
leader = OpenArmsMini(leader_config)
# Connect and calibrate follower
print("Note: If you have existing calibration, just press ENTER to use it.")
follower.connect(calibrate=True)
# Connect and calibrate leader
print("Note: The leader arms will have torque disabled for manual control.")
leader.connect(calibrate=True)
print("\nPress ENTER to start teleoperation...")
input()
print("Press Ctrl+C to stop.\n")
# All joints for both arms (16 motors total)
all_joints = [
# Right arm
"right_joint_1",
"right_joint_2",
"right_joint_3",
"right_joint_4",
"right_joint_5",
"right_joint_6",
"right_joint_7",
"right_gripper",
# Left arm
"left_joint_1",
"left_joint_2",
"left_joint_3",
"left_joint_4",
"left_joint_5",
"left_joint_6",
"left_joint_7",
"left_gripper",
]
# Performance monitoring
loop_times = []
avg_loop_time = 0.0
min_loop_time = float('inf')
max_loop_time = 0.0
stats_update_interval = 1.0 # Update stats every 1 second
last_stats_update = time.perf_counter()
SWAPPED_JOINTS = {
"right_joint_6": "right_joint_7",
"right_joint_7": "right_joint_6",
"left_joint_6": "left_joint_7",
"left_joint_7": "left_joint_6",
}
try:
while True:
loop_start = time.perf_counter()
# Get actions and observations
leader_action = leader.get_action()
follower_obs = follower.get_observation()
joint_action = {}
for joint in all_joints:
leader_key = f"{joint}.pos"
# Determine which follower joint this leader joint controls
follower_joint = SWAPPED_JOINTS.get(joint, joint)
follower_key = f"{follower_joint}.pos"
# Get leader position (default 0 if missing)
pos = leader_action.get(leader_key, 0.0)
# Convert gripper values: Mini uses 0-100, OpenArms uses 0 to -65 degrees
if "gripper" in joint:
# Map 0-100 (Mini) to 0 to -65 (OpenArms)
# 0 (closed) -> 0°, 100 (open) -> -65°
pos = (pos / 100.0) * -65.0
# Store in action dict for follower
joint_action[follower_key] = pos
follower.send_action(joint_action)
# Loop timing
loop_end = time.perf_counter()
loop_time = loop_end - loop_start
loop_times.append(loop_time)
# Update stats periodically
current_time = time.perf_counter()
if current_time - last_stats_update >= stats_update_interval:
if loop_times:
avg_loop_time = sum(loop_times) / len(loop_times)
min_loop_time = min(loop_times)
max_loop_time = max(loop_times)
loop_times = []
last_stats_update = current_time
# Display everything
sys.stdout.write("\033[H\033[J") # Clear screen
# Show timing stats at the top
if avg_loop_time > 0:
avg_hz = 1.0 / avg_loop_time
min_hz = 1.0 / max_loop_time if max_loop_time > 0 else 0
max_hz = 1.0 / min_loop_time if min_loop_time > 0 and min_loop_time < float('inf') else 0
print(f"[Performance] Target: {TARGET_FPS} Hz | Avg: {avg_hz:.1f} Hz | Range: {min_hz:.1f}-{max_hz:.1f} Hz | Loop: {avg_loop_time*1000:.1f} ms\n")
else:
print(f"[Performance] Target: {TARGET_FPS} Hz | Measuring...\n")
# Show joint positions
print(f"{'Joint':<20} {'Leader':>15} {'Follower':>15}")
print(f"{'':20} {'(0-100/deg)':>15} {'(deg)':>15}")
print("-" * 52)
for joint in all_joints:
leader_key = f"{joint}.pos"
follower_joint = SWAPPED_JOINTS.get(joint, joint)
follower_key = f"{follower_joint}.pos"
leader_pos = leader_action.get(leader_key, 0.0)
follower_pos = follower_obs.get(follower_key, 0.0)
print(f"{joint:<20} {leader_pos:>15.2f} {follower_pos:>15.2f}")
# Smart sleep to maintain target FPS
dt_s = time.perf_counter() - loop_start
busy_wait(max(0, 1.0 / TARGET_FPS - dt_s))
except KeyboardInterrupt:
print("\n\nStopping teleoperation...")
finally:
# Disconnect devices
print("Disconnecting devices...")
try:
follower.disconnect()
except Exception as e:
print(f"Error disconnecting follower: {e}")
try:
leader.disconnect()
except Exception as e:
print(f"Error disconnecting leader: {e}")
print("Done!")

View File

@@ -0,0 +1,202 @@
"""
OpenArms Teleoperation with Gravity + Friction Compensation
Leader arms (both LEFT and RIGHT): Gravity + Friction compensation (weightless, easy to move)
Follower arms (both LEFT and RIGHT): Mirror leader movements
Uses the URDF file from the lerobot repository.
"""
import time
import numpy as np
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
# Friction compensation scale factor (1.0 = full, 0.3 = 30% for stability)
FRICTION_SCALE = 1.0
def main():
"""Main teleoperation loop with gravity compensation"""
print("=" * 70)
print("OpenArms Teleoperation with Gravity Compensation")
print("=" * 70)
# Configuration
follower_config = OpenArmsFollowerConfig(
port_left="can2",
port_right="can3",
can_interface="socketcan",
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=10.0,
)
leader_config = OpenArmsLeaderConfig(
port_left="can0",
port_right="can1",
can_interface="socketcan",
id="openarms_leader",
manual_control=False, # Enable torque control for gravity compensation
)
# Initialize and connect
print("\nInitializing devices...")
follower = OpenArmsFollower(follower_config)
leader = OpenArmsLeader(leader_config)
follower.connect()
leader.connect()
# URDF is automatically loaded in the leader constructor
if leader.pin_robot is None:
raise RuntimeError("URDF model not loaded on leader. Gravity compensation not available.")
print("\nLeader BOTH arms: Gravity + Friction comp | Follower BOTH arms: Teleop")
print("Press ENTER to start...")
input()
# Enable motors on both leader arms for gravity compensation
leader.bus_right.enable_torque()
leader.bus_left.enable_torque()
time.sleep(0.1)
print("Press Ctrl+C to stop\n")
# Main control loop
loop_times = []
last_print_time = time.perf_counter()
# All joints (both arms)
all_joints = []
for motor in leader.bus_right.motors:
all_joints.append(f"right_{motor}")
for motor in leader.bus_left.motors:
all_joints.append(f"left_{motor}")
try:
while True:
loop_start = time.perf_counter()
# Get leader state
leader_action = leader.get_action()
# Extract positions and velocities in degrees
leader_positions_deg = {}
leader_velocities_deg_per_sec = {}
for motor in leader.bus_right.motors:
pos_key = f"right_{motor}.pos"
vel_key = f"right_{motor}.vel"
if pos_key in leader_action:
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
if vel_key in leader_action:
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
for motor in leader.bus_left.motors:
pos_key = f"left_{motor}.pos"
vel_key = f"left_{motor}.vel"
if pos_key in leader_action:
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
if vel_key in leader_action:
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
# Calculate gravity torques for leader using built-in method
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
# Calculate friction torques for leader using built-in method
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
leader_friction_torques_nm = leader._friction_from_velocity(
leader_velocities_rad_per_sec,
friction_scale=FRICTION_SCALE
)
# Combine gravity + friction torques
leader_total_torques_nm = {}
for motor_name in leader_gravity_torques_nm:
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
friction = leader_friction_torques_nm.get(motor_name, 0.0)
leader_total_torques_nm[motor_name] = gravity + friction
# Apply gravity + friction compensation to leader RIGHT arm (all joints including gripper)
for motor in leader.bus_right.motors:
full_name = f"right_{motor}"
position = leader_positions_deg.get(full_name, 0.0)
torque = leader_total_torques_nm.get(full_name, 0.0)
# Get damping gain for stability
kd = leader.get_damping_kd(motor)
leader.bus_right._mit_control(
motor=motor,
kp=0.0,
kd=kd, # Add damping for stability
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque,
)
# Apply gravity + friction compensation to leader LEFT arm (all joints including gripper)
for motor in leader.bus_left.motors:
full_name = f"left_{motor}"
position = leader_positions_deg.get(full_name, 0.0)
torque = leader_total_torques_nm.get(full_name, 0.0)
# Get damping gain for stability
kd = leader.get_damping_kd(motor)
leader.bus_left._mit_control(
motor=motor,
kp=0.0,
kd=kd, # Add damping for stability
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque,
)
# Send leader positions to follower (both arms)
follower_action = {}
for joint in all_joints:
pos_key = f"{joint}.pos"
if pos_key in leader_action:
follower_action[pos_key] = leader_action[pos_key]
if follower_action:
follower.send_action(follower_action)
# Performance monitoring
loop_end = time.perf_counter()
loop_time = loop_end - loop_start
loop_times.append(loop_time)
if loop_end - last_print_time >= 2.0:
if loop_times:
avg_time = sum(loop_times) / len(loop_times)
current_hz = 1.0 / avg_time if avg_time > 0 else 0
print(f"{current_hz:.1f} Hz ({avg_time*1000:.1f} ms)")
loop_times = []
last_print_time = loop_end
except KeyboardInterrupt:
print("\n\nStopping...")
finally:
try:
leader.bus_right.disable_torque()
leader.bus_left.disable_torque()
time.sleep(0.1)
leader.disconnect()
follower.disconnect()
print("✓ Shutdown complete")
except Exception as e:
print(f"Shutdown error: {e}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,152 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unify all tasks in a dataset to a single task (modifies in-place).
This script:
1. Loads a dataset
2. Sets all task_index to 0 and task description to "fold"
3. Updates tasks.parquet and task_index in data files (in-place, no copying)
Usage:
python examples/openarms/unify_task.py --repo-id lerobot-data-collection/level1_rac1
"""
from __future__ import annotations
import argparse
import logging
from pathlib import Path
import pandas as pd
from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.utils import (
DATA_DIR,
write_info,
write_tasks,
)
from lerobot.utils.constants import HF_LEROBOT_HOME
# Single unified task
UNIFIED_TASK = "fold"
def unify_dataset_tasks(
repo_id: str,
root: Path | None = None,
push_to_hub: bool = False,
) -> None:
"""Unify all tasks in a dataset to a single task (modifies in-place).
Args:
repo_id: Dataset repository ID.
root: Optional root path for dataset.
push_to_hub: Whether to push the result to HuggingFace Hub.
"""
input_root = root if root else HF_LEROBOT_HOME / repo_id
input_repo_id = repo_id
logging.info(f"Loading metadata from {repo_id}")
# Load source metadata
src_meta = LeRobotDatasetMetadata(repo_id, root=input_root)
logging.info(f"Source dataset: {src_meta.total_episodes} episodes, {src_meta.total_frames} frames")
logging.info(f"Original tasks: {len(src_meta.tasks)}")
# Modify in-place (input_root == output_root supported)
data_dir = input_root / DATA_DIR
# Process data files - set all task_index to 0
logging.info("Processing data files (in-place)...")
for parquet_file in tqdm(sorted(data_dir.rglob("*.parquet")), desc="Processing data"):
df = pd.read_parquet(parquet_file)
df["task_index"] = 0 # All tasks unified to index 0
df.to_parquet(parquet_file)
# Process episodes metadata - set all tasks to unified task
logging.info("Processing episodes metadata (in-place)...")
episodes_dir = input_root / "meta" / "episodes"
if episodes_dir.exists():
for parquet_file in tqdm(sorted(episodes_dir.rglob("*.parquet")), desc="Processing episodes"):
df = pd.read_parquet(parquet_file)
df["tasks"] = [[UNIFIED_TASK]] * len(df) # All episodes get the unified task
df.to_parquet(parquet_file)
else:
logging.warning(f"No episodes directory found at {episodes_dir}, skipping")
# Update tasks.parquet with single task
logging.info(f"Creating single task: {UNIFIED_TASK}")
new_tasks = pd.DataFrame({"task_index": [0]}, index=[UNIFIED_TASK])
write_tasks(new_tasks, input_root)
# Update info.json
new_info = src_meta.info.copy()
new_info["total_tasks"] = 1
write_info(new_info, input_root)
logging.info(f"Dataset modified in-place at {input_root}")
logging.info(f"Task: {UNIFIED_TASK}")
if push_to_hub:
from lerobot.datasets.lerobot_dataset import LeRobotDataset
logging.info(f"Pushing {input_repo_id} to hub")
dataset = LeRobotDataset(input_repo_id, root=input_root)
dataset.push_to_hub(private=True)
logging.info("Push complete!")
def main():
parser = argparse.ArgumentParser(
description="Unify all tasks in a dataset to a single task 'fold' (modifies in-place)."
)
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Dataset repository ID",
)
parser.add_argument(
"--root",
type=Path,
default=None,
help="Optional root path (defaults to HF_LEROBOT_HOME/repo_id)",
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Push result to HuggingFace Hub",
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
unify_dataset_tasks(
repo_id=args.repo_id,
root=args.root,
push_to_hub=args.push_to_hub,
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,745 @@
body {
margin: 0;
padding: 0;
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
background: #f5f5f5;
}
main {
min-height: 100vh;
padding: 2rem;
}
header {
text-align: center;
margin-bottom: 2rem;
}
h1 {
font-size: 2rem;
font-weight: 600;
color: #333;
margin: 0;
}
h2 {
font-size: 1.25rem;
font-weight: 600;
color: #333;
margin: 0 0 1rem 0;
}
h3 {
font-size: 0.875rem;
font-weight: 600;
color: #666;
margin: 0 0 0.5rem 0;
text-transform: uppercase;
letter-spacing: 0.5px;
}
.container {
max-width: 1920px;
margin: 0 auto;
display: grid;
grid-template-columns: minmax(500px, 600px) 1fr;
gap: 2rem;
align-items: start;
}
/* Left column container */
.left-column {
display: flex;
flex-direction: column;
gap: 1.5rem;
}
/* Right column container */
.right-column {
display: flex;
flex-direction: column;
gap: 1.5rem;
}
/* Responsive: Stack on smaller screens */
@media (max-width: 1200px) {
.container {
grid-template-columns: 1fr;
}
}
.panel {
background: white;
border-radius: 8px;
padding: 1.5rem;
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}
.config-panel {
border: 2px solid #e5e7eb;
}
.config-header {
display: flex;
justify-content: space-between;
align-items: center;
cursor: pointer;
user-select: none;
padding: 0.5rem 0;
}
.config-header:hover {
opacity: 0.7;
}
.toggle-icon {
font-size: 1rem;
color: #6b7280;
transition: transform 0.2s;
}
.config-content {
margin-top: 1rem;
padding-top: 1rem;
border-top: 1px solid #e5e7eb;
}
.robot-setup {
margin-bottom: 0.5rem;
}
.robot-status {
display: flex;
align-items: center;
justify-content: space-between;
padding: 1rem;
border-radius: 6px;
font-weight: 500;
gap: 1rem;
}
.robot-status.ready {
background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%);
color: #065f46;
border: 1px solid #10b981;
}
.robot-status.not-ready {
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
color: #92400e;
border: 1px solid #f59e0b;
}
.btn-setup {
background: #10b981;
color: white;
border: none;
padding: 0.5rem 1rem;
border-radius: 4px;
font-size: 0.875rem;
font-weight: 500;
cursor: pointer;
transition: background 0.2s;
}
.btn-setup:hover:not(:disabled) {
background: #059669;
}
.btn-setup:disabled {
background: #d1d5db;
cursor: not-allowed;
}
.btn-zero {
background: #8b5cf6;
color: white;
border: none;
padding: 0.5rem 1rem;
border-radius: 4px;
font-size: 0.875rem;
font-weight: 500;
cursor: pointer;
transition: background 0.2s;
}
.btn-zero:hover:not(:disabled) {
background: #7c3aed;
}
.btn-zero:disabled {
background: #d1d5db;
cursor: not-allowed;
}
.zero-position-section {
margin-top: 1rem;
padding-top: 1rem;
border-top: 1px solid #e5e7eb;
}
.btn-zero-large {
width: 100%;
background: #8b5cf6;
color: white;
border: none;
padding: 0.875rem 1.5rem;
border-radius: 8px;
font-size: 1rem;
font-weight: 600;
cursor: pointer;
transition: all 0.2s;
box-shadow: 0 2px 4px rgba(139, 92, 246, 0.2);
}
.btn-zero-large:hover:not(:disabled) {
background: #7c3aed;
box-shadow: 0 4px 8px rgba(139, 92, 246, 0.3);
transform: translateY(-1px);
}
.btn-zero-large:disabled {
background: #d1d5db;
cursor: not-allowed;
box-shadow: none;
transform: none;
}
.delete-episode-section {
margin-top: 1rem;
padding-top: 1rem;
border-top: 1px solid #e5e7eb;
}
.btn-delete {
width: 100%;
background: #ef4444;
color: white;
border: none;
padding: 0.875rem 1.5rem;
border-radius: 8px;
font-size: 1rem;
font-weight: 600;
cursor: pointer;
transition: all 0.2s;
box-shadow: 0 2px 4px rgba(239, 68, 68, 0.2);
}
.btn-delete:hover:not(:disabled) {
background: #dc2626;
box-shadow: 0 4px 8px rgba(239, 68, 68, 0.3);
transform: translateY(-1px);
}
.btn-delete:disabled {
background: #d1d5db;
cursor: not-allowed;
box-shadow: none;
transform: none;
}
.delete-info {
margin-top: 0.5rem;
font-size: 0.875rem;
color: #666;
text-align: center;
font-style: italic;
}
.btn-disconnect {
background: #ef4444;
color: white;
border: none;
padding: 0.5rem 1rem;
border-radius: 4px;
font-size: 0.875rem;
font-weight: 500;
cursor: pointer;
transition: background 0.2s;
}
.btn-disconnect:hover {
background: #dc2626;
}
.btn-refresh {
background: #3b82f6;
color: white;
border: none;
padding: 0.4rem 0.8rem;
border-radius: 4px;
font-size: 0.75rem;
font-weight: 500;
cursor: pointer;
transition: background 0.2s;
}
.btn-refresh:hover:not(:disabled) {
background: #2563eb;
}
.btn-refresh:disabled {
background: #d1d5db;
cursor: not-allowed;
}
.control-panel {
border: 2px solid #10b981;
}
.status-banner {
display: flex;
align-items: center;
gap: 1rem;
padding: 1rem 1.5rem;
border-radius: 6px;
margin-bottom: 1.5rem;
font-weight: 500;
font-size: 0.95rem;
}
.status-banner.initializing {
background: linear-gradient(135deg, #dbeafe 0%, #bfdbfe 100%);
color: #1e40af;
border-left: 4px solid #3b82f6;
}
.status-banner.encoding {
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
color: #92400e;
border-left: 4px solid #f59e0b;
}
.status-banner.uploading {
background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%);
color: #3730a3;
border-left: 4px solid #6366f1;
}
.status-banner.success {
background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%);
color: #065f46;
border-left: 4px solid #10b981;
}
.status-banner.warning {
background: linear-gradient(135deg, #fee2e2 0%, #fecaca 100%);
color: #991b1b;
border-left: 4px solid #ef4444;
}
.spinner {
width: 20px;
height: 20px;
border: 3px solid rgba(0, 0, 0, 0.1);
border-top-color: currentColor;
border-radius: 50%;
animation: spin 0.8s linear infinite;
}
@keyframes spin {
to { transform: rotate(360deg); }
}
.control-horizontal {
display: flex;
flex-direction: column;
gap: 1.5rem;
}
.control-left {
display: flex;
flex-direction: column;
gap: 1rem;
}
.control-right {
display: flex;
align-items: center;
justify-content: center;
}
.input-group {
display: flex;
gap: 0.5rem;
margin-bottom: 0;
}
input[type="text"] {
flex: 1;
padding: 0.75rem;
border: 1px solid #ddd;
border-radius: 4px;
font-size: 1rem;
}
input[type="text"]:disabled {
background: #f5f5f5;
cursor: not-allowed;
}
input[type="text"]:focus {
outline: none;
border-color: #10b981;
}
button {
padding: 0.75rem 1.5rem;
border: none;
border-radius: 4px;
font-size: 1rem;
font-weight: 500;
cursor: pointer;
transition: all 0.2s;
}
.btn-set-task {
background: #3b82f6;
color: white;
min-width: 120px;
}
.btn-set-task:hover:not(:disabled) {
background: #2563eb;
}
.btn-set-task:disabled {
background: #d1d5db;
cursor: not-allowed;
}
.btn-start {
background: #10b981;
color: white;
}
.btn-start:hover:not(:disabled) {
background: #059669;
}
.btn-start:disabled {
background: #d1d5db;
cursor: not-allowed;
}
.btn-stop {
background: #ef4444;
color: white;
}
.btn-stop:hover {
background: #dc2626;
}
.btn-reset {
padding: 0.5rem 1rem;
background: #6b7280;
color: white;
font-size: 0.875rem;
}
.btn-reset:hover {
background: #4b5563;
}
.status {
display: flex;
align-items: center;
gap: 0.75rem;
padding: 1rem;
border-radius: 4px;
margin-bottom: 1rem;
}
.status.recording {
background: #fee2e2;
color: #991b1b;
}
.status.recording.recording-active {
display: flex;
flex-direction: column;
gap: 1rem;
background: #dc2626;
color: white;
padding: 1.5rem;
border: 4px solid #991b1b;
box-shadow: 0 4px 12px rgba(220, 38, 38, 0.4);
font-weight: 700;
font-size: 1rem;
}
.status.recording.recording-active .indicator {
width: 20px;
height: 20px;
background: #fef2f2;
animation: pulse-strong 1s ease-in-out infinite;
}
@keyframes pulse-strong {
0%, 100% {
opacity: 1;
transform: scale(1);
}
50% {
opacity: 0.7;
transform: scale(1.1);
}
}
.status.recording.recording-active .time-display {
display: flex;
flex-direction: column;
gap: 0.5rem;
font-size: 1.5rem;
font-weight: 700;
color: white;
}
.fps-display {
font-size: 1rem;
font-weight: 500;
opacity: 0.95;
}
.fps-warning {
color: #fef2f2;
animation: pulse-warning 1s ease-in-out infinite;
}
@keyframes pulse-warning {
0%, 100% { opacity: 1; }
50% { opacity: 0.5; }
}
.status.recording.recording-active .btn-stop {
align-self: stretch;
}
.ramp-up-countdown {
display: flex;
justify-content: center;
margin-bottom: 1rem;
}
.countdown-box {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
padding: 2rem 3rem;
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
border: 4px solid #f59e0b;
border-radius: 16px;
box-shadow: 0 6px 20px rgba(245, 158, 11, 0.4);
min-width: 280px;
animation: pulse-warm 1.5s ease-in-out infinite;
}
@keyframes pulse-warm {
0%, 100% {
box-shadow: 0 6px 20px rgba(245, 158, 11, 0.4);
}
50% {
box-shadow: 0 6px 25px rgba(245, 158, 11, 0.6);
}
}
.countdown-label {
font-size: 1rem;
color: #92400e;
text-transform: uppercase;
letter-spacing: 1.5px;
font-weight: 800;
margin-bottom: 1rem;
text-align: center;
}
.countdown-value {
font-size: 4.5rem;
font-weight: 900;
color: #d97706;
font-family: 'Courier New', monospace;
line-height: 1;
text-shadow: 2px 2px 6px rgba(0, 0, 0, 0.15);
margin-bottom: 0.5rem;
}
.countdown-subtitle {
font-size: 0.875rem;
color: #78350f;
font-weight: 600;
font-style: italic;
text-align: center;
margin-top: 0.5rem;
}
.status.idle {
background: #f3f4f6;
color: #374151;
}
.indicator {
width: 12px;
height: 12px;
border-radius: 50%;
background: #ef4444;
animation: pulse 1.5s ease-in-out infinite;
}
@keyframes pulse {
0%, 100% { opacity: 1; }
50% { opacity: 0.5; }
}
.counter {
display: flex;
flex-direction: column;
align-items: center;
gap: 0.75rem;
padding: 1.5rem;
background: linear-gradient(135deg, #f9fafb 0%, #f3f4f6 100%);
border-radius: 8px;
border: 2px solid #e5e7eb;
min-width: 200px;
}
.counter-label {
font-size: 0.75rem;
color: #6b7280;
text-transform: uppercase;
letter-spacing: 0.5px;
font-weight: 600;
}
.counter-value {
font-size: 3rem;
font-weight: 700;
color: #10b981;
line-height: 1;
}
.time-display {
font-size: 1.5rem;
font-weight: 600;
font-family: 'Courier New', monospace;
}
.error-box {
padding: 1rem;
background: #fee2e2;
color: #991b1b;
border-radius: 4px;
border-left: 4px solid #ef4444;
font-size: 0.875rem;
}
.config-section {
margin-bottom: 1.5rem;
}
.config-section:last-child {
margin-bottom: 0;
}
.config-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 1rem;
}
label {
display: flex;
flex-direction: column;
gap: 0.5rem;
font-size: 0.875rem;
color: #374151;
font-weight: 500;
}
select {
padding: 0.5rem;
border: 1px solid #ddd;
border-radius: 4px;
font-size: 0.875rem;
background: white;
}
select:disabled {
background: #f5f5f5;
cursor: not-allowed;
}
/* Camera Layout */
.camera-layout {
display: flex;
flex-direction: column;
gap: 1.5rem;
}
.camera-base {
width: 100%;
}
.camera-wrist-container {
display: grid;
grid-template-columns: repeat(2, 1fr);
gap: 1.5rem;
}
.camera-wrist {
width: 100%;
}
.camera {
border: 1px solid #e5e7eb;
border-radius: 4px;
overflow: hidden;
}
.camera h3 {
padding: 0.75rem;
background: #f9fafb;
border-bottom: 1px solid #e5e7eb;
margin: 0;
}
.camera img {
width: 100%;
height: auto;
display: block;
background: #000;
min-height: 300px;
object-fit: cover;
}
.camera-placeholder {
text-align: center;
padding: 4rem 2rem;
background: #f9fafb;
border-radius: 4px;
border: 2px dashed #d1d5db;
}
.camera-placeholder p {
margin: 0.5rem 0;
font-size: 1rem;
color: #6b7280;
}
.camera-placeholder p:first-child {
font-size: 1.25rem;
font-weight: 500;
color: #374151;
}
.hint {
margin-top: 0.5rem;
font-size: 0.75rem;
color: #6b7280;
display: flex;
align-items: center;
gap: 0.5rem;
flex-wrap: wrap;
}

View File

@@ -0,0 +1,857 @@
import { useState, useEffect, useCallback, useRef } from 'react';
import './App.css';
const API_BASE = 'http://localhost:8000/api';
function App() {
// State
const [task, setTask] = useState('');
const [isRecording, setIsRecording] = useState(false);
const [isInitializing, setIsInitializing] = useState(false);
const [isEncoding, setIsEncoding] = useState(false);
const [isUploading, setIsUploading] = useState(false);
const [robotsReady, setRobotsReady] = useState(false);
const [elapsedTime, setElapsedTime] = useState(0);
const [currentFps, setCurrentFps] = useState(0);
const [loopFps, setLoopFps] = useState(0);
const [episodeCount, setEpisodeCount] = useState(0);
const [error, setError] = useState(null);
const [statusMessage, setStatusMessage] = useState('Ready');
const [uploadStatus, setUploadStatus] = useState(null);
const [rampUpRemaining, setRampUpRemaining] = useState(0);
const [movingToZero, setMovingToZero] = useState(false);
const [configExpanded, setConfigExpanded] = useState(false);
const [latestRepoId, setLatestRepoId] = useState(null);
// Configuration
const [config, setConfig] = useState({
leader_type: 'openarms', // 'openarms' or 'openarms_mini'
leader_left: 'can0',
leader_right: 'can1',
follower_left: 'can2',
follower_right: 'can3',
left_wrist: '/dev/video0',
right_wrist: '/dev/video1',
base: '/dev/video4'
});
// Available options
const [availableCameras, setAvailableCameras] = useState([]);
const [availableUsbPorts, setAvailableUsbPorts] = useState([]);
const canInterfaces = ['can0', 'can1', 'can2', 'can3'];
const statusIntervalRef = useRef(null);
const hasInitializedRef = useRef(false);
const loadConfig = () => {
try {
const saved = localStorage.getItem('openarms_config');
if (saved) {
const loadedConfig = JSON.parse(saved);
setConfig(prev => ({ ...prev, ...loadedConfig }));
}
} catch (e) {
console.error('Load config error:', e);
}
};
const saveConfig = (newConfig) => {
try {
localStorage.setItem('openarms_config', JSON.stringify(newConfig || config));
} catch (e) {
console.error('Save config error:', e);
}
};
// Fetch status periodically
const fetchStatus = async () => {
try {
const response = await fetch(`${API_BASE}/status`);
const data = await response.json();
setIsRecording(data.is_recording);
setIsInitializing(data.is_initializing);
setIsEncoding(data.is_encoding);
setIsUploading(data.is_uploading);
setRobotsReady(data.robots_ready);
setElapsedTime(data.elapsed_time);
setCurrentFps(data.current_fps || 0);
setLoopFps(data.loop_fps || 0);
setEpisodeCount(data.episode_count);
setError(data.error);
setStatusMessage(data.status_message || 'Ready');
setUploadStatus(data.upload_status);
setRampUpRemaining(data.ramp_up_remaining || 0);
setMovingToZero(data.moving_to_zero || false);
// Track the latest repo_id from the backend
if (data.latest_repo_id) {
setLatestRepoId(data.latest_repo_id);
}
if (data.config) {
// Only merge server config if we don't have a saved config (first load)
if (!localStorage.getItem('openarms_config')) {
setConfig(prev => {
const merged = { ...data.config, ...prev };
localStorage.setItem('openarms_config', JSON.stringify(merged));
return merged;
});
}
}
} catch (e) {
console.error('Failed to fetch status:', e);
}
};
const setupRobots = async () => {
// Show warning to verify camera positions
const confirmed = window.confirm(
'⚠️ IMPORTANT: Before connecting robots, please verify:\n\n' +
'📹 Check that cameras are correctly positioned:\n' +
' • LEFT wrist camera is actually on the LEFT arm\n' +
' • RIGHT wrist camera is actually on the RIGHT arm\n' +
' • BASE camera is actually the BASE/overhead camera\n\n' +
'Incorrect camera positioning will result in invalid training data!\n\n' +
'Click OK to continue with robot setup, or Cancel to review configuration.'
);
if (!confirmed) {
return; // User cancelled, don't proceed
}
setError(null);
try {
const response = await fetch(`${API_BASE}/robots/setup`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(config)
});
if (!response.ok) {
const data = await response.json();
throw new Error(data.detail || 'Failed to setup robots');
}
await response.json();
saveConfig(config);
} catch (e) {
setError(`Robot setup failed: ${e.message}`);
}
};
// Disconnect robots
const disconnectRobots = async () => {
try {
await fetch(`${API_BASE}/robots/disconnect`, { method: 'POST' });
setRobotsReady(false);
} catch (e) {
console.error('Failed to disconnect robots:', e);
}
};
// Discover cameras
const discoverCameras = async () => {
try {
const response = await fetch(`${API_BASE}/cameras/discover`);
const data = await response.json();
const cameras = data.cameras || [];
setAvailableCameras(cameras);
// Get list of valid camera IDs
const validCameraIds = cameras.map(cam => String(cam.id));
// Auto-fix config if current values are invalid or not set
const updated = { ...config };
let changed = false;
// Auto-fix invalid camera config
if (!config.left_wrist || !validCameraIds.includes(config.left_wrist)) {
if (cameras.length >= 1) {
updated.left_wrist = String(cameras[0].id);
changed = true;
}
}
if (!config.right_wrist || !validCameraIds.includes(config.right_wrist)) {
if (cameras.length >= 2) {
updated.right_wrist = String(cameras[1].id);
changed = true;
}
}
if (!config.base || !validCameraIds.includes(config.base)) {
if (cameras.length >= 3) {
updated.base = String(cameras[2].id);
changed = true;
}
}
if (changed) {
setConfig(updated);
saveConfig(updated);
}
if (cameras.length === 0) {
setError('No cameras detected! Please connect cameras and refresh.');
}
} catch (e) {
console.error('Failed to discover cameras:', e);
setError(`Camera discovery failed: ${e.message}`);
}
};
// Discover USB ports
const discoverUsbPorts = async () => {
try {
const response = await fetch(`${API_BASE}/usb/discover`);
const data = await response.json();
const ports = data.ports || [];
setAvailableUsbPorts(ports);
// Auto-fix config if OpenArms Mini is selected and ports are invalid
if (config.leader_type === 'openarms_mini') {
const updated = { ...config };
let changed = false;
if (ports.length >= 1 && !ports.includes(config.leader_left)) {
updated.leader_left = ports[0];
changed = true;
}
if (ports.length >= 2 && !ports.includes(config.leader_right)) {
updated.leader_right = ports[1];
changed = true;
}
if (changed) {
setConfig(updated);
saveConfig(updated);
}
}
if (ports.length === 0) {
console.warn('No USB ports detected for OpenArms Mini');
}
} catch (e) {
console.error('Failed to discover USB ports:', e);
}
};
// Set task only (for pedal use)
const setTaskOnly = async () => {
if (!task.trim()) {
setError('Please enter a task description');
return;
}
setError(null);
try {
const response = await fetch(`${API_BASE}/recording/set-task`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ task, ...config })
});
if (!response.ok) {
const data = await response.json();
throw new Error(data.detail || 'Failed to set task');
}
const result = await response.json();
setStatusMessage(result.message || `Task set: ${task}`);
saveConfig(config);
// Clear success message after 3 seconds
setTimeout(() => {
if (!isRecording && !isInitializing) {
setStatusMessage('Ready');
}
}, 3000);
} catch (e) {
setError(e.message);
}
};
// Start recording
const startRecording = async () => {
if (!task.trim()) {
setError('Please enter a task description');
return;
}
setError(null);
try {
const response = await fetch(`${API_BASE}/recording/start`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ task, ...config })
});
if (!response.ok) {
const data = await response.json();
throw new Error(data.detail || 'Failed to start recording');
}
await response.json();
saveConfig(config);
} catch (e) {
setError(e.message);
}
};
// Stop recording
const stopRecording = async () => {
try {
const response = await fetch(`${API_BASE}/recording/stop`, {
method: 'POST'
});
if (!response.ok) {
const data = await response.json();
throw new Error(data.detail || 'Failed to stop recording');
}
const data = await response.json();
setError(null);
// Update latest repo_id after recording
if (data.dataset_name) {
setLatestRepoId(`lerobot-data-collection/${data.dataset_name}`);
}
} catch (e) {
setError(e.message);
}
};
const deleteLatestEpisode = async () => {
if (!latestRepoId) {
setError('No episode to delete');
return;
}
const confirmed = window.confirm(
`WARNING: This will permanently delete the repository:\n\n${latestRepoId}\n\nThis action cannot be undone. Continue?`
);
if (!confirmed) {
return;
}
try {
const response = await fetch(`${API_BASE}/recording/delete-latest`, { method: 'POST' });
if (!response.ok) {
const data = await response.json();
throw new Error(data.detail || 'Failed to delete episode');
}
const data = await response.json();
setLatestRepoId(null);
setEpisodeCount(Math.max(0, episodeCount - 1));
setStatusMessage(`Deleted: ${data.deleted_repo}`);
setTimeout(() => {
if (!isRecording && !isInitializing) {
setStatusMessage('Ready');
}
}, 3000);
} catch (e) {
setError(`Delete failed: ${e.message}`);
}
};
// Reset counter
const resetCounter = async () => {
try {
await fetch(`${API_BASE}/counter/reset`, { method: 'POST' });
setEpisodeCount(0);
} catch (e) {
console.error('Failed to reset counter:', e);
}
};
// Move robot to zero position
const moveToZero = async () => {
setError(null);
try {
const response = await fetch(`${API_BASE}/robots/move-to-zero`, { method: 'POST' });
if (!response.ok) {
const data = await response.json();
throw new Error(data.detail || 'Failed to move to zero position');
}
await response.json();
} catch (e) {
setError(`Move to zero failed: ${e.message}`);
}
};
// Format time as MM:SS
const formatTime = (seconds) => {
const mins = Math.floor(seconds / 60);
const secs = Math.floor(seconds % 60);
return `${mins.toString().padStart(2, '0')}:${secs.toString().padStart(2, '0')}`;
};
// Update config and save
const updateConfig = (key, value) => {
const updated = { ...config, [key]: value };
setConfig(updated);
saveConfig(updated);
};
// Initialize on mount only
useEffect(() => {
// Prevent double-initialization in development
if (hasInitializedRef.current) {
return;
}
hasInitializedRef.current = true;
loadConfig();
discoverCameras();
discoverUsbPorts();
fetchStatus();
statusIntervalRef.current = setInterval(fetchStatus, 1000);
return () => {
if (statusIntervalRef.current) {
clearInterval(statusIntervalRef.current);
}
};
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []); // Run only once on mount
// Discover USB ports when leader type changes to Mini
useEffect(() => {
if (config.leader_type === 'openarms_mini') {
discoverUsbPorts();
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [config.leader_type]);
return (
<main>
<header>
<h1>OpenArms Recording</h1>
</header>
<div className="container">
{/* Left Column: Configuration and Recording Control */}
<div className="left-column">
{/* Configuration Panel */}
<section className="panel config-panel">
<div
className="config-header"
onClick={() => setConfigExpanded(!configExpanded)}
role="button"
tabIndex={0}
onKeyDown={(e) => e.key === 'Enter' && setConfigExpanded(!configExpanded)}
>
<h2> Configuration</h2>
<span className="toggle-icon">{configExpanded ? '▼' : '▶'}</span>
</div>
{configExpanded && (
<div className="config-content">
{/* Robot Setup */}
<div className="config-section">
<h3>🤖 Robot Setup</h3>
<div className="robot-setup">
{robotsReady ? (
<div className="robot-status ready">
<span> Robots Ready - Recording will start instantly</span>
<button onClick={disconnectRobots} className="btn-disconnect">
Disconnect Robots
</button>
</div>
) : (
<div className="robot-status not-ready">
<span> Robots not initialized - Recording will take ~10 seconds</span>
<button
onClick={setupRobots}
disabled={isRecording || isInitializing}
className="btn-setup"
>
🚀 Setup Robots
</button>
</div>
)}
</div>
</div>
{/* Leader Type Selection */}
<div className="config-section">
<h3>🎮 Leader Type</h3>
<div className="config-grid">
<label style={{gridColumn: '1 / -1'}}>
Leader Arm Type
<select
value={config.leader_type}
onChange={(e) => updateConfig('leader_type', e.target.value)}
disabled={isRecording || robotsReady}
>
<option value="openarms">OpenArms (CAN Bus - Damiao Motors)</option>
<option value="openarms_mini">OpenArms Mini (USB - Feetech Motors)</option>
</select>
</label>
</div>
</div>
{/* Leader Interfaces (CAN or USB based on type) */}
<div className="config-section">
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '0.5rem' }}>
<h3>
{config.leader_type === 'openarms_mini'
? `Leader Ports (USB/Serial) ${availableUsbPorts.length > 0 ? `(${availableUsbPorts.length} detected)` : ''}`
: 'Leader Interfaces (CAN)'}
</h3>
{config.leader_type === 'openarms_mini' && (
<button
onClick={discoverUsbPorts}
className="btn-refresh"
disabled={isRecording || robotsReady}
>
🔄 Refresh
</button>
)}
</div>
<div className="config-grid">
<label>
Leader Left
<select
value={config.leader_left}
onChange={(e) => updateConfig('leader_left', e.target.value)}
disabled={isRecording || robotsReady}
>
{config.leader_type === 'openarms_mini' ? (
availableUsbPorts.length > 0 ? (
availableUsbPorts.map((port) => (
<option key={port} value={port}>{port}</option>
))
) : (
<option value="">No USB ports detected</option>
)
) : (
canInterfaces.map((iface) => (
<option key={iface} value={iface}>{iface}</option>
))
)}
</select>
</label>
<label>
Leader Right
<select
value={config.leader_right}
onChange={(e) => updateConfig('leader_right', e.target.value)}
disabled={isRecording || robotsReady}
>
{config.leader_type === 'openarms_mini' ? (
availableUsbPorts.length > 0 ? (
availableUsbPorts.map((port) => (
<option key={port} value={port}>{port}</option>
))
) : (
<option value="">No USB ports detected</option>
)
) : (
canInterfaces.map((iface) => (
<option key={iface} value={iface}>{iface}</option>
))
)}
</select>
</label>
</div>
</div>
{/* Follower CAN Interfaces */}
<div className="config-section">
<h3>Follower Interfaces (CAN)</h3>
<div className="config-grid">
<label>
Follower Left
<select
value={config.follower_left}
onChange={(e) => updateConfig('follower_left', e.target.value)}
disabled={isRecording || robotsReady}
>
{canInterfaces.map((iface) => (
<option key={iface} value={iface}>{iface}</option>
))}
</select>
</label>
<label>
Follower Right
<select
value={config.follower_right}
onChange={(e) => updateConfig('follower_right', e.target.value)}
disabled={isRecording || robotsReady}
>
{canInterfaces.map((iface) => (
<option key={iface} value={iface}>{iface}</option>
))}
</select>
</label>
</div>
</div>
{/* Camera Configuration */}
<div className="config-section">
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '0.5rem' }}>
<h3>Cameras {availableCameras.length > 0 && `(${availableCameras.length} detected)`}</h3>
<button
onClick={discoverCameras}
className="btn-refresh"
disabled={isRecording || robotsReady}
>
🔄 Refresh
</button>
</div>
<div className="config-grid">
<label>
Left Wrist
<select
value={config.left_wrist}
onChange={(e) => updateConfig('left_wrist', e.target.value)}
disabled={isRecording || robotsReady}
>
{availableCameras.map((cam) => (
<option key={cam.id} value={String(cam.id)}>
{cam.name || `Camera @ ${cam.id}`}
</option>
))}
</select>
</label>
<label>
Right Wrist
<select
value={config.right_wrist}
onChange={(e) => updateConfig('right_wrist', e.target.value)}
disabled={isRecording || robotsReady}
>
{availableCameras.map((cam) => (
<option key={cam.id} value={String(cam.id)}>
{cam.name || `Camera @ ${cam.id}`}
</option>
))}
</select>
</label>
<label>
Base Camera
<select
value={config.base}
onChange={(e) => updateConfig('base', e.target.value)}
disabled={isRecording || robotsReady}
>
{availableCameras.map((cam) => (
<option key={cam.id} value={String(cam.id)}>
{cam.name || `Camera @ ${cam.id}`}
</option>
))}
</select>
</label>
</div>
</div>
</div>
)}
</section>
{/* Control Panel */}
<section className="panel control-panel">
<h2>🎬 Recording Control</h2>
{/* Status Banner - Always show important statuses */}
{isInitializing && (
<div className="status-banner initializing">
<div className="spinner"></div>
<span>{statusMessage}</span>
</div>
)}
{isEncoding && (
<div className="status-banner encoding">
<div className="spinner"></div>
<span>📹 {statusMessage}</span>
</div>
)}
{isUploading && (
<div className="status-banner uploading">
<div className="spinner"></div>
<span> {statusMessage}</span>
</div>
)}
{uploadStatus && !isRecording && !isEncoding && !isUploading && (
<div className={`status-banner ${uploadStatus.startsWith('✓') ? 'success' : 'warning'}`}>
<span>{uploadStatus}</span>
</div>
)}
<div className="control-horizontal">
{/* Task Input and Status */}
<div className="control-left">
<div className="input-group">
<input
type="text"
value={task}
onChange={(e) => setTask(e.target.value)}
placeholder="Task description (e.g., 'pick and place')"
disabled={isRecording || isInitializing || isEncoding || isUploading}
onKeyPress={(e) => {
if (e.key === 'Enter' && robotsReady) {
setTaskOnly();
}
}}
/>
<button
onClick={setTaskOnly}
disabled={isRecording || isInitializing || isEncoding || isUploading || !robotsReady}
className="btn-set-task"
title={!robotsReady ? 'Please setup robots first' : 'Store task for pedal use (Enter key)'}
>
💾 Set Task
</button>
<button
onClick={startRecording}
disabled={isRecording || isInitializing || isEncoding || isUploading || !robotsReady}
className="btn-start"
title={!robotsReady ? 'Please setup robots first' : ''}
>
{isInitializing
? '⏳ Initializing...'
: isRecording
? '⏺ Recording...'
: robotsReady
? '⏺ Start Recording'
: '⏺ Setup Robots First'}
</button>
</div>
{/* Ramp-up Countdown */}
{isRecording && rampUpRemaining > 0 && (
<div className="ramp-up-countdown">
<div className="countdown-box">
<div className="countdown-label"> WARMING UP - PID RAMP-UP</div>
<div className="countdown-value">{rampUpRemaining.toFixed(1)}s</div>
<div className="countdown-subtitle">Recording will start automatically...</div>
</div>
</div>
)}
{/* Recording Status - Only show after ramp-up */}
{isRecording && rampUpRemaining <= 0 && (
<div className="status recording recording-active">
<div className="indicator"></div>
<div className="time-display">
<span>{formatTime(elapsedTime)}</span>
<span className="fps-display">
Loop: {loopFps.toFixed(1)} Hz
{loopFps > 0 && loopFps < 29 && <span className="fps-warning"> </span>}
</span>
<span className="fps-display">Recording: {currentFps.toFixed(1)} FPS</span>
</div>
<button onClick={stopRecording} className="btn-stop">
Stop
</button>
</div>
)}
</div>
{/* Episode Counter */}
<div className="control-right">
<div className="counter">
<div className="counter-label">Episodes Recorded</div>
<div className="counter-value">{episodeCount}</div>
<button onClick={resetCounter} className="btn-reset">
Reset
</button>
</div>
</div>
</div>
{/* Delete Latest Episode Button */}
{!isRecording && !isInitializing && latestRepoId && (
<div className="delete-episode-section">
<button
onClick={deleteLatestEpisode}
className="btn-delete"
title="Delete the latest recorded episode from HuggingFace Hub"
>
Delete Latest Episode
</button>
<div className="delete-info">Will delete: {latestRepoId}</div>
</div>
)}
{/* Move to Zero Button */}
{robotsReady && !isRecording && !isInitializing && (
<div className="zero-position-section">
<button
onClick={moveToZero}
disabled={movingToZero}
className="btn-zero-large"
title="Move both leader and follower robots to zero position (2s)"
>
{movingToZero ? '⏳ Moving to Zero Position...' : '🎯 Move to Zero Position (Leader + Follower)'}
</button>
</div>
)}
{/* Error Display */}
{error && (
<div className="error-box">
{error}
</div>
)}
</section>
</div>
{/* Right Column: Camera Feeds */}
<div className="right-column">
<section className="panel cameras">
<h2>📹 Camera Views</h2>
{robotsReady || isRecording || isInitializing ? (
<div className="camera-layout">
{/* Base camera - full width */}
<div className="camera camera-base">
<h3>Base Camera</h3>
<img src={`${API_BASE}/camera/stream/base`} alt="Base Camera" />
</div>
{/* Wrist cameras - side by side */}
<div className="camera-wrist-container">
<div className="camera camera-wrist">
<h3>Left Wrist</h3>
<img src={`${API_BASE}/camera/stream/left_wrist`} alt="Left Wrist Camera" />
</div>
<div className="camera camera-wrist">
<h3>Right Wrist</h3>
<img src={`${API_BASE}/camera/stream/right_wrist`} alt="Right Wrist Camera" />
</div>
</div>
</div>
) : (
<div className="camera-placeholder">
<p>📷 Camera feeds will appear when robots are set up</p>
<p className="hint">Click "Setup Robots" above to preview camera feeds</p>
</div>
)}
</section>
</div>
</div>
</main>
);
}
export default App;

View File

@@ -0,0 +1,41 @@
# OpenArms Web Recording Interface
A web interface for recording OpenArms datasets.
## Installation
```bash
cd examples/openarms_web_interface
npm install
```
## Usage
**Start everything with one command:**
```bash
./launch.sh
```
This will:
- Start the FastAPI backend on port 8000
- Start the React frontend on port 5173
- Show live logs from both services
Then open your browser to: **http://localhost:5173**
**Stop with:** `Ctrl+C`
---
## Workflow
1. **Configure CAN interfaces** and **camera paths** in the dropdowns
2. Click **"Setup Robots"** to initialize (once at start)
3. Enter a **task description**
4. Click **"Start Recording"** to begin an episode
5. Click **"Stop Recording"** when done
6. Dataset is automatically encoded and uploaded to HuggingFace Hub as **private**
7. Repeat steps 3-6 for more episodes (no need to re-setup robots!)
---

View File

@@ -0,0 +1,12 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>OpenArms Recording Interface</title>
</head>
<body>
<div id="root"></div>
<script type="module" src="/main.jsx"></script>
</body>
</html>

View File

@@ -0,0 +1,142 @@
#!/bin/bash
# OpenArms Web Interface Launcher
# Starts Rerun viewer, FastAPI backend, and React frontend
set -e
# Colors for output
GREEN='\033[0;32m'
BLUE='\033[0;34m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
NC='\033[0m' # No Color
# Get script directory
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd "$SCRIPT_DIR"
echo -e "${BLUE}╔════════════════════════════════════════╗${NC}"
echo -e "${BLUE}║ OpenArms Web Recording Interface ║${NC}"
echo -e "${BLUE}╚════════════════════════════════════════╝${NC}"
echo ""
# Function to cleanup on exit
cleanup() {
echo ""
echo -e "${YELLOW}Shutting down services...${NC}"
# Kill all child processes
pkill -P $$ 2>/dev/null || true
# Kill specific services by port
lsof -ti:8000 | xargs kill -9 2>/dev/null || true # Backend
lsof -ti:5173 | xargs kill -9 2>/dev/null || true # Frontend
lsof -ti:9876 | xargs kill -9 2>/dev/null || true # Rerun (if spawned)
echo -e "${GREEN}✓ Services stopped${NC}"
exit 0
}
# Register cleanup on script exit
trap cleanup EXIT INT TERM
# Check if required commands exist
command -v rerun >/dev/null 2>&1 || {
echo -e "${RED}✗ Error: 'rerun' not found. Please install: pip install rerun-sdk${NC}"
exit 1
}
command -v python >/dev/null 2>&1 || {
echo -e "${RED}✗ Error: 'python' not found${NC}"
exit 1
}
command -v npm >/dev/null 2>&1 || {
echo -e "${RED}✗ Error: 'npm' not found${NC}"
exit 1
}
# Check if node_modules exists
if [ ! -d "node_modules" ]; then
echo -e "${YELLOW}⚠ node_modules not found. Running npm install...${NC}"
npm install
echo -e "${GREEN}✓ Dependencies installed${NC}"
echo ""
fi
echo -e "${GREEN}Starting services...${NC}"
echo ""
# 1. Start FastAPI backend (Rerun will start when recording begins)
echo -e "${BLUE}[1/2]${NC} Starting FastAPI backend on port 8000..."
cd "$SCRIPT_DIR"
# Use Python from current environment (if lerobot env is active, it will use that)
# Otherwise, check if we need to use conda run
if [[ "$CONDA_DEFAULT_ENV" == "lerobot" ]]; then
# Already in lerobot environment
echo -e "${GREEN}✓ Using active lerobot environment${NC}"
PYTHON_CMD="python"
elif command -v conda >/dev/null 2>&1 && conda env list | grep -q "^lerobot "; then
# lerobot env exists but not active - use conda run
echo -e "${YELLOW}Using conda run with lerobot environment...${NC}"
PYTHON_CMD="conda run -n lerobot --no-capture-output python"
else
# Fall back to system python
echo -e "${YELLOW}⚠ Warning: lerobot environment not found, using system python${NC}"
PYTHON_CMD="python"
fi
$PYTHON_CMD web_record_server.py > /tmp/openarms_backend.log 2>&1 &
BACKEND_PID=$!
sleep 3
if ps -p $BACKEND_PID > /dev/null; then
echo -e "${GREEN}✓ Backend started${NC} (PID: $BACKEND_PID)"
echo -e " URL: ${BLUE}http://localhost:8000${NC}"
else
echo -e "${RED}✗ Failed to start backend${NC}"
echo -e "${YELLOW}Check logs: tail -f /tmp/openarms_backend.log${NC}"
exit 1
fi
echo ""
# 2. Start React frontend
echo -e "${BLUE}[2/2]${NC} Starting React frontend on port 5173..."
cd "$SCRIPT_DIR"
npm run dev > /tmp/openarms_frontend.log 2>&1 &
FRONTEND_PID=$!
sleep 3
if ps -p $FRONTEND_PID > /dev/null; then
echo -e "${GREEN}✓ Frontend started${NC} (PID: $FRONTEND_PID)"
echo -e " URL: ${BLUE}http://localhost:5173${NC}"
else
echo -e "${RED}✗ Failed to start frontend${NC}"
echo -e "${YELLOW}Check logs: tail -f /tmp/openarms_frontend.log${NC}"
exit 1
fi
echo ""
# Display status
echo -e "${GREEN}╔════════════════════════════════════════╗${NC}"
echo -e "${GREEN}║ All services running! 🚀 ║${NC}"
echo -e "${GREEN}╚════════════════════════════════════════╝${NC}"
echo ""
echo -e "🔧 ${BLUE}Backend:${NC} http://localhost:8000"
echo -e "🌐 ${BLUE}Frontend:${NC} http://localhost:5173"
echo -e "📊 ${BLUE}Rerun:${NC} Will spawn automatically when recording starts"
echo ""
echo -e "${YELLOW}Open your browser to:${NC} ${BLUE}http://localhost:5173${NC}"
echo ""
echo -e "${YELLOW}Logs:${NC}"
echo -e " • Backend: tail -f /tmp/openarms_backend.log"
echo -e " • Frontend: tail -f /tmp/openarms_frontend.log"
echo ""
echo -e "${RED}Press Ctrl+C to stop all services${NC}"
echo ""
# Keep script running and wait for any service to exit
wait

View File

@@ -0,0 +1,7 @@
import { createRoot } from 'react-dom/client'
import App from './App.jsx'
createRoot(document.getElementById('root')).render(
<App />
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,21 @@
{
"name": "openarms-web-interface",
"private": true,
"version": "0.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"build": "vite build",
"preview": "vite preview"
},
"dependencies": {
"react": "^18.3.1",
"react-dom": "^18.3.1"
},
"devDependencies": {
"@types/react": "^18.3.12",
"@types/react-dom": "^18.3.1",
"@vitejs/plugin-react": "^4.3.4",
"vite": "^6.0.1"
}
}

View File

@@ -0,0 +1,17 @@
import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react'
// https://vite.dev/config/
export default defineConfig({
plugins: [react()],
server: {
port: 5173,
strictPort: false,
host: true,
open: false
},
build: {
outDir: 'dist',
sourcemap: true
}
})

File diff suppressed because it is too large Load Diff

View File

@@ -142,38 +142,24 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="phone_so100_evaluate")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
print("Starting evaluate loop...")
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
log_say("Reset the environment")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
@@ -182,24 +168,41 @@ def main():
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Save episode
dataset.save_episode()
episode_idx += 1
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
# Save episode
dataset.save_episode()
episode_idx += 1
finally:
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":

View File

@@ -149,38 +149,23 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="phone_so100_record")
if not robot.is_connected or not phone.is_connected:
raise ValueError("Robot or teleop is not connected!")
try:
if not robot.is_connected or not phone.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting record loop. Move your phone to teleoperate the robot...")
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
print("Starting record loop. Move your phone to teleoperate the robot...")
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=phone,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=phone,
control_time_s=RESET_TIME_SEC,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
@@ -188,25 +173,43 @@ def main():
robot_observation_processor=robot_joints_to_ee_pose,
)
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=phone,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
)
# Save episode
dataset.save_episode()
episode_idx += 1
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Clean up
log_say("Stop recording")
robot.disconnect()
phone.disconnect()
listener.stop()
# Save episode
dataset.save_episode()
episode_idx += 1
finally:
# Clean up
log_say("Stop recording")
robot.disconnect()
phone.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":

View File

@@ -73,32 +73,34 @@ def main():
# Connect to the robot
robot.connect()
if not robot.is_connected:
raise ValueError("Robot is not connected!")
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
# Get recorded action from dataset
ee_action = {
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get recorded action from dataset
ee_action = {
name: float(actions[idx][ACTION][i])
for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get robot observation
robot_obs = robot.get_observation()
# Get robot observation
robot_obs = robot.get_observation()
# Dataset EE -> robot joints
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
# Dataset EE -> robot joints
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
# Send action to robot
_ = robot.send_action(joint_action)
# Send action to robot
_ = robot.send_action(joint_action)
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
# Clean up
robot.disconnect()
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
finally:
# Clean up
robot.disconnect()
if __name__ == "__main__":

1
examples/rac/cmd.sh Normal file
View File

@@ -0,0 +1 @@
python examples/rac/rac_data_collection_openarms_rtc.py --robot.type=openarms_follower --robot.port_right=can1 --robot.port_left=can0 --robot.cameras="{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}" --teleop.type=openarms_mini --teleop.port_right=/dev/ttyACM0 --teleop.port_left=/dev/ttyACM1 --policy.path=lerobot-data-collection/level1_rac3_100k --dataset.repo_id=lerobot-data-collection/level1_rac3_rtc_s5_2 --dataset.single_task="Fold the T-shirt properly" --dataset.num_episodes=5

View File

@@ -0,0 +1,638 @@
#!/usr/bin/env python
"""
RaC (Recovery and Correction) Data Collection with Policy Rollout + Human Intervention.
This implements the RaC paradigm from "RaC: Robot Learning for Long-Horizon Tasks
by Scaling Recovery and Correction" (Hu et al., 2025) for LeRobot.
RaC improves upon standard data collection (BC) and prior human-in-the-loop methods
(DAgger, HG-DAgger) by explicitly collecting recovery and correction behaviors:
The workflow:
1. Policy runs autonomously
2. Press SPACE to pause - robot holds position
3. Press 'c' to take control - human provides RECOVERY + CORRECTION
4. Press → to end episode (save and continue to next)
5. Reset, then do next rollout
Key RaC Rules:
- Rule 1 (Recover then Correct): Every intervention = recovery + correction (both human)
- Rule 2 (Terminate after Intervention): Episode ends after correction
The recovery segment (teleoperating back to good state) is recorded as training data -
this teaches the policy how to recover from errors.
Keyboard Controls:
SPACE - Pause policy (robot holds position, no recording)
c - Take control (start correction, recording resumes)
→ - End episode (save and continue to next)
← - Re-record episode
ESC - Stop recording and push dataset to hub
Usage:
python examples/rac/rac_data_collection.py \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \
--dataset.repo_id=my_user/rac_dataset \
--dataset.single_task="Pick up the cube"
"""
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from pprint import pformat
from typing import Any
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import make_robot_action
from lerobot.processor import (
IdentityProcessor,
PolicyAction,
PolicyProcessorPipeline,
RobotAction,
RobotObservation,
RobotProcessorPipeline,
)
from lerobot.processor.converters import (
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from lerobot.processor.rename_processor import rename_stats
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import is_headless, predict_action
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
@dataclass
class RaCDatasetConfig:
repo_id: str
single_task: str
root: str | Path | None = None
fps: int = 30
episode_time_s: float = 120
reset_time_s: float = 30
num_episodes: int = 50
video: bool = True
push_to_hub: bool = True
private: bool = False
tags: list[str] | None = None
num_image_writer_processes: int = 0
num_image_writer_threads_per_camera: int = 4
video_encoding_batch_size: int = 1
rename_map: dict[str, str] = field(default_factory=dict)
@dataclass
class RaCConfig:
robot: RobotConfig
dataset: RaCDatasetConfig
policy: PreTrainedConfig
teleop: TeleoperatorConfig
display_data: bool = True
play_sounds: bool = True
resume: bool = False
def __post_init__(self):
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
@classmethod
def __get_path_fields__(cls) -> list[str]:
return ["policy"]
def init_rac_keyboard_listener():
"""Initialize keyboard listener with RaC-specific controls."""
events = {
"exit_early": False,
"rerecord_episode": False,
"stop_recording": False,
"policy_paused": False, # SPACE pressed - policy paused, teleop tracking robot
"correction_active": False, # 'c' pressed - human controlling, recording correction
"in_reset": False, # True during reset period
"start_next_episode": False, # Signal to start next episode
}
if is_headless():
logging.warning("Headless environment - keyboard controls unavailable")
return None, events
from pynput import keyboard
def on_press(key):
try:
if events["in_reset"]:
# During reset: any action key starts next episode
if key == keyboard.Key.space or key == keyboard.Key.right:
print("\n[RaC] Starting next episode...")
events["start_next_episode"] = True
elif hasattr(key, 'char') and key.char == 'c':
print("\n[RaC] Starting next episode...")
events["start_next_episode"] = True
elif key == keyboard.Key.esc:
print("[RaC] ESC - Stop recording, pushing to hub...")
events["stop_recording"] = True
events["start_next_episode"] = True
else:
# During episode
if key == keyboard.Key.space:
if not events["policy_paused"] and not events["correction_active"]:
print("\n[RaC] ⏸ PAUSED - Policy stopped, teleop moving to robot position")
print(" Press 'c' or START to take control")
events["policy_paused"] = True
elif hasattr(key, 'char') and key.char == 'c':
if events["policy_paused"] and not events["correction_active"]:
print("\n[RaC] ▶ START pressed - taking control")
events["start_next_episode"] = True
elif key == keyboard.Key.right:
print("[RaC] → End episode")
events["exit_early"] = True
elif key == keyboard.Key.left:
print("[RaC] ← Re-record episode")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
print("[RaC] ESC - Stop recording, pushing to hub...")
events["stop_recording"] = True
events["exit_early"] = True
except Exception as e:
print(f"Key error: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
start_pedal_listener(events)
return listener, events
def start_pedal_listener(events: dict):
"""Start foot pedal listener thread if evdev is available."""
import threading
try:
from evdev import InputDevice, ecodes
except ImportError:
logging.info("[Pedal] evdev not installed - pedal support disabled")
return
PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
KEY_LEFT = "KEY_A" # Left pedal
KEY_RIGHT = "KEY_C" # Right pedal
def pedal_reader():
try:
dev = InputDevice(PEDAL_DEVICE)
print(f"[Pedal] Connected: {dev.name}")
print(f"[Pedal] Right=pause/next, Left=take control/start")
for ev in dev.read_loop():
if ev.type != ecodes.EV_KEY:
continue
from evdev import categorize
key = categorize(ev)
code = key.keycode
if isinstance(code, (list, tuple)):
code = code[0]
# Only trigger on key down
if key.keystate != 1:
continue
if events["in_reset"]:
# During reset: either pedal starts next episode
if code in [KEY_LEFT, KEY_RIGHT]:
print("\n[Pedal] Starting next episode...")
events["start_next_episode"] = True
else:
# During episode
if code == KEY_RIGHT:
# Right pedal: SPACE (pause) when running, → (next) when in correction
if events["correction_active"]:
print("\n[Pedal] → End episode")
events["exit_early"] = True
elif not events["policy_paused"]:
print("\n[Pedal] ⏸ PAUSED - Policy stopped, teleop moving to robot")
print(" Press left pedal to take control")
events["policy_paused"] = True
elif code == KEY_LEFT:
# Left pedal: START (take control) when paused
if events["policy_paused"] and not events["correction_active"]:
print("\n[Pedal] ▶ START pressed - taking control")
events["start_next_episode"] = True
except FileNotFoundError:
logging.info(f"[Pedal] Device not found: {PEDAL_DEVICE}")
except PermissionError:
logging.warning(f"[Pedal] Permission denied. Run: sudo setfacl -m u:$USER:rw {PEDAL_DEVICE}")
except Exception as e:
logging.debug(f"[Pedal] Error: {e}")
thread = threading.Thread(target=pedal_reader, daemon=True)
thread.start()
def make_identity_processors():
"""Create identity processors for RaC recording."""
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[IdentityProcessor()],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
robot_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[IdentityProcessor()],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[IdentityProcessor()],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
return teleop_proc, robot_proc, obs_proc
def move_robot_to_zero(robot: Robot, duration_s: float = 2.0, fps: int = 50):
"""Smoothly move all robot joints to zero position."""
obs = robot.get_observation()
current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
target_pos = {k: 0.0 for k in current_pos}
print(f"[RaC] Moving robot to zero position ({duration_s}s)...")
steps = int(duration_s * fps)
for step in range(steps + 1):
t = step / steps
interp_pos = {k: current_pos[k] * (1 - t) + target_pos[k] * t for k in current_pos}
robot.send_action(interp_pos)
time.sleep(1 / fps)
print("[RaC] Robot at zero position.")
@safe_stop_image_writer
def rac_rollout_loop(
robot: Robot,
teleop: Teleoperator,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
dataset: LeRobotDataset,
events: dict,
fps: int,
control_time_s: float,
single_task: str,
display_data: bool = True,
) -> dict:
"""
RaC rollout loop with two-stage intervention:
1. Policy runs autonomously (recording)
2. SPACE: Policy pauses (NOT recording) - robot holds position
3. 'c': Human takes control (recording correction)
4. →: End episode
"""
policy.reset()
preprocessor.reset()
postprocessor.reset()
device = get_safe_torch_device(policy.config.device)
frame_buffer = []
stats = {
"total_frames": 0,
"autonomous_frames": 0,
"paused_frames": 0,
"correction_frames": 0,
}
last_robot_action = None
was_paused = False
was_correction_active = False
waiting_for_takeover = False
timestamp = 0
start_t = time.perf_counter()
while timestamp < control_time_s:
loop_start = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
break
# Detect transition to paused state
if events["policy_paused"] and not was_paused:
obs = robot.get_observation()
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
print("[RaC] Moving teleop to robot position (2s smooth transition)...")
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
print("[RaC] Teleop aligned. Press START to take control.")
events["start_next_episode"] = False
waiting_for_takeover = True
was_paused = True
# Wait for start button before enabling correction mode
if waiting_for_takeover and events["start_next_episode"]:
print("[RaC] Start pressed - enabling teleop control...")
events["start_next_episode"] = False
events["correction_active"] = True
waiting_for_takeover = False
was_correction_active = True
obs = robot.get_observation()
obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
if events["correction_active"]:
# Human controlling - record correction data
robot_action = teleop.get_action()
robot.send_action(robot_action)
stats["correction_frames"] += 1
# Record this frame
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
frame = {**obs_frame, **action_frame, "task": single_task}
frame_buffer.append(frame)
stats["total_frames"] += 1
elif waiting_for_takeover:
# Waiting for START - policy stopped, no recording, robot holds position
if last_robot_action is not None:
robot.send_action(last_robot_action)
stats["paused_frames"] += 1
elif events["policy_paused"]:
# Paused and user acknowledged - hold last position, don't record
if last_robot_action is not None:
robot.send_action(last_robot_action)
stats["paused_frames"] += 1
robot_action = last_robot_action
else:
# Normal policy execution - record
action_values = predict_action(
observation=obs_frame,
policy=policy,
device=device,
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
robot_action: RobotAction = make_robot_action(action_values, dataset.features)
robot.send_action(robot_action)
last_robot_action = robot_action
stats["autonomous_frames"] += 1
# Record this frame
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
frame = {**obs_frame, **action_frame, "task": single_task}
frame_buffer.append(frame)
stats["total_frames"] += 1
if display_data and robot_action is not None:
log_rerun_data(observation=obs, action=robot_action)
dt = time.perf_counter() - loop_start
precise_sleep(1 / fps - dt)
timestamp = time.perf_counter() - start_t
for frame in frame_buffer:
dataset.add_frame(frame)
return stats
def reset_loop(
robot: Robot,
teleop: Teleoperator,
events: dict,
fps: int,
):
"""Reset period where human repositions environment. Two-stage: enable teleop, then start episode."""
print("\n" + "=" * 65)
print(" [RaC] RESET - Moving teleop to robot position...")
print("=" * 65)
# Enter reset mode
events["in_reset"] = True
events["start_next_episode"] = False
# Move teleop to match robot position to avoid sudden jumps
obs = robot.get_observation()
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
# Stage 1: Wait for user to press start to enable teleoperation
print(" Teleop aligned. Press any key/pedal to enable teleoperation")
while not events["start_next_episode"] and not events["stop_recording"]:
precise_sleep(0.05)
if events["stop_recording"]:
return
# Stage 2: Enable teleop and let user move robot to starting position
events["start_next_episode"] = False
teleop.disable_torque()
print(" Teleop enabled - move robot to starting position")
print(" Press any key/pedal to start next episode")
# Wait for user to signal ready for next episode
while not events["start_next_episode"] and not events["stop_recording"]:
loop_start = time.perf_counter()
action = teleop.get_action()
robot.send_action(action)
dt = time.perf_counter() - loop_start
precise_sleep(1 / fps - dt)
# Exit reset mode and clear flags for next episode
events["in_reset"] = False
events["start_next_episode"] = False
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
@parser.wrap()
def rac_collect(cfg: RaCConfig) -> LeRobotDataset:
"""Main RaC data collection function."""
init_logging()
logging.info(pformat(cfg.__dict__))
if cfg.display_data:
init_rerun(session_name="rac_collection")
robot = make_robot_from_config(cfg.robot)
teleop = make_teleoperator_from_config(cfg.teleop)
teleop_proc, robot_proc, obs_proc = make_identity_processors()
dataset_features = combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=teleop_proc,
initial_features=create_initial_features(action=robot.action_features),
use_videos=cfg.dataset.video,
),
aggregate_pipeline_dataset_features(
pipeline=obs_proc,
initial_features=create_initial_features(observation=robot.observation_features),
use_videos=cfg.dataset.video,
),
)
dataset = None
listener = None
try:
if cfg.resume:
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
if hasattr(robot, "cameras") and robot.cameras:
dataset.start_image_writer(
num_processes=cfg.dataset.num_image_writer_processes,
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
)
else:
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
root=cfg.dataset.root,
robot_type=robot.name,
features=dataset_features,
use_videos=cfg.dataset.video,
image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
* len(robot.cameras if hasattr(robot, "cameras") else []),
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
preprocessor_overrides={
"device_processor": {"device": cfg.policy.device},
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
},
)
robot.connect()
teleop.connect()
listener, events = init_rac_keyboard_listener()
print("\n" + "=" * 65)
print(" RaC (Recovery and Correction) Data Collection")
print("=" * 65)
print(" Policy runs autonomously until you intervene.")
print()
print(" Controls:")
print(" SPACE - Pause policy (robot holds position, no recording)")
print(" c - Take control (start correction, recording)")
print(" → - End episode (save)")
print(" ← - Re-record episode")
print(" ESC - Stop session and push to hub")
print("=" * 65 + "\n")
with VideoEncodingManager(dataset):
recorded = 0
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
log_say(f"RaC episode {dataset.num_episodes}", cfg.play_sounds)
move_robot_to_zero(robot, duration_s=2.0, fps=cfg.dataset.fps)
stats = rac_rollout_loop(
robot=robot,
teleop=teleop,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
events=events,
fps=cfg.dataset.fps,
control_time_s=cfg.dataset.episode_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
)
logging.info(f"Episode stats: {stats}")
if events["rerecord_episode"]:
log_say("Re-recording", cfg.play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
dataset.save_episode()
recorded += 1
# Reset between episodes
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
reset_loop(
robot=robot,
teleop=teleop,
events=events,
fps=cfg.dataset.fps,
)
finally:
log_say("Stop recording", cfg.play_sounds, blocking=True)
if dataset:
dataset.finalize()
if robot.is_connected:
robot.disconnect()
if teleop.is_connected:
teleop.disconnect()
if not is_headless() and listener:
listener.stop()
if cfg.dataset.push_to_hub:
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
return dataset
def main():
from lerobot.utils.import_utils import register_third_party_plugins
register_third_party_plugins()
rac_collect()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,659 @@
#!/usr/bin/env python
"""
RaC (Recovery and Correction) Data Collection for OpenArms Robot.
This implements the RaC paradigm from "RaC: Robot Learning for Long-Horizon Tasks
by Scaling Recovery and Correction" (Hu et al., 2025) for LeRobot with OpenArms.
RaC improves upon standard data collection (BC) and prior human-in-the-loop methods
(DAgger, HG-DAgger) by explicitly collecting recovery and correction behaviors:
The workflow:
1. Policy runs autonomously (teleop is idle/free)
2. Press SPACE to pause - teleop moves to match robot position
3. Press 'c' to take control - teleop is free, human provides RECOVERY + CORRECTION
4. Press → to end episode (save and continue to next)
5. Reset, then do next rollout
Key RaC Rules:
- Rule 1 (Recover then Correct): Every intervention = recovery + correction (both human)
- Rule 2 (Terminate after Intervention): Episode ends after correction
The recovery segment (teleoperating back to good state) is recorded as training data -
this teaches the policy how to recover from errors.
Keyboard Controls:
SPACE - Pause policy (teleop mirrors robot, no recording)
c - Take control (teleop free, recording correction)
→ - End episode (save and continue to next)
← - Re-record episode
ESC - Stop recording and push dataset to hub
Usage:
python examples/rac/rac_data_collection_openarms.py \
--robot.type=openarms_follower \
--robot.port_right=can0 \
--robot.port_left=can1 \
--robot.cameras="{ left_wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30}}" \
--teleop.type=openarms_mini \
--teleop.port_right=/dev/ttyUSB0 \
--teleop.port_left=/dev/ttyUSB1 \
--policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \
--dataset.repo_id=my_user/rac_openarms_dataset \
--dataset.single_task="Pick up the cube"
"""
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from pprint import pformat
from typing import Any
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import make_robot_action
from lerobot.processor import (
IdentityProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RobotAction,
RobotObservation,
RobotProcessorPipeline,
)
from lerobot.processor.converters import (
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from lerobot.processor.rename_processor import rename_stats
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig # noqa: F401
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig # noqa: F401
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import is_headless, predict_action
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
@dataclass
class RaCDatasetConfig:
repo_id: str
single_task: str
root: str | Path | None = None
fps: int = 30
episode_time_s: float = 120
reset_time_s: float = 30
num_episodes: int = 50
video: bool = True
push_to_hub: bool = True
private: bool = False
tags: list[str] | None = None
num_image_writer_processes: int = 0
num_image_writer_threads_per_camera: int = 4
video_encoding_batch_size: int = 1
rename_map: dict[str, str] = field(default_factory=dict)
@dataclass
class RaCConfig:
robot: RobotConfig
dataset: RaCDatasetConfig
teleop: TeleoperatorConfig
policy: PreTrainedConfig | None = None
display_data: bool = True
play_sounds: bool = True
resume: bool = False
def __post_init__(self):
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
if self.policy is None:
raise ValueError("policy.path is required")
@classmethod
def __get_path_fields__(cls) -> list[str]:
return ["policy"]
def init_rac_keyboard_listener():
"""Initialize keyboard listener with RaC-specific controls."""
events = {
"exit_early": False,
"rerecord_episode": False,
"stop_recording": False,
"policy_paused": False, # SPACE pressed - policy paused, teleop tracking robot
"correction_active": False, # 'c' pressed - human controlling, recording correction
"in_reset": False, # True during reset period
"start_next_episode": False, # Signal to start next episode
}
if is_headless():
logging.warning("Headless environment - keyboard controls unavailable")
return None, events
from pynput import keyboard
def on_press(key):
try:
if events["in_reset"]:
# During reset: any action key starts next episode
if key == keyboard.Key.space or key == keyboard.Key.right:
print("\n[RaC] Starting next episode...")
events["start_next_episode"] = True
elif hasattr(key, 'char') and key.char == 'c':
print("\n[RaC] Starting next episode...")
events["start_next_episode"] = True
elif key == keyboard.Key.esc:
print("[RaC] ESC - Stop recording, pushing to hub...")
events["stop_recording"] = True
events["start_next_episode"] = True
else:
# During episode
if key == keyboard.Key.space:
if not events["policy_paused"] and not events["correction_active"]:
print("\n[RaC] ⏸ PAUSED - Policy stopped, teleop moving to robot position")
print(" Press 'c' or START to take control")
events["policy_paused"] = True
elif hasattr(key, 'char') and key.char == 'c':
if events["policy_paused"] and not events["correction_active"]:
print("\n[RaC] ▶ START pressed - taking control")
events["start_next_episode"] = True
elif key == keyboard.Key.right:
print("[RaC] → End episode")
events["exit_early"] = True
elif key == keyboard.Key.left:
print("[RaC] ← Re-record episode")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
print("[RaC] ESC - Stop recording, pushing to hub...")
events["stop_recording"] = True
events["exit_early"] = True
except Exception as e:
print(f"Key error: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
start_pedal_listener(events)
return listener, events
def start_pedal_listener(events: dict):
"""Start foot pedal listener thread if evdev is available."""
import threading
try:
from evdev import InputDevice, ecodes
except ImportError:
logging.info("[Pedal] evdev not installed - pedal support disabled")
return
PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
KEY_LEFT = "KEY_A" # Left pedal
KEY_RIGHT = "KEY_C" # Right pedal
def pedal_reader():
try:
dev = InputDevice(PEDAL_DEVICE)
print(f"[Pedal] Connected: {dev.name}")
print(f"[Pedal] Right=pause/next, Left=take control/start")
for ev in dev.read_loop():
if ev.type != ecodes.EV_KEY:
continue
from evdev import categorize
key = categorize(ev)
code = key.keycode
if isinstance(code, (list, tuple)):
code = code[0]
# Only trigger on key down
if key.keystate != 1:
continue
if events["in_reset"]:
# During reset: either pedal starts next episode
if code in [KEY_LEFT, KEY_RIGHT]:
print("\n[Pedal] Starting next episode...")
events["start_next_episode"] = True
else:
# During episode
if code == KEY_RIGHT:
# Right pedal: SPACE (pause) when running, → (next) when in correction
if events["correction_active"]:
print("\n[Pedal] → End episode")
events["exit_early"] = True
elif not events["policy_paused"]:
print("\n[Pedal] ⏸ PAUSED - Policy stopped, teleop moving to robot")
print(" Press left pedal to take control")
events["policy_paused"] = True
elif code == KEY_LEFT:
# Left pedal: START (take control) when paused
if events["policy_paused"] and not events["correction_active"]:
print("\n[Pedal] ▶ START pressed - taking control")
events["start_next_episode"] = True
except FileNotFoundError:
logging.info(f"[Pedal] Device not found: {PEDAL_DEVICE}")
except PermissionError:
logging.warning(f"[Pedal] Permission denied. Run: sudo setfacl -m u:$USER:rw {PEDAL_DEVICE}")
except Exception as e:
logging.debug(f"[Pedal] Error: {e}")
thread = threading.Thread(target=pedal_reader, daemon=True)
thread.start()
def make_identity_processors():
"""Create identity processors for RaC recording."""
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[IdentityProcessorStep()],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
robot_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[IdentityProcessorStep()],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[IdentityProcessorStep()],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
return teleop_proc, robot_proc, obs_proc
def move_robot_to_zero(robot: Robot, duration_s: float = 2.0, fps: int = 50):
"""Smoothly move all robot joints to zero position."""
obs = robot.get_observation()
current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
target_pos = {k: 0.0 for k in current_pos}
print(f"[RaC] Moving robot to zero position ({duration_s}s)...")
steps = int(duration_s * fps)
for step in range(steps + 1):
t = step / steps
interp_pos = {k: current_pos[k] * (1 - t) + target_pos[k] * t for k in current_pos}
robot.send_action(interp_pos)
time.sleep(1 / fps)
print("[RaC] Robot at zero position.")
@safe_stop_image_writer
def rac_rollout_loop(
robot: Robot,
teleop: Teleoperator,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
dataset: LeRobotDataset,
events: dict,
fps: int,
control_time_s: float,
single_task: str,
display_data: bool = True,
) -> dict:
"""
RaC rollout loop with two-stage intervention:
1. Policy runs autonomously (recording) - teleop free/idle
2. SPACE: Policy pauses, teleop mirrors robot position (NOT recording)
3. 'c': Human takes control, teleop torque disabled (recording correction)
4. →: End episode
This allows smooth handoff - teleop tracks robot only when paused.
"""
policy.reset()
preprocessor.reset()
postprocessor.reset()
device = get_safe_torch_device(policy.config.device)
frame_buffer = []
stats = {
"total_frames": 0,
"autonomous_frames": 0,
"paused_frames": 0,
"correction_frames": 0,
}
# Start with teleop torque disabled - only enable when paused to track robot
teleop.disable_torque()
was_paused = False
was_correction_active = False
waiting_for_takeover = False
timestamp = 0
start_t = time.perf_counter()
while timestamp < control_time_s:
loop_start = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
break
# Detect transition to paused state - smooth move teleop to robot position
if events["policy_paused"] and not was_paused:
obs = robot.get_observation()
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")}
print("[RaC] Moving teleop to robot position (2s smooth transition)...")
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
print("[RaC] Teleop aligned. Press START to take control.")
events["start_next_episode"] = False
waiting_for_takeover = True
was_paused = True
# Wait for start button before enabling correction mode
if waiting_for_takeover and events["start_next_episode"]:
print("[RaC] Start pressed - enabling teleop control...")
teleop.disable_torque()
events["start_next_episode"] = False
events["correction_active"] = True
waiting_for_takeover = False
was_correction_active = True
obs = robot.get_observation()
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
if events["correction_active"]:
# Human controlling - record correction data
robot_action = teleop.get_action()
# Convert gripper from teleop range (0-100) to robot degrees (-65 to 0)
for key in robot_action:
if "gripper" in key:
robot_action[key] = -0.65 * robot_action[key]
robot.send_action(robot_action)
stats["correction_frames"] += 1
# Record this frame
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
frame = {**obs_frame, **action_frame, "task": single_task}
frame_buffer.append(frame)
stats["total_frames"] += 1
elif waiting_for_takeover:
# Waiting for START - policy stopped, no recording, robot holds position
stats["paused_frames"] += 1
elif events["policy_paused"]:
# Paused and user acknowledged - teleop tracks robot position, don't record
robot_action = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")}
teleop.send_feedback(robot_action)
stats["paused_frames"] += 1
else:
# Normal policy execution - record (teleop is free/idle)
action_values = predict_action(
observation=obs_frame,
policy=policy,
device=device,
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
robot_action: RobotAction = make_robot_action(action_values, dataset.features)
robot.send_action(robot_action)
stats["autonomous_frames"] += 1
# Record this frame
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
frame = {**obs_frame, **action_frame, "task": single_task}
frame_buffer.append(frame)
stats["total_frames"] += 1
if display_data:
log_rerun_data(observation=obs_filtered, action=robot_action)
dt = time.perf_counter() - loop_start
precise_sleep(1 / fps - dt)
timestamp = time.perf_counter() - start_t
# Ensure teleoperator torque is disabled at end
teleop.disable_torque()
for frame in frame_buffer:
dataset.add_frame(frame)
return stats
def reset_loop(
robot: Robot,
teleop: Teleoperator,
events: dict,
fps: int,
):
"""Reset period where human repositions environment. Two-stage: enable teleop, then start episode."""
print("\n" + "=" * 65)
print(" [RaC] RESET - Moving teleop to robot position...")
print("=" * 65)
# Enter reset mode
events["in_reset"] = True
events["start_next_episode"] = False
# First move teleop to match robot position to avoid sudden jumps
obs = robot.get_observation()
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
# Stage 1: Wait for user to press start to enable teleoperation
print(" Teleop aligned. Press any key/pedal to enable teleoperation")
while not events["start_next_episode"] and not events["stop_recording"]:
precise_sleep(0.05)
if events["stop_recording"]:
return
# Stage 2: Enable teleop and let user move robot to starting position
events["start_next_episode"] = False
teleop.disable_torque()
print(" Teleop enabled - move robot to starting position")
print(" Press any key/pedal to start next episode")
# Wait for user to signal ready for next episode
while not events["start_next_episode"] and not events["stop_recording"]:
loop_start = time.perf_counter()
action = teleop.get_action()
# Convert gripper from teleop range (0-100) to robot degrees (-65 to 0)
for key in action:
if "gripper" in key:
action[key] = -0.65 * action[key]
robot.send_action(action)
dt = time.perf_counter() - loop_start
precise_sleep(1 / fps - dt)
# Exit reset mode and clear flags for next episode
events["in_reset"] = False
events["start_next_episode"] = False
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
@parser.wrap()
def rac_collect(cfg: RaCConfig) -> LeRobotDataset:
"""Main RaC data collection function."""
init_logging()
logging.info(pformat(cfg.__dict__))
if cfg.display_data:
init_rerun(session_name="rac_collection_openarms")
robot = make_robot_from_config(cfg.robot)
teleop = make_teleoperator_from_config(cfg.teleop)
teleop_proc, robot_proc, obs_proc = make_identity_processors()
dataset_features = combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=teleop_proc,
initial_features=create_initial_features(action=robot.action_features),
use_videos=cfg.dataset.video,
),
aggregate_pipeline_dataset_features(
pipeline=obs_proc,
initial_features=create_initial_features(observation=robot.observation_features),
use_videos=cfg.dataset.video,
),
)
dataset = None
listener = None
try:
if cfg.resume:
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
if hasattr(robot, "cameras") and robot.cameras:
dataset.start_image_writer(
num_processes=cfg.dataset.num_image_writer_processes,
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
)
else:
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
root=cfg.dataset.root,
robot_type=robot.name,
features=dataset_features,
use_videos=cfg.dataset.video,
image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
* len(robot.cameras if hasattr(robot, "cameras") else []),
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
preprocessor_overrides={
"device_processor": {"device": cfg.policy.device},
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
},
)
robot.connect()
teleop.connect()
listener, events = init_rac_keyboard_listener()
print("\n" + "=" * 65)
print(" RaC (Recovery and Correction) Data Collection - OpenArms")
print("=" * 65)
print(" Policy runs autonomously until you intervene.")
print()
print(" Controls:")
print(" SPACE - Pause policy (teleop tracks robot, no recording)")
print(" c - Take control (start correction, recording)")
print(" → - End episode (save)")
print(" ← - Re-record episode")
print(" ESC - Stop session and push to hub")
print("=" * 65 + "\n")
with VideoEncodingManager(dataset):
recorded = 0
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
log_say(f"RaC episode {dataset.num_episodes}", cfg.play_sounds)
move_robot_to_zero(robot, duration_s=2.0, fps=cfg.dataset.fps)
stats = rac_rollout_loop(
robot=robot,
teleop=teleop,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
events=events,
fps=cfg.dataset.fps,
control_time_s=cfg.dataset.episode_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
)
logging.info(f"Episode stats: {stats}")
if events["rerecord_episode"]:
log_say("Re-recording", cfg.play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
dataset.save_episode()
recorded += 1
# Reset between episodes
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
reset_loop(
robot=robot,
teleop=teleop,
events=events,
fps=cfg.dataset.fps,
)
finally:
log_say("Stop recording", cfg.play_sounds, blocking=True)
if dataset:
dataset.finalize()
if robot.is_connected:
robot.disconnect()
if teleop.is_connected:
teleop.disconnect()
if not is_headless() and listener:
listener.stop()
if cfg.dataset.push_to_hub:
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
return dataset
def main():
from lerobot.utils.import_utils import register_third_party_plugins
register_third_party_plugins()
rac_collect()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,902 @@
#!/usr/bin/env python
"""
RaC (Recovery and Correction) Data Collection for OpenArms Robot with RTC.
This combines RaC data collection with Real-Time Chunking (RTC) for smooth policy execution.
RTC enables large flow-matching policies (Pi0, Pi0.5, SmolVLA) to produce reactive motion
despite high inference latency by asynchronously generating action chunks.
The workflow:
1. Policy runs autonomously with RTC (teleop is idle/free)
2. Press SPACE to pause - teleop moves to match robot position
3. Press 'c' to take control - teleop is free, human provides RECOVERY + CORRECTION
4. Press → to end episode (save and continue to next)
5. Reset, then do next rollout
Usage:
python examples/rac/rac_data_collection_openarms_rtc.py \
--robot.port_right=can0 \
--robot.port_left=can1 \
--teleop.port_right=/dev/ttyUSB0 \
--teleop.port_left=/dev/ttyUSB1 \
--policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \
--dataset.repo_id=my_user/rac_openarms_dataset \
--dataset.single_task="Pick up the cube"
"""
import logging
import math
import time
from dataclasses import dataclass, field
from pathlib import Path
from pprint import pformat
from threading import Event, Lock, Thread
from typing import Any
import torch
from torch import Tensor
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.policies.utils import make_robot_action
from lerobot.processor import (
IdentityProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RobotAction,
RobotObservation,
RobotProcessorPipeline,
)
from lerobot.processor.converters import (
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from lerobot.processor.rename_processor import rename_stats
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig # noqa: F401
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig # noqa: F401
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import is_headless, predict_action
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ============================================================================
# Configuration
# ============================================================================
@dataclass
class RaCRTCDatasetConfig:
repo_id: str = "lerobot/rac_openarms_rtc"
single_task: str = "default task"
root: str | Path | None = None
fps: int = 30
episode_time_s: float = 500
reset_time_s: float = 30
num_episodes: int = 50
video: bool = True
push_to_hub: bool = True
private: bool = False
tags: list[str] | None = None
num_image_writer_processes: int = 0
num_image_writer_threads_per_camera: int = 4
video_encoding_batch_size: int = 1
streaming_encoding: bool = True
rename_map: dict[str, str] = field(default_factory=dict)
@dataclass
class RaCRTCConfig:
robot: RobotConfig = field(default_factory=lambda: OpenArmsFollowerConfig(
port_left="can0",
port_right="can1",
))
teleop: TeleoperatorConfig = field(default_factory=lambda: OpenArmsMiniConfig(
port_left="/dev/ttyUSB1",
port_right="/dev/ttyUSB0",
))
dataset: RaCRTCDatasetConfig = field(default_factory=RaCRTCDatasetConfig)
policy: PreTrainedConfig | None = None
rtc: RTCConfig = field(default_factory=lambda: RTCConfig(
enabled=True,
execution_horizon=20,
max_guidance_weight=5.0,
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
))
interpolation: bool = True
display_data: bool = True
play_sounds: bool = True
resume: bool = False
device: str = "cuda"
action_queue_size_to_get_new_actions: int = 30
# Torch compile is disabled by default for real-time inference
# First inference with compile takes minutes to compile kernels
use_torch_compile: bool = False
def __post_init__(self):
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
if self.policy is None:
raise ValueError("policy.path is required")
@classmethod
def __get_path_fields__(cls) -> list[str]:
return ["policy"]
# ============================================================================
# Thread-Safe Robot Wrapper (from evaluate_with_rtc.py)
# ============================================================================
class RobotWrapper:
"""Thread-safe wrapper for robot operations."""
def __init__(self, robot: Robot):
self.robot = robot
self.lock = Lock()
def get_observation(self) -> dict[str, Tensor]:
with self.lock:
return self.robot.get_observation()
def send_action(self, action: dict) -> None:
with self.lock:
self.robot.send_action(action)
@property
def observation_features(self) -> dict:
return self.robot.observation_features
@property
def action_features(self) -> dict:
return self.robot.action_features
@property
def name(self) -> str:
return self.robot.name
@property
def robot_type(self) -> str:
return self.robot.robot_type
# ============================================================================
# Keyboard/Pedal Listeners
# ============================================================================
def init_rac_keyboard_listener():
"""Initialize keyboard listener with RaC-specific controls."""
events = {
"exit_early": False,
"rerecord_episode": False,
"stop_recording": False,
"policy_paused": False,
"correction_active": False,
"in_reset": False,
"start_next_episode": False,
}
if is_headless():
logging.warning("Headless environment - keyboard controls unavailable")
return None, events
from pynput import keyboard
def on_press(key):
try:
if events["in_reset"]:
if key == keyboard.Key.space or key == keyboard.Key.right:
print("\n[RaC] Starting next episode...")
events["start_next_episode"] = True
elif hasattr(key, 'char') and key.char == 'c':
print("\n[RaC] Starting next episode...")
events["start_next_episode"] = True
elif key == keyboard.Key.esc:
print("[RaC] ESC - Stop recording, pushing to hub...")
events["stop_recording"] = True
events["start_next_episode"] = True
else:
if key == keyboard.Key.space:
if not events["policy_paused"] and not events["correction_active"]:
print("\n[RaC] ⏸ PAUSED - Policy stopped, teleop moving to robot position")
print(" Press 'c' or START to take control")
events["policy_paused"] = True
elif hasattr(key, 'char') and key.char == 'c':
if events["policy_paused"] and not events["correction_active"]:
print("\n[RaC] ▶ START pressed - taking control")
events["start_next_episode"] = True
elif key == keyboard.Key.right:
print("[RaC] → End episode")
events["exit_early"] = True
elif key == keyboard.Key.left:
print("[RaC] ← Re-record episode")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
print("[RaC] ESC - Stop recording, pushing to hub...")
events["stop_recording"] = True
events["exit_early"] = True
except Exception as e:
print(f"Key error: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
start_pedal_listener(events)
return listener, events
def start_pedal_listener(events: dict):
"""Start foot pedal listener thread if evdev is available."""
import threading
try:
from evdev import InputDevice, ecodes # noqa: F401
except ImportError:
logging.info("[Pedal] evdev not installed - pedal support disabled")
return
PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
KEY_LEFT = "KEY_A"
KEY_RIGHT = "KEY_C"
def pedal_reader():
try:
dev = InputDevice(PEDAL_DEVICE)
print(f"[Pedal] Connected: {dev.name}")
for ev in dev.read_loop():
if ev.type != ecodes.EV_KEY:
continue
from evdev import categorize # noqa: F401
key = categorize(ev)
code = key.keycode
if isinstance(code, (list, tuple)):
code = code[0]
if key.keystate != 1:
continue
if events["in_reset"]:
if code in [KEY_LEFT, KEY_RIGHT]:
events["start_next_episode"] = True
else:
if code == KEY_RIGHT:
if events["correction_active"]:
events["exit_early"] = True
elif not events["policy_paused"]:
events["policy_paused"] = True
elif code == KEY_LEFT:
if events["policy_paused"] and not events["correction_active"]:
events["start_next_episode"] = True
except FileNotFoundError:
logging.info(f"[Pedal] Device not found: {PEDAL_DEVICE}")
except PermissionError:
logging.warning(f"[Pedal] Permission denied for {PEDAL_DEVICE}")
except Exception as e:
logging.debug(f"[Pedal] Error: {e}")
thread = threading.Thread(target=pedal_reader, daemon=True)
thread.start()
def make_identity_processors():
"""Create identity processors for RaC recording."""
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[IdentityProcessorStep()],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
robot_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[IdentityProcessorStep()],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[IdentityProcessorStep()],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
return teleop_proc, robot_proc, obs_proc
# ============================================================================
# RTC Inference Thread (from evaluate_with_rtc.py)
# ============================================================================
def rtc_inference_thread(
policy,
obs_holder: dict,
hw_features: dict,
preprocessor,
postprocessor,
queue_holder: dict,
shutdown_event: Event,
policy_active: Event,
cfg: RaCRTCConfig,
):
"""Background thread that generates action chunks using RTC."""
try:
logger.info("[RTC] ========== INFERENCE THREAD STARTED ==========")
logger.info(f"[RTC] policy={policy.name}, hw_features has {len(hw_features)} keys")
latency_tracker = LatencyTracker()
time_per_chunk = 1.0 / cfg.dataset.fps
policy_device = policy.config.device
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
if not cfg.rtc.enabled:
get_actions_threshold = 0
inference_count = 0
wait_logged = False
while not shutdown_event.is_set():
if not policy_active.is_set():
if not wait_logged:
logger.info("[RTC] Waiting for policy_active...")
wait_logged = True
time.sleep(0.01)
continue
wait_logged = False
action_queue = queue_holder["queue"]
if action_queue is None:
logger.warning("[RTC] queue_holder['queue'] is None!")
time.sleep(0.01)
continue
obs_filtered = obs_holder.get("obs")
if obs_filtered is None:
logger.warning("[RTC] obs_holder['obs'] is None!")
time.sleep(0.01)
continue
qsize = action_queue.qsize()
if qsize <= get_actions_threshold:
try:
if inference_count == 0:
logger.info(f"[RTC] Starting first inference, obs keys={len(obs_filtered)}, qsize={qsize}")
current_time = time.perf_counter()
action_index_before_inference = action_queue.get_action_index()
prev_actions = action_queue.get_left_over()
inference_latency = latency_tracker.max()
inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0
obs_with_policy_features = build_dataset_frame(hw_features, obs_filtered, prefix="observation")
for name in obs_with_policy_features:
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
if "image" in name:
obs_with_policy_features[name] = obs_with_policy_features[name].float() / 255
obs_with_policy_features[name] = obs_with_policy_features[name].permute(2, 0, 1).contiguous()
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0).to(policy_device)
obs_with_policy_features["task"] = [cfg.dataset.single_task]
obs_with_policy_features["robot_type"] = obs_holder.get("robot_type", "openarms_follower")
preprocessed_obs = preprocessor(obs_with_policy_features)
actions = policy.predict_action_chunk(
preprocessed_obs,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
original_actions = actions.squeeze(0).clone()
postprocessed_actions = postprocessor(actions).squeeze(0)
new_latency = time.perf_counter() - current_time
new_delay = math.ceil(new_latency / time_per_chunk)
latency_tracker.add(new_latency)
action_queue.merge(original_actions, postprocessed_actions, new_delay, action_index_before_inference)
inference_count += 1
logger.info(f"[RTC] Inference #{inference_count}, latency={new_latency:.2f}s, queue={action_queue.qsize()}")
except Exception as e:
logger.error(f"[RTC] Inference error: {e}")
import traceback
traceback.print_exc()
time.sleep(1.0)
else:
time.sleep(0.01)
logger.info("[RTC] Inference thread shutting down")
except Exception as e:
logger.error(f"[RTC] THREAD CRASHED: {e}")
import traceback
traceback.print_exc()
# ============================================================================
# Main Rollout Loop
# ============================================================================
@safe_stop_image_writer
def rac_rtc_rollout_loop(
robot: RobotWrapper,
teleop: Teleoperator,
policy: PreTrainedPolicy,
preprocessor,
postprocessor,
dataset: LeRobotDataset,
events: dict,
cfg: RaCRTCConfig,
queue_holder: dict,
obs_holder: dict, # Main loop writes obs here for RTC thread to read
policy_active: Event,
hw_features: dict,
) -> dict:
"""RaC rollout loop with RTC for smooth policy execution."""
fps = cfg.dataset.fps
single_task = cfg.dataset.single_task
control_time_s = cfg.dataset.episode_time_s
device = get_safe_torch_device(cfg.device)
# Reset policy state
policy.reset()
preprocessor.reset()
postprocessor.reset()
streaming = dataset._streaming_encoder is not None
frame_buffer = [] if not streaming else None
stats = {
"total_frames": 0,
"autonomous_frames": 0,
"paused_frames": 0,
"correction_frames": 0,
}
teleop.disable_torque()
was_paused = False
waiting_for_takeover = False
# Action keys for converting tensor to dict
action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
# Interpolation state
prev_action: Tensor | None = None
interpolated_actions: list[Tensor] = []
interp_idx = 0
if cfg.interpolation:
control_interval = 1.0 / (fps * 2) # 2x rate
else:
control_interval = 1.0 / fps
robot_action = {}
timestamp = 0
start_t = time.perf_counter()
while timestamp < control_time_s:
loop_start = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
break
# State transition: entering paused state
if events["policy_paused"] and not was_paused:
policy_active.clear() # Stop RTC inference
obs = robot.get_observation()
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")}
print("[RaC] Moving teleop to robot position...")
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
print("[RaC] Teleop aligned. Press 'c' to take control.")
events["start_next_episode"] = False
waiting_for_takeover = True
was_paused = True
# Reset interpolation
prev_action = None
interpolated_actions = []
interp_idx = 0
# Wait for takeover
if waiting_for_takeover and events["start_next_episode"]:
print("[RaC] Taking control...")
teleop.disable_torque()
events["start_next_episode"] = False
events["correction_active"] = True
waiting_for_takeover = False
# Get observation (ONLY the main loop reads from robot!)
obs = robot.get_observation()
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
# Share observation with RTC thread (thread reads, main loop writes)
obs_holder["obs"] = obs_filtered
if events["correction_active"]:
# Human controlling
robot_action = teleop.get_action()
for key in robot_action:
if "gripper" in key:
robot_action[key] = -0.65 * robot_action[key]
robot.send_action(robot_action)
stats["correction_frames"] += 1
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
frame = {**obs_frame, **action_frame, "task": single_task}
if streaming:
dataset.add_frame(frame)
else:
frame_buffer.append(frame)
stats["total_frames"] += 1
elif waiting_for_takeover:
stats["paused_frames"] += 1
elif events["policy_paused"]:
robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")}
teleop.send_feedback(robot_pos)
stats["paused_frames"] += 1
else:
# Policy execution with RTC
if not policy_active.is_set():
policy_active.set()
logger.info("[ROLLOUT] Policy activated, waiting for first actions...")
action_queue = queue_holder["queue"]
# Get action from queue (with interpolation)
if interp_idx >= len(interpolated_actions):
new_action = action_queue.get() if action_queue else None
# Log queue status periodically
if stats["autonomous_frames"] == 0 and new_action is None:
qsize = action_queue.qsize() if action_queue else -1
if timestamp < 0.5 or int(timestamp * 10) % 10 == 0:
logger.info(f"[ROLLOUT] Waiting for actions... queue_size={qsize}, obs_set={obs_holder.get('obs') is not None}")
if new_action is not None:
current_action = new_action.cpu()
if cfg.interpolation and prev_action is not None:
mid = prev_action + 0.5 * (current_action - prev_action)
interpolated_actions = [mid, current_action]
else:
interpolated_actions = [current_action]
prev_action = current_action
interp_idx = 0
if stats["autonomous_frames"] == 0:
logger.info(f"[ROLLOUT] Got first action! Starting robot motion.")
if interp_idx < len(interpolated_actions):
action_to_send = interpolated_actions[interp_idx]
interp_idx += 1
robot_action = {}
for i, key in enumerate(action_keys):
if i < len(action_to_send):
robot_action[key] = action_to_send[i].item()
robot.send_action(robot_action)
stats["autonomous_frames"] += 1
# Record at original fps
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
frame = {**obs_frame, **action_frame, "task": single_task}
if streaming:
dataset.add_frame(frame)
else:
frame_buffer.append(frame)
stats["total_frames"] += 1
if cfg.display_data:
log_rerun_data(observation=obs_filtered, action=robot_action)
dt = time.perf_counter() - loop_start
sleep_time = control_interval - dt
if sleep_time > 0:
precise_sleep(sleep_time)
timestamp = time.perf_counter() - start_t
policy_active.clear()
teleop.disable_torque()
if not streaming:
for frame in frame_buffer:
dataset.add_frame(frame)
return stats
def reset_loop(robot: RobotWrapper, teleop: Teleoperator, events: dict, fps: int):
"""Reset period where human repositions environment."""
print("\n" + "=" * 65)
print(" [RaC] RESET")
print("=" * 65)
events["in_reset"] = True
events["start_next_episode"] = False
obs = robot.get_observation()
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
print(" Press any key/pedal to enable teleoperation")
while not events["start_next_episode"] and not events["stop_recording"]:
precise_sleep(0.05)
if events["stop_recording"]:
return
events["start_next_episode"] = False
teleop.disable_torque()
print(" Teleop enabled - press any key/pedal to start episode")
while not events["start_next_episode"] and not events["stop_recording"]:
loop_start = time.perf_counter()
action = teleop.get_action()
for key in action:
if "gripper" in key:
action[key] = -0.65 * action[key]
robot.send_action(action)
dt = time.perf_counter() - loop_start
precise_sleep(1 / fps - dt)
events["in_reset"] = False
events["start_next_episode"] = False
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
# ============================================================================
# Main Entry Point
# ============================================================================
@parser.wrap()
def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
"""Main RaC data collection function with RTC."""
init_logging()
logging.info(pformat(cfg.__dict__))
if cfg.display_data:
init_rerun(session_name="rac_rtc_collection_openarms")
robot_raw = make_robot_from_config(cfg.robot)
teleop = make_teleoperator_from_config(cfg.teleop)
teleop_proc, robot_proc, obs_proc = make_identity_processors()
dataset_features = combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=teleop_proc,
initial_features=create_initial_features(action=robot_raw.action_features),
use_videos=cfg.dataset.video,
),
aggregate_pipeline_dataset_features(
pipeline=obs_proc,
initial_features=create_initial_features(observation=robot_raw.observation_features),
use_videos=cfg.dataset.video,
),
)
dataset = None
listener = None
shutdown_event = Event()
policy_active = Event()
rtc_thread = None
try:
if cfg.resume:
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
if cfg.dataset.streaming_encoding:
dataset.start_streaming_encoder()
if hasattr(robot_raw, "cameras") and robot_raw.cameras:
dataset.start_image_writer(
num_processes=cfg.dataset.num_image_writer_processes,
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot_raw.cameras),
)
else:
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
root=cfg.dataset.root,
robot_type=robot_raw.name,
features=dataset_features,
use_videos=cfg.dataset.video,
image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
* len(robot_raw.cameras if hasattr(robot_raw, "cameras") else []),
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
streaming_encoding=cfg.dataset.streaming_encoding,
)
# Load policy
logger.info(f"Loading policy from: {cfg.policy.pretrained_path}")
policy_class = get_policy_class(cfg.policy.type)
# Override compile_model for real-time inference (first compile takes minutes)
policy_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
if cfg.policy.type in ["pi05", "pi0"]:
policy_config.compile_model = cfg.use_torch_compile
logger.info(f"Set compile_model={cfg.use_torch_compile} for real-time inference")
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=policy_config)
policy.config.rtc_config = cfg.rtc
policy.init_rtc_processor()
policy = policy.to(cfg.device)
policy.eval()
logger.info(f"Policy loaded: {policy.name}")
# Setup preprocessor/postprocessor
hw_features = hw_to_dataset_features(robot_raw.observation_features, "observation")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
preprocessor_overrides={
"device_processor": {"device": cfg.device},
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
},
)
# Connect robot and wrap for thread safety
robot_raw.connect()
robot = RobotWrapper(robot_raw)
teleop.connect()
listener, events = init_rac_keyboard_listener()
# Shared state holders (main loop writes, RTC thread reads)
queue_holder = {"queue": ActionQueue(cfg.rtc)}
obs_holder = {"obs": None, "robot_type": robot.robot_type} # Main loop updates obs
# Start RTC inference thread
# NOTE: Thread does NOT access robot directly - reads from obs_holder
rtc_thread = Thread(
target=rtc_inference_thread,
args=(
policy,
obs_holder, # Thread reads obs from here (set by main loop)
hw_features,
preprocessor,
postprocessor,
queue_holder,
shutdown_event,
policy_active,
cfg,
),
daemon=True,
name="RTCInference",
)
rtc_thread.start()
logger.info("Started RTC inference thread")
print("\n" + "=" * 65)
print(" RaC Data Collection with RTC")
print("=" * 65)
print(f" Policy: {cfg.policy.pretrained_path}")
print(f" Task: {cfg.dataset.single_task}")
print(f" FPS: {cfg.dataset.fps}")
print(f" Interpolation: {cfg.interpolation}")
print()
print(" Controls:")
print(" SPACE - Pause policy")
print(" c - Take control")
print(" → - End episode")
print(" ESC - Stop and push to hub")
print("=" * 65 + "\n")
with VideoEncodingManager(dataset):
recorded = 0
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
log_say(f"RaC episode {dataset.num_episodes}", cfg.play_sounds)
# Fresh action queue per episode (update holder so thread sees it)
queue_holder["queue"] = ActionQueue(cfg.rtc)
logger.info(f"Episode {recorded + 1} / {cfg.dataset.num_episodes}")
stats = rac_rtc_rollout_loop(
robot=robot,
teleop=teleop,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
events=events,
cfg=cfg,
queue_holder=queue_holder,
obs_holder=obs_holder,
policy_active=policy_active,
hw_features=hw_features,
)
logging.info(f"Episode stats: {stats}")
if events["rerecord_episode"]:
log_say("Re-recording", cfg.play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
t_save_start = time.perf_counter()
dataset.save_episode()
logging.info(f"[RaC] save_episode total: {time.perf_counter() - t_save_start:.2f}s")
recorded += 1
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
reset_loop(robot, teleop, events, cfg.dataset.fps)
finally:
log_say("Stop recording", cfg.play_sounds, blocking=True)
shutdown_event.set()
policy_active.clear()
if rtc_thread and rtc_thread.is_alive():
rtc_thread.join(timeout=2.0)
if dataset:
dataset.finalize()
if robot_raw.is_connected:
robot_raw.disconnect()
if teleop.is_connected:
teleop.disconnect()
if not is_headless() and listener:
listener.stop()
if cfg.dataset.push_to_hub:
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
return dataset
def main():
from lerobot.utils.import_utils import register_third_party_plugins
register_third_party_plugins()
rac_rtc_collect()
if __name__ == "__main__":
main()

View File

@@ -142,38 +142,24 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="so100_so100_evaluate")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
print("Starting evaluate loop...")
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
log_say("Reset the environment")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
@@ -182,24 +168,41 @@ def main():
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Save episode
dataset.save_episode()
episode_idx += 1
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
# Save episode
dataset.save_episode()
episode_idx += 1
finally:
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":

View File

@@ -146,38 +146,23 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="recording_phone")
if not leader.is_connected or not follower.is_connected:
raise ValueError("Robot or teleop is not connected!")
try:
if not leader.is_connected or not follower.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting record loop...")
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
print("Starting record loop...")
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=follower,
events=events,
fps=FPS,
teleop=leader,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
# Main record loop
record_loop(
robot=follower,
events=events,
fps=FPS,
teleop=leader,
control_time_s=RESET_TIME_SEC,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=leader_joints_to_ee,
@@ -185,25 +170,44 @@ def main():
robot_observation_processor=follower_joints_to_ee,
)
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=follower,
events=events,
fps=FPS,
teleop=leader,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
)
# Save episode
dataset.save_episode()
episode_idx += 1
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Clean up
log_say("Stop recording")
leader.disconnect()
follower.disconnect()
listener.stop()
# Save episode
dataset.save_episode()
episode_idx += 1
dataset.finalize()
dataset.push_to_hub()
finally:
# Clean up
log_say("Stop recording")
leader.disconnect()
follower.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":

View File

@@ -74,32 +74,35 @@ def main():
# Connect to the robot
robot.connect()
if not robot.is_connected:
raise ValueError("Robot is not connected!")
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
# Get recorded action from dataset
ee_action = {
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get recorded action from dataset
ee_action = {
name: float(actions[idx][ACTION][i])
for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get robot observation
robot_obs = robot.get_observation()
# Get robot observation
robot_obs = robot.get_observation()
# Dataset EE -> robot joints
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
# Dataset EE -> robot joints
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
# Send action to robot
_ = robot.send_action(joint_action)
# Send action to robot
_ = robot.send_action(joint_action)
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
# Clean up
robot.disconnect()
finally:
# Clean up
robot.disconnect()
if __name__ == "__main__":

10
loop_datasets.py Normal file
View File

@@ -0,0 +1,10 @@
from huggingface_hub import HfApi, list_datasets
api = HfApi()
datasets = list_datasets(author="lerobot-data-collection")
print('"[', end="")
i=0
for dataset in datasets:
if "three-folds-dataset" in dataset.id:
print("'" + dataset.id + "',", end="")
print(']"',)

View File

@@ -105,12 +105,17 @@ dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
damiao = ["python-can>=4.2.0,<5.0.0"]
# Robots
openarms = ["lerobot[damiao]"]
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
unitree_g1 = [
"pyzmq>=26.2.1,<28.0.0",
"onnxruntime>=1.16.0,<2.0.0"
"onnxruntime>=1.16.0,<2.0.0",
"pin>=3.0.0,<4.0.0",
"meshcat>=0.3.0,<0.4.0",
"matplotlib>=3.9.0,<4.0.0",
"casadi>=3.6.0,<4.0.0",
]
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
kinematics = ["lerobot[placo-dep]"]
@@ -355,9 +360,9 @@ ignore_errors = false
module = "lerobot.cameras.*"
ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.motors.*"
# ignore_errors = false
[[tool.mypy.overrides]]
module = "lerobot.motors.*"
ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.robots.*"

View File

@@ -13,5 +13,5 @@
# limitations under the License.
from .camera import Camera
from .configs import CameraConfig, ColorMode, Cv2Rotation
from .configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation
from .utils import make_cameras_from_configs

View File

@@ -15,11 +15,12 @@
# limitations under the License.
import abc
import warnings
from typing import Any
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
from .configs import CameraConfig, ColorMode
from .configs import CameraConfig
class Camera(abc.ABC):
@@ -30,20 +31,12 @@ class Camera(abc.ABC):
Manages basic camera properties (FPS, resolution) and core operations:
- Connection/disconnection
- Frame capture (sync/async)
- Frame capture (sync/async/latest)
Attributes:
fps (int | None): Configured frames per second
width (int | None): Frame width in pixels
height (int | None): Frame height in pixels
Example:
class MyCamera(Camera):
def __init__(self, config): ...
@property
def is_connected(self) -> bool: ...
def connect(self, warmup=True): ...
# Plus other required methods
"""
def __init__(self, config: CameraConfig):
@@ -56,6 +49,32 @@ class Camera(abc.ABC):
self.width: int | None = config.width
self.height: int | None = config.height
def __enter__(self):
"""
Context manager entry.
Automatically connects to the camera.
"""
self.connect()
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
"""
Context manager exit.
Automatically disconnects, ensuring resources are released even on error.
"""
self.disconnect()
def __del__(self) -> None:
"""
Destructor safety net.
Attempts to disconnect if the object is garbage collected without cleanup.
"""
try:
if self.is_connected:
self.disconnect()
except Exception: # nosec B110
pass
@property
@abc.abstractmethod
def is_connected(self) -> bool:
@@ -89,12 +108,10 @@ class Camera(abc.ABC):
pass
@abc.abstractmethod
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""Capture and return a single frame from the camera.
def read(self) -> NDArray[Any]:
"""Capture and return a single frame from the camera synchronously.
Args:
color_mode: Desired color mode for the output frame. If None,
uses the camera's default color mode.
This is a blocking call that will wait for the hardware and its SDK.
Returns:
np.ndarray: Captured frame as a numpy array.
@@ -103,17 +120,64 @@ class Camera(abc.ABC):
@abc.abstractmethod
def async_read(self, timeout_ms: float = ...) -> NDArray[Any]:
"""Asynchronously capture and return a single frame from the camera.
"""Return the most recent new frame.
This method retrieves the latest frame captured by the background thread.
If a new frame is already available in the buffer (captured since the last call),
it returns it immediately.
It blocks up to `timeout_ms` only if the buffer is empty or if the latest frame
was already consumed by a previous `async_read` call.
Essentially, this method return the latest unconsumed frame, waiting if necessary
for a new one to arrive within the specified timeout.
Usage:
- Ideal for control loops where you want to ensure every processed frame
is fresh, effectively synchronizing your loop to the camera's FPS.
- Causes of a timeout usually include: very low camera FPS, heavy processing load,
or if the camera is disconnected.
Args:
timeout_ms: Maximum time to wait for a frame in milliseconds.
Defaults to implementation-specific timeout.
timeout_ms: Maximum time to wait for a new frame in milliseconds.
Defaults to 200ms (0.2s).
Returns:
np.ndarray: Captured frame as a numpy array.
Raises:
TimeoutError: If no new frame arrives within `timeout_ms`.
"""
pass
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
memory buffer. The frame may be stale,
meaning it could have been captured a while ago (hanging camera scenario e.g.).
Usage:
Ideal for scenarios requiring zero latency or decoupled frequencies & when
we want a guaranteed frame, such as UI visualization, logging, or
non-critical monitoring.
Returns:
NDArray[Any]: The frame image (numpy array).
Raises:
TimeoutError: If the latest frame is older than `max_age_ms`.
NotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
warnings.warn(
f"{self.__class__.__name__}.read_latest() is not implemented. "
"Please override read_latest(); it will be required in future releases.",
FutureWarning,
stacklevel=2,
)
return self.async_read()
@abc.abstractmethod
def disconnect(self) -> None:
"""Disconnect from the camera and release resources."""

View File

@@ -25,6 +25,10 @@ class ColorMode(str, Enum):
RGB = "rgb"
BGR = "bgr"
@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(f"`color_mode` is expected to be in {list(cls)}, but {value} is provided.")
class Cv2Rotation(int, Enum):
NO_ROTATION = 0
@@ -32,6 +36,25 @@ class Cv2Rotation(int, Enum):
ROTATE_180 = 180
ROTATE_270 = -90
@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(f"`rotation` is expected to be in {list(cls)}, but {value} is provided.")
# Subset from https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html
class Cv2Backends(int, Enum):
ANY = 0
V4L2 = 200
DSHOW = 700
PVAPI = 800
ANDROID = 1000
AVFOUNDATION = 1200
MSMF = 1400
@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(f"`backend` is expected to be in {list(cls)}, but {value} is provided.")
@dataclass(kw_only=True)
class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus

View File

@@ -32,10 +32,11 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2 # type: ignore # TODO: add type stubs for OpenCV
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from ..camera import Camera
from ..utils import get_cv2_backend, get_cv2_rotation
from ..utils import get_cv2_rotation
from .configuration_opencv import ColorMode, OpenCVCameraConfig
# NOTE(Steven): The maximum opencv device index depends on your operating system. For instance,
@@ -70,34 +71,24 @@ class OpenCVCamera(Camera):
Example:
```python
from lerobot.cameras.opencv import OpenCVCamera
from lerobot.cameras.configuration_opencv import OpenCVCameraConfig, ColorMode, Cv2Rotation
from lerobot.cameras.configuration_opencv import OpenCVCameraConfig
# Basic usage with camera index 0
config = OpenCVCameraConfig(index_or_path=0)
camera = OpenCVCamera(config)
camera.connect()
# Read 1 frame synchronously
# Read 1 frame synchronously (blocking)
color_image = camera.read()
print(color_image.shape)
# Read 1 frame asynchronously
# Read 1 frame asynchronously (waits for new frame with a timeout)
async_image = camera.async_read()
# Get the latest frame immediately (no wait, returns timestamp)
latest_image, timestamp = camera.read_latest()
# When done, properly disconnect the camera using
camera.disconnect()
# Example with custom settings
custom_config = OpenCVCameraConfig(
index_or_path='/dev/video0', # Or use an index
fps=30,
width=1280,
height=720,
color_mode=ColorMode.RGB,
rotation=Cv2Rotation.ROTATE_90
)
custom_camera = OpenCVCamera(custom_config)
# ... connect, read, disconnect ...
```
"""
@@ -123,10 +114,11 @@ class OpenCVCamera(Camera):
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: NDArray[Any] | None = None
self.latest_timestamp: float | None = None
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
self.backend: int = get_cv2_backend()
self.backend: int = config.backend
if self.height and self.width:
self.capture_width, self.capture_height = self.width, self.height
@@ -141,20 +133,23 @@ class OpenCVCamera(Camera):
"""Checks if the camera is currently connected and opened."""
return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened()
@check_if_already_connected
def connect(self, warmup: bool = True) -> None:
"""
Connects to the OpenCV camera specified in the configuration.
Initializes the OpenCV VideoCapture object, sets desired camera properties
(FPS, width, height), and performs initial checks.
(FPS, width, height), starts the background reading thread and performs initial checks.
Args:
warmup (bool): If True, waits at connect() time until at least one valid frame
has been captured by the background thread. Defaults to True.
Raises:
DeviceAlreadyConnectedError: If the camera is already connected.
ConnectionError: If the specified camera index/path is not found or the camera is found but fails to open.
RuntimeError: If the camera opens but fails to apply requested FPS/resolution settings.
ConnectionError: If the specified camera index/path is not found or fails to open.
RuntimeError: If the camera opens but fails to apply requested settings.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
# Use 1 thread for OpenCV operations to avoid potential conflicts or
# blocking in multi-threaded applications, especially during data collection.
@@ -170,15 +165,20 @@ class OpenCVCamera(Camera):
)
self._configure_capture_settings()
self._start_read_thread()
if warmup:
if warmup and self.warmup_s > 0:
start_time = time.time()
while time.time() - start_time < self.warmup_s:
self.read()
self.async_read(timeout_ms=self.warmup_s * 1000)
time.sleep(0.1)
with self.frame_lock:
if self.latest_frame is None:
raise ConnectionError(f"{self} failed to capture frames during warmup.")
logger.info(f"{self} connected.")
@check_if_not_connected
def _configure_capture_settings(self) -> None:
"""
Applies the specified FOURCC, FPS, width, and height settings to the connected camera.
@@ -196,11 +196,8 @@ class OpenCVCamera(Camera):
Raises:
RuntimeError: If the camera fails to set any of the specified properties
to the requested value.
DeviceNotConnectedError: If the camera is not connected when attempting
to configure settings.
DeviceNotConnectedError: If the camera is not connected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
if self.config.fourcc is not None:
@@ -339,6 +336,18 @@ class OpenCVCamera(Camera):
return found_cameras_info
def _read_from_hardware(self) -> NDArray[Any]:
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
ret, frame = self.videocapture.read()
if not ret:
raise RuntimeError(f"{self} read failed (status={ret}).")
return frame
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
@@ -346,11 +355,6 @@ class OpenCVCamera(Camera):
This is a blocking call. It waits for the next available frame from the
camera hardware via OpenCV.
Args:
color_mode (Optional[ColorMode]): If specified, overrides the default
color mode (`self.color_mode`) for this read operation (e.g.,
request RGB even if default is BGR).
Returns:
np.ndarray: The captured frame as a NumPy array in the format
(height, width, channels), using the specified or default
@@ -362,34 +366,31 @@ class OpenCVCamera(Camera):
received frame dimensions don't match expectations before rotation.
ValueError: If an invalid `color_mode` is requested.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
start_time = time.perf_counter()
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
if color_mode is not None:
logger.warning(
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
ret, frame = self.videocapture.read()
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
if not ret or frame is None:
raise RuntimeError(f"{self} read failed (status={ret}).")
processed_frame = self._postprocess_image(frame, color_mode)
self.new_frame_event.clear()
frame = self.async_read(timeout_ms=10000)
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return processed_frame
return frame
def _postprocess_image(self, image: NDArray[Any], color_mode: ColorMode | None = None) -> NDArray[Any]:
def _postprocess_image(self, image: NDArray[Any]) -> NDArray[Any]:
"""
Applies color conversion, dimension validation, and rotation to a raw frame.
Args:
image (np.ndarray): The raw image frame (expected BGR format from OpenCV).
color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None,
uses the instance's default `self.color_mode`.
Returns:
np.ndarray: The processed image frame.
@@ -399,11 +400,10 @@ class OpenCVCamera(Camera):
RuntimeError: If the raw frame dimensions do not match the configured
`width` and `height`.
"""
requested_color_mode = self.color_mode if color_mode is None else color_mode
if requested_color_mode not in (ColorMode.RGB, ColorMode.BGR):
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"Invalid color mode '{requested_color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
)
h, w, c = image.shape
@@ -417,7 +417,7 @@ class OpenCVCamera(Camera):
raise RuntimeError(f"{self} frame channels={c} do not match expected 3 channels (RGB/BGR).")
processed_image = image
if requested_color_mode == ColorMode.RGB:
if self.color_mode == ColorMode.RGB:
processed_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]:
@@ -431,7 +431,7 @@ class OpenCVCamera(Camera):
On each iteration:
1. Reads a color frame
2. Stores result in latest_frame (thread-safe)
2. Stores result in latest_frame and updates timestamp (thread-safe)
3. Sets new_frame_event to notify listeners
Stops on DeviceNotConnectedError, logs other errors and continues.
@@ -439,30 +439,37 @@ class OpenCVCamera(Camera):
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
failure_count = 0
while not self.stop_event.is_set():
try:
color_image = self.read()
raw_frame = self._read_from_hardware()
processed_frame = self._postprocess_image(raw_frame)
capture_time = time.perf_counter()
with self.frame_lock:
self.latest_frame = color_image
self.latest_frame = processed_frame
self.latest_timestamp = capture_time
self.new_frame_event.set()
failure_count = 0
except DeviceNotConnectedError:
break
except Exception as e:
logger.warning(f"Error reading frame in background thread for {self}: {e}")
if failure_count <= 10:
failure_count += 1
logger.warning(f"Error reading frame in background thread for {self}: {e}")
else:
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
def _start_read_thread(self) -> None:
"""Starts or restarts the background read thread if it's not running."""
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=0.1)
if self.stop_event is not None:
self.stop_event.set()
self._stop_read_thread()
self.stop_event = Event()
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
self.thread.daemon = True
self.thread.start()
time.sleep(0.1)
def _stop_read_thread(self) -> None:
"""Signals the background read thread to stop and waits for it to join."""
@@ -475,6 +482,12 @@ class OpenCVCamera(Camera):
self.thread = None
self.stop_event = None
with self.frame_lock:
self.latest_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame asynchronously.
@@ -482,6 +495,7 @@ class OpenCVCamera(Camera):
This method retrieves the most recent frame captured by the background
read thread. It does not block waiting for the camera hardware directly,
but may wait up to timeout_ms for the background thread to provide a frame.
It is “best effort” under high FPS.
Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame
@@ -496,17 +510,14 @@ class OpenCVCamera(Camera):
TimeoutError: If no frame becomes available within the specified timeout.
RuntimeError: If an unexpected error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
self._start_read_thread()
raise RuntimeError(f"{self} read thread is not running.")
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
thread_alive = self.thread is not None and self.thread.is_alive()
raise TimeoutError(
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
f"Read thread alive: {thread_alive}."
f"Read thread alive: {self.thread.is_alive()}."
)
with self.frame_lock:
@@ -518,6 +529,41 @@ class OpenCVCamera(Camera):
return frame
@check_if_not_connected
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
memory buffer. The frame may be stale,
meaning it could have been captured a while ago (hanging camera scenario e.g.).
Returns:
NDArray[Any]: The frame image (numpy array).
Raises:
TimeoutError: If the latest frame is older than `max_age_ms`.
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
with self.frame_lock:
frame = self.latest_frame
timestamp = self.latest_timestamp
if frame is None or timestamp is None:
raise RuntimeError(f"{self} has not captured any frames yet.")
age_ms = (time.perf_counter() - timestamp) * 1e3
if age_ms > max_age_ms:
raise TimeoutError(
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
)
return frame
def disconnect(self) -> None:
"""
Disconnects from the camera and cleans up resources.
@@ -538,4 +584,9 @@ class OpenCVCamera(Camera):
self.videocapture.release()
self.videocapture = None
with self.frame_lock:
self.latest_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
logger.info(f"{self} disconnected.")

View File

@@ -15,9 +15,9 @@
from dataclasses import dataclass
from pathlib import Path
from ..configs import CameraConfig, ColorMode, Cv2Rotation
from ..configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"]
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation", "Cv2Backends"]
@CameraConfig.register_subclass("opencv")
@@ -50,6 +50,7 @@ class OpenCVCameraConfig(CameraConfig):
rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation.
warmup_s: Time reading frames before returning from connect (in seconds)
fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect).
backend: OpenCV backend identifier (https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html). Defaults to ANY.
Note:
- Only 3-channel color output (RGB/BGR) is currently supported.
@@ -62,22 +63,12 @@ class OpenCVCameraConfig(CameraConfig):
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
warmup_s: int = 1
fourcc: str | None = None
backend: Cv2Backends = Cv2Backends.ANY
def __post_init__(self) -> None:
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
if self.rotation not in (
Cv2Rotation.NO_ROTATION,
Cv2Rotation.ROTATE_90,
Cv2Rotation.ROTATE_180,
Cv2Rotation.ROTATE_270,
):
raise ValueError(
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
)
self.color_mode = ColorMode(self.color_mode)
self.rotation = Cv2Rotation(self.rotation)
self.backend = Cv2Backends(self.backend)
if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4):
raise ValueError(

View File

@@ -74,7 +74,4 @@ class Reachy2CameraConfig(CameraConfig):
f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided."
)
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
self.color_mode = ColorMode(self.color_mode)

View File

@@ -32,6 +32,7 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
import cv2 # type: ignore # TODO: add type stubs for OpenCV
import numpy as np # type: ignore # TODO: add type stubs for numpy
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.import_utils import _reachy2_sdk_available
if TYPE_CHECKING or _reachy2_sdk_available:
@@ -80,6 +81,8 @@ class Reachy2Camera(Camera):
self.config = config
self.color_mode = config.color_mode
self.latest_frame: NDArray[Any] | None = None
self.latest_timestamp: float | None = None
self.cam_manager: CameraManager | None = None
@@ -121,16 +124,12 @@ class Reachy2Camera(Camera):
"""
raise NotImplementedError("Camera detection is not implemented for Reachy2 cameras.")
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
This is a blocking call.
Args:
color_mode (Optional[ColorMode]): If specified, overrides the default
color mode (`self.color_mode`) for this read operation (e.g.,
request RGB even if default is BGR).
This method retrieves the most recent frame available in Reachy 2's low-level software.
Returns:
np.ndarray: The captured frame as a NumPy array in the format
@@ -139,12 +138,14 @@ class Reachy2Camera(Camera):
"""
start_time = time.perf_counter()
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.cam_manager is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
if color_mode is not None:
logger.warning(
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8)
if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"):
@@ -165,25 +166,27 @@ class Reachy2Camera(Camera):
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
if frame is None:
return np.empty((0, 0, 3), dtype=np.uint8)
raise RuntimeError(f"Internal error: No frame available for {self}.")
if self.config.color_mode == "rgb":
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
)
if self.color_mode == ColorMode.RGB:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
self.latest_frame = frame
self.latest_timestamp = time.perf_counter()
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return frame
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame.
This method retrieves the most recent frame available in Reachy 2's low-level software.
Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame
to become available. Defaults to 200ms (0.2 seconds).
Same as read()
Returns:
np.ndarray: The latest captured frame as a NumPy array in the format
@@ -194,16 +197,40 @@ class Reachy2Camera(Camera):
TimeoutError: If no frame becomes available within the specified timeout.
RuntimeError: If an unexpected error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
frame = self.read()
return self.read()
if frame is None:
raise RuntimeError(f"Internal error: No frame available for {self}.")
@check_if_not_connected
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
return frame
This method is non-blocking and returns whatever is currently in the
memory buffer. The frame may be stale,
meaning it could have been captured a while ago (hanging camera scenario e.g.).
Returns:
tuple[NDArray, float]:
- The frame image (numpy array).
- The timestamp (time.perf_counter) when this frame was captured.
Raises:
TimeoutError: If the latest frame is older than `max_age_ms`.
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if self.latest_frame is None or self.latest_timestamp is None:
raise RuntimeError(f"{self} has not captured any frames yet.")
age_ms = (time.perf_counter() - self.latest_timestamp) * 1e3
if age_ms > max_age_ms:
raise TimeoutError(
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
)
return self.latest_frame
@check_if_not_connected
def disconnect(self) -> None:
"""
Stops the background read thread (if running).
@@ -211,8 +238,6 @@ class Reachy2Camera(Camera):
Raises:
DeviceNotConnectedError: If the camera is already disconnected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} not connected.")
if self.cam_manager is not None:
self.cam_manager.disconnect()

View File

@@ -30,7 +30,8 @@ try:
except Exception as e:
logging.info(f"Could not import realsense: {e}")
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from ..camera import Camera
from ..configs import ColorMode
@@ -72,15 +73,14 @@ class RealSenseCamera(Camera):
camera = RealSenseCamera(config)
camera.connect()
# Read 1 frame synchronously
# Read 1 frame synchronously (blocking)
color_image = camera.read()
print(color_image.shape)
# Read 1 frame asynchronously
# Read 1 frame asynchronously (waits for new frame with a timeout)
async_image = camera.async_read()
# When done, properly disconnect the camera using
camera.disconnect()
# Get the latest frame immediately (no wait, returns timestamp)
latest_image, timestamp = camera.read_latest()
# Example with depth capture and custom settings
custom_config = RealSenseCameraConfig(
@@ -133,7 +133,9 @@ class RealSenseCamera(Camera):
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: NDArray[Any] | None = None
self.latest_color_frame: NDArray[Any] | None = None
self.latest_depth_frame: NDArray[Any] | None = None
self.latest_timestamp: float | None = None
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
@@ -151,6 +153,7 @@ class RealSenseCamera(Camera):
"""Checks if the camera pipeline is started and streams are active."""
return self.rs_pipeline is not None and self.rs_profile is not None
@check_if_already_connected
def connect(self, warmup: bool = True) -> None:
"""
Connects to the RealSense camera specified in the configuration.
@@ -158,14 +161,16 @@ class RealSenseCamera(Camera):
Initializes the RealSense pipeline, configures the required streams (color
and optionally depth), starts the pipeline, and validates the actual stream settings.
Args:
warmup (bool): If True, waits at connect() time until at least one valid frame
has been captured by the background thread. Defaults to True.
Raises:
DeviceAlreadyConnectedError: If the camera is already connected.
ValueError: If the configuration is invalid (e.g., missing serial/name, name not unique).
ConnectionError: If the camera is found but fails to start the pipeline or no RealSense devices are detected at all.
RuntimeError: If the pipeline starts but fails to apply requested settings.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
self.rs_pipeline = rs.pipeline()
rs_config = rs.config()
@@ -181,15 +186,18 @@ class RealSenseCamera(Camera):
) from e
self._configure_capture_settings()
self._start_read_thread()
if warmup:
time.sleep(
1
) # NOTE(Steven): RS cameras need a bit of time to warm up before the first read. If we don't wait, the first read from the warmup will raise.
start_time = time.time()
while time.time() - start_time < self.warmup_s:
self.read()
time.sleep(0.1)
# NOTE(Steven/Caroline): Enforcing at least one second of warmup as RS cameras need a bit of time before the first read. If we don't wait, the first read from the warmup will raise.
self.warmup_s = max(self.warmup_s, 1)
start_time = time.time()
while time.time() - start_time < self.warmup_s:
self.async_read(timeout_ms=self.warmup_s * 1000)
time.sleep(0.1)
with self.frame_lock:
if self.latest_color_frame is None or self.use_depth and self.latest_depth_frame is None:
raise ConnectionError(f"{self} failed to capture frames during warmup.")
logger.info(f"{self} connected.")
@@ -282,6 +290,7 @@ class RealSenseCamera(Camera):
if self.use_depth:
rs_config.enable_stream(rs.stream.depth)
@check_if_not_connected
def _configure_capture_settings(self) -> None:
"""Sets fps, width, and height from device stream if not already configured.
@@ -291,8 +300,6 @@ class RealSenseCamera(Camera):
Raises:
DeviceNotConnectedError: If device is not connected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.")
if self.rs_profile is None:
raise RuntimeError(f"{self}: rs_profile must be initialized before use.")
@@ -312,6 +319,7 @@ class RealSenseCamera(Camera):
self.width, self.height = actual_width, actual_height
self.capture_width, self.capture_height = actual_width, actual_height
@check_if_not_connected
def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]:
"""
Reads a single frame (depth) synchronously from the camera.
@@ -319,9 +327,6 @@ class RealSenseCamera(Camera):
This is a blocking call. It waits for a coherent set of frames (depth)
from the camera hardware via the RealSense pipeline.
Args:
timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms.
Returns:
np.ndarray: The depth map as a NumPy array (height, width)
of type `np.uint16` (raw depth values in millimeters) and rotation.
@@ -330,44 +335,50 @@ class RealSenseCamera(Camera):
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If reading frames from the pipeline fails or frames are invalid.
"""
if timeout_ms:
logger.warning(
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
)
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if not self.use_depth:
raise RuntimeError(
f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}."
)
start_time = time.perf_counter()
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
self.new_frame_event.clear()
_ = self.async_read(timeout_ms=10000)
with self.frame_lock:
depth_map = self.latest_depth_frame
if depth_map is None:
raise RuntimeError("No depth frame available. Ensure camera is streaming.")
return depth_map
def _read_from_hardware(self):
if self.rs_pipeline is None:
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=10000)
if not ret or frame is None:
raise RuntimeError(f"{self} read_depth failed (status={ret}).")
raise RuntimeError(f"{self} read failed (status={ret}).")
depth_frame = frame.get_depth_frame()
depth_map = np.asanyarray(depth_frame.get_data())
return frame
depth_map_processed = self._postprocess_image(depth_map, depth_frame=True)
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return depth_map_processed
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> NDArray[Any]:
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 0) -> NDArray[Any]:
"""
Reads a single frame (color) synchronously from the camera.
This is a blocking call. It waits for a coherent set of frames (color)
from the camera hardware via the RealSense pipeline.
Args:
timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms.
Returns:
np.ndarray: The captured color frame as a NumPy array
(height, width, channels), processed according to `color_mode` and rotation.
@@ -378,39 +389,36 @@ class RealSenseCamera(Camera):
ValueError: If an invalid `color_mode` is requested.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
start_time = time.perf_counter()
if self.rs_pipeline is None:
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
if color_mode is not None:
logger.warning(
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
if timeout_ms:
logger.warning(
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
)
if not ret or frame is None:
raise RuntimeError(f"{self} read failed (status={ret}).")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
color_frame = frame.get_color_frame()
color_image_raw = np.asanyarray(color_frame.get_data())
self.new_frame_event.clear()
color_image_processed = self._postprocess_image(color_image_raw, color_mode)
frame = self.async_read(timeout_ms=10000)
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return color_image_processed
return frame
def _postprocess_image(
self, image: NDArray[Any], color_mode: ColorMode | None = None, depth_frame: bool = False
) -> NDArray[Any]:
def _postprocess_image(self, image: NDArray[Any], depth_frame: bool = False) -> NDArray[Any]:
"""
Applies color conversion, dimension validation, and rotation to a raw color frame.
Args:
image (np.ndarray): The raw image frame (expected RGB format from RealSense).
color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None,
uses the instance's default `self.color_mode`.
Returns:
np.ndarray: The processed image frame according to `self.color_mode` and `self.rotation`.
@@ -421,9 +429,9 @@ class RealSenseCamera(Camera):
`width` and `height`.
"""
if color_mode and color_mode not in (ColorMode.RGB, ColorMode.BGR):
if self.color_mode and self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"Invalid requested color mode '{color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
f"Invalid requested color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
)
if depth_frame:
@@ -454,7 +462,7 @@ class RealSenseCamera(Camera):
On each iteration:
1. Reads a color frame with 500ms timeout
2. Stores result in latest_frame (thread-safe)
2. Stores result in latest_frame and updates timestamp (thread-safe)
3. Sets new_frame_event to notify listeners
Stops on DeviceNotConnectedError, logs other errors and continues.
@@ -462,25 +470,41 @@ class RealSenseCamera(Camera):
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
failure_count = 0
while not self.stop_event.is_set():
try:
color_image = self.read(timeout_ms=500)
frame = self._read_from_hardware()
color_frame_raw = frame.get_color_frame()
color_frame = np.asanyarray(color_frame_raw.get_data())
processed_color_frame = self._postprocess_image(color_frame)
if self.use_depth:
depth_frame_raw = frame.get_depth_frame()
depth_frame = np.asanyarray(depth_frame_raw.get_data())
processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True)
capture_time = time.perf_counter()
with self.frame_lock:
self.latest_frame = color_image
self.latest_color_frame = processed_color_frame
if self.use_depth:
self.latest_depth_frame = processed_depth_frame
self.latest_timestamp = capture_time
self.new_frame_event.set()
failure_count = 0
except DeviceNotConnectedError:
break
except Exception as e:
logger.warning(f"Error reading frame in background thread for {self}: {e}")
if failure_count <= 10:
failure_count += 1
logger.warning(f"Error reading frame in background thread for {self}: {e}")
else:
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
def _start_read_thread(self) -> None:
"""Starts or restarts the background read thread if it's not running."""
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=0.1)
if self.stop_event is not None:
self.stop_event.set()
self._stop_read_thread()
self.stop_event = Event()
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
@@ -498,7 +522,14 @@ class RealSenseCamera(Camera):
self.thread = None
self.stop_event = None
with self.frame_lock:
self.latest_color_frame = None
self.latest_depth_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame data (color) asynchronously.
@@ -506,6 +537,7 @@ class RealSenseCamera(Camera):
This method retrieves the most recent color frame captured by the background
read thread. It does not block waiting for the camera hardware directly,
but may wait up to timeout_ms for the background thread to provide a frame.
It is “best effort” under high FPS.
Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame
@@ -520,21 +552,18 @@ class RealSenseCamera(Camera):
TimeoutError: If no frame data becomes available within the specified timeout.
RuntimeError: If the background thread died unexpectedly or another error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
self._start_read_thread()
raise RuntimeError(f"{self} read thread is not running.")
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
thread_alive = self.thread is not None and self.thread.is_alive()
raise TimeoutError(
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
f"Read thread alive: {thread_alive}."
f"Read thread alive: {self.thread.is_alive()}."
)
with self.frame_lock:
frame = self.latest_frame
frame = self.latest_color_frame
self.new_frame_event.clear()
if frame is None:
@@ -542,6 +571,42 @@ class RealSenseCamera(Camera):
return frame
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent (color) frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
memory buffer. The frame may be stale,
meaning it could have been captured a while ago (hanging camera scenario e.g.).
Returns:
NDArray[Any]: The frame image (numpy array).
Raises:
TimeoutError: If the latest frame is older than `max_age_ms`.
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
with self.frame_lock:
frame = self.latest_color_frame
timestamp = self.latest_timestamp
if frame is None or timestamp is None:
raise RuntimeError(f"{self} has not captured any frames yet.")
age_ms = (time.perf_counter() - timestamp) * 1e3
if age_ms > max_age_ms:
raise TimeoutError(
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
)
return frame
def disconnect(self) -> None:
"""
Disconnects from the camera, stops the pipeline, and cleans up resources.
@@ -565,4 +630,10 @@ class RealSenseCamera(Camera):
self.rs_pipeline = None
self.rs_profile = None
with self.frame_lock:
self.latest_color_frame = None
self.latest_depth_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
logger.info(f"{self} disconnected.")

View File

@@ -60,20 +60,8 @@ class RealSenseCameraConfig(CameraConfig):
warmup_s: int = 1
def __post_init__(self) -> None:
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
if self.rotation not in (
Cv2Rotation.NO_ROTATION,
Cv2Rotation.ROTATE_90,
Cv2Rotation.ROTATE_180,
Cv2Rotation.ROTATE_270,
):
raise ValueError(
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
)
self.color_mode = ColorMode(self.color_mode)
self.rotation = Cv2Rotation(self.rotation)
values = (self.fps, self.width, self.height)
if any(v is not None for v in values) and any(v is None for v in values):

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
from typing import cast
from lerobot.utils.import_utils import make_device_from_device_class
@@ -68,14 +67,3 @@ def get_cv2_rotation(rotation: Cv2Rotation) -> int | None:
return int(cv2.ROTATE_90_COUNTERCLOCKWISE)
else:
return None
def get_cv2_backend() -> int:
import cv2
if platform.system() == "Windows":
return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION
# elif platform.system() == "Darwin": # macOS
# return cv2.CAP_AVFOUNDATION
else: # Linux and others
return int(cv2.CAP_ANY)

View File

@@ -34,7 +34,8 @@ import cv2
import numpy as np
from numpy.typing import NDArray
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from ..camera import Camera
from ..configs import ColorMode
@@ -45,6 +46,12 @@ logger = logging.getLogger(__name__)
class ZMQCamera(Camera):
"""
Manages camera interactions via ZeroMQ for receiving frames from a remote server.
This class connects to a ZMQ Publisher, subscribes to frame topics, and decodes
incoming JSON messages containing Base64 encoded images. It supports both
synchronous and asynchronous frame reading patterns.
Example usage:
```python
from lerobot.cameras.zmq import ZMQCamera, ZMQCameraConfig
@@ -52,7 +59,16 @@ class ZMQCamera(Camera):
config = ZMQCameraConfig(server_address="192.168.123.164", port=5555, camera_name="head_camera")
camera = ZMQCamera(config)
camera.connect()
frame = camera.read()
# Read 1 frame synchronously (blocking)
color_image = camera.read()
# Read 1 frame asynchronously (waits for new frame with a timeout)
async_image = camera.async_read()
# Get the latest frame immediately (no wait, returns timestamp)
latest_image, timestamp = camera.read_latest()
camera.disconnect()
```
"""
@@ -68,14 +84,17 @@ class ZMQCamera(Camera):
self.color_mode = config.color_mode
self.timeout_ms = config.timeout_ms
# ZMQ Context and Socket
self.context: zmq.Context | None = None
self.socket: zmq.Socket | None = None
self._connected = False
# Threading resources
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: NDArray[Any] | None = None
self.latest_timestamp: float | None = None
self.new_frame_event: Event = Event()
def __str__(self) -> str:
@@ -83,12 +102,17 @@ class ZMQCamera(Camera):
@property
def is_connected(self) -> bool:
"""Checks if the ZMQ socket is initialized and connected."""
return self._connected and self.context is not None and self.socket is not None
@check_if_already_connected
def connect(self, warmup: bool = True) -> None:
"""Connect to ZMQ camera server."""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
"""Connect to ZMQ camera server.
Args:
warmup (bool): If True, waits for the camera to provide at least one
valid frame before returning. Defaults to True.
"""
logger.info(f"Connecting to {self}...")
@@ -103,17 +127,28 @@ class ZMQCamera(Camera):
self.socket.connect(f"tcp://{self.server_address}:{self.port}")
self._connected = True
# Auto-detect resolution
# Auto-detect resolution if not provided
if self.width is None or self.height is None:
h, w = self.read().shape[:2]
# Read directly from hardware because the thread isn't running yet
temp_frame = self._read_from_hardware()
h, w = temp_frame.shape[:2]
self.height = h
self.width = w
logger.info(f"{self} resolution: {w}x{h}")
logger.info(f"{self} resolution detected: {w}x{h}")
self._start_read_thread()
logger.info(f"{self} connected.")
if warmup:
time.sleep(0.1)
# Ensure we have captured at least one frame via the thread
start_time = time.time()
while time.time() - start_time < (self.config.warmup_s): # Wait a bit more than timeout
self.async_read(timeout_ms=self.config.warmup_s * 1000)
time.sleep(0.1)
with self.frame_lock:
if self.latest_frame is None:
raise ConnectionError(f"{self} failed to capture frames during warmup.")
except Exception as e:
self._cleanup()
@@ -131,15 +166,14 @@ class ZMQCamera(Camera):
@staticmethod
def find_cameras() -> list[dict[str, Any]]:
"""ZMQ cameras require manual configuration (server address/port)."""
return []
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Read a single frame from the ZMQ camera.
Detection not implemented for ZMQ cameras. These cameras require manual configuration (server address/port).
"""
raise NotImplementedError("Camera detection is not implemented for ZMQ cameras.")
Returns:
np.ndarray: Decoded frame (height, width, 3)
def _read_from_hardware(self) -> NDArray[Any]:
"""
Reads a single frame directly from the ZMQ socket.
"""
if not self.is_connected or self.socket is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
@@ -147,6 +181,7 @@ class ZMQCamera(Camera):
try:
message = self.socket.recv_string()
except Exception as e:
# Check for ZMQ timeout (EAGAIN/Again) without requiring global zmq import
if type(e).__name__ == "Again":
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
raise
@@ -176,42 +211,114 @@ class ZMQCamera(Camera):
return frame
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
This is a blocking call. It waits for the next available frame from the
camera background thread.
Returns:
np.ndarray: Decoded frame (height, width, 3)
"""
start_time = time.perf_counter()
if color_mode is not None:
logger.warning(
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
self.new_frame_event.clear()
frame = self.async_read(timeout_ms=10000)
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return frame
def _read_loop(self) -> None:
while self.stop_event and not self.stop_event.is_set():
"""
Internal loop run by the background thread for asynchronous reading.
"""
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized.")
failure_count = 0
while not self.stop_event.is_set():
try:
frame = self.read()
frame = self._read_from_hardware()
capture_time = time.perf_counter()
with self.frame_lock:
self.latest_frame = frame
self.latest_timestamp = capture_time
self.new_frame_event.set()
failure_count = 0
except DeviceNotConnectedError:
break
except TimeoutError:
pass
except Exception as e:
logger.warning(f"Read error: {e}")
except (TimeoutError, Exception) as e:
if failure_count <= 10:
failure_count += 1
logger.warning(f"Read error: {e}")
else:
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
def _start_read_thread(self) -> None:
if self.thread and self.thread.is_alive():
return
if self.stop_event is not None:
self.stop_event.set()
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
with self.frame_lock:
self.latest_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
self.stop_event = Event()
self.thread = Thread(target=self._read_loop, daemon=True)
self.thread = Thread(target=self._read_loop, daemon=True, name=f"{self}_read_loop")
self.thread.start()
time.sleep(0.1)
def _stop_read_thread(self) -> None:
if self.stop_event:
if self.stop_event is not None:
self.stop_event.set()
if self.thread and self.thread.is_alive():
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
self.thread = None
self.stop_event = None
def async_read(self, timeout_ms: float = 10000) -> NDArray[Any]:
"""Read latest frame asynchronously (non-blocking)."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
with self.frame_lock:
self.latest_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
if not self.thread or not self.thread.is_alive():
self._start_read_thread()
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame asynchronously.
Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame
to become available. Defaults to 200ms.
Returns:
np.ndarray: The latest captured frame.
Raises:
DeviceNotConnectedError: If the camera is not connected.
TimeoutError: If no frame data becomes available within the specified timeout.
RuntimeError: If the background thread is not running.
"""
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
raise TimeoutError(f"{self} async_read timeout after {timeout_ms}ms")
@@ -225,11 +332,54 @@ class ZMQCamera(Camera):
return frame
@check_if_not_connected
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
memory buffer. The frame may be stale,
meaning it could have been captured a while ago (hanging camera scenario e.g.).
Returns:
NDArray[Any]: The frame image (numpy array).
Raises:
TimeoutError: If the latest frame is older than `max_age_ms`.
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
with self.frame_lock:
frame = self.latest_frame
timestamp = self.latest_timestamp
if frame is None or timestamp is None:
raise RuntimeError(f"{self} has not captured any frames yet.")
age_ms = (time.perf_counter() - timestamp) * 1e3
if age_ms > max_age_ms:
raise TimeoutError(
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
)
return frame
def disconnect(self) -> None:
"""Disconnect from ZMQ camera."""
if not self.is_connected and not self.thread:
if not self.is_connected and self.thread is None:
raise DeviceNotConnectedError(f"{self} not connected.")
self._stop_read_thread()
if self.thread is not None:
self._stop_read_thread()
self._cleanup()
with self.frame_lock:
self.latest_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
logger.info(f"{self} disconnected.")

View File

@@ -29,12 +29,10 @@ class ZMQCameraConfig(CameraConfig):
camera_name: str = "zmq_camera"
color_mode: ColorMode = ColorMode.RGB
timeout_ms: int = 5000
warmup_s: int = 1
def __post_init__(self) -> None:
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
self.color_mode = ColorMode(self.color_mode)
if self.timeout_ms <= 0:
raise ValueError(f"`timeout_ms` must be positive, but {self.timeout_ms} is provided.")

View File

@@ -45,12 +45,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
input_shapes: A dictionary defining the shapes of the input data for the policy.
output_shapes: A dictionary defining the shapes of the output data for the policy.
input_normalization_modes: A dictionary with key representing the modality and the value specifies the
normalization mode to apply.
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
the original scale.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
"""
n_obs_steps: int = 1

View File

@@ -116,6 +116,9 @@ def update_meta_data(
Adjusts all indices and timestamps to account for previously aggregated
data and videos in the destination dataset.
For data file indices, uses the 'src_to_dst' mapping from aggregate_data()
to correctly map source file indices to their destination locations.
Args:
df: DataFrame containing the metadata to be updated.
dst_meta: Destination dataset metadata.
@@ -129,8 +132,50 @@ def update_meta_data(
df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
# Update data file indices using source-to-destination mapping
# This is critical for handling datasets that are already results of a merge
data_src_to_dst = data_idx.get("src_to_dst", {})
if data_src_to_dst:
# Store original indices for lookup
df["_orig_data_chunk"] = df["data/chunk_index"].copy()
df["_orig_data_file"] = df["data/file_index"].copy()
# Vectorized mapping from (src_chunk, src_file) to (dst_chunk, dst_file)
# This is much faster than per-row iteration for large metadata tables
mapping_index = pd.MultiIndex.from_tuples(
list(data_src_to_dst.keys()),
names=["chunk_index", "file_index"],
)
mapping_values = list(data_src_to_dst.values())
mapping_df = pd.DataFrame(
mapping_values,
index=mapping_index,
columns=["dst_chunk", "dst_file"],
)
# Construct a MultiIndex for each row based on original data indices
row_index = pd.MultiIndex.from_arrays(
[df["_orig_data_chunk"], df["_orig_data_file"]],
names=["chunk_index", "file_index"],
)
# Align mapping to rows; missing keys fall back to the default destination
reindexed = mapping_df.reindex(row_index)
reindexed[["dst_chunk", "dst_file"]] = reindexed[["dst_chunk", "dst_file"]].fillna(
{"dst_chunk": data_idx["chunk"], "dst_file": data_idx["file"]}
)
# Assign mapped destination indices back to the DataFrame
df["data/chunk_index"] = reindexed["dst_chunk"].to_numpy()
df["data/file_index"] = reindexed["dst_file"].to_numpy()
# Clean up temporary columns
df = df.drop(columns=["_orig_data_chunk", "_orig_data_file"])
else:
# Fallback to simple offset (backward compatibility for single-file sources)
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
for key, video_idx in videos_idx.items():
# Store original video file indices before updating
orig_chunk_col = f"videos/{key}/chunk_index"
@@ -146,8 +191,7 @@ def update_meta_data(
if src_to_dst:
# Map each episode to its correct destination file and apply offset
for idx in df.index:
# Convert to Python int to avoid numpy type mismatch in dict lookup
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
# Get destination chunk/file for this source file
dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"]))
@@ -163,8 +207,7 @@ def update_meta_data(
df[orig_chunk_col] = video_idx["chunk"]
df[orig_file_col] = video_idx["file"]
for idx in df.index:
# Convert to Python int to avoid numpy type mismatch in dict lookup
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
offset = src_to_offset.get(src_key, 0)
df.at[idx, f"videos/{key}/from_timestamp"] += offset
df.at[idx, f"videos/{key}/to_timestamp"] += offset
@@ -262,6 +305,10 @@ def aggregate_datasets(
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
# Clear the src_to_dst mapping after processing each source dataset
# to avoid interference between different source datasets
data_idx.pop("src_to_dst", None)
dst_meta.info["total_episodes"] += src_meta.total_episodes
dst_meta.info["total_frames"] += src_meta.total_frames
@@ -312,10 +359,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
dst_file_durations = video_idx["dst_file_durations"]
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
# Convert to Python int to ensure consistent dict keys
src_chunk_idx = int(src_chunk_idx)
src_file_idx = int(src_file_idx)
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
video_key=key,
chunk_index=src_chunk_idx,
@@ -388,10 +431,16 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
Reads source data files, updates indices to match the aggregated dataset,
and writes them to the destination with proper file rotation.
Tracks a `src_to_dst` mapping from source (chunk, file) to destination (chunk, file)
which is critical for correctly updating episode metadata when source datasets
have multiple data files (e.g., from a previous merge operation).
Args:
src_meta: Source dataset metadata.
dst_meta: Destination dataset metadata.
data_idx: Dictionary tracking data chunk and file indices.
data_files_size_in_mb: Maximum size for data files in MB.
chunk_size: Maximum number of files per chunk.
Returns:
dict: Updated data_idx with current chunk and file indices.
@@ -409,6 +458,10 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
# retrieve features schema for proper image typing in parquet
hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None
# Track source to destination file mapping for metadata update
# This is critical for handling datasets that are already results of a merge
src_to_dst: dict[tuple[int, int], tuple[int, int]] = {}
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
chunk_index=src_chunk_idx, file_index=src_file_idx
@@ -421,7 +474,9 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
df = pd.read_parquet(src_path)
df = update_data_df(df, src_meta, dst_meta)
data_idx = append_or_create_parquet_file(
# Write data and get the actual destination file it was written to
# This avoids duplicating the rotation logic here
data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file(
df,
src_path,
data_idx,
@@ -433,6 +488,12 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
hf_features=hf_features,
)
# Record the mapping from source to actual destination
src_to_dst[(src_chunk_idx, src_file_idx)] = (dst_chunk, dst_file)
# Add the mapping to data_idx for use in metadata update
data_idx["src_to_dst"] = src_to_dst
return data_idx
@@ -473,7 +534,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
videos_idx,
)
meta_idx = append_or_create_parquet_file(
meta_idx, _ = append_or_create_parquet_file(
df,
src_path,
meta_idx,
@@ -501,7 +562,7 @@ def append_or_create_parquet_file(
contains_images: bool = False,
aggr_root: Path = None,
hf_features: datasets.Features | None = None,
):
) -> tuple[dict[str, int], tuple[int, int]]:
"""Appends data to an existing parquet file or creates a new one based on size constraints.
Manages file rotation when size limits are exceeded to prevent individual files
@@ -519,9 +580,11 @@ def append_or_create_parquet_file(
hf_features: Optional HuggingFace Features schema for proper image typing.
Returns:
dict: Updated index dictionary with current chunk and file indices.
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
and (dst_chunk, dst_file) is the actual destination file the data was written to.
"""
dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
dst_chunk, dst_file = idx["chunk"], idx["file"]
dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
if not dst_path.exists():
dst_path.parent.mkdir(parents=True, exist_ok=True)
@@ -529,14 +592,15 @@ def append_or_create_parquet_file(
to_parquet_with_hf_images(df, dst_path, features=hf_features)
else:
df.to_parquet(dst_path)
return idx
return idx, (dst_chunk, dst_file)
src_size = get_parquet_file_size_in_mb(src_path)
dst_size = get_parquet_file_size_in_mb(dst_path)
if dst_size + src_size >= max_mb:
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
dst_chunk, dst_file = idx["chunk"], idx["file"]
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
new_path.parent.mkdir(parents=True, exist_ok=True)
final_df = df
target_path = new_path
@@ -555,7 +619,7 @@ def append_or_create_parquet_file(
else:
final_df.to_parquet(target_path)
return idx
return idx, (dst_chunk, dst_file)
def finalize_aggregation(aggr_meta, all_metadata):

View File

@@ -13,6 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from lerobot.datasets.utils import load_image_as_numpy
@@ -227,19 +231,20 @@ def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_si
return img[:, ::downsample_factor, ::downsample_factor]
def _load_single_image(path: str) -> np.ndarray:
img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
return auto_downsample_height_width(img)
def sample_images(image_paths: list[str]) -> np.ndarray:
sampled_indices = sample_indices(len(image_paths))
paths = [image_paths[idx] for idx in sampled_indices]
images = None
for i, idx in enumerate(sampled_indices):
path = image_paths[idx]
# we load as uint8 to reduce memory usage
img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
img = auto_downsample_height_width(img)
if images is None:
images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
with ThreadPoolExecutor(max_workers=min(8, len(paths))) as pool:
loaded = list(pool.map(_load_single_image, paths))
images = np.empty((len(loaded), *loaded[0].shape), dtype=np.uint8)
for i, img in enumerate(loaded):
images[i] = img
return images
@@ -504,27 +509,46 @@ def compute_episode_stats(
quantile_list = DEFAULT_QUANTILES
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] == "string":
continue
def _compute_single_feature_stats(key, data):
t0 = time.perf_counter()
if features[key]["dtype"] in ["image", "video"]:
ep_ft_array = sample_images(data)
axes_to_reduce = (0, 2, 3)
keepdims = True
kd = True
else:
ep_ft_array = data
axes_to_reduce = 0
keepdims = data.ndim == 1
kd = data.ndim == 1
ep_stats[key] = get_feature_stats(
ep_ft_array, axis=axes_to_reduce, keepdims=keepdims, quantile_list=quantile_list
)
stats = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=kd, quantile_list=quantile_list)
if features[key]["dtype"] in ["image", "video"]:
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
}
stats = {k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in stats.items()}
dt = time.perf_counter() - t0
if dt > 0.1:
logging.info(f"[compute_episode_stats] {key} ({features[key]['dtype']}): {dt:.2f}s")
return key, stats
# Split into image/video features (heavy I/O) and numeric features (fast)
image_keys = [(k, d) for k, d in episode_data.items()
if k in features and features[k]["dtype"] in ["image", "video"]]
numeric_keys = [(k, d) for k, d in episode_data.items()
if k in features and features[k]["dtype"] not in ["image", "video", "string"]]
# Run image features in parallel (I/O bound)
if image_keys:
with ThreadPoolExecutor(max_workers=len(image_keys)) as pool:
futures = [pool.submit(_compute_single_feature_stats, k, d) for k, d in image_keys]
for f in futures:
key, stats = f.result()
ep_stats[key] = stats
# Numeric features are fast — run sequentially
for k, d in numeric_keys:
_, stats = _compute_single_feature_stats(k, d)
ep_stats[k] = stats
return ep_stats

View File

@@ -37,7 +37,7 @@ import torch
from tqdm import tqdm
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats, get_feature_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import (
DATA_DIR,
@@ -1396,6 +1396,248 @@ BYTES_PER_KIB = 1024
BYTES_PER_MIB = BYTES_PER_KIB * BYTES_PER_KIB
def modify_tasks(
dataset: LeRobotDataset,
new_task: str | None = None,
episode_tasks: dict[int, str] | None = None,
) -> LeRobotDataset:
"""Modify tasks in a LeRobotDataset.
This function allows you to either:
1. Set a single task for the entire dataset (using `new_task`)
2. Set specific tasks for specific episodes (using `episode_tasks`)
You can combine both: `new_task` sets the default, and `episode_tasks` overrides
specific episodes.
The dataset is modified in-place, updating only the task-related files:
- meta/tasks.parquet
- data/**/*.parquet (task_index column)
- meta/episodes/**/*.parquet (tasks column)
- meta/info.json (total_tasks)
Args:
dataset: The source LeRobotDataset to modify.
new_task: A single task string to apply to all episodes. If None and episode_tasks
is also None, raises an error.
episode_tasks: Optional dict mapping episode indices to their task strings.
Overrides `new_task` for specific episodes.
Examples:
Set a single task for all episodes:
dataset = modify_tasks(dataset, new_task="Pick up the cube")
Set different tasks for specific episodes:
dataset = modify_tasks(
dataset,
episode_tasks={0: "Task A", 1: "Task B", 2: "Task A"}
)
Set a default task with overrides:
dataset = modify_tasks(
dataset,
new_task="Default task",
episode_tasks={5: "Special task for episode 5"}
)
"""
if new_task is None and episode_tasks is None:
raise ValueError("Must specify at least one of new_task or episode_tasks")
if episode_tasks is not None:
valid_indices = set(range(dataset.meta.total_episodes))
invalid = set(episode_tasks.keys()) - valid_indices
if invalid:
raise ValueError(f"Invalid episode indices: {invalid}")
# Ensure episodes metadata is loaded
if dataset.meta.episodes is None:
dataset.meta.episodes = load_episodes(dataset.root)
# Build the mapping from episode index to task string
episode_to_task: dict[int, str] = {}
for ep_idx in range(dataset.meta.total_episodes):
if episode_tasks and ep_idx in episode_tasks:
episode_to_task[ep_idx] = episode_tasks[ep_idx]
elif new_task is not None:
episode_to_task[ep_idx] = new_task
else:
# Keep original task if not overridden and no default provided
original_tasks = dataset.meta.episodes[ep_idx]["tasks"]
if not original_tasks:
raise ValueError(f"Episode {ep_idx} has no tasks and no default task was provided")
episode_to_task[ep_idx] = original_tasks[0]
# Collect all unique tasks and create new task mapping
unique_tasks = sorted(set(episode_to_task.values()))
new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks)
task_to_index = {task: idx for idx, task in enumerate(unique_tasks)}
logging.info(f"Modifying tasks in {dataset.repo_id}")
logging.info(f"New tasks: {unique_tasks}")
root = dataset.root
# Update data files - modify task_index column
logging.info("Updating data files...")
data_dir = root / DATA_DIR
for parquet_path in tqdm(sorted(data_dir.rglob("*.parquet")), desc="Updating data"):
df = pd.read_parquet(parquet_path)
# Build a mapping from episode_index to new task_index for rows in this file
episode_indices_in_file = df["episode_index"].unique()
ep_to_new_task_idx = {
ep_idx: task_to_index[episode_to_task[ep_idx]] for ep_idx in episode_indices_in_file
}
# Update task_index column
df["task_index"] = df["episode_index"].map(ep_to_new_task_idx)
df.to_parquet(parquet_path, index=False)
# Update episodes metadata - modify tasks column
logging.info("Updating episodes metadata...")
episodes_dir = root / "meta" / "episodes"
for parquet_path in tqdm(sorted(episodes_dir.rglob("*.parquet")), desc="Updating episodes"):
df = pd.read_parquet(parquet_path)
# Update tasks column
df["tasks"] = df["episode_index"].apply(lambda ep_idx: [episode_to_task[ep_idx]])
df.to_parquet(parquet_path, index=False)
# Write new tasks.parquet
write_tasks(new_task_df, root)
# Update info.json
dataset.meta.info["total_tasks"] = len(unique_tasks)
write_info(dataset.meta.info, root)
# Reload metadata to reflect changes
dataset.meta.tasks = new_task_df
dataset.meta.episodes = load_episodes(root)
logging.info(f"Tasks: {unique_tasks}")
return dataset
def recompute_stats(
dataset: LeRobotDataset,
skip_image_video: bool = True,
delta_action: bool = False,
delta_exclude_joints: list[str] | None = None,
) -> LeRobotDataset:
"""Recompute stats.json from scratch by iterating all episodes.
Args:
dataset: The LeRobotDataset to recompute stats for.
skip_image_video: If True (default), only recompute stats for numeric features
(action, state, etc.) and keep existing image/video stats unchanged.
delta_action: If True, compute action stats as delta (action - state).
Useful when training with use_delta_actions=True so normalization matches.
delta_exclude_joints: Joint names to exclude from delta conversion when
delta_action=True. These dims keep absolute stats. Uses dataset's
action feature names to build the mask. Default: ["gripper"].
Returns:
The same dataset with updated stats.
"""
features = dataset.meta.features
numeric_features = {
k: v for k, v in features.items()
if v["dtype"] not in ["image", "video", "string"]
and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
}
if skip_image_video:
features_to_compute = numeric_features
else:
features_to_compute = {
k: v for k, v in features.items()
if v["dtype"] != "string"
and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
}
# Build delta mask if delta_action is enabled
delta_mask = None
if delta_action and "action" in features and "observation.state" in features:
if delta_exclude_joints is None:
delta_exclude_joints = ["gripper"]
action_names = features["action"].get("names")
if action_names is not None:
exclude = set(delta_exclude_joints)
delta_mask = [n not in exclude for n in action_names]
else:
action_dim = features["action"]["shape"][0]
delta_mask = [True] * action_dim
# Only recompute action stats when delta is enabled — state stays unchanged
features_to_compute = {"action": features["action"]}
logging.info(f"Recomputing action stats as delta (exclude: {delta_exclude_joints})")
else:
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
data_dir = dataset.root / DATA_DIR
parquet_files = sorted(data_dir.glob("*/*.parquet"))
if not parquet_files:
raise ValueError(f"No parquet files found in {data_dir}")
all_episode_stats = []
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
# Also need state for delta computation even though we don't recompute state stats
needs_state = delta_mask is not None
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
df = pd.read_parquet(parquet_path)
for ep_idx in sorted(df["episode_index"].unique()):
ep_df = df[df["episode_index"] == ep_idx]
episode_data = {}
for key in numeric_keys:
if key in ep_df.columns:
values = ep_df[key].values
if hasattr(values[0], "__len__"):
episode_data[key] = np.stack(values)
else:
episode_data[key] = np.array(values)
# Apply delta conversion to actions before computing stats
if delta_mask is not None and "action" in episode_data:
from lerobot.processor.delta_action_processor import to_delta_actions
# Load state for delta even if we're not computing state stats
if needs_state and "observation.state" in ep_df.columns:
state_values = ep_df["observation.state"].values
if hasattr(state_values[0], "__len__"):
states = np.stack(state_values)
else:
states = np.array(state_values)
actions_t = torch.from_numpy(episode_data["action"]).float()
states_t = torch.from_numpy(states).float()
episode_data["action"] = to_delta_actions(actions_t, states_t, delta_mask).numpy()
ep_stats = compute_episode_stats(episode_data, features_to_compute)
all_episode_stats.append(ep_stats)
if not all_episode_stats:
logging.warning("No episode stats computed")
return dataset
new_stats = aggregate_stats(all_episode_stats)
# Merge: keep existing stats for features we didn't recompute
if dataset.meta.stats:
for key, value in dataset.meta.stats.items():
if key not in new_stats:
new_stats[key] = value
write_stats(new_stats, dataset.root)
dataset.meta.stats = new_stats
logging.info(f"Stats recomputed for {len(all_episode_stats)} episodes")
return dataset
def convert_image_to_video_dataset(
dataset: LeRobotDataset,
output_dir: Path,

View File

@@ -18,22 +18,30 @@ import contextlib
import logging
import shutil
import tempfile
import time
from collections.abc import Callable
from pathlib import Path
import datasets
import numpy as np
import os
import packaging.version
import pandas as pd
import PIL.Image
import pyarrow as pa
import pyarrow.parquet as pq
from concurrent.futures import ProcessPoolExecutor
import torch
import torch.utils
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.errors import RevisionNotFoundError
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.datasets.compute_stats import (
RunningQuantileStats,
aggregate_stats,
auto_downsample_height_width,
compute_episode_stats,
)
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
@@ -57,6 +65,7 @@ from lerobot.datasets.utils import (
load_info,
load_nested_dataset,
load_stats,
load_subtasks,
load_tasks,
update_chunk_file_indices,
validate_episode_buffer,
@@ -67,6 +76,7 @@ from lerobot.datasets.utils import (
write_tasks,
)
from lerobot.datasets.video_utils import (
StreamingVideoEncoder,
VideoFrame,
concatenate_video_files,
decode_video_frames,
@@ -78,7 +88,6 @@ from lerobot.datasets.video_utils import (
from lerobot.utils.constants import HF_LEROBOT_HOME
CODEBASE_VERSION = "v3.0"
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1"}
class LeRobotDatasetMetadata:
@@ -162,6 +171,7 @@ class LeRobotDatasetMetadata:
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks = load_tasks(self.root)
self.subtasks = load_subtasks(self.root)
self.episodes = load_episodes(self.root)
self.stats = load_stats(self.root)
@@ -418,8 +428,10 @@ class LeRobotDatasetMetadata:
write_info(self.info, self.root)
t0 = time.perf_counter()
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
write_stats(self.stats, self.root)
logging.info(f"[meta.save_episode] aggregate+write_stats: {time.perf_counter() - t0:.2f}s")
def update_video_info(self, video_key: str | None = None) -> None:
"""
@@ -518,6 +530,7 @@ class LeRobotDatasetMetadata:
_validate_feature_names(features)
obj.tasks = None
obj.subtasks = None
obj.episodes = None
obj.stats = None
obj.info = create_empty_dataset_info(
@@ -541,13 +554,11 @@ class LeRobotDatasetMetadata:
return obj
def _encode_video_worker(
video_key: str, episode_index: int, root: Path, fps: int, vcodec: str = "libsvtav1"
) -> Path:
def _encode_video_worker(video_key: str, episode_index: int, root: Path, fps: int) -> Path:
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
img_dir = (root / fpath).parent
encode_video_frames(img_dir, temp_path, fps, vcodec=vcodec, overwrite=True)
encode_video_frames(img_dir, temp_path, fps, overwrite=True)
shutil.rmtree(img_dir)
return temp_path
@@ -566,7 +577,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
download_videos: bool = True,
video_backend: str | None = None,
batch_encoding_size: int = 1,
vcodec: str = "libsvtav1",
):
"""
2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -679,13 +689,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
'libsvtav1'. Defaults to 'libsvtav1'. Use 'h264' for faster encoding on systems where AV1
encoding is CPU-heavy.
"""
super().__init__()
if vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
self.repo_id = repo_id
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
self.image_transforms = image_transforms
@@ -697,7 +702,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.delta_indices = None
self.batch_encoding_size = batch_encoding_size
self.episodes_since_last_encoding = 0
self.vcodec = vcodec
# Unused attributes
self.image_writer = None
@@ -705,6 +709,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.writer = None
self.latest_episode = None
self._current_file_start_frame = None # Track the starting frame index of the current parquet file
self._streaming_encoder = None
self._running_video_stats = {}
self.root.mkdir(exist_ok=True, parents=True)
@@ -935,30 +941,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
else:
return get_hf_features_from_features(self.features)
def _get_query_indices(
self, abs_idx: int, ep_idx: int
) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]:
"""Compute query indices for delta timestamps.
Args:
abs_idx: The absolute index in the full dataset (not the relative index in filtered episodes).
ep_idx: The episode index.
Returns:
A tuple of (query_indices, padding) where:
- query_indices: Dict mapping keys to lists of absolute indices to query
- padding: Dict mapping "{key}_is_pad" to boolean tensors indicating padded positions
"""
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
ep = self.meta.episodes[ep_idx]
ep_start = ep["dataset_from_index"]
ep_end = ep["dataset_to_index"]
query_indices = {
key: [max(ep_start, min(ep_end - 1, abs_idx + delta)) for delta in delta_idx]
key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx]
for key, delta_idx in self.delta_indices.items()
}
padding = { # Pad values outside of current episode range
f"{key}_is_pad": torch.BoolTensor(
[(abs_idx + delta < ep_start) | (abs_idx + delta >= ep_end) for delta in delta_idx]
[(idx + delta < ep_start) | (idx + delta >= ep_end) for delta in delta_idx]
)
for key, delta_idx in self.delta_indices.items()
}
@@ -1050,12 +1043,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
self._ensure_hf_dataset_loaded()
item = self.hf_dataset[idx]
ep_idx = item["episode_index"].item()
# Use the absolute index from the dataset for delta timestamp calculations
abs_idx = item["index"].item()
query_indices = None
if self.delta_indices is not None:
query_indices, padding = self._get_query_indices(abs_idx, ep_idx)
query_indices, padding = self._get_query_indices(idx, ep_idx)
query_result = self._query_hf_dataset(query_indices)
item = {**item, **padding}
for key, val in query_result.items():
@@ -1075,6 +1066,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Add task as a string
task_idx = item["task_index"].item()
item["task"] = self.meta.tasks.iloc[task_idx].name
# add subtask information if available
if "subtask_index" in self.features and self.meta.subtasks is not None:
subtask_idx = item["subtask_index"].item()
item["subtask"] = self.meta.subtasks.iloc[subtask_idx].name
return item
def __repr__(self):
@@ -1093,6 +1090,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
Close the parquet writers. This function needs to be called after data collection/conversion, else footer metadata won't be written to the parquet files.
The dataset won't be valid and can't be loaded as ds = LeRobotDataset(repo_id=repo, root=HF_LEROBOT_HOME.joinpath(repo))
"""
if self._streaming_encoder:
self._streaming_encoder.close()
self._close_writer()
self.meta._close_writer()
@@ -1144,6 +1143,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Automatically add frame_index and timestamp to episode buffer
frame_index = self.episode_buffer["size"]
if frame_index == 0 and self._streaming_encoder:
self._streaming_encoder.start_episode(self.meta.video_keys, self.root)
self._init_running_video_stats()
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
self.episode_buffer["frame_index"].append(frame_index)
self.episode_buffer["timestamp"].append(timestamp)
@@ -1157,14 +1159,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
if self.features[key]["dtype"] in ["image", "video"]:
img_path = self._get_image_file_path(
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
)
if frame_index == 0:
img_path.parent.mkdir(parents=True, exist_ok=True)
compress_level = 1 if self.features[key]["dtype"] == "video" else 6
self._save_image(frame[key], img_path, compress_level)
self.episode_buffer[key].append(str(img_path))
if self._streaming_encoder and self.features[key]["dtype"] == "video":
self._feed_streaming_frame(key, frame[key])
self.episode_buffer[key].append(None)
else:
img_path = self._get_image_file_path(
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
)
if frame_index == 0:
img_path.parent.mkdir(parents=True, exist_ok=True)
compress_level = 1 if self.features[key]["dtype"] == "video" else 6
self._save_image(frame[key], img_path, compress_level)
self.episode_buffer[key].append(str(img_path))
else:
self.episode_buffer[key].append(frame[key])
@@ -1215,53 +1221,50 @@ class LeRobotDataset(torch.utils.data.Dataset):
continue
episode_buffer[key] = np.stack(episode_buffer[key])
# Wait for image writer to end, so that episode stats over images can be computed
self._wait_image_writer()
ep_stats = compute_episode_stats(episode_buffer, self.features)
t0 = time.perf_counter()
if self._streaming_encoder:
filtered = {k: v for k, v in episode_buffer.items() if k not in self.meta.video_keys}
ep_stats = compute_episode_stats(filtered, self.features)
for key in self.meta.video_keys:
stats = self._running_video_stats[key].get_statistics()
ep_stats[key] = {
k: v if k == "count" else (v.reshape(-1, 1, 1) / 255.0)
for k, v in stats.items()
}
else:
ep_stats = compute_episode_stats(episode_buffer, self.features)
t_stats = time.perf_counter() - t0
t0 = time.perf_counter()
ep_metadata = self._save_episode_data(episode_buffer)
t_save_data = time.perf_counter() - t0
has_video_keys = len(self.meta.video_keys) > 0
use_batched_encoding = self.batch_encoding_size > 1
if has_video_keys and not use_batched_encoding:
num_cameras = len(self.meta.video_keys)
if parallel_encoding and num_cameras > 1:
# TODO(Steven): Ideally we would like to control the number of threads per encoding such that:
# num_cameras * num_threads = (total_cpu -1)
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cameras) as executor:
future_to_key = {
executor.submit(
_encode_video_worker,
video_key,
episode_index,
self.root,
self.fps,
self.vcodec,
): video_key
for video_key in self.meta.video_keys
}
results = {}
for future in concurrent.futures.as_completed(future_to_key):
video_key = future_to_key[future]
try:
temp_path = future.result()
results[video_key] = temp_path
except Exception as exc:
logging.error(f"Video encoding failed for {video_key}: {exc}")
raise exc
for video_key in self.meta.video_keys:
temp_path = results[video_key]
ep_metadata.update(
self._save_episode_video(video_key, episode_index, temp_path=temp_path)
)
else:
for video_key in self.meta.video_keys:
ep_metadata.update(self._save_episode_video(video_key, episode_index))
t0 = time.perf_counter()
if has_video_keys and self._streaming_encoder:
video_paths = self._streaming_encoder.finish_episode()
for video_key in self.meta.video_keys:
ep_metadata.update(self._save_episode_video(video_key, episode_index, video_paths[video_key]))
elif has_video_keys and not use_batched_encoding:
video_paths = self._encode_multiple_temporary_episode_videos(self.meta.video_keys, episode_index)
for video_key, video_path in zip(self.meta.video_keys, video_paths):
ep_metadata.update(self._save_episode_video(video_key, episode_index, video_path))
t_video = time.perf_counter() - t0
# `meta.save_episode` need to be executed after encoding the videos
t0 = time.perf_counter()
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
t_meta = time.perf_counter() - t0
logging.info(
f"[save_episode] ep={episode_index} frames={episode_length} | "
f"stats={t_stats:.2f}s data={t_save_data:.2f}s video={t_video:.2f}s meta={t_meta:.2f}s "
f"total={t_stats + t_save_data + t_video + t_meta:.2f}s"
)
if has_video_keys and use_batched_encoding:
# Check if we should trigger batch encoding
@@ -1429,6 +1432,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_index: int,
temp_path: Path | None = None,
) -> dict:
t0 = time.perf_counter()
# Encode episode frames into a temporary video
if temp_path is None:
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
@@ -1502,9 +1506,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
f"videos/{video_key}/from_timestamp": latest_duration_in_s,
f"videos/{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
}
save_time = time.perf_counter() - t0
rate = ep_duration_in_s / save_time if save_time > 0 else float("inf")
logging.info(
f"[save_episode_video] {video_key} ep={episode_index} "
f"save={save_time:.2f}s video_dur={ep_duration_in_s:.1f}s "
f"size={ep_size_in_mb:.1f}MB rate={rate:.2f}x realtime"
)
return metadata
def clear_episode_buffer(self, delete_images: bool = True) -> None:
if self._streaming_encoder:
self._streaming_encoder.stop_episode()
# Clean up image files for the current episode buffer
if delete_images:
# Wait for the async image writer to finish
@@ -1513,7 +1526,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_index = self.episode_buffer["episode_index"]
if isinstance(episode_index, np.ndarray):
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
for cam_key in self.meta.image_keys:
for cam_key in self.meta.camera_keys:
img_dir = self._get_image_file_dir(episode_index, cam_key)
if img_dir.is_dir():
shutil.rmtree(img_dir)
@@ -1546,13 +1559,66 @@ class LeRobotDataset(torch.utils.data.Dataset):
if self.image_writer is not None:
self.image_writer.wait_until_done()
def start_streaming_encoder(self):
"""Enable streaming video encoding for recording."""
if len(self.meta.video_keys) > 0:
self._streaming_encoder = StreamingVideoEncoder(fps=self.fps)
self._running_video_stats = {}
def _init_running_video_stats(self):
self._running_video_stats = {key: RunningQuantileStats() for key in self.meta.video_keys}
def _feed_streaming_frame(self, key: str, image) -> None:
"""Feed image to streaming encoder and accumulate running stats."""
if isinstance(image, np.ndarray):
if image.ndim == 3 and image.shape[0] in (1, 3, 4):
img_chw = image
else:
img_chw = image.transpose(2, 0, 1)
else:
img_chw = np.array(image).transpose(2, 0, 1)
self._streaming_encoder.feed_frame(key, image)
img_ds = auto_downsample_height_width(img_chw)
c, h, w = img_ds.shape
self._running_video_stats[key].update(
img_ds.transpose(1, 2, 0).reshape(-1, c).astype(np.float64)
)
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
"""
Use ffmpeg to convert frames stored as png into mp4 videos.
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
since video encoding with ffmpeg is already using multithreading.
"""
return _encode_video_worker(video_key, episode_index, self.root, self.fps, self.vcodec)
return _encode_video_worker(video_key, episode_index, self.root, self.fps)
def _encode_multiple_temporary_episode_videos(self, video_keys, episode_index):
temp_paths = []
img_dirs = []
for video_key in video_keys:
temp_paths.append(Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4")
img_dirs.append(self._get_image_file_dir(episode_index, video_key))
fps = [self.fps]*len(video_keys)
t0 = time.perf_counter()
with ProcessPoolExecutor(max_workers=len(video_keys)) as executor:
executor.map(encode_video_frames,img_dirs,temp_paths,fps)
encode_time = time.perf_counter() - t0
n_frames = len(list(img_dirs[0].glob("*"))) if img_dirs and img_dirs[0].exists() else 0
video_duration_s = n_frames / self.fps if n_frames > 0 else 0
rate = video_duration_s / encode_time if encode_time > 0 else float("inf")
logging.info(
f"[encode_videos] ep={episode_index} keys={len(video_keys)} "
f"encode={encode_time:.2f}s video_dur={video_duration_s:.1f}s "
f"rate={rate:.2f}x realtime"
)
for img_dir in img_dirs:
shutil.rmtree(img_dir)
return temp_paths
@classmethod
def create(
@@ -1568,11 +1634,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_writer_threads: int = 0,
video_backend: str | None = None,
batch_encoding_size: int = 1,
vcodec: str = "libsvtav1",
streaming_encoding: bool = False,
) -> "LeRobotDataset":
"""Create a LeRobot Dataset from scratch in order to record data."""
if vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
obj = cls.__new__(cls)
obj.meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
@@ -1589,7 +1653,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.image_writer = None
obj.batch_encoding_size = batch_encoding_size
obj.episodes_since_last_encoding = 0
obj.vcodec = vcodec
if image_writer_processes or image_writer_threads:
obj.start_image_writer(image_writer_processes, image_writer_threads)
@@ -1607,6 +1670,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.writer = None
obj.latest_episode = None
obj._current_file_start_frame = None
obj._streaming_encoder = None
obj._running_video_stats = {}
if streaming_encoding and len(obj.meta.video_keys) > 0:
obj._streaming_encoder = StreamingVideoEncoder(fps=fps)
# Initialize tracking for incremental recording
obj._lazy_loading = False
obj._recorded_frames = 0

View File

@@ -216,16 +216,17 @@ class ImageTransformsConfig:
def make_transform_from_config(cfg: ImageTransformConfig):
if cfg.type == "Identity":
return v2.Identity(**cfg.kwargs)
elif cfg.type == "ColorJitter":
return v2.ColorJitter(**cfg.kwargs)
elif cfg.type == "SharpnessJitter":
if cfg.type == "SharpnessJitter":
return SharpnessJitter(**cfg.kwargs)
elif cfg.type == "RandomAffine":
return v2.RandomAffine(**cfg.kwargs)
else:
raise ValueError(f"Transform '{cfg.type}' is not valid.")
transform_cls = getattr(v2, cfg.type, None)
if isinstance(transform_cls, type) and issubclass(transform_cls, Transform):
return transform_cls(**cfg.kwargs)
raise ValueError(
f"Transform '{cfg.type}' is not valid. It must be a class in "
f"torchvision.transforms.v2 or 'SharpnessJitter'."
)
class ImageTransforms(Transform):

View File

@@ -60,6 +60,7 @@ VIDEO_DIR = "videos"
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
@@ -353,6 +354,14 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
return tasks
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
"""Load subtasks from subtasks.parquet if it exists."""
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
if subtasks_path.exists():
return pd.read_parquet(subtasks_path)
return None
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
This function writes episode-level metadata to a single parquet file.

View File

@@ -16,16 +16,18 @@
import glob
import importlib
import logging
import queue
import shutil
import tempfile
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from threading import Lock
from threading import Lock, Thread
from typing import Any, ClassVar
import av
import fsspec
import numpy as np
import pyarrow as pa
import torch
import torchvision
@@ -310,7 +312,7 @@ def encode_video_frames(
crf: int | None = 30,
fast_decode: int = 0,
log_level: int | None = av.logging.ERROR,
overwrite: bool = False,
overwrite: bool = True,
preset: int | None = None,
) -> None:
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
@@ -355,6 +357,9 @@ def encode_video_frames(
if crf is not None:
video_options["crf"] = str(crf)
#TEMPORARY FIX
video_options["preset"] = "12"
if fast_decode:
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
@@ -397,6 +402,141 @@ def encode_video_frames(
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
_DONE = object()
class _CameraEncoder:
"""Encodes frames for one camera in a daemon thread."""
def __init__(self, video_path, fps, vcodec, pix_fmt, g, crf):
self.video_path = Path(video_path)
self.fps = fps
self.vcodec = vcodec
self.pix_fmt = pix_fmt
self.g = g
self.crf = crf
self.queue = queue.Queue()
self._thread = None
self._cancelled = False
def start(self):
self.video_path.parent.mkdir(parents=True, exist_ok=True)
self._thread = Thread(target=self._run, daemon=True)
self._thread.start()
def finish(self) -> Path:
self.queue.put(_DONE)
self._thread.join(timeout=120)
return self.video_path
def cancel(self):
self._cancelled = True
while not self.queue.empty():
try:
self.queue.get_nowait()
except queue.Empty:
break
self.queue.put(_DONE)
if self._thread:
self._thread.join(timeout=5)
if self.video_path.parent.exists():
shutil.rmtree(self.video_path.parent, ignore_errors=True)
def _run(self):
options = {}
if self.g is not None:
options["g"] = str(self.g)
if self.crf is not None:
options["crf"] = str(self.crf)
if self.vcodec == "libsvtav1":
options["preset"] = "12"
output = None
output_stream = None
try:
while True:
data = self.queue.get()
if data is _DONE or self._cancelled:
break
if isinstance(data, np.ndarray):
if data.ndim == 3 and data.shape[0] in (1, 3, 4):
data = data.transpose(1, 2, 0)
pil = Image.fromarray(data.astype(np.uint8)).convert("RGB")
else:
pil = data.convert("RGB")
if output is None:
w, h = pil.size
output = av.open(str(self.video_path), "w")
output_stream = output.add_stream(self.vcodec, self.fps, options=options)
output_stream.pix_fmt = self.pix_fmt
output_stream.width = w
output_stream.height = h
pkt = output_stream.encode(av.VideoFrame.from_image(pil))
if pkt:
output.mux(pkt)
if output_stream and not self._cancelled:
pkt = output_stream.encode()
if pkt:
output.mux(pkt)
except Exception as e:
logging.error(f"[StreamingEncoder] {e}")
finally:
if output:
output.close()
class StreamingVideoEncoder:
"""Encodes video on-the-fly using one background thread per camera.
PyAV releases the GIL during encoding, so Python threads give true
parallelism for the CPU-intensive codec work. The queue is unbounded
so feed_frame never blocks the caller (teleop thread always has priority).
"""
def __init__(self, fps, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30):
self.fps = fps
self._vcodec = vcodec
self._pix_fmt = pix_fmt
self._g = g
self._crf = crf
self._encoders: dict[str, _CameraEncoder] = {}
def start_episode(self, video_keys, temp_dir):
self.stop_episode()
for key in video_keys:
path = Path(tempfile.mkdtemp(dir=temp_dir)) / f"{key}_stream.mp4"
enc = _CameraEncoder(path, self.fps, self._vcodec, self._pix_fmt, self._g, self._crf)
enc.start()
self._encoders[key] = enc
def feed_frame(self, video_key, image):
"""Non-blocking: put frame on unbounded queue (never blocks caller)."""
enc = self._encoders.get(video_key)
if enc:
enc.queue.put(image)
def finish_episode(self) -> dict[str, Path]:
"""Flush all encoders, wait for completion, return {key: video_path}."""
paths = {}
for key, enc in self._encoders.items():
paths[key] = enc.finish()
self._encoders.clear()
return paths
def stop_episode(self):
"""Cancel current episode encoding (for re-record)."""
for enc in self._encoders.values():
enc.cancel()
self._encoders.clear()
def close(self):
self.stop_episode()
def concatenate_video_files(
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
):

View File

@@ -205,6 +205,7 @@ class ObservationConfig:
add_joint_velocity_to_observation: bool = False
add_current_to_observation: bool = False
add_ee_pose_to_observation: bool = False
display_cameras: bool = False

View File

@@ -112,6 +112,7 @@ class LiberoEnv(gym.Env):
visualization_height: int = 480,
init_states: bool = True,
episode_index: int = 0,
n_envs: int = 1,
camera_name_mapping: dict[str, str] | None = None,
num_steps_wait: int = 10,
control_mode: str = "relative",
@@ -145,7 +146,9 @@ class LiberoEnv(gym.Env):
self.episode_length = episode_length
# Load once and keep
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._reset_stride = n_envs # when performing a reset, append `_reset_stride` to `init_state_id`.
self.init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._env = self._make_envs_task(task_suite, self.task_id)
default_steps = 500
@@ -295,7 +298,8 @@ class LiberoEnv(gym.Env):
self._env.seed(seed)
raw_obs = self._env.reset()
if self.init_states and self._init_states is not None:
raw_obs = self._env.set_init_state(self._init_states[self._init_state_id])
raw_obs = self._env.set_init_state(self._init_states[self.init_state_id % len(self._init_states)])
self.init_state_id += self._reset_stride # Change init_state_id when reset
# After reset, objects may be unstable (slightly floating, intersecting, etc.).
# Step the simulator with a no-op action for a few frames so everything settles.
@@ -373,6 +377,7 @@ def _make_env_fns(
init_states=init_states,
episode_length=episode_length,
episode_index=episode_index,
n_envs=n_envs,
control_mode=control_mode,
**local_kwargs,
)

View File

@@ -18,4 +18,7 @@ from .motors_bus import (
Motor,
MotorCalibration,
MotorNormMode,
MotorsBus, # Backward compatibility (alias for SerialMotorsBus)
MotorsBusBase,
SerialMotorsBus,
)

View File

@@ -221,7 +221,7 @@ class RangeFinderGUI:
self.bus = bus
self.groups = groups if groups is not None else {"all": list(bus.motors)}
self.group_names = list(groups)
self.group_names = list(self.groups)
self.current_group = self.group_names[0]
if not bus.is_connected:
@@ -230,18 +230,20 @@ class RangeFinderGUI:
self.calibration = bus.read_calibration()
self.res_table = bus.model_resolution_table
self.present_cache = {
m: bus.read("Present_Position", m, normalize=False) for motors in groups.values() for m in motors
m: bus.read("Present_Position", m, normalize=False)
for motors in self.groups.values()
for m in motors
}
pygame.init()
self.font = pygame.font.Font(None, FONT_SIZE)
label_pad = max(self.font.size(m)[0] for ms in groups.values() for m in ms)
label_pad = max(self.font.size(m)[0] for ms in self.groups.values() for m in ms)
self.label_pad = label_pad
width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10
self.controls_bottom = 10 + SAVE_H
self.base_y = self.controls_bottom + TOP_GAP
height = self.base_y + PADDING_Y * len(groups[self.current_group]) + 40
height = self.base_y + PADDING_Y * len(self.groups[self.current_group]) + 40
self.screen = pygame.display.set_mode((width, height))
pygame.display.set_caption("Motors range finder")

View File

@@ -23,17 +23,20 @@ from copy import deepcopy
from functools import cached_property
from typing import TYPE_CHECKING, Any, TypedDict
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.import_utils import _can_available
if TYPE_CHECKING or _can_available:
import can
else:
can.Message = object
can.interface = None
class can: # noqa: N801
Message = object
interface = None
import numpy as np
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import enter_pressed, move_cursor_up
@@ -152,6 +155,7 @@ class DamiaoMotorsBus(MotorsBusBase):
"""Check if the CAN bus is connected."""
return self._is_connected and self.canbus is not None
@check_if_already_connected
def connect(self, handshake: bool = True) -> None:
"""
Open the CAN bus and initialize communication.
@@ -159,10 +163,6 @@ class DamiaoMotorsBus(MotorsBusBase):
Args:
handshake: If True, ping all motors to verify they're present
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(
f"{self.__class__.__name__}('{self.port}') is already connected."
)
try:
# Auto-detect interface type based on port name
@@ -206,11 +206,34 @@ class DamiaoMotorsBus(MotorsBusBase):
Raises ConnectionError if any motor fails to respond.
"""
logger.info("Starting handshake with motors...")
missing_motors = []
# Drain any pending messages
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
while self.canbus.recv(timeout=0.01):
pass
missing_motors = []
for motor_name in self.motors:
msg = self._refresh_motor(motor_name)
if msg is None:
motor_id = self._get_motor_id(motor_name)
recv_id = self._get_motor_recv_id(motor_name)
# Send enable command
data = [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, CAN_CMD_ENABLE]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
self.canbus.send(msg)
# Wait for response with longer timeout
response = None
start_time = time.time()
while time.time() - start_time < 0.1:
response = self.canbus.recv(timeout=0.1)
if response and response.arbitration_id == recv_id:
break
response = None
if response is None:
missing_motors.append(motor_name)
else:
self._process_response(motor_name, msg)
@@ -223,6 +246,7 @@ class DamiaoMotorsBus(MotorsBusBase):
)
logger.info("Handshake successful. All motors ready.")
@check_if_not_connected
def disconnect(self, disable_torque: bool = True) -> None:
"""
Close the CAN bus connection.
@@ -230,8 +254,6 @@ class DamiaoMotorsBus(MotorsBusBase):
Args:
disable_torque: If True, disable torque on all motors before disconnecting
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self.__class__.__name__}('{self.port}') is not connected.")
if disable_torque:
try:
@@ -259,7 +281,11 @@ class DamiaoMotorsBus(MotorsBusBase):
motor_name = self._get_motor_name(motor)
recv_id = self._get_motor_recv_id(motor)
data = [0xFF] * 7 + [command_byte]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
self.canbus.send(msg)
if msg := self._recv_motor_response(expected_recv_id=recv_id):
self._process_response(motor_name, msg)
@@ -317,7 +343,11 @@ class DamiaoMotorsBus(MotorsBusBase):
motor_id = self._get_motor_id(motor)
recv_id = self._get_motor_recv_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd)
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
self.canbus.send(msg)
return self._recv_motor_response(expected_recv_id=recv_id)
@@ -333,6 +363,10 @@ class DamiaoMotorsBus(MotorsBusBase):
Returns:
CAN message if received, None otherwise
"""
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
try:
start_time = time.time()
messages_seen = []
@@ -371,10 +405,13 @@ class DamiaoMotorsBus(MotorsBusBase):
Returns:
Dictionary mapping recv_id to CAN message
"""
responses = {}
responses: dict[int, can.Message] = {}
expected_set = set(expected_recv_ids)
start_time = time.time()
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
try:
while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout:
# 100us poll timeout
@@ -438,8 +475,11 @@ class DamiaoMotorsBus(MotorsBusBase):
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types[motor_name]
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
self.canbus.send(msg)
recv_id = self._get_motor_recv_id(motor)
@@ -465,6 +505,9 @@ class DamiaoMotorsBus(MotorsBusBase):
recv_id_to_motor: dict[int, str] = {}
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
# Step 1: Send all MIT control commands
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
motor_id = self._get_motor_id(motor)
@@ -472,7 +515,7 @@ class DamiaoMotorsBus(MotorsBusBase):
motor_type = self._motor_types[motor_name]
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
self.canbus.send(msg)
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
@@ -539,10 +582,9 @@ class DamiaoMotorsBus(MotorsBusBase):
except Exception as e:
logger.warning(f"Failed to decode response from {motor}: {e}")
@check_if_not_connected
def read(self, data_name: str, motor: str) -> Value:
"""Read a value from a single motor. Positions are always in degrees."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Refresh motor to get latest state
msg = self._refresh_motor(motor)
@@ -572,6 +614,7 @@ class DamiaoMotorsBus(MotorsBusBase):
raise ValueError(f"Unknown data_name: {data_name}")
return mapping[data_name]
@check_if_not_connected
def write(
self,
data_name: str,
@@ -582,8 +625,6 @@ class DamiaoMotorsBus(MotorsBusBase):
Write a value to a single motor. Positions are always in degrees.
Can write 'Goal_Position', 'Kp', or 'Kd'.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if data_name in ("Kp", "Kd"):
self._gains[motor][data_name.lower()] = float(value)
@@ -633,14 +674,18 @@ class DamiaoMotorsBus(MotorsBusBase):
def _batch_refresh(self, motors: list[str]) -> None:
"""Internal helper to refresh a list of motors and update cache."""
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
# Send refresh commands
for motor in motors:
motor_id = self._get_motor_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
msg = can.Message(
arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd
)
self.canbus.send(msg)
# Small delay to reduce bus congestion if necessary, though removed in sync_read previously
# precise_sleep(PRECISE_SLEEP_SEC)
# Collect responses
expected_recv_ids = [self._get_motor_recv_id(m) for m in motors]
@@ -655,10 +700,12 @@ class DamiaoMotorsBus(MotorsBusBase):
else:
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
@check_if_not_connected
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
"""
Write values to multiple motors simultaneously. Positions are always in degrees.
"""
if data_name in ("Kp", "Kd"):
key = data_name.lower()
for motor, val in values.items():
@@ -667,6 +714,8 @@ class DamiaoMotorsBus(MotorsBusBase):
elif data_name == "Goal_Position":
# Step 1: Send all MIT control commands
recv_id_to_motor: dict[int, str] = {}
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
for motor, value_degrees in values.items():
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
@@ -676,7 +725,9 @@ class DamiaoMotorsBus(MotorsBusBase):
kd = self._gains[motor]["kd"]
data = self._encode_mit_packet(motor_type, kp, kd, float(value_degrees), 0.0, 0.0)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
msg = can.Message(
arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd
)
self.canbus.send(msg)
precise_sleep(PRECISE_TIMEOUT_SEC)
@@ -707,9 +758,9 @@ class DamiaoMotorsBus(MotorsBusBase):
def record_ranges_of_motion(
self,
motors: NameOrID | list[NameOrID] | None = None,
motors: str | list[str] | None = None,
display_values: bool = True,
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
) -> tuple[dict[str, Value], dict[str, Value]]:
"""
Interactively record the min/max values of each motor in degrees.

View File

@@ -15,7 +15,7 @@
"""Configuration tables for Damiao motors."""
from enum import IntEnum
from typing import Dict, List, Tuple
# Motor type definitions
class MotorType(IntEnum):
@@ -33,7 +33,6 @@ class MotorType(IntEnum):
DMH6215 = 11
DMG6220 = 12
# Control modes
class ControlMode(IntEnum):
MIT = 1
@@ -41,7 +40,6 @@ class ControlMode(IntEnum):
VEL = 3
TORQUE_POS = 4
# Motor variable IDs (RID)
class MotorVariable(IntEnum):
UV_VALUE = 0
@@ -90,8 +88,7 @@ class MotorVariable(IntEnum):
P_M = 80
XOUT = 81
# Motor limit parameters [PMAX, VMAX, TMAX]
# Motor limit parameters [PMAX, VMAX, TMAX]
# PMAX: Maximum position (rad)
# VMAX: Maximum velocity (rad/s)
# TMAX: Maximum torque (N·m)
@@ -147,10 +144,10 @@ MODEL_RESOLUTION = {
# CAN baudrates supported by Damiao motors
AVAILABLE_BAUDRATES = [
125000, # 0: 125 kbps
200000, # 1: 200 kbps
250000, # 2: 250 kbps
500000, # 3: 500 kbps
125000, # 0: 125 kbps
200000, # 1: 200 kbps
250000, # 2: 250 kbps
500000, # 3: 500 kbps
1000000, # 4: 1 mbps (default for OpenArms)
2000000, # 5: 2 mbps
2500000, # 6: 2.5 mbps
@@ -163,6 +160,9 @@ DEFAULT_BAUDRATE = 1000000 # 1 Mbps is standard for OpenArms
# Default timeout in milliseconds
DEFAULT_TIMEOUT_MS = 1000
# Data that should be normalized
NORMALIZED_DATA = ["Present_Position", "Goal_Position"]
# OpenArms specific configurations
# Based on: https://docs.openarm.dev/software/setup/configure-test
# OpenArms has 7 DOF per arm (14 total for dual arm)
@@ -182,14 +182,14 @@ OPENARMS_GRIPPER_MOTOR_IDS = {
# Default motor types for OpenArms
OPENARMS_DEFAULT_MOTOR_TYPES = {
"joint_1": MotorType.DM8009, # Shoulder pan - high torque
"joint_2": MotorType.DM8009, # Shoulder lift - high torque
"joint_3": MotorType.DM4340, # Shoulder rotation
"joint_4": MotorType.DM4340, # Elbow flex
"joint_5": MotorType.DM4310, # Wrist roll
"joint_6": MotorType.DM4310, # Wrist pitch
"joint_7": MotorType.DM4310, # Wrist rotation
"gripper": MotorType.DM4310, # Gripper
"joint_1": MotorType.DM8009, # Shoulder pan - high torque
"joint_2": MotorType.DM8009, # Shoulder lift - high torque
"joint_3": MotorType.DM4340, # Shoulder rotation
"joint_4": MotorType.DM4340, # Elbow flex
"joint_5": MotorType.DM4310, # Wrist roll
"joint_6": MotorType.DM4310, # Wrist pitch
"joint_7": MotorType.DM4310, # Wrist rotation
"gripper": MotorType.DM4310, # Gripper
}
# MIT control parameter ranges

View File

@@ -22,7 +22,8 @@ import logging
from copy import deepcopy
from enum import Enum
from ..encoding_utils import decode_twos_complement, encode_twos_complement
from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
from .tables import (
AVAILABLE_BAUDRATES,
@@ -181,10 +182,10 @@ class DynamixelMotorsBus(SerialMotorsBus):
for motor, m in self.motors.items():
calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=drive_modes[motor],
homing_offset=offsets[motor],
range_min=mins[motor],
range_max=maxes[motor],
drive_mode=int(drive_modes[motor]),
homing_offset=int(offsets[motor]),
range_min=int(mins[motor]),
range_max=int(maxes[motor]),
)
return calibration
@@ -198,15 +199,15 @@ class DynamixelMotorsBus(SerialMotorsBus):
if cache:
self.calibration = calibration_dict
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None:
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
@@ -235,7 +236,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
On Dynamixel Motors:
Present_Position = Actual_Position + Homing_Offset
"""
half_turn_homings = {}
half_turn_homings: dict[NameOrID, Value] = {}
for motor, pos in positions.items():
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
@@ -258,6 +259,6 @@ class DynamixelMotorsBus(SerialMotorsBus):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return
return None
return {id_: data[0] for id_, data in data_list.items()}

View File

@@ -17,7 +17,8 @@ from copy import deepcopy
from enum import Enum
from pprint import pformat
from ..encoding_utils import decode_sign_magnitude, encode_sign_magnitude
from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
from .tables import (
FIRMWARE_MAJOR_VERSION,
@@ -126,7 +127,7 @@ class FeetechMotorsBus(SerialMotorsBus):
self.port_handler = scs.PortHandler(self.port)
# HACK: monkeypatch
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__(
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign]
self.port_handler, scs.PortHandler
)
self.packet_handler = scs.PacketHandler(protocol_version)
@@ -164,7 +165,7 @@ class FeetechMotorsBus(SerialMotorsBus):
def _handshake(self) -> None:
self._assert_motors_exist()
self._assert_same_firmware()
#self._assert_same_firmware()
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
if self.protocol_version == 0:
@@ -262,9 +263,9 @@ class FeetechMotorsBus(SerialMotorsBus):
calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=0,
homing_offset=offsets[motor],
range_min=mins[motor],
range_max=maxes[motor],
homing_offset=int(offsets[motor]),
range_min=int(mins[motor]),
range_max=int(maxes[motor]),
)
return calibration
@@ -284,7 +285,7 @@ class FeetechMotorsBus(SerialMotorsBus):
On Feetech Motors:
Present_Position = Actual_Position - Homing_Offset
"""
half_turn_homings = {}
half_turn_homings: dict[NameOrID, Value] = {}
for motor, pos in positions.items():
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
@@ -292,18 +293,18 @@ class FeetechMotorsBus(SerialMotorsBus):
return half_turn_homings
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
self.write("Lock", motor, 0, num_retry=num_retry)
def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None:
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
addr, length = get_address(self.model_ctrl_table, model, "Lock")
self._write(addr, length, motor, 0, num_retry=num_retry)
self._write(addr, length, motor_id, 0, num_retry=num_retry)
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
self.write("Lock", motor, 1, num_retry=num_retry)
@@ -334,7 +335,7 @@ class FeetechMotorsBus(SerialMotorsBus):
def _broadcast_ping(self) -> tuple[dict[int, int], int]:
import scservo_sdk as scs
data_list = {}
data_list: dict[int, int] = {}
status_length = 6
@@ -414,7 +415,7 @@ class FeetechMotorsBus(SerialMotorsBus):
if not self._is_comm_success(comm):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return
return None
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
if ids_errors:

View File

@@ -23,6 +23,7 @@ from __future__ import annotations
import abc
import logging
from collections.abc import Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
@@ -93,7 +94,7 @@ class MotorsBusBase(abc.ABC):
pass
@abc.abstractmethod
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
"""Write values to multiple motors."""
pass
@@ -179,15 +180,16 @@ class Motor:
class PortHandler(Protocol):
def __init__(self, port_name):
self.is_open: bool
self.baudrate: int
self.packet_start_time: float
self.packet_timeout: float
self.tx_time_per_byte: float
self.is_using: bool
self.port_name: str
self.ser: serial.Serial
is_open: bool
baudrate: int
packet_start_time: float
packet_timeout: float
tx_time_per_byte: float
is_using: bool
port_name: str
ser: serial.Serial
def __init__(self, port_name: str) -> None: ...
def openPort(self): ...
def closePort(self): ...
@@ -240,19 +242,22 @@ class PacketHandler(Protocol):
def regWriteTxRx(self, port, id, address, length, data): ...
def syncReadTx(self, port, start_address, data_length, param, param_length): ...
def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ...
def broadcastPing(self, port): ...
class GroupSyncRead(Protocol):
def __init__(self, port, ph, start_address, data_length):
self.port: str
self.ph: PortHandler
self.start_address: int
self.data_length: int
self.last_result: bool
self.is_param_changed: bool
self.param: list
self.data_dict: dict
port: str
ph: PortHandler
start_address: int
data_length: int
last_result: bool
is_param_changed: bool
param: list
data_dict: dict
def __init__(
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
) -> None: ...
def makeParam(self): ...
def addParam(self, id): ...
def removeParam(self, id): ...
@@ -265,15 +270,17 @@ class GroupSyncRead(Protocol):
class GroupSyncWrite(Protocol):
def __init__(self, port, ph, start_address, data_length):
self.port: str
self.ph: PortHandler
self.start_address: int
self.data_length: int
self.is_param_changed: bool
self.param: list
self.data_dict: dict
port: str
ph: PortHandler
start_address: int
data_length: int
is_param_changed: bool
param: list
data_dict: dict
def __init__(
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
) -> None: ...
def makeParam(self): ...
def addParam(self, id, data): ...
def removeParam(self, id): ...
@@ -400,7 +407,7 @@ class SerialMotorsBus(MotorsBusBase):
else:
raise TypeError(f"'{motor}' should be int, str.")
def _get_motor_model(self, motor: NameOrID) -> int:
def _get_motor_model(self, motor: NameOrID) -> str:
if isinstance(motor, str):
return self.motors[motor].model
elif isinstance(motor, int):
@@ -408,17 +415,19 @@ class SerialMotorsBus(MotorsBusBase):
else:
raise TypeError(f"'{motor}' should be int, str.")
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
def _get_motors_list(self, motors: NameOrID | Sequence[NameOrID] | None) -> list[str]:
if motors is None:
return list(self.motors)
elif isinstance(motors, str):
return [motors]
elif isinstance(motors, list):
return motors.copy()
elif isinstance(motors, int):
return [self._id_to_name(motors)]
elif isinstance(motors, Sequence):
return [m if isinstance(m, str) else self._id_to_name(m) for m in motors]
else:
raise TypeError(motors)
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> dict[int, Value]:
if isinstance(values, (int | float)):
return dict.fromkeys(self.ids, values)
elif isinstance(values, dict):
@@ -640,18 +649,19 @@ class SerialMotorsBus(MotorsBusBase):
pass
@abc.abstractmethod
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
"""Enable torque on selected motors.
Args:
motor (int): Same semantics as :pymeth:`disable_torque`. Defaults to `None`.
motors (int | str | list[str] | None, optional): Same semantics as :pymeth:`disable_torque`.
Defaults to `None`.
num_retry (int, optional): Number of additional retry attempts on communication failure.
Defaults to 0.
"""
pass
@contextmanager
def torque_disabled(self, motors: int | str | list[str] | None = None):
def torque_disabled(self, motors: str | list[str] | None = None):
"""Context-manager that guarantees torque is re-enabled.
This helper is useful to temporarily disable torque when configuring motors.
@@ -728,24 +738,19 @@ class SerialMotorsBus(MotorsBusBase):
"""
pass
def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None:
def reset_calibration(self, motors: NameOrID | Sequence[NameOrID] | None = None) -> None:
"""Restore factory calibration for the selected motors.
Homing offset is set to ``0`` and min/max position limits are set to the full usable range.
The in-memory :pyattr:`calibration` is cleared.
Args:
motors (NameOrID | list[NameOrID] | None, optional): Selection of motors. `None` (default)
motors (NameOrID | Sequence[NameOrID] | None, optional): Selection of motors. `None` (default)
resets every motor.
"""
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
motor_names = self._get_motors_list(motors)
for motor in motors:
for motor in motor_names:
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
self.write("Homing_Offset", motor, 0, normalize=False)
@@ -754,7 +759,9 @@ class SerialMotorsBus(MotorsBusBase):
self.calibration = {}
def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) -> dict[NameOrID, Value]:
def set_half_turn_homings(
self, motors: NameOrID | Sequence[NameOrID] | None = None
) -> dict[NameOrID, Value]:
"""Centre each motor range around its current position.
The function computes and writes a homing offset such that the present position becomes exactly one
@@ -764,17 +771,12 @@ class SerialMotorsBus(MotorsBusBase):
motors (NameOrID | list[NameOrID] | None, optional): Motors to adjust. Defaults to all motors (`None`).
Returns:
dict[NameOrID, Value]: Mapping *motor → written homing offset*.
dict[str, Value]: Mapping *motor name → written homing offset*.
"""
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
motor_names = self._get_motors_list(motors)
self.reset_calibration(motors)
actual_positions = self.sync_read("Present_Position", motors, normalize=False)
self.reset_calibration(motor_names)
actual_positions = self.sync_read("Present_Position", motor_names, normalize=False)
homing_offsets = self._get_half_turn_homings(actual_positions)
for motor, offset in homing_offsets.items():
self.write("Homing_Offset", motor, offset)
@@ -786,8 +788,8 @@ class SerialMotorsBus(MotorsBusBase):
pass
def record_ranges_of_motion(
self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
self, motors: NameOrID | Sequence[NameOrID] | None = None, display_values: bool = True
) -> tuple[dict[str, Value], dict[str, Value]]:
"""Interactively record the min/max encoder values of each motor.
Move the joints by hand (with torque disabled) while the method streams live positions. Press
@@ -799,30 +801,25 @@ class SerialMotorsBus(MotorsBusBase):
display_values (bool, optional): When `True` (default) a live table is printed to the console.
Returns:
tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: Two dictionaries *mins* and *maxes* with the
tuple[dict[str, Value], dict[str, Value]]: Two dictionaries *mins* and *maxes* with the
extreme values observed for each motor.
"""
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
motor_names = self._get_motors_list(motors)
start_positions = self.sync_read("Present_Position", motors, normalize=False)
start_positions = self.sync_read("Present_Position", motor_names, normalize=False)
mins = start_positions.copy()
maxes = start_positions.copy()
user_pressed_enter = False
while not user_pressed_enter:
positions = self.sync_read("Present_Position", motors, normalize=False)
positions = self.sync_read("Present_Position", motor_names, normalize=False)
mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()}
maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()}
if display_values:
print("\n-------------------------------------------")
print(f"{'NAME':<15} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}")
for motor in motors:
for motor in motor_names:
print(f"{motor:<15} | {mins[motor]:>6} | {positions[motor]:>6} | {maxes[motor]:>6}")
if enter_pressed():
@@ -830,9 +827,9 @@ class SerialMotorsBus(MotorsBusBase):
if display_values and not user_pressed_enter:
# Move cursor up to overwrite the previous output
move_cursor_up(len(motors) + 3)
move_cursor_up(len(motor_names) + 3)
same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]]
same_min_max = [motor for motor in motor_names if mins[motor] == maxes[motor]]
if same_min_max:
raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}")
@@ -955,12 +952,12 @@ class SerialMotorsBus(MotorsBusBase):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
else:
return
return None
if self._is_error(error):
if raise_on_error:
raise RuntimeError(self.packet_handler.getRxPacketError(error))
else:
return
return None
return model_number
@@ -1007,12 +1004,13 @@ class SerialMotorsBus(MotorsBusBase):
err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
id_value = self._decode_sign(data_name, {id_: value})
decoded = self._decode_sign(data_name, {id_: value})
if normalize and data_name in self.normalized_data:
id_value = self._normalize(id_value)
normalized = self._normalize(decoded)
return normalized[id_]
return id_value[id_]
return decoded[id_]
def _read(
self,
@@ -1023,7 +1021,7 @@ class SerialMotorsBus(MotorsBusBase):
num_retry: int = 0,
raise_on_error: bool = True,
err_msg: str = "",
) -> tuple[int, int]:
) -> tuple[int, int, int]:
if length == 1:
read_fn = self.packet_handler.read1ByteTxRx
elif length == 2:
@@ -1073,13 +1071,14 @@ class SerialMotorsBus(MotorsBusBase):
model = self.motors[motor].model
addr, length = get_address(self.model_ctrl_table, model, data_name)
int_value = int(value)
if normalize and data_name in self.normalized_data:
value = self._unnormalize({id_: value})[id_]
int_value = self._unnormalize({id_: value})[id_]
value = self._encode_sign(data_name, {id_: value})[id_]
int_value = self._encode_sign(data_name, {id_: int_value})[id_]
err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
err_msg = f"Failed to write '{data_name}' on {id_=} with '{int_value}' after {num_retry + 1} tries."
self._write(addr, length, id_, int_value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
def _write(
self,
@@ -1113,7 +1112,7 @@ class SerialMotorsBus(MotorsBusBase):
def sync_read(
self,
data_name: str,
motors: str | list[str] | None = None,
motors: NameOrID | Sequence[NameOrID] | None = None,
*,
normalize: bool = True,
num_retry: int = 0,
@@ -1122,7 +1121,7 @@ class SerialMotorsBus(MotorsBusBase):
Args:
data_name (str): Register name.
motors (str | list[str] | None, optional): Motors to query. `None` (default) reads every motor.
motors (NameOrID | Sequence[NameOrID] | None, optional): Motors to query. `None` (default) reads every motor.
normalize (bool, optional): Normalisation flag. Defaults to `True`.
num_retry (int, optional): Retry attempts. Defaults to `0`.
@@ -1143,16 +1142,17 @@ class SerialMotorsBus(MotorsBusBase):
addr, length = get_address(self.model_ctrl_table, model, data_name)
err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
ids_values, _ = self._sync_read(
raw_ids_values, _ = self._sync_read(
addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
)
ids_values = self._decode_sign(data_name, ids_values)
decoded = self._decode_sign(data_name, raw_ids_values)
if normalize and data_name in self.normalized_data:
ids_values = self._normalize(ids_values)
normalized = self._normalize(decoded)
return {self._id_to_name(id_): value for id_, value in normalized.items()}
return {self._id_to_name(id_): value for id_, value in ids_values.items()}
return {self._id_to_name(id_): value for id_, value in decoded.items()}
def _sync_read(
self,
@@ -1224,21 +1224,24 @@ class SerialMotorsBus(MotorsBusBase):
num_retry (int, optional): Retry attempts. Defaults to `0`.
"""
ids_values = self._get_ids_values_dict(values)
models = [self._id_to_model(id_) for id_ in ids_values]
raw_ids_values = self._get_ids_values_dict(values)
models = [self._id_to_model(id_) for id_ in raw_ids_values]
if self._has_different_ctrl_tables:
assert_same_address(self.model_ctrl_table, models, data_name)
model = next(iter(models))
addr, length = get_address(self.model_ctrl_table, model, data_name)
int_ids_values = {id_: int(val) for id_, val in raw_ids_values.items()}
if normalize and data_name in self.normalized_data:
ids_values = self._unnormalize(ids_values)
int_ids_values = self._unnormalize(raw_ids_values)
ids_values = self._encode_sign(data_name, ids_values)
int_ids_values = self._encode_sign(data_name, int_ids_values)
err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
err_msg = f"Failed to sync write '{data_name}' with ids_values={int_ids_values} after {num_retry + 1} tries."
self._sync_write(
addr, length, int_ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
)
def _sync_write(
self,

View File

@@ -28,7 +28,7 @@ class ACTConfig(PreTrainedConfig):
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and 'output_shapes`.
Those are: `input_features` and `output_features`.
Notes on the inputs and outputs:
- Either:
@@ -48,21 +48,12 @@ class ACTConfig(PreTrainedConfig):
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
environment, and throws the other 50 out.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights.

View File

@@ -30,7 +30,7 @@ class DiffusionConfig(PreTrainedConfig):
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and `output_shapes`.
Those are: `input_features` and `output_features`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
@@ -48,21 +48,12 @@ class DiffusionConfig(PreTrainedConfig):
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
See `DiffusionPolicy.select_action` for more details.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
@@ -73,7 +64,7 @@ class DiffusionConfig(PreTrainedConfig):
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view.
use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view.
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
You may provide a variable number of dimensions, therefore also controlling the degree of
downsampling.

View File

@@ -470,6 +470,13 @@ def make_policy(
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
if not cfg.input_features:
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
# Store action feature names for delta_exclude_joints support
if ds_meta is not None and hasattr(cfg, "action_feature_names"):
action_names = ds_meta.features.get(ACTION, {}).get("names")
if action_names is not None:
cfg.action_feature_names = list(action_names)
kwargs["config"] = cfg
# Pass dataset_stats to the policy if available (needed for some policies like SARM)

View File

@@ -20,7 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
DEFAULT_IMAGE_SIZE = 224
@@ -50,9 +50,15 @@ class PI0Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Real-Time Chunking (RTC) configurations
# Delta actions: converts absolute actions to delta (relative to state).
use_delta_actions: bool = False
# Joint names to exclude from delta (kept absolute). Empty list = all dims delta.
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Populated at runtime from dataset metadata by make_policy.
action_feature_names: list[str] | None = None
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
rtc_training_config: RTCTrainingConfig | None = None
image_resolution: tuple[int, int] = (
DEFAULT_IMAGE_SIZE,

View File

@@ -21,8 +21,10 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
AbsoluteActionsProcessorStep,
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
DeltaActionsProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
@@ -126,7 +128,13 @@ def make_pi0_pre_post_processors(
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
delta_step = DeltaActionsProcessorStep(
enabled=config.use_delta_actions,
exclude_joints=getattr(config, "delta_exclude_joints", []),
action_names=getattr(config, "action_feature_names", None),
)
# OpenPI order: raw → delta → normalize → model → unnormalize → absolute
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
@@ -138,6 +146,7 @@ def make_pi0_pre_post_processors(
padding="max_length",
),
DeviceProcessorStep(device=config.device),
delta_step,
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
@@ -149,6 +158,7 @@ def make_pi0_pre_post_processors(
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
AbsoluteActionsProcessorStep(enabled=config.use_delta_actions, delta_step=delta_step),
DeviceProcessorStep(device="cpu"),
]

View File

@@ -50,6 +50,13 @@ class PI05Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Delta actions: converts absolute actions to delta (relative to state).
use_delta_actions: bool = False
# Joint names to exclude from delta (kept absolute). Empty list = all dims delta.
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Populated at runtime from dataset metadata by make_policy.
action_feature_names: list[str] | None = None
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
rtc_training_config: RTCTrainingConfig | None = None

View File

@@ -25,7 +25,9 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pi05.modeling_pi05 import pad_vector
from lerobot.processor import (
AbsoluteActionsProcessorStep,
AddBatchDimensionProcessorStep,
DeltaActionsProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
@@ -129,10 +131,19 @@ def make_pi05_pre_post_processors(
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
delta_step = DeltaActionsProcessorStep(
enabled=config.use_delta_actions,
exclude_joints=getattr(config, "delta_exclude_joints", []),
action_names=getattr(config, "action_feature_names", None),
)
# OpenPI order: raw → delta → normalize → model → unnormalize → absolute
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
delta_step,
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
NormalizerProcessorStep(
@@ -154,6 +165,7 @@ def make_pi05_pre_post_processors(
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
AbsoluteActionsProcessorStep(enabled=config.use_delta_actions, delta_step=delta_step),
DeviceProcessorStep(device="cpu"),
]

View File

@@ -41,6 +41,9 @@ class PI0FastConfig(PreTrainedConfig):
max_action_dim: int = 32
max_action_tokens: int = 256
# Delta actions: converts absolute actions to delta (relative to state).
use_delta_actions: bool = False
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None

View File

@@ -48,12 +48,14 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.processor.delta_action_processor import to_absolute_actions
from lerobot.utils.constants import (
ACTION,
ACTION_TOKEN_MASK,
ACTION_TOKENS,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
OPENPI_ATTENTION_MASK_VALUE,
)
@@ -1315,6 +1317,12 @@ class PI0FastPolicy(PreTrainedPolicy):
action_tokens, action_horizon=action_horizon, action_dim=action_dim
)
if self.config.use_delta_actions and OBS_STATE in batch:
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
continuous_actions = to_absolute_actions(
continuous_actions, state, [True] * continuous_actions.shape[-1]
)
return continuous_actions
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:

View File

@@ -27,6 +27,7 @@ from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector
from lerobot.processor import (
ActionTokenizerProcessorStep,
AddBatchDimensionProcessorStep,
DeltaActionsProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
@@ -147,6 +148,7 @@ def make_pi0_fast_pre_post_processors(
padding_side="right",
padding="max_length",
),
DeltaActionsProcessorStep(enabled=config.use_delta_actions),
ActionTokenizerProcessorStep(
action_tokenizer_name=config.action_tokenizer_name,
max_action_tokens=config.max_action_tokens,

View File

@@ -239,8 +239,10 @@ class SACPolicy(
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
def update_temperature(self):
self.temperature = self.log_alpha.exp().item()
@property
def temperature(self) -> float:
"""Return the current temperature value, always in sync with log_alpha."""
return self.log_alpha.exp().item()
def compute_loss_critic(
self,
@@ -457,11 +459,10 @@ class SACPolicy(
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
self.target_entropy = -np.prod(dim) / 2
def _init_temperature(self):
"""Set up temperature parameter and initial log_alpha."""
def _init_temperature(self) -> None:
"""Set up temperature parameter (log_alpha)."""
temp_init = self.config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
self.temperature = self.log_alpha.exp().item()
class SACObservationEncoder(nn.Module):

View File

@@ -63,12 +63,6 @@ from typing_extensions import Unpack
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.policies.rtc.training_time import (
apply_rtc_training_time,
apply_training_time_rtc_inference,
masked_mean,
sample_rtc_delay,
)
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
from lerobot.policies.utils import (
@@ -91,8 +85,8 @@ def create_sinusoidal_pos_embedding(
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim not in (1, 2):
raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
@@ -100,14 +94,9 @@ def create_sinusoidal_pos_embedding(
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
if time.ndim == 1:
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
time_flat = time.reshape(-1)
sin_input = scaling_factor[None, :] * time_flat[:, None]
sin_input = scaling_factor[None, :] * time[:, None]
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
return pos_emb.reshape(*time.shape, dimension)
return pos_emb
def make_att_2d_masks(pad_masks, att_masks):
@@ -386,39 +375,28 @@ class SmolVLAPolicy(PreTrainedPolicy):
lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
actions = self.prepare_action(batch)
postfix_mask = None
rtc_cfg = self.config.rtc_training_config
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
batch_size = actions.shape[0]
if time is None:
time = self.model.sample_time(batch_size, actions.device)
if noise is None:
noise = self.model.sample_noise(actions.shape, actions.device)
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
actions_is_pad = batch.get("actions_id_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
loss_dict["losses_after_forward"] = losses.clone()
loss_dict["losses_after_forward"] = losses.clone().mean().item()
if actions_is_pad is not None:
in_episode_bound = ~actions_is_pad
losses = losses * in_episode_bound.unsqueeze(-1)
loss_dict["losses_after_in_ep_bound"] = losses.clone()
postfix_mask = in_episode_bound if postfix_mask is None else (postfix_mask & in_episode_bound)
loss_dict["losses_after_in_ep_bound"] = losses.clone().mean().item()
# Remove padding
losses = losses[:, :, : self.config.max_action_dim]
loss_dict["losses_after_rm_padding"] = losses.clone()
loss_dict["losses_after_rm_padding"] = losses.clone().mean().item()
if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims
per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
per_sample_loss = losses.mean(dim=(1, 2))
loss_dict["loss"] = per_sample_loss.mean().item()
return per_sample_loss, loss_dict
else:
# Default: return scalar mean loss
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
loss = losses.mean()
loss_dict["loss"] = loss.item()
return loss, loss_dict
@@ -618,9 +596,6 @@ class VLAFlowMatching(nn.Module):
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _training_time_rtc_inference_enabled(self):
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
def set_requires_grad(self):
for params in self.state_proj.parameters():
params.requires_grad = self.config.train_state_proj
@@ -756,10 +731,7 @@ class VLAFlowMatching(nn.Module):
)
time_emb = time_emb.type(dtype=dtype)
if time_emb.dim() == 2:
time_emb = time_emb[:, None, :].expand_as(action_emb)
elif time_emb.shape[:2] != action_emb.shape[:2]:
raise ValueError(f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}")
time_emb = time_emb[:, None, :].expand_as(action_emb)
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
action_time_emb = self.action_time_mlp_in(action_time_emb)
@@ -791,12 +763,7 @@ class VLAFlowMatching(nn.Module):
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
if time.ndim == 1:
time_expanded = time[:, None, None]
elif time.ndim == 2:
time_expanded = time[:, :, None]
else:
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
@@ -859,35 +826,23 @@ class VLAFlowMatching(nn.Module):
num_steps = self.config.num_steps
dt = -1.0 / num_steps
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
use_training_time_rtc = self._training_time_rtc_inference_enabled()
x_t = noise
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
if use_training_time_rtc:
x_t_cond, time_tensor = apply_training_time_rtc_inference(
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
)
v_t = self.denoise_step(
x_t=x_t_cond,
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
x_t=input_x_t,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
timestep=time_tensor,
timestep=current_timestep,
)
elif self._rtc_enabled():
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
x_t=input_x_t,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
@@ -898,13 +853,7 @@ class VLAFlowMatching(nn.Module):
execution_horizon=execution_horizon,
)
else:
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
v_t = self.denoise_step(
x_t=x_t,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
timestep=time_tensor,
)
v_t = denoise_step_partial_call(x_t)
x_t = x_t + dt * v_t

View File

@@ -30,7 +30,7 @@ class TDMPCConfig(PreTrainedConfig):
camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.
Those are: `input_features`, `output_features`, and perhaps `max_random_shift_ratio`.
Args:
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
@@ -40,24 +40,12 @@ class TDMPCConfig(PreTrainedConfig):
is an alternative to using action repeats. If this is set to more than 1, then we require
`n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this
approach of using multiple steps from the plan is not in the original implementation.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to
match the original implementation.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping
to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max"
normalization mode here.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding.
state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding.
latent_dim: Observation's latent embedding dimension.

View File

@@ -32,7 +32,7 @@ class VQBeTConfig(PreTrainedConfig):
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and `output_shapes`.
Those are: `input_features` and `output_features`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
@@ -46,21 +46,12 @@ class VQBeTConfig(PreTrainedConfig):
current step and additional steps going back).
n_action_pred_token: Total number of current token and future tokens that VQ-BeT predicts.
action_chunk_size: Action chunk size of each action prediction token.
input_shapes: A dictionary defining the shapes of the input data for the policy.
The key represents the input data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "observation.image" refers to an input from
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
Importantly, shapes doesnt include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.

View File

@@ -15,5 +15,7 @@
# limitations under the License.
from .configuration_wall_x import WallXConfig
from .modeling_wall_x import WallXPolicy
from .processor_wall_x import make_wall_x_pre_post_processors
__all__ = ["WallXConfig", "WallXPolicy", "make_wall_x_pre_post_processors"]

View File

@@ -28,7 +28,14 @@ from .core import (
RobotObservation,
TransitionKey,
)
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
from .delta_action_processor import (
AbsoluteActionsProcessorStep,
DeltaActionsProcessorStep,
MapDeltaActionToRobotActionStep,
MapTensorToDeltaActionDictStep,
to_absolute_actions,
to_delta_actions,
)
from .device_processor import DeviceProcessorStep
from .factory import (
make_default_processors,
@@ -97,6 +104,8 @@ __all__ = [
"make_default_teleop_action_processor",
"make_default_robot_action_processor",
"make_default_robot_observation_processor",
"AbsoluteActionsProcessorStep",
"DeltaActionsProcessorStep",
"MapDeltaActionToRobotActionStep",
"MapTensorToDeltaActionDictStep",
"NormalizerProcessorStep",
@@ -126,6 +135,8 @@ __all__ = [
"transition_to_batch",
"TransitionKey",
"TruncatedProcessorStep",
"to_absolute_actions",
"to_delta_actions",
"UnnormalizerProcessorStep",
"VanillaObservationProcessorStep",
]

View File

@@ -168,11 +168,12 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
"""
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
def create_transition(

View File

@@ -14,12 +14,54 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any
import torch
from torch import Tensor
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_STATE
from .core import PolicyAction, RobotAction
from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
from .core import EnvTransition, PolicyAction, RobotAction, TransitionKey
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
def to_delta_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
"""Convert absolute actions to delta: delta = action - state (for masked dims).
Args:
actions: (B, T, action_dim) or (B, action_dim).
state: (B, state_dim). Broadcast across time dimension.
mask: Which dims to convert. Can be shorter than action_dim.
"""
mask_t = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
dims = mask_t.shape[0]
state_offset = state[..., :dims] * mask_t
if actions.ndim == 3:
state_offset = state_offset.unsqueeze(-2)
actions = actions.clone()
actions[..., :dims] -= state_offset
return actions
def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
"""Convert delta actions back to absolute: absolute = delta + state (for masked dims).
Args:
actions: (B, T, action_dim) or (B, action_dim).
state: (B, state_dim). Broadcast across time dimension.
mask: Which dims to convert. Can be shorter than action_dim.
"""
mask_t = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
dims = mask_t.shape[0]
state_offset = state[..., :dims] * mask_t
if actions.ndim == 3:
state_offset = state_offset.unsqueeze(-2)
actions = actions.clone()
actions[..., :dims] += state_offset
return actions
@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
@@ -141,3 +183,126 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
)
return features
@ProcessorStepRegistry.register("delta_actions_processor")
@dataclass
class DeltaActionsProcessorStep(ProcessorStep):
"""Converts absolute actions to delta actions (action -= state) for masked dimensions.
Mirrors OpenPI's DeltaActions transform. Applied during preprocessing so the model
trains on relative offsets instead of absolute positions.
Caches the last seen state so a paired AbsoluteActionsProcessorStep can reverse
the conversion during postprocessing.
Attributes:
enabled: Whether to apply the delta conversion.
exclude_joints: Joint names to keep absolute (not converted to delta).
action_names: Action dimension names from dataset metadata, used to build
the mask from exclude_joints. If None, all dims are converted.
"""
enabled: bool = False
exclude_joints: list[str] = field(default_factory=list)
action_names: list[str] | None = None
_last_state: torch.Tensor | None = field(default=None, init=False, repr=False)
def _build_mask(self, action_dim: int) -> list[bool]:
if not self.exclude_joints or self.action_names is None:
return [True] * action_dim
exclude_tokens = [str(name).lower() for name in self.exclude_joints if name]
if not exclude_tokens:
return [True] * action_dim
mask = []
for name in self.action_names[:action_dim]:
action_name = str(name).lower()
is_excluded = any(token == action_name or token in action_name for token in exclude_tokens)
mask.append(not is_excluded)
if len(mask) < action_dim:
mask.extend([True] * (action_dim - len(mask)))
return mask
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION, {})
state = observation.get(OBS_STATE) if observation else None
# Always cache state for the paired AbsoluteActionsProcessorStep
if state is not None:
self._last_state = state
if not self.enabled:
return transition
new_transition = transition.copy()
action = new_transition.get(TransitionKey.ACTION)
if action is None or state is None:
return new_transition
mask = self._build_mask(action.shape[-1])
new_transition[TransitionKey.ACTION] = to_delta_actions(action, state, mask)
return new_transition
def get_config(self) -> dict[str, Any]:
return {"enabled": self.enabled, "exclude_joints": self.exclude_joints}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@ProcessorStepRegistry.register("absolute_actions_processor")
@dataclass
class AbsoluteActionsProcessorStep(ProcessorStep):
"""Converts delta actions back to absolute actions (action += state) for all dimensions.
Mirrors OpenPI's AbsoluteActions transform. Applied during postprocessing so
predicted deltas are converted back to absolute positions for execution.
Reads the cached state from its paired DeltaActionsProcessorStep.
Attributes:
enabled: Whether to apply the absolute conversion.
delta_step: Reference to the paired DeltaActionsProcessorStep that caches state.
"""
enabled: bool = False
delta_step: DeltaActionsProcessorStep | None = field(default=None, repr=False)
def __call__(self, transition: EnvTransition) -> EnvTransition:
if not self.enabled:
return transition
if self.delta_step is None:
raise RuntimeError(
"AbsoluteActionsProcessorStep requires a paired DeltaActionsProcessorStep "
"but delta_step is None. Ensure delta_step is set when constructing the postprocessor."
)
if self.delta_step._last_state is None:
raise RuntimeError(
"AbsoluteActionsProcessorStep requires state from DeltaActionsProcessorStep "
"but no state has been cached. Ensure the preprocessor runs before the postprocessor."
)
new_transition = transition.copy()
action = new_transition.get(TransitionKey.ACTION)
if action is None:
return new_transition
mask = self.delta_step._build_mask(action.shape[-1])
new_transition[TransitionKey.ACTION] = to_absolute_actions(
action, self.delta_step._last_state, mask
)
return new_transition
def get_config(self) -> dict[str, Any]:
return {"enabled": self.enabled}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features

View File

@@ -17,7 +17,7 @@ from dataclasses import dataclass
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@@ -92,7 +92,7 @@ class LiberoProcessorStep(ObservationProcessorStep):
# copy over non-STATE features
for ft, feats in features.items():
if ft != PipelineFeatureType.STATE:
if ft != FeatureType.STATE:
new_features[ft] = feats.copy()
# rebuild STATE features
@@ -100,13 +100,11 @@ class LiberoProcessorStep(ObservationProcessorStep):
# add our new flattened state
state_feats[OBS_STATE] = PolicyFeature(
key=OBS_STATE,
type=FeatureType.STATE,
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
dtype="float32",
description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."),
)
new_features[PipelineFeatureType.STATE] = state_feats
new_features[FeatureType.STATE] = state_feats
return new_features

View File

@@ -18,16 +18,18 @@
import math
import time
from dataclasses import dataclass
from typing import Any, Protocol, TypeVar, runtime_checkable
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable
import numpy as np
import torch
import torchvision.transforms.functional as F # noqa: N812
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents
if TYPE_CHECKING:
from lerobot.teleoperators.teleoperator import Teleoperator
from .core import EnvTransition, PolicyAction, TransitionKey
from .pipeline import (
ComplementaryDataProcessorStep,
@@ -69,10 +71,10 @@ class HasTeleopEvents(Protocol):
# Type variable constrained to Teleoperator subclasses that also implement events
TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator)
TeleopWithEvents = TypeVar("TeleopWithEvents", bound="Teleoperator")
def _check_teleop_with_events(teleop: Teleoperator) -> None:
def _check_teleop_with_events(teleop: "Teleoperator") -> None:
"""
Runtime check that a teleoperator implements the `HasTeleopEvents` protocol.
@@ -103,7 +105,7 @@ class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
teleop_device: The teleoperator instance to get the action from.
"""
teleop_device: Teleoperator
teleop_device: "Teleoperator"
def complementary_data(self, complementary_data: dict) -> dict:
"""
@@ -312,7 +314,7 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
class GripperPenaltyProcessorStep(ProcessorStep):
"""
Applies a penalty for inefficient gripper usage.
@@ -327,26 +329,27 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
penalty: float = -0.01
max_gripper_pos: float = 30.0
def complementary_data(self, complementary_data: dict) -> dict:
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
Calculates the gripper penalty and adds it to the complementary data.
Args:
complementary_data: The incoming complementary data, which should contain
raw joint positions.
transition: The incoming environment transition.
Returns:
A new complementary data dictionary with the `discrete_penalty` key added.
The modified transition with the penalty added to complementary data.
"""
action = self.transition.get(TransitionKey.ACTION)
new_transition = transition.copy()
action = new_transition.get(TransitionKey.ACTION)
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
raw_joint_positions = complementary_data.get("raw_joint_positions")
if raw_joint_positions is None:
return complementary_data
return new_transition
current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None)
if current_gripper_pos is None:
return complementary_data
return new_transition
# Gripper action is a PolicyAction at this stage
gripper_action = action[-1].item()
@@ -362,11 +365,12 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
gripper_penalty = self.penalty * int(gripper_penalty_bool)
# Create new complementary data with penalty info
# Update complementary data with penalty info
new_complementary_data = dict(complementary_data)
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_complementary_data
return new_transition
def get_config(self) -> dict[str, Any]:
"""

View File

@@ -331,11 +331,9 @@ class _NormalizationMixin:
)
mean, std = stats["mean"], stats["std"]
# Avoid division by zero by adding a small epsilon.
denom = std + self.eps
if inverse:
return tensor * std + mean
return (tensor - mean) / denom
return tensor * (std + 1e-6) + mean
return (tensor - mean) / (std + 1e-6)
if norm_mode == NormalizationMode.MIN_MAX:
min_val = stats.get("min", None)
@@ -367,11 +365,7 @@ class _NormalizationMixin:
"QUANTILES normalization mode requires q01 and q99 stats, please update the dataset with the correct stats using the `augment_dataset_quantile_stats.py` script"
)
denom = q99 - q01
# Avoid division by zero by adding epsilon when quantiles are identical
denom = torch.where(
denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
)
denom = q99 - q01 + 1e-6
if inverse:
return (tensor + 1.0) * denom / 2.0 + q01
return 2.0 * (tensor - q01) / denom - 1.0

View File

@@ -413,7 +413,7 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
Args:
save_directory: The directory where the pipeline will be saved. If None, saves to
HF_LEROBOT_HOME/processors/{sanitized_pipeline_name}.
repo_id: ID of your repository on the Hub. Used only if `push_to_hub=True`.
repo_id: ID of your repository on the Hub. Used only if `push_to_hub=true`.
push_to_hub: Whether or not to push your object to the Hugging Face Hub after saving it.
card_kwargs: Additional arguments passed to the card template to customize the card.
config_filename: The name of the JSON configuration file. If None, a name is

Some files were not shown because too many files have changed in this diff Show More