Compare commits

..

354 Commits

Author SHA1 Message Date
Pepijn
ac1c2454c5 Small fixes 2025-07-28 09:13:01 +02:00
Pepijn
2e3c116fad add urdfs 2025-07-28 08:52:37 +02:00
Pepijn
65c174e9f8 fix record and replay and startup issue 2025-07-22 14:07:20 +02:00
Pepijn
005d4bb011 Modify replay 2025-07-18 16:38:09 +02:00
Pepijn
779d38fff0 hack to get images "at" 100fps 2025-07-18 16:33:38 +02:00
Pepijn
c0ffb92735 Update record 2025-07-17 09:56:31 +02:00
Pepijn
baa9b95b97 add acc, vel to dataset 2025-07-17 09:56:23 +02:00
Pepijn
75ce54e212 remove settings add record 2025-07-16 16:06:53 +02:00
Pepijn
05a2316a63 modify gains 2025-07-16 14:26:29 +02:00
Pepijn
2437decd3f Cleanup unneeded code 2025-07-16 10:40:58 +02:00
Pepijn
2d2f5d3d60 remove set_motors 2025-07-15 14:08:41 +02:00
Pepijn
2d608f086a Merge branch 'main' into feat/add-biteleop-so101 2025-07-15 14:03:15 +02:00
Ben Zhang
1c0ac8e341 Parse draccus subclass overrides when using --policy.path (#1501)
* Parse draccus subclass overrides when using --policy.path

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-07-15 12:29:07 +02:00
pre-commit-ci[bot]
c4c0105a47 [pre-commit.ci] pre-commit autoupdate (#1327)
* [pre-commit.ci] pre-commit autoupdate

updates:
- [github.com/adhtruong/mirrors-typos: v1.33.1 → v1.34.0](https://github.com/adhtruong/mirrors-typos/compare/v1.33.1...v1.34.0)
- [github.com/astral-sh/ruff-pre-commit: v0.11.13 → v0.12.3](https://github.com/astral-sh/ruff-pre-commit/compare/v0.11.13...v0.12.3)
- [github.com/woodruffw/zizmor-pre-commit: v1.9.0 → v1.11.0](https://github.com/woodruffw/zizmor-pre-commit/compare/v1.9.0...v1.11.0)
- [github.com/PyCQA/bandit: 1.8.3 → 1.8.6](https://github.com/PyCQA/bandit/compare/1.8.3...1.8.6)

* Ignore B615

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-07-15 12:28:22 +02:00
aka
1b878c9155 fix(record): Improve OpenCV backend handling for Windows systems (#1495)
* fix(record): Improve OpenCV backend handling for Windows systems

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Resolved ruff's E402 error (import statements not at the beginning of the file):
- Moved all import statements to the beginning of the file
- Defined _fix_opencv_backend() as a function
- Adjusted the timing of the fix call
- Code structure conforming to ruff

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix(record): Correct OpenCV backend for Windows systems

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix(opencv): Set OpenCV environment variable for Windows systems

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix(opencv): Refactor MSMF hardware transform environment variable setting for Windows

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-07-15 11:33:02 +02:00
Simon Alibert
724874e063 Fix tests (#1510) 2025-07-15 11:27:01 +02:00
Adil Zouitine
91b110d806 fix(mps): gradient exploding and nan loss issues with ACT (#1490)
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-07-15 10:28:19 +02:00
Ben Zhang
519b76110e Remove random noise injected by policy server (#1496) 2025-07-13 21:58:05 +02:00
Pepijn
e925ef3f18 tune 2025-07-11 13:39:54 +02:00
Pepijn
fbf5f04545 Add vel filter and better static friction parameters 2025-07-11 13:34:28 +02:00
Pepijn
9fdec23cee uncomment handshake (issue on this model) 2025-07-11 10:41:22 +02:00
Pepijn
d9af2f1b89 set direction bit 2025-07-11 10:17:55 +02:00
Francesco Capuano
d2645cb19f fix(docs): Record-Upload failed? Don't panic! (#1478)
* fix: add instruction to manually upload dataset

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>

* fix: repo type is explicited

---------

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-07-10 20:13:56 +02:00
Pepijn
57f7c8b03e Use multi turn, single turn is problem! 2025-07-10 19:31:14 +02:00
Pepijn
e9c795e479 remove set phase 2025-07-10 12:27:43 +02:00
Francesco Capuano
abe51eeba3 Update async docs with blogpost (#1479)
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-07-10 12:24:40 +02:00
Pepijn
c9cff132c3 Add better hls table 2025-07-10 10:56:47 +02:00
Francesco Capuano
30c161006d Add Async Inference (#1196)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-07-10 10:39:11 +02:00
Pepijn
0136912fa4 Print more 2025-07-09 19:21:44 +02:00
Adil Zouitine
ce2b9724bf fix(hil-serl): discrete critic send through network (#1468)
Co-authored-by: Khalil Meftah <kmeftah.khalil@gmail.com>
Co-authored-by: jpizarrom <jpizarrom@gmail.com>
2025-07-09 16:22:40 +02:00
Caroline Pascal
cf86b9300d fix(logging): Fixing logging levels (#1466)
* fix(logging): Fixing logging levels, adding custom logging levels for console and file logging

* clean(typing): Adding typing in logging formatter, use proper getter for logging message
2025-07-08 18:59:13 +02:00
Pepijn
67d6bfee78 increase protection current 2025-07-08 15:51:15 +02:00
Simon Alibert
039de254ea Add Hope Jr (#935)
* Fix imports

* Add feetech write tests

* Nit

* Add autoclosing fixture

* Assert ping stub called

* Add CalibrationMode

* Add Motor in dxl robots

* Simplify split_int_bytes

* Rename read/write -> sync_read/write, refactor, add write

* Rename tests

* Refactor dxl tests by functionality

* Add dxl write test

* Refactor _is_comm_success

* Refactor feetech tests by functionality

* Add feetech write test

* Simplify _is_comm_success & _is_error

* Move mock_serial patch to dedicated file

* Remove test skips & fix docstrings

* Nit

* Add dxl operating modes

* Add is_connected in robots and teleops

* Update Koch

* Add feetech operating modes

* Caps dxl OperatingMode

* Update ensure_safe_goal_position

* Update so100

* Privatize methods & renames

* Fix dict

* Add _configure_motors & move ping methods

* Return models (str) with pings

* Implement feetech broadcast ping

* Add raw_values option

* Rename idx -> id_

* Improve errors

* Fix feetech ping tests

* Ensure motors exist at connection time

* Update tests

* Add test_motors_bus

* Move DriveMode & TorqueMode

* Update Koch imports

* Update so100 imports

* Fix visualize_motors_bus

* Fix imports

* Add calibration

* Rename idx -> id_

* Rename idx -> id_

* (WIP) _async_read

* Add new calibration method for robot refactor (#896)

Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>

* Remove deprecated scripts

* Rename CalibrationMode -> MotorNormMode

* Fix calibration functions

* Remove todo

* Add scan_port utility

* Add calibration utilities

* Move encoding functions to encoding_utils

* Add test_encoding_utils

* Rename test

* Add more calibration utilities

* Format baudrate tables

* Implement SO-100 leader calibration

* Implement SO-100 follower calibration

* Implement Koch calibration

* Add test_scan_port (TODO)

* Fix calibration

* Hack feetech firmware bug

* Update tests

* Update Koch & SO-100

* Improve format

* Rename SO-100 classes

* Rename Koch classes

* Add calibration tests

* Remove old calibration tests

* Revert feetech hack and monkeypatch instead

* Simplify motors mocks

* Add is_calibrated test

* Update viperx & widowx

* Rename viperx & widowx

* Remove old calibration

* feat(teleop): thread-safe keyboard teleop implementation (#869)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>

* Add support for feetech scs series + various fixes

* Update dynamixel with motors bus & tables changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (WIP) Add Hope Jr

* Rename arm -> hand

* (WIP) Add homonculus arm & glove

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add Feetech protocol version

* Implement read

* Use constants from sdks

* (nit) move write

* Fix broadcast ping type hint

* Add protocol 1 broadcast ping

* Refactor & add _serialize_data

* Add feetech sm8512bl

* Make feetech broadcast ping faster in protocol 1

* Cleanup

* Add support for feetech protocol 1 to _split_into_byte_chunks

* Fix unormalize

* Remove test_motors_bus fixtures

* Add more segmented tests (base motor bus & feetech), add feetech protocol 1 support

* Add more segmented tests (dynamixel)

* Refactor tests

* Add handshake, fix feetech _read_firmware_version

* Fix tests

* Motors config & disconnect fixes

* Add torque_disabled context

* Update branch & fix pre-commit errors

* Fix hand & glove readings

* Update feetech tables

* Move read/write_calibration implementations

* Add setup_motor

* Fix calibration msg display

* Fix setup_motor & add it to robots

* Fix _find_single_motor

* Remove deprecated configure_motor

* Remove deprecated dynamixel_calibration

* Remove names

* Remove deprecated import

* refactor/lekiwi robot (#863)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>

* fix(teleoperators): use property is_connected (#1075)

* Remove deprecated manipulator

* Update robot features & naming

* Update teleop features & naming

* Add make_teleoperator_from_config

* Rename find_port

* Fix config parsing

* Remove app script

* Add setup_motors

* Add teleoperate

* Add record

* Add replay

* Fix test_datasets

* Add mock robot & teleop

* Add new test_control_robot

* Add test_record_and_resume

* Remove deprecated scripts & tests

* Add calibrate

* Add docstrings

* Fix tests (no-extras install)

* Add SO101

* Remove pynput from optional deps

* Rename example 7

* Remove unecessary id

* Add MotorsBus docstrings

* Rename arm -> bus

* Remove Moss arm

* Fix setup_motors & calibrate configs

* Fix test_calibrate

* Add copyrights

* Update hand & arm

* Update homonculus hand & arm

* Fix dxl _find_single_motor

* Update glove

* Add setup_motors for lekiwi

* Fix glove calibration

* Complete docstring

* Add check for same min and max during calibration

* Move MockMotorsBus

* Add so100_follower tests

* (WIP) add calibration gui

* Fix test

* Add setup_motors

* Update calibration gui

* Remove old .cache folder

* Replace deprecated abc.abstractproperty

* Fix feetech protocol 1 configure

* Cleanup gui & add copyrights

* Anatomically precise joint names

* (WIP) Add glove to hand joints translation

* Move make_robot_config

* Add drive_mode & norm_mode in glove calibration

* Fix joints translation

* Fix normalization drive_mode

* nit

* Fix glove to hand conversion

* Adapt feetech calibration

* Remove pygame prompt

* Implement arm calibration (hacks)

* Better MotorsBus error messages

* Update feetech read_calibration

* Fix feetech test_is_calibrated

* Cleanup glove

* (WIP) Update arm

* Add changes from #1117

* refactor(cameras): cameras implementations + tests improvements (#1108)

Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Fix arm joints order

* Add timeout/event logic

* Fix arm & glove

* Fix predict_action from record

* fix(cameras): update docstring + handle sn when starts with 0 + update timeouts to more reasonable value (#1154)

* fix(scripts): parser instead of draccus in record + add __get_path_fields__() to RecordConfig (#1155)

* Left/Right sides + other fixes

* Arm fixes and add config

* More hacks

* Add control scripts

* Fix merge errors

* push changes to calibration, teleop and docs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Move readme to docs

* update readme

Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com>

* Add files via upload

Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com>

* Update image sources

* Symlink doc

* Compress image

* Move image

* Update docs link

* fix docs

* simplify teleop scripts

* fix variable names

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address code review

* add EMA to glove

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* integrate teleoperation for hand

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update docs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* import hopejr/homunculus in teleoperate

* update docs for teleoperate, record, replay, train and inference

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* chore(hopejr): address comments

* chore(hopejr): address coments 2

* chore(docs): update teleoperation instructions for the hand/glove

* fix(hopejr): calibration int + update docs

---------

Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com>
Signed-off-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: nepyope <nopyeps@gmail.com>
Co-authored-by: Martino Russi <77496684+nepyope@users.noreply.github.com>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2025-07-08 15:47:11 +02:00
Pepijn
a3feadbbfb Increase max torque limit 2025-07-08 15:26:13 +02:00
Pepijn
25e22ea3ba Add friction to distribiutor estimation 2025-07-08 15:08:35 +02:00
Pepijn
5e27248bba Tune everything a bit 2025-07-08 14:55:34 +02:00
Francesco Capuano
a5e0aae13a Fixes @torch.no_grad() usage (#1455)
* fix: decorator calls with parentheses

* fix no grad for normalize too

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>

---------

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
2025-07-08 13:08:32 +02:00
Pepijn
7f7b45cfbb Add rerun vis 2025-07-04 17:33:25 +02:00
Pepijn
28857dccb1 Add friction component (helps!) but not right one yet 2025-07-04 14:48:31 +02:00
Pepijn
a4d46d4adb modify gains 2025-07-04 14:36:37 +02:00
Pepijn
043b720505 Add inertia 2025-07-04 14:25:14 +02:00
Pepijn
d985f4b1db fix: current impl 2025-07-04 13:20:27 +02:00
Pepijn
ab53de989a fix: current 2025-07-04 09:21:53 +02:00
Pepijn
a56cf87f42 fix gravity compensation 2025-07-02 15:16:58 +02:00
Pepijn
12d1629aae Subtract middle 2025-07-01 18:09:36 +02:00
Pepijn
63e2a2e129 fix: change to actual degrees 2025-07-01 16:32:43 +02:00
Pepijn
2a46f3a53f Merge branch 'main' into feat/add-biteleop-so101 2025-07-01 14:59:26 +02:00
Pepijn
171c355858 Add grav compensation 2025-07-01 14:56:37 +02:00
Pepijn
9ad19d4e81 Add pseudo code for bi teleoperation (4channel) 2025-06-26 18:28:25 +02:00
Pepijn
e171fa788a First bi teleop so101 2025-06-12 09:53:30 +02:00
Simon Alibert
b1386fd79e Disconnect after scan_port 2025-06-04 17:12:30 +02:00
Simon Alibert
b47620cd59 Remove comment 2025-06-04 16:59:44 +02:00
Simon Alibert
a32d988536 Refactor feetech _broadcast_ping 2025-06-04 16:41:33 +02:00
Simon Alibert
9571a713df Refactor record_ranges_of_motion 2025-06-04 14:54:29 +02:00
Pepijn
b418409b24 Fix small issues in docs and refactor (#1194)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-06-04 14:27:57 +02:00
Simon Alibert
0a6b3992ee Fix docstring 2025-06-04 13:16:41 +02:00
pre-commit-ci[bot]
e6d19116c4 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-06-04 11:14:42 +00:00
Simon Alibert
92ea7fc0fb Apply suggestions from code review
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-06-04 13:13:50 +02:00
Simon Alibert
46cd157c55 Dirty fix nightlies 2025-06-04 12:54:09 +02:00
Simon Alibert
52028f5201 Address Michel's comments 2025-06-04 12:47:24 +02:00
Simon Alibert
f5b1ef0045 Remove unused variable 2025-06-04 12:18:54 +02:00
Simon Alibert
81a4deadc3 Address potential None in _assert_same_firmware 2025-06-04 12:17:18 +02:00
Simon Alibert
fef83ce349 Simplify feetech read_calibration 2025-06-04 12:09:48 +02:00
Simon Alibert
eb3986e131 Fix docstring 2025-06-04 11:49:02 +02:00
Simon Alibert
d45226ad06 Remove unused max id 2025-06-04 11:46:10 +02:00
Simon Alibert
fe43f93553 Remove more code 2025-06-04 11:39:19 +02:00
Simon Alibert
40e0a311b5 Remove deprecated code 2025-06-04 11:33:33 +02:00
Simon Alibert
13677cb720 Remove os.name in favor of platform.system() 2025-06-04 11:21:33 +02:00
Simon Alibert
247d493d06 Add TODO 2025-06-03 19:53:25 +02:00
Simon Alibert
2f00475fc6 Fix snippet error 2025-06-03 19:34:06 +02:00
Simon Alibert
4687296d93 Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots 2025-06-03 19:10:17 +02:00
Simon Alibert
5c2f8ccd14 Remove dead code & cleanup 2025-06-03 18:30:51 +02:00
Simon Alibert
d25e3bd989 Apply suggestions from code review
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-06-03 18:18:44 +02:00
Simon Alibert
adcb07bf62 Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots 2025-06-02 19:41:50 +02:00
Pepijn
67e3383ffc Refactor LeKiwi (#1136)
Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-06-02 19:40:48 +02:00
Pepijn
ac5a9b90c7 Update the docs for the robots refactor (#1115)
Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-06-02 18:14:21 +02:00
Simon Alibert
f35d24a9c3 Cleanup control_utils 2025-06-02 17:09:08 +02:00
Steven Palma
fbdefb2e3e fix: several fixes identified in the docs PR (#1181)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-06-02 16:05:05 +02:00
Simon Alibert
0e39d0f6e6 Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots 2025-06-02 13:19:26 +02:00
Simon Alibert
b8eecba63d Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots 2025-05-28 15:39:30 +02:00
Steven Palma
7308aa57a2 fix(scripts): reconstructs action dict from policy output (#1162)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-05-28 15:36:21 +02:00
Steven Palma
1fd3b2e2db fix(utils): Convert observation values in predict_action to torch.Tensor (#1157)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-05-28 15:29:08 +02:00
Simon Alibert
69e48bbe19 Merge branch 'user/aliberts/2025_02_25_refactor_robots' of github.com:huggingface/lerobot into user/aliberts/2025_02_25_refactor_robots 2025-05-28 15:08:48 +02:00
Steven Palma
0db1a67eaf fix(dataset): key is an action if it starts with such prefix in dataset_to_policy_features (#1156) 2025-05-28 15:08:10 +02:00
Simon Alibert
ccb8468e9b Complete TODO for cameras on robots 'is_connected' 2025-05-28 10:15:19 +02:00
Simon Alibert
f6198d20c6 Add suggestion from Caroline 2025-05-26 17:57:51 +02:00
Simon Alibert
78e29f4f20 Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots 2025-05-26 10:54:07 +02:00
Steven Palma
fb4bfaf029 fix(scripts): parser instead of draccus in record + add __get_path_fields__() to RecordConfig (#1155) 2025-05-26 10:51:05 +02:00
Steven Palma
809a9c6de0 fix(cameras): update docstring + handle sn when starts with 0 + update timeouts to more reasonable value (#1154) 2025-05-26 10:48:42 +02:00
Simon Alibert
f4c11593d4 Fix predict_action from record 2025-05-24 10:48:06 +02:00
Steven Palma
71e6520cd1 refactor(cameras): cameras implementations + tests improvements (#1108)
Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-05-23 14:47:37 +02:00
Simon Alibert
a5f15db057 Add changes from #1117 2025-05-23 13:16:14 +02:00
Simon Alibert
edec51988d Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots 2025-05-23 13:13:37 +02:00
Simon Alibert
ddca6765b8 Fix feetech test_is_calibrated 2025-05-23 11:46:26 +02:00
Simon Alibert
cedaa83bce Update feetech read_calibration 2025-05-22 17:59:54 +02:00
Simon Alibert
4bb965c283 Better MotorsBus error messages 2025-05-22 17:59:27 +02:00
Simon Alibert
4feaef3436 Adapt feetech calibration 2025-05-22 16:02:55 +02:00
Simon Alibert
e9aac40ba8 nit 2025-05-22 11:34:16 +02:00
Simon Alibert
386ad61007 Fix normalization drive_mode 2025-05-22 11:32:52 +02:00
Simon Alibert
cac4289619 Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots 2025-05-21 20:19:33 +02:00
Simon Alibert
0bda18eab5 Move make_robot_config 2025-05-21 20:18:47 +02:00
Simon Alibert
8ab2227148 Replace deprecated abc.abstractproperty 2025-05-20 13:16:34 +02:00
Simon Alibert
9dab08dfbc Remove old .cache folder 2025-05-20 09:53:01 +02:00
Simon Alibert
05dfa26c54 Fix test 2025-05-19 11:24:10 +02:00
Simon Alibert
edbba48e81 Add so100_follower tests 2025-05-19 10:58:35 +02:00
Simon Alibert
10119c1a59 Move MockMotorsBus 2025-05-18 11:51:47 +02:00
Simon Alibert
c7ef189da0 Add check for same min and max during calibration 2025-05-16 10:48:45 +02:00
Simon Alibert
51efe6dfee Add setup_motors for lekiwi 2025-05-15 11:46:41 +02:00
Simon Alibert
b0592d9bc8 Fix dxl _find_single_motor 2025-05-14 13:43:36 +02:00
Simon Alibert
363fe64ff9 Add copyrights 2025-05-13 17:38:39 +02:00
Simon Alibert
bbcb12e919 Fix test_calibrate 2025-05-13 17:19:40 +02:00
Simon Alibert
3e87b09d34 Fix setup_motors & calibrate configs 2025-05-13 17:06:24 +02:00
Simon Alibert
81de27dc9a Remove Moss arm 2025-05-13 16:30:50 +02:00
Simon Alibert
eb94a5f03f Rename arm -> bus 2025-05-13 13:26:04 +02:00
Simon Alibert
742708942c Add MotorsBus docstrings 2025-05-13 13:24:46 +02:00
Simon Alibert
5a2f9b6589 Remove unecessary id 2025-05-12 19:01:30 +02:00
Simon Alibert
06f0c579b7 Rename example 7 2025-05-12 18:56:22 +02:00
Simon Alibert
7890767d34 Remove pynput from optional deps 2025-05-12 18:54:08 +02:00
Simon Alibert
d322cb0220 Add SO101 2025-05-11 13:15:28 +02:00
Simon Alibert
f011173ff6 Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots 2025-05-11 12:53:04 +02:00
Simon Alibert
20129cd4c2 Fix tests (no-extras install) 2025-05-11 12:52:17 +02:00
Simon Alibert
307823bc8d Add docstrings 2025-05-11 12:45:22 +02:00
Simon Alibert
64303781c2 Add calibrate 2025-05-08 18:27:19 +02:00
Simon Alibert
dd3e305164 Remove deprecated scripts & tests 2025-05-08 18:08:38 +02:00
Simon Alibert
cb9cac6a1b Add test_record_and_resume 2025-05-08 17:54:58 +02:00
Simon Alibert
95f9b45418 Add new test_control_robot 2025-05-08 17:38:16 +02:00
Simon Alibert
f9db727647 Add mock robot & teleop 2025-05-08 17:37:49 +02:00
Simon Alibert
230c7fdfab Fix test_datasets 2025-05-08 14:57:12 +02:00
Simon Alibert
320f7e8450 Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots 2025-05-08 13:24:12 +02:00
Simon Alibert
08fbbb318f Add replay 2025-05-08 13:21:42 +02:00
Simon Alibert
8b98399206 Add record 2025-05-08 13:21:42 +02:00
Simon Alibert
237b14a6ec Add teleoperate 2025-05-08 13:21:42 +02:00
Simon Alibert
2e705ff554 Add setup_motors 2025-05-08 13:21:42 +02:00
Simon Alibert
d72a3f9c32 Remove app script 2025-05-08 13:21:42 +02:00
Simon Alibert
73ac4f38b2 Fix config parsing 2025-05-08 13:21:18 +02:00
Simon Alibert
a0e69dd708 Rename find_port 2025-05-08 13:21:18 +02:00
Simon Alibert
b207babd9e Add make_teleoperator_from_config 2025-05-08 13:21:18 +02:00
Simon Alibert
293870d0f6 Update teleop features & naming 2025-05-08 13:21:17 +02:00
Simon Alibert
87a8cb6d89 Update robot features & naming 2025-05-08 13:20:32 +02:00
Simon Alibert
69dc3f5c9c Remove deprecated manipulator 2025-05-08 13:17:16 +02:00
Steven Palma
e4d4754600 fix(teleoperators): use property is_connected (#1075) 2025-05-07 10:52:44 +02:00
Steven Palma
2e528a8b12 refactor/lekiwi robot (#863)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-04-29 17:48:41 +02:00
Simon Alibert
b7a9b0689a Remove deprecated import 2025-04-18 17:13:08 +02:00
Simon Alibert
b6b9635be6 Remove names 2025-04-18 09:48:16 +02:00
Simon Alibert
21b1026872 Remove deprecated dynamixel_calibration 2025-04-18 09:34:46 +02:00
Simon Alibert
8c3eab32b0 Remove deprecated configure_motor 2025-04-18 09:19:43 +02:00
Simon Alibert
29633865c7 Fix _find_single_motor 2025-04-18 09:18:56 +02:00
Simon Alibert
702749b7d3 Fix setup_motor & add it to robots 2025-04-17 16:56:38 +02:00
Simon Alibert
bf1c737858 Fix calibration msg display 2025-04-17 13:18:32 +02:00
Simon Alibert
d07c7347f8 Add setup_motor 2025-04-17 13:14:06 +02:00
Simon Alibert
57e5e4cc07 Move read/write_calibration implementations 2025-04-16 11:23:33 +02:00
Simon Alibert
2743c29a96 Update feetech tables 2025-04-16 11:01:12 +02:00
Simon Alibert
2bb73ac431 Add torque_disabled context 2025-04-15 11:43:22 +02:00
Simon Alibert
9afc4b771c Motors config & disconnect fixes 2025-04-15 11:20:42 +02:00
Simon Alibert
f71e224023 Fix tests 2025-04-15 11:18:44 +02:00
Simon Alibert
889de7c415 Add handshake, fix feetech _read_firmware_version 2025-04-14 17:14:06 +02:00
Simon Alibert
3539251b18 Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots 2025-04-14 15:30:35 +02:00
Simon Alibert
1f210bc8a3 Refactor tests 2025-04-14 15:26:29 +02:00
Simon Alibert
d70bc4bde9 Add more segmented tests (dynamixel) 2025-04-14 15:16:38 +02:00
Simon Alibert
bdbca09cb2 Add more segmented tests (base motor bus & feetech), add feetech protocol 1 support 2025-04-14 11:56:53 +02:00
Simon Alibert
e0b292ab51 Remove test_motors_bus fixtures 2025-04-11 12:24:30 +02:00
Simon Alibert
f960f4d8d4 Fix unormalize 2025-04-11 11:58:31 +02:00
Simon Alibert
9e57ec7837 Add support for feetech protocol 1 to _split_into_byte_chunks 2025-04-11 11:58:09 +02:00
Simon Alibert
0a7f51f0da Cleanup 2025-04-11 11:03:09 +02:00
Simon Alibert
4ca92a28e9 Make feetech broadcast ping faster in protocol 1 2025-04-11 11:02:54 +02:00
Simon Alibert
0464dc91b3 Add feetech sm8512bl 2025-04-11 11:02:01 +02:00
Simon Alibert
d32daebf75 Refactor & add _serialize_data 2025-04-11 11:01:12 +02:00
Simon Alibert
27cb0c40bd Add protocol 1 broadcast ping 2025-04-10 17:14:40 +02:00
Simon Alibert
12abc9ca86 Fix broadcast ping type hint 2025-04-10 00:53:17 +02:00
Simon Alibert
4005065223 (nit) move write 2025-04-10 00:51:23 +02:00
Simon Alibert
443fed216c Use constants from sdks 2025-04-10 00:49:03 +02:00
Simon Alibert
42a87e7211 Implement read 2025-04-10 00:35:14 +02:00
Simon Alibert
034171a89a Add Feetech protocol version 2025-04-09 10:26:30 +02:00
pre-commit-ci[bot]
782dff1163 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-08 08:48:18 +00:00
Simon Alibert
8924ccbbab Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots 2025-04-08 10:47:40 +02:00
Simon Alibert
792c3d961d Update dynamixel with motors bus & tables changes 2025-04-08 10:47:11 +02:00
Simon Alibert
e998dddcfa Add support for feetech scs series + various fixes 2025-04-08 10:46:29 +02:00
Steven Palma
99c0938b42 feat(teleop): thread-safe keyboard teleop implementation (#869)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-04-04 09:45:18 +02:00
Simon Alibert
716029b1e3 Remove old calibration 2025-04-03 18:42:39 +02:00
Simon Alibert
3848a8f9aa Rename viperx & widowx 2025-04-03 18:37:21 +02:00
Simon Alibert
f7672e14c7 Update viperx & widowx 2025-04-03 18:34:08 +02:00
Simon Alibert
e393af2d88 Add is_calibrated test 2025-04-03 17:35:10 +02:00
Simon Alibert
0dcb2caba8 Simplify motors mocks 2025-04-03 16:43:23 +02:00
Simon Alibert
4679725957 Revert feetech hack and monkeypatch instead 2025-04-03 15:53:54 +02:00
Simon Alibert
57319062aa Remove old calibration tests 2025-04-03 12:17:43 +02:00
Simon Alibert
078f59bfd1 Add calibration tests 2025-04-03 12:14:15 +02:00
Simon Alibert
36fcea2002 Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots 2025-04-03 08:40:39 +02:00
Simon Alibert
2971bdfed5 Rename Koch classes 2025-04-03 08:23:31 +02:00
Simon Alibert
28cd3a6f3a Rename SO-100 classes 2025-04-03 08:14:35 +02:00
Simon Alibert
c0570b3003 Improve format 2025-04-02 22:40:00 +02:00
Simon Alibert
eeb8490016 Update Koch & SO-100 2025-04-02 22:33:48 +02:00
Simon Alibert
854b78975a Update tests 2025-04-02 22:31:53 +02:00
Simon Alibert
e55d2ffe50 Hack feetech firmware bug 2025-04-02 22:31:45 +02:00
Simon Alibert
1ebd81552c Fix calibration 2025-04-02 22:27:49 +02:00
Simon Alibert
65569ba90e Add test_scan_port (TODO) 2025-03-31 18:40:23 +02:00
Simon Alibert
79293800f1 Implement Koch calibration 2025-03-31 18:18:27 +02:00
Simon Alibert
bc765f9e95 Implement SO-100 follower calibration 2025-03-31 18:17:20 +02:00
Simon Alibert
201311503f Implement SO-100 leader calibration 2025-03-31 18:16:42 +02:00
Simon Alibert
8cc0232e73 Format baudrate tables 2025-03-31 18:14:57 +02:00
Simon Alibert
6bfcc18e73 Add more calibration utilities 2025-03-31 18:14:11 +02:00
Simon Alibert
e096754d14 Rename test 2025-03-31 00:41:25 +02:00
Simon Alibert
02803f545d Add test_encoding_utils 2025-03-31 00:37:28 +02:00
Simon Alibert
8503e8e166 Move encoding functions to encoding_utils 2025-03-31 00:35:31 +02:00
Simon Alibert
d6007c6e7d Add calibration utilities 2025-03-30 15:41:39 +02:00
Simon Alibert
50963fcf13 Add scan_port utility 2025-03-30 15:32:25 +02:00
Simon Alibert
051a52a4ce Remove todo 2025-03-25 21:32:30 +01:00
Simon Alibert
2292b514aa Fix calibration functions 2025-03-25 17:58:54 +01:00
Simon Alibert
1f1a01a798 Rename CalibrationMode -> MotorNormMode 2025-03-25 17:42:18 +01:00
Simon Alibert
faa476f0d2 Remove deprecated scripts 2025-03-25 17:33:05 +01:00
Simon Alibert
5130b69ece Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots 2025-03-25 16:25:47 +01:00
Simon Alibert
aed85241b7 Merge branch 'user/aliberts/2025_02_25_refactor_robots' of github.com:huggingface/lerobot into user/aliberts/2025_02_25_refactor_robots 2025-03-25 16:24:40 +01:00
Pepijn
21c3ac42ee Add new calibration method for robot refactor (#896)
Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
2025-03-25 16:24:04 +01:00
Simon Alibert
2d3a5fb2be (WIP) _async_read 2025-03-25 15:37:18 +01:00
Simon Alibert
a631e4c11c Rename idx -> id_ 2025-03-25 15:36:36 +01:00
Simon Alibert
222d6f104e Rename idx -> id_ 2025-03-25 14:20:12 +01:00
Simon Alibert
7a3b424cd3 Add calibration 2025-03-25 14:13:55 +01:00
Simon Alibert
af295fadb5 Fix imports 2025-03-25 12:48:58 +01:00
Simon Alibert
9644e2b086 Fix visualize_motors_bus 2025-03-25 12:47:44 +01:00
Simon Alibert
6ccf083127 Update so100 imports 2025-03-25 12:32:38 +01:00
Simon Alibert
bb774e7acd Update Koch imports 2025-03-25 12:31:51 +01:00
Simon Alibert
dcbbeab80b Move DriveMode & TorqueMode 2025-03-25 12:30:07 +01:00
Simon Alibert
b71ac34214 Add test_motors_bus 2025-03-25 12:11:56 +01:00
Simon Alibert
c237d1379e Update tests 2025-03-25 11:12:52 +01:00
Simon Alibert
cf963eb1b0 Ensure motors exist at connection time 2025-03-25 11:12:26 +01:00
Simon Alibert
4293b6a4fb Fix feetech ping tests 2025-03-25 07:26:34 +01:00
Simon Alibert
7a75bb9f61 Improve errors 2025-03-24 21:13:26 +01:00
Simon Alibert
0c1d4cb323 Rename idx -> id_ 2025-03-24 20:58:56 +01:00
Simon Alibert
c6212d585d Add raw_values option 2025-03-24 20:56:58 +01:00
Simon Alibert
7c8ab8e2d6 Implement feetech broadcast ping 2025-03-24 20:46:36 +01:00
Simon Alibert
1de75c46c0 Return models (str) with pings 2025-03-24 20:42:43 +01:00
Simon Alibert
4ad109cff8 Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots 2025-03-24 13:25:29 +01:00
Simon Alibert
8994252019 Add _configure_motors & move ping methods 2025-03-24 12:19:03 +01:00
Simon Alibert
9832daf08d Fix dict 2025-03-24 12:16:54 +01:00
Simon Alibert
39d8f45810 Privatize methods & renames 2025-03-24 11:57:12 +01:00
Simon Alibert
30fcd3d417 Update so100 2025-03-23 20:15:47 +01:00
Simon Alibert
039b437ef0 Update ensure_safe_goal_position 2025-03-23 19:43:58 +01:00
Simon Alibert
7582a0a2b0 Caps dxl OperatingMode 2025-03-23 19:42:21 +01:00
Simon Alibert
25388d0947 Add feetech operating modes 2025-03-23 19:41:46 +01:00
Simon Alibert
7152bc8aa7 Update Koch 2025-03-23 19:32:26 +01:00
Simon Alibert
5b46dc0b6a Add is_connected in robots and teleops 2025-03-23 19:26:10 +01:00
Simon Alibert
4273f1f384 Add dxl operating modes 2025-03-23 19:25:21 +01:00
Simon Alibert
97194bf7f3 Nit 2025-03-23 17:05:08 +01:00
Simon Alibert
0ac026b521 Remove test skips & fix docstrings 2025-03-23 17:04:30 +01:00
Simon Alibert
ff7cfdaf40 Move mock_serial patch to dedicated file 2025-03-23 17:03:04 +01:00
Simon Alibert
57c97762e1 Simplify _is_comm_success & _is_error 2025-03-23 16:52:29 +01:00
Simon Alibert
a38bb15e79 Add feetech write test 2025-03-23 16:48:32 +01:00
Simon Alibert
3ceaee999d Refactor feetech tests by functionality 2025-03-23 16:25:12 +01:00
Simon Alibert
d485dc1313 Refactor _is_comm_success 2025-03-23 16:15:53 +01:00
Simon Alibert
329d103453 Add dxl write test 2025-03-23 16:12:24 +01:00
Simon Alibert
9f46a3d8f9 Refactor dxl tests by functionality 2025-03-23 16:11:24 +01:00
Simon Alibert
c9ca9e4316 Rename tests 2025-03-23 13:32:08 +01:00
Simon Alibert
5a57e6f4a7 Rename read/write -> sync_read/write, refactor, add write 2025-03-23 13:25:45 +01:00
Simon Alibert
a2f5c34625 Simplify split_int_bytes 2025-03-23 11:55:39 +01:00
Simon Alibert
1f1e1bcfe8 Add Motor in dxl robots 2025-03-23 11:08:20 +01:00
Simon Alibert
e047074825 Add CalibrationMode 2025-03-23 10:20:08 +01:00
Simon Alibert
c2e761437d Assert ping stub called 2025-03-22 18:53:57 +01:00
Simon Alibert
fedac994c3 Add autoclosing fixture 2025-03-22 18:16:13 +01:00
Simon Alibert
7d558d058e Nit 2025-03-22 17:03:18 +01:00
Simon Alibert
1d3e1cbdbd Add feetech write tests 2025-03-22 17:02:01 +01:00
Simon Alibert
0ccc957d5c Fix imports 2025-03-22 16:58:41 +01:00
Simon Alibert
a4d487bc1d Remove comment 2025-03-22 16:52:42 +01:00
Simon Alibert
8ca03a7255 Add dxl write tests 2025-03-22 14:50:05 +01:00
Simon Alibert
f2ed2bfb2f Improve logging & typing 2025-03-22 11:11:39 +01:00
Simon Alibert
40675ec76c Add logger, rm logs 2025-03-22 10:33:42 +01:00
Simon Alibert
9e34c1d731 Move feetech table & cleanup 2025-03-22 01:24:48 +01:00
Simon Alibert
857f335be9 Improve feetech mocking 2025-03-22 01:19:51 +01:00
Simon Alibert
fc4a95f187 Add CRC docstring 2025-03-22 00:50:01 +01:00
Simon Alibert
4fe1880887 Add ping testing 2025-03-22 00:40:22 +01:00
Simon Alibert
6fa859fa19 Improve dynamixel mocking 2025-03-22 00:39:41 +01:00
Simon Alibert
2abfa5838d Improve read ergonomics & typing, rm find_motor_indices 2025-03-22 00:34:07 +01:00
Simon Alibert
3d119c0ccb Add single value write 2025-03-21 12:31:41 +01:00
Simon Alibert
a32081757d Add Motor class 2025-03-21 12:13:44 +01:00
Simon Alibert
56c04ffc53 Move dxl table & cleanup 2025-03-21 11:28:11 +01:00
Simon Alibert
715d4557af Ruff ignore F401 & F403 for init files 2025-03-21 11:22:02 +01:00
Simon Alibert
6541982dff Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots 2025-03-20 14:48:19 +01:00
Simon Alibert
43bc9404bb Remove motors from koch teleop config 2025-03-20 14:47:53 +01:00
Simon Alibert
375499c323 Remove set_operating_mode 2025-03-20 14:47:17 +01:00
Simon Alibert
17a4447cef Add debugging init 2025-03-20 14:45:18 +01:00
Simon Alibert
287dc13d96 Remove CLI for calibration visualization + move to debugging 2025-03-20 14:44:23 +01:00
Simon Alibert
02a1cf6a4e Fix calibration visualization 2025-03-20 14:33:36 +01:00
Simon Alibert
34cd1e47bf Remove obsolete test 2025-03-20 14:07:55 +01:00
Simon Alibert
74d56834af Fix dxl calib import 2025-03-20 14:03:11 +01:00
Simon Alibert
dd80dbb4cd Simplify Dxl motors bus import 2025-03-20 14:01:34 +01:00
Simon Alibert
bc020ee0a4 Remove mock_feetech sdk & add feetech new tests 2025-03-20 14:00:10 +01:00
Simon Alibert
a15767aff1 Fix feetech reader/writer 2025-03-20 13:59:00 +01:00
Simon Alibert
9af0a9bf37 Add mock_feetech 2025-03-20 13:58:02 +01:00
Simon Alibert
e2c8bc6948 Fix packet length, remove bytearray for easier debug, improve doctrings 2025-03-20 13:57:15 +01:00
Simon Alibert
2c68c6ca40 Implement FeetechMotorsBus & move functions to calibration 2025-03-20 10:22:47 +01:00
Simon Alibert
dd1f33e5ed Add pytest param ids 2025-03-20 09:44:47 +01:00
Simon Alibert
2c1bb766ff Refactor MockMotors, add return values 2025-03-20 09:40:58 +01:00
Simon Alibert
c1c71fb994 Ignore patching when not on MacOS 2025-03-20 09:38:36 +01:00
Simon Alibert
2d56f35071 Improve formats & docstrings 2025-03-20 09:36:17 +01:00
Simon Alibert
64ce2669ca Add bytes stuffing 2025-03-20 09:33:33 +01:00
Simon Alibert
f527adf7a9 Add mock-serial 2025-03-19 19:03:34 +01:00
Simon Alibert
6a77189f50 Fix import 2025-03-19 19:02:58 +01:00
Simon Alibert
e4a6d035f9 Remove Dxl mock sdk & update tests 2025-03-19 19:02:25 +01:00
Simon Alibert
794f6e00fc Introduce Dxl packet mocking logic 2025-03-19 18:57:29 +01:00
Simon Alibert
97494c6a39 (WIP) Implement Dynamixel 2025-03-19 18:46:04 +01:00
Simon Alibert
9358d334c7 Rewrite MotorsBus 2025-03-19 18:44:05 +01:00
Simon Alibert
c85a9253e7 Move imports 2025-03-15 23:43:26 +01:00
Simon Alibert
8d659a6aa9 Update mock SDKs 2025-03-15 22:26:47 +01:00
Simon Alibert
f6a2396484 Move test_configure_motors_all_ids_1 2025-03-15 22:19:50 +01:00
Simon Alibert
7a7af82e35 Nit docstring 2025-03-15 21:53:42 +01:00
Simon Alibert
7f23972f3f Add Feetech & Dxl basic tests 2025-03-15 21:45:05 +01:00
Simon Alibert
3362b665e6 Move test files 2025-03-15 21:44:01 +01:00
Simon Alibert
eeeccdba53 Update docstrings 2025-03-15 21:42:54 +01:00
Simon Alibert
bd5b181dfd Improve type hints 2025-03-15 21:33:45 +01:00
Simon Alibert
858678786a Remove unused functions 2025-03-15 19:22:40 +01:00
Simon Alibert
0f972661e1 Move imports & remove mock entirely 2025-03-15 19:22:12 +01:00
Simon Alibert
2e9b144c56 Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots 2025-03-15 13:15:28 +01:00
Simon Alibert
fa8ba9e4e2 Rename set_operating_mode arg 2025-03-15 13:14:29 +01:00
Simon Alibert
2037cc0219 Rename ID -> id 2025-03-15 13:14:05 +01:00
Simon Alibert
5006da72ff Update configure_motor script 2025-03-15 13:13:26 +01:00
Simon Alibert
ad0bacbfe4 Ass model_baudrate_table 2025-03-15 13:11:56 +01:00
Simon Alibert
e33ca2c980 Fix TorqueMode imports 2025-03-15 13:10:57 +01:00
Simon Alibert
f0505e81cc Move common Feetech/Dxl code into MotorsBus base class 2025-03-14 22:00:09 +01:00
Simon Alibert
1f7ddc1d76 New Feetech calibration (#859)
Co-authored-by: Pepijn <pepijn@huggingface.co>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-03-14 11:31:23 +01:00
Simon Alibert
ce63cfdb25 Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots 2025-03-13 14:24:50 +01:00
Simon Alibert
d6f1359e69 Remove motors from Koch config 2025-03-12 17:16:09 +01:00
Simon Alibert
2357d4aceb Update base classes typing 2025-03-12 17:15:39 +01:00
Simon Alibert
d6ccdc222c Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots 2025-03-10 18:39:48 +01:00
Simon Alibert
9bd0788131 Update paths 2025-03-10 18:34:01 +01:00
Simon Alibert
1ae62c28f7 Move lekiwi files 2025-03-10 18:33:28 +01:00
Simon Alibert
baf6e66c3d Add init files 2025-03-10 18:29:58 +01:00
Simon Alibert
a065bd61ae Add keyboard teleop 2025-03-10 18:28:50 +01:00
Simon Alibert
5dc3c74e64 Add WidowX 2025-03-06 21:31:35 +01:00
Simon Alibert
4214b01703 Add ViperX 2025-03-06 12:53:55 +01:00
Simon Alibert
b974e5541f Update stretch teleop 2025-03-06 11:46:06 +01:00
Simon Alibert
fd64dc84ae Move stretch3 teleop 2025-03-06 10:24:27 +01:00
Simon Alibert
06988b2135 WIP stretch 3 robot & teleop 2025-03-04 13:32:58 +01:00
Simon Alibert
7ed7570b17 WIP Add stretch 2025-03-04 11:42:07 +01:00
Simon Alibert
e2d13ba7e4 Update paths 2025-03-04 11:38:31 +01:00
Simon Alibert
f6c1049474 Update errors 2025-03-04 11:24:05 +01:00
Simon Alibert
2b24feb604 Update constants 2025-03-04 11:07:15 +01:00
Simon Alibert
a13e49073c Add Moss Robot 2025-03-03 20:42:48 +01:00
Simon Alibert
2c7e0f17b6 Add SO-100 teleop 2025-03-03 20:31:04 +01:00
Simon Alibert
418866007e Fixes for Koch robot 2025-03-03 20:19:23 +01:00
Simon Alibert
5ab418dbeb Add feetech calibration 2025-03-03 20:17:54 +01:00
Simon Alibert
95f61ee9d4 Add SO-100 robot 2025-03-03 20:17:18 +01:00
Simon Alibert
ac89c8d226 Add Koch teleop 2025-03-03 18:58:54 +01:00
Simon Alibert
d75d904e43 Add teleoperator base class 2025-03-03 18:55:59 +01:00
Simon Alibert
ea4d8d990c Add Koch robot 2025-03-03 18:53:45 +01:00
Simon Alibert
c93cbb8311 Fix base robot class 2025-03-03 18:49:40 +01:00
Simon Alibert
c0137e89b9 Add calibration dir 2025-03-03 18:44:39 +01:00
Simon Alibert
3111ba78ad Add errors 2025-03-03 18:44:15 +01:00
Simon Alibert
3d3a176940 Move & organize motors, add base class 2025-03-03 18:18:24 +01:00
Simon Alibert
212c6095a2 Move & organize cameras, add base class 2025-03-03 18:16:30 +01:00
Simon Alibert
48469ec674 Move motor files 2025-03-02 21:33:22 +01:00
Simon Alibert
c7dfd32b43 Update DynamixelMotorsBus signature 2025-03-02 21:29:35 +01:00
Simon Alibert
731fb6ebaf Fix import 2025-02-26 19:02:15 +01:00
Simon Alibert
13e124302f Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots 2025-02-26 18:49:18 +01:00
Simon Alibert
59bdd29106 Move more files & objects around 2025-02-26 18:48:58 +01:00
Simon Alibert
124829104b Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots 2025-02-26 16:26:03 +01:00
Simon Alibert
21cd2940a9 Reorganize files 2025-02-26 16:22:07 +01:00
110 changed files with 7861 additions and 6588 deletions

View File

@@ -37,7 +37,7 @@ repos:
- id: trailing-whitespace
- repo: https://github.com/adhtruong/mirrors-typos
rev: v1.33.1
rev: v1.34.0
hooks:
- id: typos
args: [--force-exclude]
@@ -48,7 +48,7 @@ repos:
- id: pyupgrade
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.13
rev: v0.12.3
hooks:
- id: ruff
args: [--fix]
@@ -62,12 +62,12 @@ repos:
- id: gitleaks
- repo: https://github.com/woodruffw/zizmor-pre-commit
rev: v1.9.0
rev: v1.11.0
hooks:
- id: zizmor
- repo: https://github.com/PyCQA/bandit
rev: 1.8.3
rev: 1.8.6
hooks:
- id: bandit
args: ["-c", "pyproject.toml"]

View File

@@ -22,6 +22,29 @@
</div>
<h2 align="center">
<p><a href="https://huggingface.co/docs/lerobot/hope_jr">
Build Your Own HopeJR Robot!</a></p>
</h2>
<div align="center">
<img
src="media/hope_jr/hopejr.png?raw=true"
alt="HopeJR robot"
title="HopeJR robot"
style="width: 60%;"
/>
<p><strong>Meet HopeJR A humanoid robot arm and hand for dexterous manipulation!</strong></p>
<p>Control it with exoskeletons and gloves for precise hand movements.</p>
<p>Perfect for advanced manipulation tasks! 🤖</p>
<p><a href="https://huggingface.co/docs/lerobot/hope_jr">
See the full HopeJR tutorial here.</a></p>
</div>
<br/>
<h2 align="center">
<p><a href="https://huggingface.co/docs/lerobot/so101">
Build Your Own SO-101 Robot!</a></p>

View File

@@ -17,12 +17,16 @@
title: Train a Robot with RL
- local: hilserl_sim
title: Train RL in Simulation
- local: async
title: Use Async Inference
title: "Tutorials"
- sections:
- local: smolvla
title: Finetune SmolVLA
title: "Policies"
- sections:
- local: hope_jr
title: Hope Jr
- local: so101
title: SO-101
- local: so100

272
docs/source/async.mdx Normal file
View File

@@ -0,0 +1,272 @@
# Asynchronous Inference
With our [SmolVLA](https://huggingface.co/papers/2506.01844) we introduced a new way to run inference on real-world robots, **decoupling action prediction from action execution**.
In this tutorial, we'll show how to use asynchronous inference (_async inference_) using a finetuned version of SmolVLA, and all the policies supported by LeRobot.
**Try async inference with all the policies** supported by LeRobot!
**What you'll learn:**
1. Why asynchronous inference matters and how it compares to, more traditional, sequential inference.
2. How to spin-up a `PolicyServer` and connect a `RobotClient` from the same machine, and even over the network.
3. How to tune key parameters (`actions_per_chunk`, `chunk_size_threshold`) for your robot and policy.
If you get stuck, hop into our [Discord community](https://discord.gg/s3KuuzsPFb)!
In a nutshell: with *async inference*, your robot keeps acting while the policy server is already busy computing the next chunk of actions---eliminating "wait-for-inference" lags and unlocking smoother, more reactive behaviours.
This is fundamentally different from synchronous inference (sync), where the robot stays idle while the policy computes the next chunk of actions.
---
## Getting started with async inference
You can read more information on asynchronous inference in our [blogpost](https://huggingface.co/blog/async-robot-inference). This guide is designed to help you quickly set up and run asynchronous inference in your environment.
First, install `lerobot` with the `async` tag, to install the extra dependencies required to run async inference.
```shell
pip install -e ".[async]"
```
Then, spin up a policy server (in one terminal, or in a separate machine) specifying the host address and port for the client to connect to.
You can spin up a policy server running:
```shell
python src/lerobot/scripts/server/policy_server.py \
--host=127.0.0.1 \
--port=8080 \
```
This will start a policy server listening on `127.0.0.1:8080` (`localhost`, port 8080). At this stage, the policy server is empty, as all information related to which policy to run and with which parameters are specified during the first handshake with the client. Spin up a client with:
```shell
python src/lerobot/scripts/server/robot_client.py \
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
--robot.type=so100_follower \ # ROBOT: your robot type
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
--robot.id=follower_so100 \ # ROBOT: your robot id, to load calibration file
--robot.cameras="{ laptop: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, phone: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ # POLICY: the cameras used to acquire frames, with keys matching the keys expected by the policy
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
--policy_device=mps \ # POLICY: the device to run the policy on, on the server
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
--debug_visualize_queue_size=True # CLIENT: whether to visualize the queue size at runtime
```
In summary, you need to specify instructions for:
- `SERVER`: the address and port of the policy server
- `ROBOT`: the type of robot to connect to, the port to connect to, and the local `id` of the robot
- `POLICY`: the type of policy to run, and the model name/path on server to the checkpoint to run. You also need to specify which device should the sever be using, and how many actions to output at once (capped at the policy max actions value).
- `CLIENT`: the threshold for the chunk size before sending a new observation to the server, and the function to aggregate actions on overlapping portions. Optionally, you can also visualize the queue size at runtime, to help you tune the `CLIENT` parameters.
Importantly,
- `actions_per_chunk` and `chunk_size_threshold` are key parameters to tune for your setup.
- `aggregate_fn_name` is the function to aggregate actions on overlapping portions. You can either add a new one to a registry of functions, or add your own in `robot_client.py` (see [here](NOTE:addlinktoLOC))
- `debug_visualize_queue_size` is a useful tool to tune the `CLIENT` parameters.
Done! You should see your robot moving around by now 😉
---
## Async vs. synchronous inference
Synchronous inference relies on interleaving action chunk prediction and action execution. This inherently results in *idle frames*, frames where the robot awaits idle the policy's output: a new action chunk.
In turn, inference is plagued by evident real-time lags, where the robot simply stops acting due to the lack of available actions.
With robotics models increasing in size, this problem risks becoming only more severe.
<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/sync.png" width="80%"></img>
</p>
<p align="center"><i>Synchronous inference</i> makes the robot idle while the policy is computing the next chunk of actions.</p>
To overcome this, we design async inference, a paradigm where action planning and execution are decoupled, resulting in (1) higher adaptability and, most importantly, (2) no idle frames.
Crucially, with async inference, the next action chunk is computed *before* the current one is exhausted, resulting in no idleness.
Higher adaptability is ensured by aggregating the different action chunks on overlapping portions, obtaining an up-to-date plan and a tighter control loop.
<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/async.png" width="80%"></img>
</p>
<p align="center"><i>Asynchronous inference</i> results in no idleness because the next chunk is computed before the current chunk is exhausted.</p>
---
## Start the Policy Server
Policy servers are wrappers around a `PreTrainedPolicy` interfacing them with observations coming from a robot client.
Policy servers are initialized as empty containers which are populated with the requested policy specified in the initial handshake between the robot client and the policy server.
As such, spinning up a policy server is as easy as specifying the host address and port. If you're running the policy server on the same machine as the robot client, you can use `localhost` as the host address.
<hfoptions id="start_policy_server">
<hfoption id="Command">
```bash
python -m lerobot.scripts.server.policy_server \
--host="localhost" \
--port=8080
```
</hfoption>
<hfoption id="API example">
```python
from lerobot.scripts.server.configs import PolicyServerConfig
from lerobot.scripts.server.policy_server import serve
config = PolicyServerConfig(
host="localhost",
port=8080,
)
serve(config)
```
</hfoption>
</hfoptions>
This listens on `localhost:8080` for an incoming connection from the associated`RobotClient`, which will communicate which policy to run during the first client-server handshake.
---
## Launch the Robot Client
`RobotClient` is a wrapper around a `Robot` instance, which `RobotClient` connects to the (possibly remote) `PolicyServer`.
The `RobotClient` streams observations to the `PolicyServer`, and receives action chunks obtained running inference on the server (which we assume to have better computational resources than the robot controller).
<hfoptions id="start_robot_client">
<hfoption id="Command">
```bash
python src/lerobot/scripts/server/robot_client.py \
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
--robot.type=so100_follower \ # ROBOT: your robot type
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
--robot.id=follower_so100 \ # ROBOT: your robot id, to load calibration file
--robot.cameras="{ laptop: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, phone: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ # POLICY: the cameras used to acquire frames, with keys matching the keys expected by the policy
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
--policy_device=mps \ # POLICY: the device to run the policy on, on the server
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
--debug_visualize_queue_size=True # CLIENT: whether to visualize the queue size at runtime
```
</hfoption>
<hfoption id="API example">
```python
import threading
from lerobot.robots.so100_follower import SO100FollowerConfig
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.scripts.server.configs import RobotClientConfig
from lerobot.scripts.server.robot_client import RobotClient
from lerobot.scripts.server.helpers import visualize_action_queue_size
# 1. Create the robot instance
"""Check out the cameras available in your setup by running `python lerobot/find_cameras.py`"""
# these cameras must match the ones expected by the policy
# check the config.json on the Hub for the policy you are using
camera_cfg = {
"top": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
}
robot_cfg = SO100FollowerConfig(
port="/dev/tty.usbmodem585A0076841",
id="follower_so100",
cameras=camera_cfg
)
# 3. Create client configuration
client_cfg = RobotClientConfig(
robot=robot_cfg,
server_address="localhost:8080",
policy_device="mps",
policy_type="smolvla",
pretrained_name_or_path="fracapuano/smolvla_async",
chunk_size_threshold=0.5,
actions_per_chunk=50, # make sure this is less than the max actions of the policy
)
# 4. Create and start client
client = RobotClient(client_cfg)
# 5. Specify the task
task = "Don't do anything, stay still"
if client.start():
# Start action receiver thread
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
action_receiver_thread.start()
try:
# Run the control loop
client.control_loop(task)
except KeyboardInterrupt:
client.stop()
action_receiver_thread.join()
# (Optionally) plot the action queue size
visualize_action_queue_size(client.action_queue_size)
```
</hfoption>
</hfoptions>
The following two parameters are key in every setup:
<table>
<thead>
<tr>
<th>Hyperparameter</th>
<th>Default</th>
<th>What it does</th>
</tr>
</thead>
<tbody>
<tr>
<td><code>actions_per_chunk</code></td>
<td>50</td>
<td>How many actions the policy outputs at once. Typical values: 10-50.</td>
</tr>
<tr>
<td><code>chunk_size_threshold</code></td>
<td>0.7</td>
<td>When the queue is ≤ 50% full, the client sends a fresh observation. Value in [0, 1].</td>
</tr>
</tbody>
</table>
<Tip>
Different values of `actions_per_chunk` and `chunk_size_threshold` do result in different behaviours.
</Tip>
On the one hand, increasing the value of `actions_per_chunk` will result in reducing the likelihood of ending up with no actions to execute, as more actions will be available when the new chunk is computed.
However, larger values of `actions_per_chunk` might also result in less precise actions, due to the compounding errors consequent to predicting actions over longer timespans.
On the other hand, increasing the value of `chunk_size_threshold` will result in sending out to the `PolicyServer` observations for inference more often, resulting in a larger number of updates action chunks, overlapping on significant portions. This results in high adaptability, in the limit predicting one action chunk for each observation, which is in turn only marginally consumed while a new one is produced.
This option does also put more pressure on the inference pipeline, as a consequence of the many requests. Conversely, values of `chunk_size_threshold` close to 0.0 collapse to the synchronous edge case, whereby new observations are only sent out whenever the current chunk is exhausted.
We found the default values of `actions_per_chunk` and `chunk_size_threshold` to work well in the experiments we developed for the [SmolVLA paper](https://huggingface.co/papers/2506.01844), but recommend experimenting with different values to find the best fit for your setup.
### Tuning async inference for your setup
1. **Choose your computational resources carefully.** [PI0](https://huggingface.co/lerobot/pi0) occupies 14GB of memory at inference time, while [SmolVLA](https://huggingface.co/lerobot/smolvla_base) requires only ~2GB. You should identify the best computational resource for your use case keeping in mind smaller policies require less computational resources. The combination of policy and device used (CPU-intensive, using MPS, or the number of CUDA cores on a given NVIDIA GPU) directly impacts the average inference latency you should expect.
2. **Adjust your `fps` based on inference latency.** While the server generates a new action chunk, the client is not idle and is stepping through its current action queue. If the two processes happen at fundamentally different speeds, the client might end up with an empty queue. As such, you should reduce your fps if you consistently run out of actions in queue.
3. **Adjust `chunk_size_threshold`**.
- Values closer to `0.0` result in almost sequential behavior. Values closer to `1.0` → send observation every step (more bandwidth, relies on good world-model).
- We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` setting the `--debug-visualize-queue-size` to `True`. This will plot the action queue size evolution at runtime, and you can use it to find the value of `chunk_size_threshold` that works best for your setup.
<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/queues.png" width="80%"></img>
</p>
<p align="center"><i>The action queue size is plotted at runtime when the `--debug-visualize-queue-size` flag is passed, for various levels of `chunk_size_threshold` (`g` in the SmolVLA paper).</i></p>
---
## Conclusion
Asynchronous inference represents a significant advancement in real-time robotics control, addressing the fundamental challenge of inference latency that has long plagued robotics applications. Through this tutorial, you've learned how to implement a complete async inference pipeline that eliminates idle frames and enables smoother, more reactive robot behaviors.
**Key Takeaways:**
- **Paradigm Shift**: Async inference decouples action prediction from execution, allowing robots to continue acting while new action chunks are computed in parallel
- **Performance Benefits**: Eliminates "wait-for-inference" lags that are inherent in synchronous approaches, becoming increasingly important as policy models grow larger
- **Flexible Architecture**: The server-client design enables distributed computing, where inference can run on powerful remote hardware while maintaining real-time robot control
- **Tunable Parameters**: Success depends on properly configuring `actions_per_chunk` and `chunk_size_threshold` for your specific hardware, policy, and task requirements
- **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA
Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case.
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/lerobot/lerobot/issues).

1
docs/source/hope_jr.mdx Symbolic link
View File

@@ -0,0 +1 @@
../../src/lerobot/robots/hope_jr/hope_jr.mdx

View File

@@ -282,6 +282,12 @@ Your dataset will be automatically tagged with `LeRobot` for the community to fi
You can look for other LeRobot datasets on the hub by searching for `LeRobot` [tags](https://huggingface.co/datasets?other=LeRobot).
You can also push your local dataset to the Hub manually, running:
```bash
huggingface-cli upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset
```
#### Record function
The `record` function provides a suite of tools for capturing and managing data during robot operation:

BIN
media/hope_jr/hopejr.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 72 KiB

View File

@@ -46,7 +46,7 @@ classifiers = [
]
dependencies = [
"cmake>=3.29.0.1",
"datasets>=2.19.0",
"datasets>=2.19.0,<=3.6.0",
"deepdiff>=7.0.1",
"diffusers>=0.27.2",
"draccus==0.10.0",
@@ -79,13 +79,14 @@ dependencies = [
[project.optional-dependencies]
aloha = ["gym-aloha>=0.1.1 ; python_version < '4.0'"]
docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"]
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1"]
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "grpcio-tools==1.71.0"]
dora = [
"gym-dora @ git+https://github.com/dora-rs/dora-lerobot.git#subdirectory=gym_dora ; python_version < '4.0'",
]
dynamixel = ["dynamixel-sdk>=3.7.31"]
feetech = ["feetech-servo-sdk>=1.0.0"]
gamepad = ["pygame>=2.5.1", "hidapi>=0.14.0"]
hopejr = ["feetech-servo-sdk>=1.0.0", "pygame>=2.5.1"]
kinematics = ["placo>=0.9.6"]
intelrealsense = [
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
@@ -104,6 +105,7 @@ hilserl = ["transformers>=4.50.3", "gym-hil>=0.1.9", "protobuf>=5.29.3", "grpcio
umi = ["imagecodecs>=2024.1.1"]
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"]
async = ["grpcio==1.71.0", "matplotlib>=3.10.3"]
[tool.poetry]
requires-poetry = ">=2.1"
@@ -114,7 +116,7 @@ packages = [
[tool.ruff]
line-length = 110
target-version = "py310"
exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py", "*.part", "*.stl"]
[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
@@ -131,7 +133,7 @@ exclude_dirs = [
"src/lerobot/policies/pi0/conversion_scripts",
"src/lerobot/scripts/push_dataset_to_hub.py",
]
skips = ["B101", "B311", "B404", "B603"]
skips = ["B101", "B311", "B404", "B603", "B615"]
[tool.typos]
default.extend-ignore-re = [
@@ -146,6 +148,12 @@ default.extend-ignore-identifiers-re = [
"ein",
]
[tool.typos.files]
extend-exclude = [
"*.stl",
"*.part",
]
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

View File

@@ -0,0 +1,175 @@
import math
import sys
import time
from lerobot.robots.so101_follower_torque.config_so101_follower_t import SO101FollowerTConfig
from lerobot.robots.so101_follower_torque.so101_follower_t import SO101FollowerT
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
FRQ = 100
PRINT_HZ = 10
RERUN_HZ = 100
ESC_CLR_EOL = "\x1b[K"
CURSOR_UP = "\x1b[F"
follower_cfg = SO101FollowerTConfig(
port="/dev/tty.usbmodem58760432961",
id="follower_arm_torque",
)
leader_cfg = SO101FollowerTConfig(
port="/dev/tty.usbmodem58760432571",
id="leader_arm_torque",
)
follower = SO101FollowerT(follower_cfg)
leader = SO101FollowerT(leader_cfg)
follower.connect()
leader.connect()
_init_rerun("bilateral_teleoperation")
print("Starting 4-channel bilateral teleoperation")
first_print = True
loop_count = 0
tic_prev = time.perf_counter()
while True:
tic = time.perf_counter()
obs_l, obs_f = leader.get_observation(), follower.get_observation()
dt = tic - tic_prev
tic_prev = tic
if dt <= 0.0:
dt = 0.01 # avoid div-by-zero
tau_cmd_f, tau_cmd_l = [], []
debug_info_f, debug_info_l = {}, {}
pos_f = {j: obs_f[f"{j}.pos"] for j in follower.bus.motors}
vel_f = {j: obs_f[f"{j}.vel"] for j in follower.bus.motors}
tau_reaction_f = {j: obs_f[f"{j}.effort"] for j in follower.bus.motors}
pos_l = {j: obs_l[f"{j}.pos"] for j in leader.bus.motors}
vel_l = {j: obs_l[f"{j}.vel"] for j in leader.bus.motors}
tau_reaction_l = {j: obs_l[f"{j}.effort"] for j in leader.bus.motors}
# Joint-specific control gains
kp_gains = follower.kp_gains
kd_gains = follower.kd_gains
kf_gains = follower.kf_gains
# Compute torque commands
tau_cmd_f = [
kp_gains[j] * (pos_l[j] - pos_f[j]) # Position tracking
+ kd_gains[j] * (vel_l[j] - vel_f[j]) # Velocity damping
+ kf_gains[j] * (-tau_reaction_l[j] - tau_reaction_f[j]) # Force reflection
for j in follower.bus.motors
]
tau_cmd_l = [
kp_gains[j] * (pos_f[j] - pos_l[j]) # Position tracking
+ kd_gains[j] * (vel_f[j] - vel_l[j]) # Velocity damping
+ kf_gains[j] * (-tau_reaction_f[j] - tau_reaction_l[j]) # Force reflection
for j in leader.bus.motors
]
# Store debug info
for i, j in enumerate(follower.bus.motors):
debug_info_f[j] = {
"τ_reaction": tau_reaction_f[j],
"τ_ref": tau_cmd_f[i],
"θ_err": pos_l[j] - pos_f[j],
"ω_err": vel_l[j] - vel_f[j],
"τ_err": -tau_reaction_l[j] - tau_reaction_f[j],
}
debug_info_l[j] = {
"τ_reaction": tau_reaction_l[j],
"τ_ref": tau_cmd_l[i],
"θ_err": pos_f[j] - pos_l[j],
"ω_err": vel_f[j] - vel_l[j],
"τ_err": -tau_reaction_f[j] - tau_reaction_l[j],
}
# Send torques to both arms
follower.send_action({f"{m}.effort": tau_cmd_f[i] for i, m in enumerate(follower.bus.motors)})
leader.send_action({f"{m}.effort": tau_cmd_l[i] for i, m in enumerate(leader.bus.motors)})
observation = {
"follower_joint_angles": pos_f, # θ_f: current angles
"follower_angular_velocities": vel_f, # ω_f: current velocities
"follower_external_torques": tau_reaction_f, # τ_ext: measured minus deterministic components
}
action = {
"leader_target_angles": pos_l, # θ_leader[τ]: absolute target angles
"leader_target_velocities": vel_l, # ω_leader[τ]: absolute target velocities
"leader_interaction_torques": tau_reaction_l, # τ_leader[τ]: cmd minus deterministic components
}
if loop_count % (FRQ // RERUN_HZ) == 0:
log_rerun_data(observation, action)
loop_count += 1
if loop_count % (FRQ // PRINT_HZ) == 0:
hz = 1.0 / dt
lines = [f"Loop {hz:6.1f} Hz Δt {dt * 1e3:5.2f} ms"]
lines.append("=" * 106)
lines.append("LEADER ARM TORQUE ANALYSIS:")
lines.append(f"{'Joint':<13}{'Pos':>8}{'React':>6}{'Cmd':>6}")
lines.append(f"{'':13}{'(deg)':>8}{'(Nm)':>6}{'(Nm)':>6}")
lines.append("-" * 86)
for i, j in enumerate(leader.bus.motors):
debug_l = debug_info_l[j]
lines.append(
f"{j:<13s}{math.degrees(pos_l[j]):+8.1f}{debug_l['τ_reaction']:+6.2f}{tau_cmd_l[i]:+6.2f}"
)
lines.append("")
lines.append("FOLLOWER ARM TORQUE ANALYSIS:")
lines.append(f"{'Joint':<13}{'Pos':>8}{'React':>6}{'Cmd':>6}")
lines.append(f"{'':13}{'(deg)':>8}{'(Nm)':>6}{'(Nm)':>6}")
lines.append("-" * 86)
for i, j in enumerate(follower.bus.motors):
debug_f = debug_info_f[j]
lines.append(
f"{j:<13s}{math.degrees(pos_f[j]):+8.1f}{debug_f['τ_reaction']:+6.2f}{tau_cmd_f[i]:+6.2f}"
)
lines.append("")
lines.append("=" * 86)
lines.append("TORQUE COMPONENT EXPLANATIONS:")
lines.append("• Pos (joint pos) = Joint position in degrees")
lines.append("• React (reaction) = External forces (human interaction, contact)")
lines.append("• Meas (measured) = Raw torque from motor current sensor")
lines.append("• Cmd (command) = Final torque sent to motor")
lines.append("-" * 86)
lines.append(
"Cmd = Track + Vel + Force + (Added as feedforward in send_action: Grav + Inert + Frict)"
)
lines.append("React = Meas - Grav - Inert - Frict (external forces)")
lines.append("Force = Kf × (reflect_other_robot - React) (telepresence)")
lines.append("Frict = b_visc×ω + f_coulomb×sign(ω) (transparency)")
lines.append(
f"Joint Gains: shoulder_pan Kp={kp_gains['shoulder_pan']:.1f} | shoulder_pan Kd={kd_gains['shoulder_pan']:.1f} | shoulder_pan Kf={kf_gains['shoulder_pan']:.1f}"
)
lines.append(
f"Friction Comp, Viscous: {follower.friction_viscous['shoulder_pan']:.3f} | Coulomb: {follower.friction_coulomb['shoulder_pan']:.3f} (robot-class)"
)
block = "\n".join(lines)
if first_print:
sys.stdout.write(block + "\n")
first_print = False
else:
sys.stdout.write(CURSOR_UP * len(lines) + ESC_CLR_EOL + block + "\n")
sys.stdout.flush()
busy_wait(max(0.0, 1.0 / FRQ - (time.perf_counter() - tic)))

View File

@@ -36,6 +36,7 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
hope_jr,
koch_follower,
lekiwi,
make_robot_from_config,
@@ -45,6 +46,7 @@ from lerobot.robots import ( # noqa: F401
from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
homunculus,
koch_leader,
make_teleoperator_from_config,
so100_leader,

View File

@@ -18,12 +18,16 @@ Provides the OpenCVCamera class for capturing frames from cameras using OpenCV.
import logging
import math
import os
import platform
import time
from pathlib import Path
from threading import Event, Lock, Thread
from typing import Any, Dict, List
# Fix MSMF hardware transform compatibility for Windows before importing cv2
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2
import numpy as np
@@ -108,7 +112,8 @@ class OpenCVCamera(Camera):
self.config = config
self.index_or_path = config.index_or_path
self.fps = config.fps
self.wanted_fps = config.fps
self.camera_fps = None
self.color_mode = config.color_mode
self.warmup_s = config.warmup_s
@@ -196,10 +201,9 @@ class OpenCVCamera(Camera):
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
if self.fps is None:
self.fps = self.videocapture.get(cv2.CAP_PROP_FPS)
else:
self._validate_fps()
# We don't set the FPS. We GET the actual (max) FPS from the camera.
self.camera_fps = self.videocapture.get(cv2.CAP_PROP_FPS)
logger.info(f"{self} is running at its default/max FPS: {self.camera_fps:.2f}")
default_width = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_WIDTH)))
default_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
@@ -312,19 +316,23 @@ class OpenCVCamera(Camera):
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
start_time = time.perf_counter()
# Start the background capture thread if it's not running
if self.thread is None or not self.thread.is_alive():
# Perform an initial blocking read to populate the first frame
ret, frame = self.videocapture.read()
if not ret or frame is None:
raise RuntimeError(f"{self} failed to read initial frame.")
ret, frame = self.videocapture.read()
self.latest_frame = self._postprocess_image(frame)
self._start_read_thread()
if not ret or frame is None:
raise RuntimeError(f"{self} read failed (status={ret}).")
with self.frame_lock:
frame = self.latest_frame
processed_frame = self._postprocess_image(frame, color_mode)
if frame is None:
raise RuntimeError(f"Internal error: Read thread started but no frame is available for {self}.")
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return processed_frame
return frame.copy()
def _postprocess_image(self, image: np.ndarray, color_mode: ColorMode | None = None) -> np.ndarray:
"""
@@ -382,16 +390,23 @@ class OpenCVCamera(Camera):
"""
while not self.stop_event.is_set():
try:
color_image = self.read()
ret, frame = self.videocapture.read()
if not ret or frame is None:
logger.warning(f"Failed to read frame in background for {self}.")
time.sleep(0.01)
continue
processed_frame = self._postprocess_image(frame)
with self.frame_lock:
self.latest_frame = color_image
self.latest_frame = processed_frame
self.new_frame_event.set()
except DeviceNotConnectedError:
break
except Exception as e:
logger.warning(f"Error reading frame in background thread for {self}: {e}")
if not self.is_connected:
break
def _start_read_thread(self) -> None:
"""Starts or restarts the background read thread if it's not running."""

View File

@@ -60,6 +60,8 @@ def get_cv2_backend() -> int:
import cv2
if platform.system() == "Windows":
return cv2.CAP_AVFOUNDATION
else:
return cv2.CAP_MSMF # Use MSMF for Windows instead of AVFOUNDATION
# elif platform.system() == "Darwin": # macOS
# return cv2.CAP_AVFOUNDATION
else: # Linux and others
return cv2.CAP_ANY

File diff suppressed because it is too large Load Diff

View File

@@ -37,21 +37,6 @@ class DatasetConfig:
revision: str | None = None
use_imagenet_stats: bool = True
video_backend: str = field(default_factory=get_safe_default_codec)
# Multi-dataset support
sampling_weights: str | None = None
max_action_dim: int | None = None
max_state_dim: int | None = None
max_num_images: int | None = None
max_image_dim: int | None = None
train_on_all_features: bool = False
features_version: int = 0
discard_first_n_frames: int = 0
min_fps: int = 1
max_fps: int = 100
discard_first_idle_frames: bool = False
motion_threshold: float = 5e-2
motion_window_size: int = 10
motion_buffer: int = 3
@dataclass

View File

@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import json
import logging
import os
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Type, TypeVar
@@ -183,8 +185,22 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
) from e
# HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus
# HACK: Parse the original config to get the config subclass, so that we can
# apply cli overrides.
# This is very ugly, ideally we'd like to be able to do that natively with draccus
# something like --policy.path (in addition to --policy.type)
cli_overrides = policy_kwargs.pop("cli_overrides", [])
with draccus.config_type("json"):
return draccus.parse(cls, config_file, args=cli_overrides)
orig_config = draccus.parse(cls, config_file, args=[])
with open(config_file) as f:
config = json.load(f)
config.pop("type")
with tempfile.NamedTemporaryFile("w+") as f:
json.dump(config, f)
config_file = f.name
f.flush()
cli_overrides = policy_kwargs.pop("cli_overrides", [])
with draccus.config_type("json"):
return draccus.parse(orig_config.__class__, config_file, args=cli_overrides)

View File

@@ -22,16 +22,15 @@ OBS_STATE = "observation.state"
OBS_IMAGE = "observation.image"
OBS_IMAGES = "observation.images"
ACTION = "action"
OBS_IMAGE_2 = "observation.image2"
OBS_IMAGE_3 = "observation.image3"
OBS_IMAGE_4 = "observation.image4"
REWARD = "next.reward"
ROBOTS = "robots"
TASK = "task"
ROBOT_TYPE = "robot_type"
TELEOPERATORS = "teleoperators"
ROBOTS = "robots"
TELEOPERATORS = "teleoperators"
# files & directories
CHECKPOINTS_DIR = "checkpoints"
LAST_CHECKPOINT_LINK = "last"

View File

@@ -1,68 +0,0 @@
from typing import Dict, List
import numpy as np
import torch
from torch.utils.data.dataloader import default_collate
def is_batch_need_padding(values: list[torch.Tensor], pad_dim: int = -1) -> int:
return len(values[0].shape) > 0 # and len(set([v.shape[pad_dim] for v in values])) > 1
def pad_tensor(
tensor: torch.Tensor, max_size: int, pad_dim: int = -1, pad_value: float = 0.0
) -> torch.Tensor:
is_numpy = isinstance(tensor, np.ndarray)
if is_numpy:
tensor = torch.tensor(tensor)
pad = max_size - tensor.shape[pad_dim]
if pad > 0:
pad_sizes = (0, pad) # pad right
tensor = torch.nn.functional.pad(tensor, pad_sizes, value=pad_value)
return tensor.numpy() if is_numpy else tensor
def pad_list_of_tensors(
tensors: List[torch.Tensor], pad_dim: int = -1, pad_value: float = 0.0
) -> List[torch.Tensor]:
max_size = max([v.shape[pad_dim] for v in tensors])
return [pad_tensor(tensor, max_size, pad_dim=pad_dim, pad_value=pad_value) for tensor in tensors]
def multidataset_collate_fn(
batch: List[Dict[str, torch.Tensor]],
pad_dim: int = -1,
pad_value: float = 0.0,
keys_to_max_dim: dict = {},
) -> Dict[str, torch.Tensor]:
"""
Custom collate function to pad tensors with multiple dimensions.
Args:
batch (List[Dict[str, torch.Tensor]]): List of dataset samples (each sample is a dictionary).
Returns:
Dict[str, torch.Tensor]: Batch with padded tensors.
"""
batch_keys = batch[0].keys()
collated_batch = [{} for _ in range(len(batch))]
# FIXME(mshukor): pad to max shape per feature type
for key in batch_keys:
values = [sample[key] for sample in batch]
if (
key in keys_to_max_dim
and isinstance(values[0], torch.Tensor)
and is_batch_need_padding(values, pad_dim=pad_dim)
and keys_to_max_dim[key] is not None
):
max_size = keys_to_max_dim[key]
for i in range(len(batch)):
collated_batch[i][key] = pad_tensor(
batch[i][key], max_size, pad_dim=pad_dim, pad_value=pad_value
)
else:
for i in range(len(batch)):
collated_batch[i][key] = batch[i][key]
collated_batch = default_collate(collated_batch)
return collated_batch

View File

@@ -125,30 +125,9 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
"""Aggregates stats for a single feature."""
# Filter out stats that don't have required keys
valid_stats = []
for s in stats_ft_list:
if all(key in s for key in ["mean", "std", "count", "min", "max"]):
valid_stats.append(s)
else:
# If count is missing, add it with a default value
if "count" not in s:
s["count"] = np.array([1]) # Default count
valid_stats.append(s)
if not valid_stats:
# If no valid stats, return empty stats
return {
"min": np.array([0]),
"max": np.array([0]),
"mean": np.array([0]),
"std": np.array([0]),
"count": np.array([0]),
}
means = np.stack([s["mean"] for s in valid_stats])
variances = np.stack([s["std"] ** 2 for s in valid_stats])
counts = np.stack([s["count"] for s in valid_stats])
means = np.stack([s["mean"] for s in stats_ft_list])
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
counts = np.stack([s["count"] for s in stats_ft_list])
total_count = counts.sum(axis=0)
# Prepare weighted mean by matching number of dimensions
@@ -165,8 +144,8 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
total_variance = weighted_variances.sum(axis=0) / total_count
return {
"min": np.min(np.stack([s["min"] for s in valid_stats]), axis=0),
"max": np.max(np.stack([s["max"] for s in valid_stats]), axis=0),
"min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
"max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
"mean": total_mean,
"std": np.sqrt(total_variance),
"count": total_count,

View File

@@ -32,8 +32,6 @@ IMAGENET_STATS = {
"std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
}
from lerobot.datasets.utils_must import EPISODES_DATASET_MAPPING, FEATURE_KEYS_MAPPING
def resolve_delta_timestamps(
cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
@@ -83,77 +81,35 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
image_transforms = (
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
)
if "," in cfg.dataset.repo_id:
repo_id = cfg.dataset.repo_id.split(",")
repo_id = [r for r in repo_id if r]
else:
repo_id = cfg.dataset.repo_id
sampling_weights = cfg.dataset.sampling_weights.split(",") if cfg.dataset.sampling_weights else None
feature_keys_mapping = FEATURE_KEYS_MAPPING
if isinstance(repo_id, str):
revision = getattr(cfg.dataset, "revision", None)
if isinstance(cfg.dataset.repo_id, str):
ds_meta = LeRobotDatasetMetadata(
cfg.dataset.repo_id,
feature_keys_mapping=feature_keys_mapping,
revision=revision,
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
)
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=getattr(cfg.dataset, "root", None),
root=cfg.dataset.root,
episodes=cfg.dataset.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=revision,
revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend,
download_videos=True,
feature_keys_mapping=feature_keys_mapping,
max_action_dim=cfg.dataset.max_action_dim,
max_state_dim=cfg.dataset.max_state_dim,
max_num_images=cfg.dataset.max_num_images,
max_image_dim=cfg.dataset.max_image_dim,
)
else:
delta_timestamps = {}
episodes = {}
for i in range(len(repo_id)):
ds_meta = LeRobotDatasetMetadata(
repo_id[i],
feature_keys_mapping=feature_keys_mapping,
) # FIXME(mshukor): ?
delta_timestamps[repo_id[i]] = resolve_delta_timestamps(cfg.policy, ds_meta)
episodes[repo_id[i]] = EPISODES_DATASET_MAPPING.get(repo_id[i], cfg.dataset.episodes)
# training_features = TRAINING_FEATURES.get(cfg.dataset.features_version, None)
# FIXME: (jadechoghari): check support for training features
training_features = None
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
dataset = MultiLeRobotDataset(
repo_id,
cfg.dataset.repo_id,
# TODO(aliberts): add proper support for multi dataset
episodes=episodes,
delta_timestamps=delta_timestamps,
# delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
video_backend=cfg.dataset.video_backend,
download_videos=True,
sampling_weights=sampling_weights,
feature_keys_mapping=feature_keys_mapping,
max_action_dim=cfg.policy.max_action_dim,
max_state_dim=cfg.policy.max_state_dim,
max_num_images=cfg.dataset.max_num_images,
max_image_dim=cfg.dataset.max_image_dim,
train_on_all_features=cfg.dataset.train_on_all_features,
training_features=training_features,
discard_first_n_frames=cfg.dataset.discard_first_n_frames,
min_fps=cfg.dataset.min_fps,
max_fps=cfg.dataset.max_fps,
discard_first_idle_frames=cfg.dataset.discard_first_idle_frames,
motion_threshold=cfg.dataset.motion_threshold,
motion_window_size=cfg.dataset.motion_window_size,
motion_buffer=cfg.dataset.motion_buffer,
)
logging.info(
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
f"{pformat(dataset.repo_id_to_index, indent=2)}"
)
if cfg.dataset.use_imagenet_stats:
for key in dataset.meta.camera_keys:
for stats_type, stats in IMAGENET_STATS.items():

View File

@@ -15,7 +15,6 @@
# limitations under the License.
import contextlib
import logging
import os
import shutil
from pathlib import Path
from typing import Callable
@@ -31,16 +30,8 @@ from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.errors import RevisionNotFoundError
from lerobot.constants import (
ACTION,
HF_LEROBOT_HOME,
OBS_ENV_STATE,
OBS_STATE,
)
from lerobot.datasets.compute_stats import ( # aggregate_stats_per_robot_type,
aggregate_stats,
compute_episode_stats,
)
from lerobot.constants import HF_LEROBOT_HOME
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.datasets.utils import (
DEFAULT_FEATURES,
@@ -50,6 +41,7 @@ from lerobot.datasets.utils import (
_validate_feature_names,
append_jsonlines,
backward_compatible_episodes_stats,
check_delta_timestamps,
check_timestamps_sync,
check_version_compatibility,
create_empty_dataset_info,
@@ -66,34 +58,12 @@ from lerobot.datasets.utils import (
load_info,
load_stats,
load_tasks,
map_dict_keys,
validate_episode_buffer,
validate_frame,
write_episode,
write_episode_stats,
write_info,
write_json,
# keep_datasets_with_the_same_features_per_robot_type,
# map_dict_pad_keys,
# keep_datasets_with_valid_fps,
# find_start_of_motion,
)
# mustafa stuff here
from lerobot.datasets.utils_must import (
OBS_IMAGE,
OBS_IMAGE_2,
OBS_IMAGE_3,
ROBOT_TYPE_KEYS_MAPPING,
TASKS_KEYS_MAPPING,
aggregate_stats_per_robot_type,
create_padded_features,
find_start_of_motion,
keep_datasets_with_the_same_features_per_robot_type,
keep_datasets_with_valid_fps,
map_dict_keys,
pad_tensor,
reshape_features_to_max_dim,
)
from lerobot.datasets.video_utils import (
VideoFrame,
@@ -104,15 +74,6 @@ from lerobot.datasets.video_utils import (
)
CODEBASE_VERSION = "v2.1"
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
def find_start_of_motion(velocities, window_size, threshold, motion_buffer):
for t in range(len(velocities) - window_size):
window_mean = velocities[t : t + window_size].mean()
if window_mean > threshold:
return max(0, t - motion_buffer) # include slight context before motion
return 0
class LeRobotDatasetMetadata:
@@ -120,13 +81,10 @@ class LeRobotDatasetMetadata:
self,
repo_id: str,
root: str | Path | None = None,
local_files_only: bool = False,
feature_keys_mapping: dict[str, str] | None = None,
revision: str | None = None,
force_cache_sync: bool = False,
):
self.repo_id = repo_id
self.local_files_only = local_files_only
self.revision = revision if revision else CODEBASE_VERSION
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
@@ -141,14 +99,6 @@ class LeRobotDatasetMetadata:
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self.pull_from_repo(allow_patterns="meta/")
self.load_metadata()
# added by mshukor
self.feature_keys_mapping = feature_keys_mapping.get(repo_id, None) if feature_keys_mapping else None
self.inverse_feature_keys_mapping = (
{v: k for k, v in self.feature_keys_mapping.items() if v} if self.feature_keys_mapping else {}
)
self.info["features"] = map_dict_keys(
self.info["features"], feature_keys_mapping=self.feature_keys_mapping
)
def load_metadata(self):
self.info = load_info(self.root)
@@ -227,15 +177,7 @@ class LeRobotDatasetMetadata:
@property
def video_keys(self) -> list[str]:
"""Keys to access visual modalities stored as videos."""
# changed
keys = []
for key, ft in self.features.items():
key_ = (
self.inverse_feature_keys_mapping.get(key, key) if self.inverse_feature_keys_mapping else key
)
if ft["dtype"] == "video":
keys.append(key_)
return keys
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
@property
def camera_keys(self) -> list[str]:
@@ -400,18 +342,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
force_cache_sync: bool = False,
download_videos: bool = True,
video_backend: str | None = None,
# new thing by M
feature_keys_mapping: dict[str, str] | None = None,
max_action_dim: int = None,
max_state_dim: int = None,
max_num_images: int = None,
max_image_dim: int = None,
training_features: list | None = None,
discard_first_n_frames: int = 0,
discard_first_idle_frames: bool = False,
motion_threshold: float = 5e-2,
motion_window_size: int = 10,
motion_buffer: int = 3,
):
"""
2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -525,34 +455,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.video_backend = video_backend if video_backend else get_safe_default_codec()
self.delta_indices = None
# by mshukor
self.training_features = training_features
self.discard_first_n_frames = discard_first_n_frames
self.discard_first_idle_frames = discard_first_idle_frames
self.motion_threshold = motion_threshold
self.motion_window_size = motion_window_size
self.motion_buffer = motion_buffer
# Unused attributes
self.image_writer = None
self.episode_buffer = None
self.root.mkdir(exist_ok=True, parents=True)
# more mshukor
self.feature_keys_mapping = feature_keys_mapping.get(repo_id, None) if feature_keys_mapping else None
self.inverse_feature_keys_mapping = (
{v: k for k, v in self.feature_keys_mapping.items() if v} if self.feature_keys_mapping else {}
)
# Load metadata
# TODO: change
self.meta = LeRobotDatasetMetadata(
self.repo_id,
self.root,
self.revision,
force_cache_sync=force_cache_sync,
feature_keys_mapping=feature_keys_mapping,
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
)
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
@@ -571,74 +482,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
# mustafa code
if self.discard_first_n_frames > 0:
print("Discarding first n frames:", self.discard_first_n_frames)
self.subset_frame_ids = []
for ep_idx in range(self.num_episodes):
from_ = self.episode_data_index["from"][ep_idx]
to_ = self.episode_data_index["to"][ep_idx]
# TODO implement advanced strategy
self.subset_frame_ids += [
frame_idx for frame_idx in range(from_ + int(self.fps * self.discard_first_n_frames), to_)
]
elif self.discard_first_idle_frames:
print(
f"Discarding first idle frames: motion_threshold={self.motion_threshold}, motion_window_size={self.motion_window_size}, motion_buffer={self.motion_buffer}"
)
self.robot_states = torch.stack(self.hf_dataset[OBS_STATE]).numpy() # shape: [T, D]
self.subset_frame_ids = []
for ep_idx in range(self.num_episodes):
from_ = self.episode_data_index["from"][ep_idx]
to_ = self.episode_data_index["to"][ep_idx]
ep_states = self.robot_states[from_:to_]
velocities = np.linalg.norm(np.diff(ep_states, axis=0), axis=1)
velocities = np.concatenate([[0.0], velocities])
start_idx = find_start_of_motion(
velocities, self.motion_window_size, self.motion_threshold, self.motion_buffer
)
self.subset_frame_ids += list(range(from_ + start_idx, to_))
# Check timestamps
# commented TODO: check why
# timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
# episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
# ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
# check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
# Setup delta_indices
if self.delta_timestamps is not None:
# TODO: check why commented
# check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
# Mustafa
self.meta.info["features"] = map_dict_keys(
self.meta.info["features"],
feature_keys_mapping=self.feature_keys_mapping,
training_features=self.training_features,
)
self.keys_to_max_dim = {
ACTION: max_action_dim,
OBS_ENV_STATE: max_state_dim,
OBS_STATE: max_state_dim,
OBS_IMAGE: max_image_dim,
OBS_IMAGE_2: max_image_dim,
OBS_IMAGE_3: max_image_dim,
}
self.meta.info["features"] = reshape_features_to_max_dim(
self.meta.info["features"], reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
)
self.meta.stats = map_dict_keys(
self.meta.stats,
feature_keys_mapping=self.feature_keys_mapping,
training_features=self.training_features,
)
self.robot_type = self.meta.info.get("robot_type", "")
# Override tasks
print(TASKS_KEYS_MAPPING.get(self.repo_id, self.meta.tasks), "previous", self.meta.tasks)
self.meta.tasks = TASKS_KEYS_MAPPING.get(self.repo_id, self.meta.tasks)
def push_to_hub(
self,
branch: str | None = None,
@@ -793,7 +647,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
for key, delta_idx in self.delta_indices.items()
}
# FIXME(mshukor): what if we train on multiple datasets with different features
padding = { # Pad values outside of current episode range
f"{key}_is_pad": torch.BoolTensor(
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
@@ -817,21 +670,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
return query_timestamps
# TODO: changed by mustafa
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
queries = {}
for key, q_idx in query_indices.items():
if (
key not in self.meta.video_keys
and self.inverse_feature_keys_mapping.get(key, key) not in self.meta.video_keys
):
key_ = (
self.inverse_feature_keys_mapping.get(key, key)
if self.inverse_feature_keys_mapping
else key
)
queries[key] = torch.stack(self.hf_dataset.select(q_idx)[key_])
return queries
return {
key: torch.stack(self.hf_dataset.select(q_idx)[key])
for key, q_idx in query_indices.items()
if key not in self.meta.video_keys
}
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
@@ -855,12 +699,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
def __len__(self):
return self.num_frames
# changed by mshukor
def __getitem__(self, idx) -> dict:
if self.discard_first_n_frames > 0 or self.discard_first_idle_frames:
idx = self.subset_frame_ids[idx]
item = self.hf_dataset[idx]
item = map_dict_keys(item, feature_keys_mapping=self.feature_keys_mapping)
ep_idx = item["episode_index"].item()
query_indices = None
@@ -877,27 +717,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_frames = self._query_videos(query_timestamps, ep_idx)
item = {**video_frames, **item}
if self.image_transforms is not None:
image_keys = self.meta.camera_keys
for cam in image_keys:
item[cam] = self.image_transforms(item[cam])
# Add task as a string
task_idx = item["task_index"].item()
try:
item["task"] = self.meta.tasks[task_idx]
except:
print(self.meta.tasks, task_idx, self.repo_id)
if "robot_type" not in item:
item["robot_type"] = self.robot_type
item = map_dict_keys(
item, feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features
)
# Add padded features
# item = self._add_padded_features(item, self.training_features)
if self.image_transforms is not None:
for cam in item:
if cam in self.meta.camera_keys or ("image" in cam and "is_pad" not in cam):
item[cam] = self.image_transforms(item[cam])
# Map pad keys
# print(item.keys(), "before")
# item = map_dict_pad_keys(item, feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features)
# print(item.keys())
item["task"] = self.meta.tasks[task_idx]
return item
def __repr__(self):
@@ -1157,7 +985,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
obj.repo_id = obj.meta.repo_id
obj.root = obj.meta.root
obj.local_files_only = obj.meta.local_files_only
obj.revision = None
obj.tolerance_s = tolerance_s
obj.image_writer = None
@@ -1178,106 +1005,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
return obj
class MultiLeRobotDatasetMeta:
def __init__(
self,
datasets: list[LeRobotDataset],
repo_ids: list[str],
keys_to_max_dim: dict[str, int],
train_on_all_features: bool = False,
):
self.repo_ids = repo_ids
self.keys_to_max_dim = keys_to_max_dim
self.train_on_all_features = train_on_all_features
self.robot_types = [ds.meta.info["robot_type"] for ds in datasets]
# assign robot_type if missing
for ds in datasets:
ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"])
ds.robot_type = ds.meta.info["robot_type"]
# step 1: compute disabled features
self.disabled_features = set()
if not self.train_on_all_features:
intersection = set(datasets[0].features)
for ds in datasets:
intersection.intersection_update(ds.features)
if not intersection:
raise RuntimeError("No common features across datasets.")
for repo_id, ds in zip(repo_ids, datasets, strict=False):
extra = set(ds.features) - intersection
logging.warning(f"Disabling {extra} for repo {repo_id}")
self.disabled_features.update(extra)
# step 2: build union_features excluding disabled
self.union_features = {}
for ds in datasets:
for k, v in ds.features.items():
if k not in self.disabled_features:
self.union_features[k] = v
# step 3: reshape feature schema
self.features = reshape_features_to_max_dim(
self.union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
)
# step 4: aggregate stats
self.stats = aggregate_stats_per_robot_type(datasets)
for robot_type_, stats_ in self.stats.items():
for feat_key, feat_stats in stats_.items():
if feat_key in [ACTION, OBS_ENV_STATE, OBS_STATE]:
for k, v in feat_stats.items():
pad_value = 0 if k in ["min", "mean"] else 1
self.stats[robot_type_][feat_key][k] = pad_tensor(
v,
max_size=self.keys_to_max_dim.get(feat_key, -1),
pad_dim=-1,
pad_value=pad_value,
)
# step 5: episodes & tasks
self.episodes = {repo_id: ds.meta.episodes for repo_id, ds in zip(repo_ids, datasets, strict=False)}
self.tasks = {repo_id: ds.meta.tasks for repo_id, ds in zip(repo_ids, datasets, strict=False)}
self.info = {repo_id: ds.meta.info for repo_id, ds in zip(repo_ids, datasets, strict=False)}
class MultiLeRobotDatasetCleaner:
def __init__(
self,
datasets: list[LeRobotDataset],
repo_ids: list[str],
sampling_weights: list[float],
datasets_repo_ids: list[str],
min_fps: int = 1,
max_fps: int = 100,
):
self.original_datasets = datasets
self.original_repo_ids = repo_ids
self.original_weights = sampling_weights
self.original_datasets_repo_ids = datasets_repo_ids
# step 1: remove datasets with invalid fps
valid_fps_datasets = keep_datasets_with_valid_fps(datasets, min_fps=min_fps, max_fps=max_fps)
# step 2: keep datasets with same features per robot type
consistent_datasets, keep_mask = keep_datasets_with_the_same_features_per_robot_type(
valid_fps_datasets
)
self.cleaned_datasets = consistent_datasets
self.keep_mask = keep_mask
self.cleaned_weights = [sampling_weights[i] for i in range(len(valid_fps_datasets)) if keep_mask[i]]
self.cleaned_repo_ids = [repo_ids[i] for i in range(len(valid_fps_datasets)) if keep_mask[i]]
self.cleaned_datasets_repo_ids = [
datasets_repo_ids[i] for i in range(len(valid_fps_datasets)) if keep_mask[i]
]
self.cumulative_sizes = np.array(
[0] + list(torch.cumsum(torch.tensor([len(d) for d in consistent_datasets]), dim=0))
)
self.cleaned_weights = np.array(self.cleaned_weights, dtype=np.float32)
class MultiLeRobotDataset(torch.utils.data.Dataset):
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
@@ -1294,24 +1021,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
delta_timestamps: dict[list[float]] | None = None,
tolerances_s: dict | None = None,
download_videos: bool = True,
local_files_only: bool = False,
video_backend: str | None = None,
# add
sampling_weights: list[float] | None = None,
feature_keys_mapping: dict[str, dict[str, str]] | None = None,
max_action_dim: int = None,
max_state_dim: int = None,
max_num_images: int = None,
max_image_dim: int = None,
train_on_all_features: bool = False,
training_features: list | None = None,
discard_first_n_frames: int = 0,
min_fps: int = 1,
max_fps: int = 100,
discard_first_idle_frames: bool = False,
motion_threshold: float = 0.05,
motion_window_size: int = 10,
motion_buffer: int = 3,
):
super().__init__()
self.repo_ids = repo_ids
@@ -1319,89 +1029,46 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
_datasets = []
datasets_repo_ids = []
self.sampling_weights = []
self.training_features = training_features
sampling_weights = sampling_weights if sampling_weights is not None else [1] * len(repo_ids)
assert len(sampling_weights) == len(repo_ids), (
"The number of sampling weights must match the number of datasets. "
f"Got {len(sampling_weights)} weights for {len(repo_ids)} datasets."
)
for i, repo_id in enumerate(repo_ids):
try:
# delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
_datasets.append(
LeRobotDataset(
repo_id,
root=self.root / repo_id,
episodes=episodes.get(repo_id, None) if episodes else None,
image_transforms=image_transforms,
delta_timestamps=delta_timestamps.get(repo_id, None) if delta_timestamps else None,
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
video_backend=video_backend,
feature_keys_mapping=feature_keys_mapping,
training_features=training_features,
discard_first_n_frames=discard_first_n_frames,
discard_first_idle_frames=discard_first_idle_frames,
motion_threshold=motion_threshold,
motion_window_size=motion_window_size,
motion_buffer=motion_buffer,
)
)
datasets_repo_ids.append(repo_id)
self.sampling_weights.append(float(sampling_weights[i]))
except Exception as e:
print(f"Failed to load dataset: {repo_id} due to Exception: {e}")
print(
f"Finish loading {len(_datasets)} datasets, with sampling weights: {self.sampling_weights} corresponding to: {datasets_repo_ids}"
)
self._datasets = [
LeRobotDataset(
repo_id,
root=self.root / repo_id,
episodes=episodes[repo_id] if episodes else None,
image_transforms=image_transforms,
delta_timestamps=delta_timestamps,
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
video_backend=video_backend,
)
for repo_id in repo_ids
]
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
# restriction in future iterations of this class. For now, this is necessary at least for being able
# to use PyTorch's default DataLoader collate function.
# FIXME(mshukor): apply mapping to unify used keys
# FIXME(mshukor): pad based on types in case we have more than one state?
self.disabled_features = set()
intersection_features = set(self._datasets[0].features)
for ds in self._datasets:
intersection_features.intersection_update(ds.features)
if len(intersection_features) == 0:
raise RuntimeError(
"Multiple datasets were provided but they had no keys common to all of them. "
"The multi-dataset functionality currently only keeps common keys."
)
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(ds.features).difference(intersection_features)
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_features.update(extra_keys)
self.image_transforms = image_transforms
self.delta_timestamps = (
delta_timestamps.get(repo_id, None) if delta_timestamps else None
) # delta_timestamps # FIXME(mshukor): last repo?
# In case datasets with the same robot_type have different features
cleaner = MultiLeRobotDatasetCleaner(
datasets=_datasets,
repo_ids=repo_ids,
sampling_weights=self.sampling_weights,
datasets_repo_ids=datasets_repo_ids,
min_fps=min_fps,
max_fps=max_fps,
)
self._datasets = cleaner.cleaned_datasets
self.sampling_weights = cleaner.cleaned_weights
self.repo_ids = cleaner.cleaned_repo_ids
self.datasets_repo_ids = cleaner.cleaned_datasets_repo_ids
self.cumulative_sizes = cleaner.cumulative_sizes
# self.meta = copy.deepcopy(self._datasets[0].meta) # FIXME(mshukor): aggregate meta from all datasets
# self.meta.info = {
# repo_id: ds.meta.info for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
# }
# self.meta.info["features"] = self._datasets[0].meta.info["features"] # Assume all datasets have the same features
self.meta = MultiLeRobotDatasetMeta(
datasets=self._datasets,
repo_ids=self.repo_ids,
keys_to_max_dim={
ACTION: max_action_dim,
OBS_ENV_STATE: max_state_dim,
OBS_STATE: max_state_dim,
OBS_IMAGE: max_image_dim,
OBS_IMAGE_2: max_image_dim,
OBS_IMAGE_3: max_image_dim,
},
train_on_all_features=train_on_all_features,
)
self.disabled_features = self.meta.disabled_features
self.stats = self.meta.stats
self.delta_timestamps = delta_timestamps
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
# with multiple robots of different ranges. Instead we should have one normalization
# per robot.
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
@property
def repo_id_to_index(self):
@@ -1489,14 +1156,23 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
if idx >= len(self):
raise IndexError(f"Index {idx} out of bounds.")
dataset_idx = np.searchsorted(self.cumulative_sizes, idx, side="right").item() - 1
local_idx = (idx - self.cumulative_sizes[dataset_idx]).item()
item = self._datasets[dataset_idx][local_idx]
# Determine which dataset to get an item from based on the index.
start_idx = 0
dataset_idx = 0
for dataset in self._datasets:
if idx >= start_idx + dataset.num_frames:
start_idx += dataset.num_frames
dataset_idx += 1
continue
break
else:
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
item = self._datasets[dataset_idx][idx - start_idx]
item["dataset_index"] = torch.tensor(dataset_idx)
item = create_padded_features(item, self.meta.features)
for data_key in self.disabled_features: # FIXME(mshukor): not in getitem?
for data_key in self.disabled_features:
if data_key in item:
del item[data_key]
return item
def __repr__(self):

View File

@@ -858,21 +858,3 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
f"In episode_buffer not in features: {buffer_keys - set(features)}"
f"In features not in episode_buffer: {set(features) - buffer_keys}"
)
def map_dict_keys(
item: dict, feature_keys_mapping: dict, training_features: list = None, pad_key: str = "is_pad"
) -> dict:
"""Maps feature keys from the dataset to the keys used in the model."""
if feature_keys_mapping is None:
return item
features = {}
for key in item:
if key in feature_keys_mapping:
if feature_keys_mapping[key] is not None:
if training_features is None or feature_keys_mapping[key] in training_features:
features[feature_keys_mapping[key]] = item[key]
else:
if training_features is None or key in training_features or pad_key in key:
features[key] = item[key]
return features

View File

@@ -1,409 +0,0 @@
"""
Utils function by Mustafa to refactor
"""
from collections import defaultdict
from typing import Dict, List
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data.dataloader import default_collate
from lerobot.datasets.compute_stats import aggregate_stats
OBS_IMAGE = "observation.image"
OBS_IMAGE_2 = "observation.image2"
OBS_IMAGE_3 = "observation.image3"
def reshape_features_to_max_dim(features: dict, reshape_dim: int = -1, keys_to_max_dim: dict = {}) -> dict:
"""Reshape features to have a maximum dimension of `max_dim`."""
reshaped_features = {}
for key in features:
if key in keys_to_max_dim and keys_to_max_dim[key] is not None:
reshaped_features[key] = features[key]
shape = list(features[key]["shape"])
if any([k in key for k in [OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3]]): # Assume square images
shape[-3] = keys_to_max_dim[key]
shape[-2] = keys_to_max_dim[key]
else:
shape[reshape_dim] = keys_to_max_dim[key]
reshaped_features[key]["shape"] = tuple(shape)
else:
reshaped_features[key] = features[key]
return reshaped_features
def keep_datasets_with_valid_fps(ls_datasets: list, min_fps: int = 1, max_fps: int = 100) -> list:
print(
f"Keeping datasets with fps between {min_fps} and {max_fps}. Considering {len(ls_datasets)} datasets."
)
for ds in ls_datasets:
if ds.fps < min_fps or ds.fps > max_fps:
print(f"Dataset {ds} has invalid fps: {ds.fps}. Removing it.")
ls_datasets.remove(ds)
print(f"Keeping {len(ls_datasets)} datasets with valid fps.")
return ls_datasets
def keep_datasets_with_the_same_features_per_robot_type(ls_datasets: list) -> list:
"""
Filters datasets to only keep those with consistent feature shapes per robot type.
Args:
ls_datasets (List): List of datasets, each with a `meta.info['robot_type']`
and `meta.episodes_stats` dictionary.
Returns:
List: Filtered list of datasets with consistent feature shapes.
"""
robot_types = {ds.meta.info["robot_type"] for ds in ls_datasets}
datasets_to_remove = set()
for robot_type in robot_types:
# Collect all stats dicts for this robot type
stats_list = [
ep_stats
for ds in ls_datasets
if ds.meta.info["robot_type"] == robot_type
for ep_stats in ds.meta.episodes_stats.values()
]
if not stats_list:
continue
# Determine the most common shape for each key
all_keys = {key for stats in stats_list for key in stats}
for ds in ls_datasets:
if ds.meta.info["robot_type"] != robot_type:
continue
for key in all_keys:
shape_counter = defaultdict(int)
for stats in stats_list:
value = stats.get(key)
if (
value and "mean" in value and isinstance(value["mean"], (torch.Tensor, np.ndarray))
): # FIXME(mshukor): check all stats; min, mean, max
shape_counter[value["mean"].shape] += 1
if not shape_counter:
continue
# Identify the most frequent shape
main_shape = max(shape_counter, key=shape_counter.get)
# Flag datasets that don't match the main shape
# for ds in ls_datasets:
first_ep_stats = next(iter(ds.meta.episodes_stats.values()), None)
if not first_ep_stats:
continue
value = first_ep_stats.get(key)
if (
value
and "mean" in value
and isinstance(value["mean"], (torch.Tensor, np.ndarray))
and value["mean"].shape != main_shape
):
datasets_to_remove.add(ds)
break
# Filter out inconsistent datasets
datasets_maks = [ds not in datasets_to_remove for ds in ls_datasets]
filtered_datasets = [ds for ds in ls_datasets if ds not in datasets_to_remove]
print(
f"Keeping {len(filtered_datasets)} datasets. Removed {len(datasets_to_remove)} inconsistent ones. Inconsistent datasets:\n{datasets_to_remove}"
)
return filtered_datasets, datasets_maks
def aggregate_stats_per_robot_type(ls_datasets) -> dict[str, dict[str, torch.Tensor]]:
"""Aggregate stats of multiple LeRobot datasets into multiple set of stats per robot type.
The final stats will have the union of all data keys from each of the datasets.
The final stats will have the union of all data keys from each of the datasets. For instance:
- new_max = max(max_dataset_0, max_dataset_1, ...)
- new_min = min(min_dataset_0, min_dataset_1, ...)
- new_mean = (mean of all data)
- new_std = (std of all data)
"""
robot_types = {ds.meta.info["robot_type"] for ds in ls_datasets}
stats = {robot_type: {} for robot_type in robot_types}
for robot_type in robot_types:
robot_type_datasets = []
for ds in ls_datasets:
if ds.meta.info["robot_type"] == robot_type:
robot_type_datasets.extend(list(ds.meta.episodes_stats.values()))
# robot_type_datasets = [list(ds.episodes_stats.values()) for ds in ls_datasets if ds.meta.info["robot_type"] == robot_type]
stat = aggregate_stats(robot_type_datasets)
stats[robot_type] = stat
return stats
def str_to_torch_dtype(dtype_str):
"""Convert a dtype string to a torch dtype."""
mapping = {
"float32": torch.float32,
"int64": torch.int64,
"int16": torch.int16,
"bool": torch.bool,
"video": torch.float32, # Assuming video is stored as uint8 images
}
return mapping.get(dtype_str, torch.float32) # Default to float32
def create_padded_features(item: dict, features: dict = {}):
for key, ft in features.items():
if any([k in key for k in ["cam", "effort", "absolute"]]): # FIXME(mshukor): temporary hack
continue
shape = ft["shape"]
if len(shape) == 3: # images to torch format (C, H, W)
shape = (shape[2], shape[0], shape[1])
if len(shape) == 1 and shape[0] == 1: # ft with shape are actually tensor(ele)
shape = []
if key not in item:
dtype = str_to_torch_dtype(ft["dtype"])
item[key] = torch.zeros(shape, dtype=dtype)
item[f"{key}_padding_mask"] = torch.tensor(0, dtype=torch.int64)
if "image" in key: # FIXME(mshukor): support other observations
item[f"{key}_is_pad"] = torch.BoolTensor([False])
else:
item[f"{key}_padding_mask"] = torch.tensor(1, dtype=torch.int64)
return item
ROBOT_TYPE_KEYS_MAPPING = {
"lerobot/stanford_hydra_dataset": "static_single_arm",
"lerobot/iamlab_cmu_pickup_insert": "static_single_arm",
"lerobot/berkeley_fanuc_manipulation": "static_single_arm",
"lerobot/toto": "static_single_arm",
"lerobot/roboturk": "static_single_arm",
"lerobot/jaco_play": "static_single_arm",
"lerobot/taco_play": "static_single_arm_7statedim",
}
def pad_tensor(
tensor: torch.Tensor, max_size: int, pad_dim: int = -1, pad_value: float = 0.0
) -> torch.Tensor:
is_numpy = isinstance(tensor, np.ndarray)
if is_numpy:
tensor = torch.tensor(tensor)
if tensor.ndim == 0:
# Scalar — return as-is, no padding needed
return tensor
pad = max_size - tensor.shape[pad_dim]
if pad > 0:
pad_sizes = (0, pad) # pad right
tensor = torch.nn.functional.pad(tensor, pad_sizes, value=pad_value)
return tensor.numpy() if is_numpy else tensor
def map_dict_keys(
item: dict, feature_keys_mapping: dict, training_features: list = None, pad_key: str = "is_pad"
) -> dict:
"""Maps feature keys from the dataset to the keys used in the model."""
if feature_keys_mapping is None:
return item
features = {}
for key in item:
if key in feature_keys_mapping:
if feature_keys_mapping[key] is not None:
if training_features is None or feature_keys_mapping[key] in training_features:
features[feature_keys_mapping[key]] = item[key]
else:
if training_features is None or key in training_features or pad_key in key:
features[key] = item[key]
# breakpoint()
return features
def find_start_of_motion(velocities, window_size, threshold, motion_buffer):
for t in range(len(velocities) - window_size):
window_mean = velocities[t : t + window_size].mean()
if window_mean > threshold:
return max(0, t - motion_buffer) # include slight context before motion
return 0
import requests
import yaml
def load_yaml_mapping(name: str) -> dict:
"""
Loads a YAML mapping from a Hugging Face repo.
Example: name='features' → https://huggingface.co/jadechoghari/smolvla-keys/resolve/main/features.yaml
"""
url = f"https://huggingface.co/jadechoghari/smolvla-keys/resolve/main/{name}.yaml"
response = requests.get(url)
response.raise_for_status() # raise if the download fails
return yaml.safe_load(response.text)
# Example usage
TASKS_KEYS_MAPPING = load_yaml_mapping("tasks")
FEATURE_KEYS_MAPPING = load_yaml_mapping("features")
EPISODES_DATASET_MAPPING = {
"cadene/droid_1.0.1": list(range(50)),
"danaaubakirova/svla_so100_task5_v3": [
0,
1,
2,
3,
4,
5,
6,
7,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
24,
25,
26,
27,
28,
29,
30,
31,
32,
33,
34,
35,
36,
37,
38,
39,
40,
41,
42,
43,
44,
45,
46,
47,
48,
49,
50,
51,
],
"danaaubakirova/svla_so100_task4_v3": [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30,
31,
32,
33,
34,
35,
40,
41,
42,
43,
44,
45,
46,
47,
48,
49,
50,
51,
52,
53,
],
}
ACTION = "action"
OBS_STATE = "observation.state"
TASK = "task"
ROBOT = "robot_type"
TRAINING_FEATURES = {
0: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE],
1: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE, OBS_IMAGE_2],
2: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3],
}
def is_batch_need_padding(values: list[torch.Tensor], pad_dim: int = -1) -> int:
return len(values[0].shape) > 0 # and len(set([v.shape[pad_dim] for v in values])) > 1
def pad_tensor_to_shape(tensor: torch.Tensor, target_shape: tuple, pad_value: float = 0.0) -> torch.Tensor:
"""Pads a tensor to the target shape (right/bottom only)."""
pad = []
for actual, target in zip(reversed(tensor.shape), reversed(target_shape), strict=False):
pad.extend([0, max(target - actual, 0)])
return F.pad(tensor, pad, value=pad_value)
def multidataset_collate_fn(
batch: List[Dict[str, torch.Tensor]],
keys_to_max_dim: Dict[str, tuple] = {},
pad_value: float = 0.0,
) -> Dict[str, torch.Tensor]:
"""
Pads tensors to given target shape (if provided), otherwise uses per-batch max.
Supports 1D (e.g. action), 3D (e.g. [C,H,W] images).
"""
collated_batch = [{} for _ in range(len(batch))]
batch_keys = batch[0].keys()
for key in batch_keys:
values = [sample[key] for sample in batch]
sample = values[0]
if not isinstance(sample, torch.Tensor):
for i in range(len(batch)):
collated_batch[i][key] = values[i]
continue
# use user-specified shape if available
if key in keys_to_max_dim and keys_to_max_dim[key] is not None:
target_shape = keys_to_max_dim[key]
else:
# compute per-batch max shape
target_shape = tuple(max(v.shape[i] for v in values) for i in range(sample.ndim))
for i in range(len(batch)):
collated_batch[i][key] = pad_tensor_to_shape(values[i], target_shape, pad_value=pad_value)
return default_collate(collated_batch)

View File

@@ -0,0 +1,401 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
from dataclasses import dataclass
os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1"
from lerobot.motors import MotorCalibration, MotorsBus
BAR_LEN, BAR_THICKNESS = 450, 8
HANDLE_R = 10
BRACKET_W, BRACKET_H = 6, 14
TRI_W, TRI_H = 12, 14
BTN_W, BTN_H = 60, 22
SAVE_W, SAVE_H = 80, 28
LOAD_W = 80
DD_W, DD_H = 160, 28
TOP_GAP = 50
PADDING_Y, TOP_OFFSET = 70, 60
FONT_SIZE, FPS = 20, 60
BG_COLOR = (30, 30, 30)
BAR_RED, BAR_GREEN = (200, 60, 60), (60, 200, 60)
HANDLE_COLOR, TEXT_COLOR = (240, 240, 240), (250, 250, 250)
TICK_COLOR = (250, 220, 40)
BTN_COLOR, BTN_COLOR_HL = (80, 80, 80), (110, 110, 110)
DD_COLOR, DD_COLOR_HL = (70, 70, 70), (100, 100, 100)
def dist(a, b):
return math.hypot(a[0] - b[0], a[1] - b[1])
@dataclass
class RangeValues:
min_v: int
pos_v: int
max_v: int
class RangeSlider:
"""One motor = one slider row"""
def __init__(self, motor, idx, res, calibration, present, label_pad, base_y):
import pygame
self.motor = motor
self.res = res
self.x0 = 40 + label_pad
self.x1 = self.x0 + BAR_LEN
self.y = base_y + idx * PADDING_Y
self.min_v = calibration.range_min
self.max_v = calibration.range_max
self.pos_v = max(self.min_v, min(present, self.max_v))
self.min_x = self._pos_from_val(self.min_v)
self.max_x = self._pos_from_val(self.max_v)
self.pos_x = self._pos_from_val(self.pos_v)
self.min_btn = pygame.Rect(self.x0 - BTN_W - 6, self.y - BTN_H // 2, BTN_W, BTN_H)
self.max_btn = pygame.Rect(self.x1 + 6, self.y - BTN_H // 2, BTN_W, BTN_H)
self.drag_min = self.drag_max = self.drag_pos = False
self.tick_val = present
self.font = pygame.font.Font(None, FONT_SIZE)
def _val_from_pos(self, x):
return round((x - self.x0) / BAR_LEN * self.res)
def _pos_from_val(self, v):
return self.x0 + (v / self.res) * BAR_LEN
def set_tick(self, v):
self.tick_val = max(0, min(v, self.res))
def _triangle_hit(self, pos):
import pygame
tri_top = self.y - BAR_THICKNESS // 2 - 2
return pygame.Rect(self.pos_x - TRI_W // 2, tri_top - TRI_H, TRI_W, TRI_H).collidepoint(pos)
def handle_event(self, e):
import pygame
if e.type == pygame.MOUSEBUTTONDOWN and e.button == 1:
if self.min_btn.collidepoint(e.pos):
self.min_x, self.min_v = self.pos_x, self.pos_v
return
if self.max_btn.collidepoint(e.pos):
self.max_x, self.max_v = self.pos_x, self.pos_v
return
if dist(e.pos, (self.min_x, self.y)) <= HANDLE_R:
self.drag_min = True
elif dist(e.pos, (self.max_x, self.y)) <= HANDLE_R:
self.drag_max = True
elif self._triangle_hit(e.pos):
self.drag_pos = True
elif e.type == pygame.MOUSEBUTTONUP and e.button == 1:
self.drag_min = self.drag_max = self.drag_pos = False
elif e.type == pygame.MOUSEMOTION:
x = e.pos[0]
if self.drag_min:
self.min_x = max(self.x0, min(x, self.pos_x))
elif self.drag_max:
self.max_x = min(self.x1, max(x, self.pos_x))
elif self.drag_pos:
self.pos_x = max(self.min_x, min(x, self.max_x))
self.min_v = self._val_from_pos(self.min_x)
self.max_v = self._val_from_pos(self.max_x)
self.pos_v = self._val_from_pos(self.pos_x)
def _draw_button(self, surf, rect, text):
import pygame
clr = BTN_COLOR_HL if rect.collidepoint(pygame.mouse.get_pos()) else BTN_COLOR
pygame.draw.rect(surf, clr, rect, border_radius=4)
t = self.font.render(text, True, TEXT_COLOR)
surf.blit(t, (rect.centerx - t.get_width() // 2, rect.centery - t.get_height() // 2))
def draw(self, surf):
import pygame
# motor name above set-min button (right-aligned)
name_surf = self.font.render(self.motor, True, TEXT_COLOR)
surf.blit(
name_surf,
(self.min_btn.right - name_surf.get_width(), self.min_btn.y - name_surf.get_height() - 4),
)
# bar + active section
pygame.draw.rect(surf, BAR_RED, (self.x0, self.y - BAR_THICKNESS // 2, BAR_LEN, BAR_THICKNESS))
pygame.draw.rect(
surf, BAR_GREEN, (self.min_x, self.y - BAR_THICKNESS // 2, self.max_x - self.min_x, BAR_THICKNESS)
)
# tick
tick_x = self._pos_from_val(self.tick_val)
pygame.draw.line(
surf,
TICK_COLOR,
(tick_x, self.y - BAR_THICKNESS // 2 - 4),
(tick_x, self.y + BAR_THICKNESS // 2 + 4),
2,
)
# brackets
for x, sign in ((self.min_x, +1), (self.max_x, -1)):
pygame.draw.line(
surf, HANDLE_COLOR, (x, self.y - BRACKET_H // 2), (x, self.y + BRACKET_H // 2), 2
)
pygame.draw.line(
surf,
HANDLE_COLOR,
(x, self.y - BRACKET_H // 2),
(x + sign * BRACKET_W, self.y - BRACKET_H // 2),
2,
)
pygame.draw.line(
surf,
HANDLE_COLOR,
(x, self.y + BRACKET_H // 2),
(x + sign * BRACKET_W, self.y + BRACKET_H // 2),
2,
)
# triangle ▼
tri_top = self.y - BAR_THICKNESS // 2 - 2
pygame.draw.polygon(
surf,
HANDLE_COLOR,
[
(self.pos_x, tri_top),
(self.pos_x - TRI_W // 2, tri_top - TRI_H),
(self.pos_x + TRI_W // 2, tri_top - TRI_H),
],
)
# numeric labels
fh = self.font.get_height()
pos_y = tri_top - TRI_H - 4 - fh
txts = [
(self.min_v, self.min_x, self.y - BRACKET_H // 2 - 4 - fh),
(self.max_v, self.max_x, self.y - BRACKET_H // 2 - 4 - fh),
(self.pos_v, self.pos_x, pos_y),
]
for v, x, y in txts:
s = self.font.render(str(v), True, TEXT_COLOR)
surf.blit(s, (x - s.get_width() // 2, y))
# buttons
self._draw_button(surf, self.min_btn, "set min")
self._draw_button(surf, self.max_btn, "set max")
# external
def values(self) -> RangeValues:
return RangeValues(self.min_v, self.pos_v, self.max_v)
class RangeFinderGUI:
def __init__(self, bus: MotorsBus, groups: dict[str, list[str]] | None = None):
import pygame
self.bus = bus
self.groups = groups if groups is not None else {"all": list(bus.motors)}
self.group_names = list(groups)
self.current_group = self.group_names[0]
if not bus.is_connected:
bus.connect()
self.calibration = bus.read_calibration()
self.res_table = bus.model_resolution_table
self.present_cache = {
m: bus.read("Present_Position", m, normalize=False) for motors in groups.values() for m in motors
}
pygame.init()
self.font = pygame.font.Font(None, FONT_SIZE)
label_pad = max(self.font.size(m)[0] for ms in groups.values() for m in ms)
self.label_pad = label_pad
width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10
self.controls_bottom = 10 + SAVE_H
self.base_y = self.controls_bottom + TOP_GAP
height = self.base_y + PADDING_Y * len(groups[self.current_group]) + 40
self.screen = pygame.display.set_mode((width, height))
pygame.display.set_caption("Motors range finder")
# ui rects
self.save_btn = pygame.Rect(width - SAVE_W - 10, 10, SAVE_W, SAVE_H)
self.load_btn = pygame.Rect(self.save_btn.left - LOAD_W - 10, 10, LOAD_W, SAVE_H)
self.dd_btn = pygame.Rect(width // 2 - DD_W // 2, 10, DD_W, DD_H)
self.dd_open = False # dropdown expanded?
self.clock = pygame.time.Clock()
self._build_sliders()
self._adjust_height()
def _adjust_height(self):
import pygame
motors = self.groups[self.current_group]
new_h = self.base_y + PADDING_Y * len(motors) + 40
if new_h != self.screen.get_height():
w = self.screen.get_width()
self.screen = pygame.display.set_mode((w, new_h))
def _build_sliders(self):
self.sliders: list[RangeSlider] = []
motors = self.groups[self.current_group]
for i, m in enumerate(motors):
self.sliders.append(
RangeSlider(
motor=m,
idx=i,
res=self.res_table[self.bus.motors[m].model] - 1,
calibration=self.calibration[m],
present=self.present_cache[m],
label_pad=self.label_pad,
base_y=self.base_y,
)
)
def _draw_dropdown(self):
import pygame
# collapsed box
hover = self.dd_btn.collidepoint(pygame.mouse.get_pos())
pygame.draw.rect(self.screen, DD_COLOR_HL if hover else DD_COLOR, self.dd_btn, border_radius=6)
txt = self.font.render(self.current_group, True, TEXT_COLOR)
self.screen.blit(
txt, (self.dd_btn.centerx - txt.get_width() // 2, self.dd_btn.centery - txt.get_height() // 2)
)
tri_w, tri_h = 12, 6
cx = self.dd_btn.right - 14
cy = self.dd_btn.centery + 1
pygame.draw.polygon(
self.screen,
TEXT_COLOR,
[(cx - tri_w // 2, cy - tri_h // 2), (cx + tri_w // 2, cy - tri_h // 2), (cx, cy + tri_h // 2)],
)
if not self.dd_open:
return
# expanded list
for i, name in enumerate(self.group_names):
item_rect = pygame.Rect(self.dd_btn.left, self.dd_btn.bottom + i * DD_H, DD_W, DD_H)
clr = DD_COLOR_HL if item_rect.collidepoint(pygame.mouse.get_pos()) else DD_COLOR
pygame.draw.rect(self.screen, clr, item_rect)
t = self.font.render(name, True, TEXT_COLOR)
self.screen.blit(
t, (item_rect.centerx - t.get_width() // 2, item_rect.centery - t.get_height() // 2)
)
def _handle_dropdown_event(self, e):
import pygame
if e.type == pygame.MOUSEBUTTONDOWN and e.button == 1:
if self.dd_btn.collidepoint(e.pos):
self.dd_open = not self.dd_open
return True
if self.dd_open:
for i, name in enumerate(self.group_names):
item_rect = pygame.Rect(self.dd_btn.left, self.dd_btn.bottom + i * DD_H, DD_W, DD_H)
if item_rect.collidepoint(e.pos):
if name != self.current_group:
self.current_group = name
self._build_sliders()
self._adjust_height()
self.dd_open = False
return True
self.dd_open = False
return False
def _save_current(self):
for s in self.sliders:
self.calibration[s.motor].range_min = s.min_v
self.calibration[s.motor].range_max = s.max_v
with self.bus.torque_disabled():
self.bus.write_calibration(self.calibration)
def _load_current(self):
self.calibration = self.bus.read_calibration()
for s in self.sliders:
s.min_v = self.calibration[s.motor].range_min
s.max_v = self.calibration[s.motor].range_max
s.min_x = s._pos_from_val(s.min_v)
s.max_x = s._pos_from_val(s.max_v)
def run(self) -> dict[str, MotorCalibration]:
import pygame
while True:
for e in pygame.event.get():
if e.type == pygame.QUIT:
pygame.quit()
return self.calibration
if self._handle_dropdown_event(e):
continue
if e.type == pygame.MOUSEBUTTONDOWN and e.button == 1:
if self.save_btn.collidepoint(e.pos):
self._save_current()
elif self.load_btn.collidepoint(e.pos):
self._load_current()
for s in self.sliders:
s.handle_event(e)
# live goal write while dragging
for s in self.sliders:
if s.drag_pos:
self.bus.write("Goal_Position", s.motor, s.pos_v, normalize=False)
# tick update
for s in self.sliders:
pos = self.bus.read("Present_Position", s.motor, normalize=False)
s.set_tick(pos)
self.present_cache[s.motor] = pos
# ─ drawing
self.screen.fill(BG_COLOR)
for s in self.sliders:
s.draw(self.screen)
self._draw_dropdown()
# load / save buttons
for rect, text in ((self.load_btn, "LOAD"), (self.save_btn, "SAVE")):
clr = BTN_COLOR_HL if rect.collidepoint(pygame.mouse.get_pos()) else BTN_COLOR
pygame.draw.rect(self.screen, clr, rect, border_radius=6)
t = self.font.render(text, True, TEXT_COLOR)
self.screen.blit(t, (rect.centerx - t.get_width() // 2, rect.centery - t.get_height() // 2))
pygame.display.flip()
self.clock.tick(FPS)

View File

@@ -162,11 +162,11 @@ class DynamixelMotorsBus(MotorsBus):
raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.")
def configure_motors(self) -> None:
def configure_motors(self, return_delay_time=0) -> None:
# By default, Dynamixel motors have a 500µs delay response time (corresponding to a value of 250 on
# the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0).
for motor in self.motors:
self.write("Return_Delay_Time", motor, 0)
self.write("Return_Delay_Time", motor, return_delay_time)
@property
def is_calibrated(self) -> bool:
@@ -190,13 +190,14 @@ class DynamixelMotorsBus(MotorsBus):
return calibration
def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None:
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
for motor, calibration in calibration_dict.items():
self.write("Homing_Offset", motor, calibration.homing_offset)
self.write("Min_Position_Limit", motor, calibration.range_min)
self.write("Max_Position_Limit", motor, calibration.range_max)
self.calibration = calibration_dict
if cache:
self.calibration = calibration_dict
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):

View File

@@ -164,8 +164,9 @@ class FeetechMotorsBus(MotorsBus):
)
def _handshake(self) -> None:
self._assert_motors_exist()
self._assert_same_firmware()
# self._assert_motors_exist()
# self._assert_same_firmware()
return
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
if self.protocol_version == 0:
@@ -219,94 +220,70 @@ class FeetechMotorsBus(MotorsBus):
raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.")
def configure_motors(self) -> None:
def configure_motors(self, return_delay_time=0, maximum_acceleration=254, acceleration=254) -> None:
for motor in self.motors:
# By default, Feetech motors have a 500µs delay response time (corresponding to a value of 250 on
# the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0).
self.write("Return_Delay_Time", motor, 0)
# self.write("Return_Delay_Time", motor, 0) # THIS DOES NOT WORK FOR HLS3625
# Set 'Maximum_Acceleration' to 254 to speedup acceleration and deceleration of the motors.
# Note: this address is not in the official STS3215 Memory Table
self.write("Maximum_Acceleration", motor, 254)
self.write("Acceleration", motor, 254)
if self.protocol_version == 0:
self.write("Maximum_Acceleration", motor, maximum_acceleration)
self.write("Acceleration", motor, acceleration)
@property
def is_calibrated(self) -> bool:
motors_calibration = self.read_calibration()
if set(motors_calibration) != set(self.calibration):
return False
same_ranges = all(
self.calibration[motor].range_min == cal.range_min
and self.calibration[motor].range_max == cal.range_max
for motor, cal in motors_calibration.items()
)
if self.protocol_version == 1:
return same_ranges
same_offsets = all(
self.calibration[motor].homing_offset == cal.homing_offset
for motor, cal in motors_calibration.items()
)
return same_ranges and same_offsets
# Check if calibration data has been loaded from file
return bool(self.calibration)
def read_calibration(self) -> dict[str, MotorCalibration]:
offsets, mins, maxes = {}, {}, {}
for motor in self.motors:
mins[motor] = self.read("Min_Position_Limit", motor, normalize=False)
maxes[motor] = self.read("Max_Position_Limit", motor, normalize=False)
offsets[motor] = (
self.read("Homing_Offset", motor, normalize=False) if self.protocol_version == 0 else 0
)
# Return empty calibration - we don't read from motors anymore
calibration = {}
for motor, m in self.motors.items():
calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=0,
homing_offset=offsets[motor],
range_min=mins[motor],
range_max=maxes[motor],
homing_offset=0,
range_min=0,
range_max=4095, # Default max resolution
)
return calibration
def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None:
for motor, calibration in calibration_dict.items():
if self.protocol_version == 0:
self.write("Homing_Offset", motor, calibration.homing_offset)
self.write("Min_Position_Limit", motor, calibration.range_min)
self.write("Max_Position_Limit", motor, calibration.range_max)
# Only update the in-memory calibration, don't write to motors
self.calibration = calibration_dict
def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]:
"""
On Feetech Motors:
Present_Position = Actual_Position - Homing_Offset
Calculate homing offsets such that the current position becomes 0 degrees.
For Feetech motors:
- The homing offset is subtracted from the raw position during normalization
- So to make current position = 0 degrees, homing_offset = current_raw_position
"""
half_turn_homings = {}
for motor, pos in positions.items():
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
half_turn_homings[motor] = pos - int(max_res / 2)
# The homing offset should be the current position
# This way, when we normalize: (pos - homing_offset) = 0
half_turn_homings[motor] = pos
return half_turn_homings
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 5) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
self.write("Lock", motor, 0, num_retry=num_retry)
# self.write("Lock", motor, 0, num_retry=num_retry)
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 5) -> None:
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
addr, length = get_address(self.model_ctrl_table, model, "Lock")
self._write(addr, length, motor_id, 0, num_retry=num_retry)
# addr, length = get_address(self.model_ctrl_table, model, "Lock")
# self._write(addr, length, motor_id, 0, num_retry=num_retry)
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 5) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
self.write("Lock", motor, 1, num_retry=num_retry)
# self.write("Lock", motor, 1, num_retry=num_retry)
def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
for id_ in ids_values:

View File

@@ -151,6 +151,95 @@ SCS_SERIES_CONTROL_TABLE = {
"Acceleration_2": (83, 1), # don't know what that is
}
# http://doc.feetech.cn/#/prodinfodownload?srcType=FT-SMS-STS-emanual-229f4476422d4059abfb1cb0
HLS_SERIES_CONTROL_TABLE = {
# Version Information (0-4) - read-only
"Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # (0, 1) read-only
"Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # (1, 1) read-only
"End_Type": (2, 1), # read-only - 0 represents little-endian storage
"Model_Number": MODEL_NUMBER, # (3, 2) read-only
# EPROM configuration (5-39)
"ID": (5, 1), # Main ID - unique identifier on bus
"Baud_Rate": (6, 1), # 0-7 for different baud rates
"Secondary_ID": (7, 1), # Secondary ID for write instructions
"Response_Status_Level": (8, 1), # 0: limited response, 1: full response
"Min_Position_Limit": (9, 2), # 0-4094 (0.087 degrees per unit)
"Max_Position_Limit": (11, 2), # 1-4095 (0.087 degrees per unit)
"Max_Temperature_Limit": (13, 1), # 0-100 (°C)
"Max_Voltage_Limit": (14, 1), # 0-254 (0.1V per unit)
"Min_Voltage_Limit": (15, 1), # 0-254 (0.1V per unit)
"Max_Torque_Limit": (16, 2), # 0-1000 (0.1% per unit)
"Phase": (18, 1), # Special function byte for motor phase configuration
"Unloading_Condition": (19, 1), # Bit flags for protection conditions
"LED_Alarm_Condition": (20, 1), # Bit flags for LED alarm conditions
"P_Coefficient": (21, 1), # Position ring P proportional coefficient
"D_Coefficient": (22, 1), # Position ring D differential coefficient
"I_Coefficient": (23, 1), # Position ring I integral coefficient
"Minimum_Startup_Force": (24, 1), # 0-254 (0.1% per unit)
"Point_Limit_Value": (25, 1), # 0-254 - maximum point value = point_limit * 4
"CW_Dead_Zone": (26, 1), # 0-16 (0.087 degrees per unit)
"CCW_Dead_Zone": (27, 1), # 0-16 (0.087 degrees per unit)
"Protection_Current": (28, 2), # 0-2047 (6.5 mA per unit)
"Angle_Resolution": (30, 1), # 1-128 - amplification coefficient
"Homing_Offset": (31, 2), # -4095 to 4095 (0.087 degrees per unit)
"Operating_Mode": (33, 1), # 0: position, 1: speed, 2: current, 3: PWM
"P_Coefficient_Curr": (34, 1), # Current ring P proportional coefficient
"I_Coefficient_Curr": (35, 1), # Current ring I integral coefficient
# Address 36 undefined
"Speed_P_Coefficient": (37, 1), # Speed closed-loop P proportional coefficient
"Overcurrent_Protection_Time": (38, 1), # 0-254 (10ms per unit)
"Speed_I_Coefficient": (39, 1), # Speed closed-loop I integral coefficient
# SRAM control (40-55)
"Torque_Enable": (40, 1), # 0: off, 1: on, 2: damping
"Acceleration": (41, 1), # 0-254 (8.7 degrees/second² per unit)
"Goal_Position": (42, 2), # -32767 to 32767 (0.087 degrees per unit)
"Target_Torque": (44, 2), # -2047 to 2047 (6.5 mA per unit)
"Goal_Velocity": (46, 2), # -32767 to 32767 (0.732 RPM per unit)
"Torque_Limit": (48, 2), # 0-1000 (0.1% per unit)
"P_Coefficient_Ring": (50, 1), # Motor position ring proportional coefficient
"D_Coefficient_Ring": (51, 1), # Motor position ring differential coefficient
"I_Coefficient_Ring": (52, 1), # Motor position ring integral coefficient
"km": (53, 1), # 0: position+current dual loop, 1: position single loop
# Address 54 undefined
"Lock": (55, 1), # 0: close write lock, 1: open write lock
# SRAM feedback (56-73) - read-only
"Present_Position": (56, 2), # read-only - current absolute position
"Present_Velocity": (58, 2), # read-only - current motor rotation speed
"Present_Load": (60, 2), # read-only - current load (0.1% per unit)
"Present_Voltage": (62, 1), # read-only - current voltage (0.1V per unit)
"Present_Temperature": (63, 1), # read-only - current temperature (°C)
"Async_Write_Flag": (64, 1), # read-only - async write instruction flag
"Status": (65, 1), # read-only - servo status bit flags
"Moving": (66, 1), # read-only - movement status flags
"Target_Position": (67, 2), # read-only - current target position
"Present_Current": (69, 2), # read-only - current motor phase current (6.5 mA per unit)
# Address 71 undefined
"Present_Bias": (73, 2), # read-only - current 0-point offset value
# Factory parameters (77-86) - read-only
"VFk_x10": (77, 1), # read-only - factory parameter
"vKgI": (78, 1), # read-only - factory parameter
"PFk_x10": (79, 1), # read-only - factory parameter
"Moving_Velocity_Threshold": (80, 1), # read-only - factory parameter
"DTs_ms": (81, 1), # read-only - factory parameter
"eFk_x10": (82, 1), # read-only - factory parameter
"Vk_ms": (83, 1), # read-only - factory parameter
"Maximum_Velocity_Limit": (84, 1), # read-only - factory parameter
"Maximum_Acceleration": (85, 1), # read-only - factory parameter
"Acceleration_Multiplier": (86, 1), # read-only - factory parameter
}
# HLS series baud rate table (same as STS/SMS series)
HLS_SERIES_BAUDRATE_TABLE = {
1_000_000: 0,
500_000: 1,
250_000: 2,
128_000: 3,
115_200: 4,
76_800: 5, # Note: HLS documentation mentions 76800 instead of 57600
57_600: 6,
38_400: 7,
}
STS_SMS_SERIES_BAUDRATE_TABLE = {
1_000_000: 0,
500_000: 1,
@@ -181,6 +270,7 @@ MODEL_CONTROL_TABLE = {
"sts3250": STS_SMS_SERIES_CONTROL_TABLE,
"scs0009": SCS_SERIES_CONTROL_TABLE,
"sm8512bl": STS_SMS_SERIES_CONTROL_TABLE,
"hls3625": HLS_SERIES_CONTROL_TABLE,
}
MODEL_RESOLUTION = {
@@ -189,8 +279,9 @@ MODEL_RESOLUTION = {
"scs_series": 1024,
"sts3215": 4096,
"sts3250": 4096,
"sm8512bl": 65536,
"sm8512bl": 4096,
"scs0009": 1024,
"hls3625": 4096,
}
MODEL_BAUDRATE_TABLE = {
@@ -201,6 +292,7 @@ MODEL_BAUDRATE_TABLE = {
"sts3215": STS_SMS_SERIES_BAUDRATE_TABLE,
"sts3250": STS_SMS_SERIES_BAUDRATE_TABLE,
"scs0009": SCS_SERIES_BAUDRATE_TABLE,
"hls3625": HLS_SERIES_BAUDRATE_TABLE,
}
# Sign-Magnitude encoding bits
@@ -210,6 +302,18 @@ STS_SMS_SERIES_ENCODINGS_TABLE = {
"Present_Velocity": 15,
}
# HLS series sign-magnitude encoding bits
HLS_SERIES_ENCODINGS_TABLE = {
"Homing_Offset": 15, # BIT15 represents positive/negative direction
"Goal_Position": 15, # BIT15 represents positive/negative direction
"Target_Torque": 15, # BIT15 represents positive/negative direction in constant current mode
"Goal_Velocity": 15, # BIT15 represents positive/negative direction in constant speed mode
"Present_Position": 15, # BIT15 represents positive/negative direction
"Present_Velocity": 15, # BIT15 represents positive/negative direction
"Present_Current": 15, # BIT15 represents positive/negative direction
"Present_Load": 10, # BIT10 represents positive/negative direction
}
MODEL_ENCODING_TABLE = {
"sts_series": STS_SMS_SERIES_ENCODINGS_TABLE,
"sms_series": STS_SMS_SERIES_ENCODINGS_TABLE,
@@ -218,6 +322,7 @@ MODEL_ENCODING_TABLE = {
"sts3250": STS_SMS_SERIES_ENCODINGS_TABLE,
"sm8512bl": STS_SMS_SERIES_ENCODINGS_TABLE,
"scs0009": {},
"hls3625": HLS_SERIES_ENCODINGS_TABLE,
}
SCAN_BAUDRATES = [
@@ -239,6 +344,7 @@ MODEL_NUMBER_TABLE = {
"sts3250": 2825,
"sm8512bl": 11272,
"scs0009": 1284,
"hls3625": 3338,
}
MODEL_PROTOCOL = {
@@ -249,4 +355,5 @@ MODEL_PROTOCOL = {
"sts3250": 0,
"sm8512bl": 0,
"scs0009": 1,
"hls3625": 0, # Uses FT-SCS protocol
}

View File

@@ -83,6 +83,9 @@ class MotorNormMode(str, Enum):
DEGREES = "degrees"
COUNT_TO_DEG = 0.087 # 1 encoder count = 0.087 °
@dataclass
class MotorCalibration:
id: int
@@ -441,8 +444,8 @@ class MotorsBus(abc.ABC):
try:
if not self.port_handler.openPort():
raise OSError(f"Failed to open port '{self.port}'.")
elif handshake:
self._handshake()
# elif handshake:
# self._handshake()
except (FileNotFoundError, OSError, serial.SerialException) as e:
raise ConnectionError(
f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port."
@@ -586,7 +589,7 @@ class MotorsBus(abc.ABC):
pass
@contextmanager
def torque_disabled(self):
def torque_disabled(self, motors: int | str | list[str] | None = None):
"""Context-manager that guarantees torque is re-enabled.
This helper is useful to temporarily disable torque when configuring motors.
@@ -596,11 +599,11 @@ class MotorsBus(abc.ABC):
... # Safe operations here
... pass
"""
self.disable_torque()
self.disable_torque(motors)
try:
yield
finally:
self.enable_torque()
self.enable_torque(motors)
def set_timeout(self, timeout_ms: int | None = None):
"""Change the packet timeout used by the SDK.
@@ -653,12 +656,13 @@ class MotorsBus(abc.ABC):
pass
@abc.abstractmethod
def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None:
"""Write calibration parameters to the motors and cache them.
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
"""Write calibration parameters to the motors and optionally cache them.
Args:
calibration_dict (dict[str, MotorCalibration]): Calibration obtained from
:pymeth:`read_calibration` or crafted by the user.
cache (bool, optional): Save the calibration to :pyattr:`calibration`. Defaults to True.
"""
pass
@@ -710,9 +714,8 @@ class MotorsBus(abc.ABC):
self.reset_calibration(motors)
actual_positions = self.sync_read("Present_Position", motors, normalize=False)
homing_offsets = self._get_half_turn_homings(actual_positions)
for motor, offset in homing_offsets.items():
self.write("Homing_Offset", motor, offset)
# Don't write to motors, just return the calculated offsets
return homing_offsets
@abc.abstractmethod
@@ -781,21 +784,32 @@ class MotorsBus(abc.ABC):
motor = self._id_to_name(id_)
min_ = self.calibration[motor].range_min
max_ = self.calibration[motor].range_max
homing_offset = self.calibration[motor].homing_offset
drive_mode = self.apply_drive_mode and self.calibration[motor].drive_mode
if max_ == min_:
raise ValueError(f"Invalid calibration for motor '{motor}': min and max are equal.")
bounded_val = min(max_, max(min_, val))
if self.motors[motor].norm_mode is MotorNormMode.RANGE_M100_100:
bounded_val = min(max_, max(min_, val))
norm = (((bounded_val - min_) / (max_ - min_)) * 200) - 100
normalized_values[id_] = -norm if drive_mode else norm
elif self.motors[motor].norm_mode is MotorNormMode.RANGE_0_100:
bounded_val = min(max_, max(min_, val))
norm = ((bounded_val - min_) / (max_ - min_)) * 100
normalized_values[id_] = 100 - norm if drive_mode else norm
elif self.motors[motor].norm_mode is MotorNormMode.DEGREES:
mid = (min_ + max_) / 2
max_res = self.model_resolution_table[self._id_to_model(id_)] - 1
normalized_values[id_] = (val - mid) * 360 / max_res
# For motors without wrap-around handling
# The homing offset becomes 0 degrees
# Calculate difference from homing position
diff = val - homing_offset
# Convert to degrees
deg = diff * COUNT_TO_DEG
# Apply drive mode if needed
normalized_values[id_] = -deg if drive_mode else deg
else:
raise NotImplementedError
@@ -810,7 +824,9 @@ class MotorsBus(abc.ABC):
motor = self._id_to_name(id_)
min_ = self.calibration[motor].range_min
max_ = self.calibration[motor].range_max
homing_offset = self.calibration[motor].homing_offset
drive_mode = self.apply_drive_mode and self.calibration[motor].drive_mode
if max_ == min_:
raise ValueError(f"Invalid calibration for motor '{motor}': min and max are equal.")
@@ -823,9 +839,22 @@ class MotorsBus(abc.ABC):
bounded_val = min(100.0, max(0.0, val))
unnormalized_values[id_] = int((bounded_val / 100) * (max_ - min_) + min_)
elif self.motors[motor].norm_mode is MotorNormMode.DEGREES:
mid = (min_ + max_) / 2
max_res = self.model_resolution_table[self._id_to_model(id_)] - 1
unnormalized_values[id_] = int((val * max_res / 360) + mid)
# For motors without wrap-around, simple conversion back
# Apply drive mode if needed
val = -val if drive_mode else val
# Convert degrees to raw counts
raw_counts = int(round(val / COUNT_TO_DEG))
# Add back the homing offset
raw_counts_with_offset = raw_counts + homing_offset
# Ensure value stays within calibrated motor range
# Use the calibration min/max if available
if min_ is not None and max_ is not None:
raw_counts_with_offset = max(min_, min(max_, raw_counts_with_offset))
unnormalized_values[id_] = raw_counts_with_offset
else:
raise NotImplementedError

View File

@@ -16,6 +16,5 @@ from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla2.configuration_smolvla2 import SmolVLA2Config as SmolVLA2Config
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig

View File

@@ -107,7 +107,7 @@ class ACTPolicy(PreTrainedPolicy):
else:
self._action_queue = deque([], maxlen=self.config.n_action_steps)
@torch.no_grad
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
@@ -132,7 +132,7 @@ class ACTPolicy(PreTrainedPolicy):
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
@torch.no_grad
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
self.eval()
@@ -485,12 +485,10 @@ class ACT(nn.Module):
self.encoder_env_state_input_proj(batch["observation.environment_state"])
)
# Camera observation features and positional embeddings.
if self.config.image_features:
all_cam_features = []
all_cam_pos_embeds = []
# For a list of images, the H and W may vary but H*W is constant.
# NOTE: If modifying this section, verify on MPS devices that
# gradients remain stable (no explosions or NaNs).
for img in batch["observation.images"]:
cam_features = self.backbone(img)["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
@@ -500,11 +498,10 @@ class ACT(nn.Module):
cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
all_cam_features.append(cam_features)
all_cam_pos_embeds.append(cam_pos_embed)
encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))
# Extend immediately instead of accumulating and concatenating
# Convert to list to extend properly
encoder_in_tokens.extend(list(cam_features))
encoder_in_pos_embed.extend(list(cam_pos_embed))
# Stack all tokens along the sequence dimension.
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)

View File

@@ -99,7 +99,7 @@ class DiffusionPolicy(PreTrainedPolicy):
if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
# stack n latest observations from the queue
@@ -111,7 +111,7 @@ class DiffusionPolicy(PreTrainedPolicy):
return actions
@torch.no_grad
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.

View File

@@ -32,7 +32,6 @@ from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla2.configuration_smolvla2 import SmolVLA2Config
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
@@ -75,10 +74,6 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
elif name == "smolvla2":
from lerobot.policies.smolvla2.modeling_smolvla2 import SmolVLA2Policy
return SmolVLA2Policy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -100,8 +95,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return SACConfig(**kwargs)
elif policy_type == "smolvla":
return SmolVLAConfig(**kwargs)
elif policy_type == "smolvla2":
return SmolVLA2Config(**kwargs)
elif policy_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
else:

View File

@@ -149,7 +149,7 @@ class Normalize(nn.Module):
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
@torch.no_grad()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# TODO: Remove this shallow copy
batch = dict(batch) # shallow copy avoids mutating the input batch
@@ -224,7 +224,7 @@ class Unnormalize(nn.Module):
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
@torch.no_grad()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():

View File

@@ -260,12 +260,12 @@ class PI0Policy(PreTrainedPolicy):
def get_optim_params(self) -> dict:
return self.parameters()
@torch.no_grad
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("Currently not implemented for PI0")
@torch.no_grad
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.

View File

@@ -192,12 +192,12 @@ class PI0FASTPolicy(PreTrainedPolicy):
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
return actions
@torch.no_grad
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("Currently not implemented for PI0FAST")
@torch.no_grad
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.

View File

@@ -76,7 +76,7 @@ class SACPolicy(
"""Reset the policy"""
pass
@torch.no_grad
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")

View File

@@ -413,6 +413,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
return batch
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
self.eval()
@@ -422,7 +423,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
actions = self._get_action_chunk(batch, noise)
return actions
@torch.no_grad
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.

View File

@@ -1,191 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
@dataclass
class PEFTConfig:
r: int = 4
lora_alpha: int = 16
lora_dropout: float = 0.1
target_modules: str = "q_proj,v_proj"
@PreTrainedConfig.register_subclass("smolvla2")
@dataclass
class SmolVLA2Config(PreTrainedConfig):
# Input / output structure.
n_obs_steps: int = 1
chunk_size: int = 50
n_action_steps: int = 50
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Shorter state and action vectors will be padded
max_state_dim: int = 32
max_action_dim: int = 32
# Image preprocessing
resize_imgs_with_padding: tuple[int, int] = (512, 512)
# Add empty images. Used by smolvla_aloha_sim which adds the empty
# left and right wrist cameras in addition to the top camera.
empty_cameras: int = 0
# Converts the joint and gripper values from the standard Aloha space to
# the space used by the pi internal runtime which was used to train the base model.
adapt_to_pi_aloha: bool = False
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
# Gripper dimensions will remain in absolute values.
use_delta_joint_actions_aloha: bool = False
# Tokenizer
tokenizer_max_length: int = 48
proj_width: int = 480
# Decoding
num_steps: int = 10
# Attention utils
use_cache: bool = True
# Finetuning settings
freeze_vision_encoder: bool = True
train_expert_only: bool = False
train_state_proj: bool = True
# Training presets
optimizer_lr: float = 2.5e-5 # 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
optimizer_grad_clip_norm: float = 10
optimizer_lr_vlm: float = 0
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" # Select the VLM backbone.
load_vlm_weights: bool = False # Set to True in case of training the expert from scratch. True when init from pretrained SmolVLA weights
checkpoint_path: str = None
peft_method: str = ""
peft_config: PEFTConfig = field(default_factory=PEFTConfig)
peft_target_model: str = ""
add_image_special_tokens: bool = False # Whether to use special image tokens around image features.
attention_mode: str = "cross_attn"
prefix_length: int = -1
pad_language_to: str = "longest" # "max_length"
num_expert_layers: int = -1 # Less or equal to 0 is the default where the action expert has the same number of layers of VLM. Otherwise the expert have less layers.
num_vlm_layers: int = 16
past_obs_keys: str = "image"
add_local_special_image_tokens: bool = False
reverse_images_order: bool = False
state_to_prefix: bool = False
pad_language_to: str = "longest" # "max_length"
causal_action_attention_mask: bool = False
self_attn_every_n_layers: int = -1 # Number of layers used in the VLM (first num_vlm_layers layers)
# self_attn_every_n_layers: int = 2 # Interleave SA layers each self_attn_every_n_layers
expert_width_multiplier: float = 0.75 # The action expert hidden size (wrt to the VLM)
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
max_period: float = 4.0
robot_type: str = ""
self_attn_only_actions: bool = False
causal_attention_on_history: bool = False
predict_relative_actions: bool = False
relative_actions_mode: str = "first"
shuffle_camera_positions: bool = False
vlm_img_size: int = -1
regression_loss: bool = False
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.use_delta_joint_actions_aloha:
raise NotImplementedError(
"`use_delta_joint_actions_aloha` is used by smolvla for aloha real models. It is not ported yet in LeRobot."
)
def validate_features(self) -> None:
for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 480, 640),
)
self.input_features[key] = empty_camera
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list:
return [0]
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None

File diff suppressed because it is too large Load Diff

View File

@@ -1,600 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import List, Optional
import torch
from torch import nn
from transformers import (
AutoConfig,
AutoModel,
AutoModelForImageTextToText,
AutoProcessor,
SmolVLMForConditionalGeneration,
)
from peft import LoraConfig, TaskType, get_peft_model
def apply_rope(x, positions, max_wavelength=10_000):
"""
Applies RoPE positions [B, L] to x [B, L, H, D].
"""
d_half = x.shape[-1] // 2
device = x.device
dtype = x.dtype
x = x.to(torch.float32)
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
timescale = max_wavelength**freq_exponents
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
radians = radians[..., None, :]
sin = torch.sin(radians) # .to(dtype=dtype)
cos = torch.cos(radians) # .to(dtype=dtype)
x1, x2 = x.split(d_half, dim=-1)
res = torch.empty_like(x)
res[..., :d_half] = x1 * cos - x2 * sin
res[..., d_half:] = x2 * cos + x1 * sin
return res.to(dtype)
def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256):
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
return hidden_dim
class SmolVLMWithExpertModel(nn.Module):
def __init__(
self,
model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
load_vlm_weights: bool = True,
train_expert_only: bool = True,
freeze_vision_encoder: bool = False,
attention_mode: str = "self_attn",
num_expert_layers: int = -1,
num_vlm_layers: int = -1,
self_attn_every_n_layers: int = -1,
expert_width_multiplier: float = 0.5,
):
super().__init__()
if load_vlm_weights:
print(f"Loading {model_id} weights ...")
self.vlm = AutoModelForImageTextToText.from_pretrained(
model_id,
device_map="auto",
torch_dtype="bfloat16",
low_cpu_mem_usage=True,
)
config = self.vlm.config
else:
config = AutoConfig.from_pretrained(model_id)
self.vlm = SmolVLMForConditionalGeneration(config=config)
self.processor = AutoProcessor.from_pretrained(model_id)
if num_vlm_layers > 0:
print(f"Reducing the number of VLM layers to {num_vlm_layers} ...")
self.get_vlm_model().text_model.layers = self.get_vlm_model().text_model.layers[:num_vlm_layers]
self.num_vlm_layers = len(self.get_vlm_model().text_model.layers)
self.config = config
# Smaller lm expert
lm_expert_config = copy.deepcopy(config.text_config)
hidden_size = lm_expert_config.hidden_size
lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2
lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier))
lm_expert_config.num_hidden_layers = self.num_vlm_layers
if num_expert_layers > 0:
assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, (
f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}"
)
lm_expert_config.num_hidden_layers = num_expert_layers
self.lm_expert = AutoModel.from_config(lm_expert_config)
self.num_expert_layers = len(self.lm_expert.layers)
self.self_attn_every_n_layers = self_attn_every_n_layers
if "cross" in attention_mode:
# Reshape qkv projections to have the same input dimension as the vlm
for layer_idx in range(len(self.lm_expert.layers)):
if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0:
continue
self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear(
config.text_config.num_key_value_heads * config.text_config.head_dim,
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
bias=lm_expert_config.attention_bias,
)
self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear(
config.text_config.num_key_value_heads * config.text_config.head_dim,
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
bias=lm_expert_config.attention_bias,
)
# Remove unused embed_tokens
self.lm_expert.embed_tokens = None
self.num_attention_heads = self.config.text_config.num_attention_heads
self.num_key_value_heads = self.config.text_config.num_key_value_heads
self.freeze_vision_encoder = freeze_vision_encoder
self.train_expert_only = train_expert_only
self.attention_mode = attention_mode
self.expert_hidden_size = lm_expert_config.hidden_size
self.set_requires_grad()
def configure_peft(self, config):
# return model
self.peft_method = config.peft_method
self.peft_target_model = config.peft_target_model
if "lora" in self.peft_method:
peft_config = config.peft_config
target_modules = peft_config.target_modules
if not isinstance(target_modules, list):
target_modules = target_modules.split(",")
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, # Based on the task type (e.g., language modeling, etc.)
r=peft_config.r, # The rank of the low-rank adaptation
lora_alpha=peft_config.lora_alpha, # Scaling factor
lora_dropout=peft_config.lora_dropout, # Dropout applied to LoRA layers
target_modules=target_modules, # The components where LoRA is applied
exclude_modules=[
"lm_expert",
"model.lm_expert.model.layers",
], # FIXME(mshukor): this does not work for now
)
self.lora_config = lora_config
# Apply LoRA and ensure only LoRA parameters are trainable
if "text" in self.peft_target_model:
self.get_vlm_model().text_model = get_peft_model(self.get_vlm_model().text_model, lora_config)
else:
self.vlm = get_peft_model(self.vlm, lora_config)
# assert config.train_expert_only, "Backbone should be frozen and only lora parameters are " # FIXME(mshukor): handle this here?
for name, param in self.vlm.named_parameters():
if (
"lora" in name and "text_model.model.layers.17" not in name
): # lm_head is not a parameter in most LLMs becasue it's tied to the embedding layer
param.requires_grad = True
else:
param.requires_grad = False
def merge_lora_weights(self):
"""
Merge LoRA weights into the base model.
"""
if "text" in self.peft_target_model:
self.get_vlm_model().text_model = self.get_vlm_model().text_model.merge_and_unload()
else:
self.vlm = self.vlm.merge_and_unload()
def get_vlm_model(
self,
):
if hasattr(self.vlm.model, "model"): # When using peft
return self.vlm.model.model
else:
return self.vlm.model
def set_requires_grad(self):
if self.freeze_vision_encoder:
self.get_vlm_model().vision_model.eval()
for params in self.get_vlm_model().vision_model.parameters():
params.requires_grad = False
if self.train_expert_only:
self.vlm.eval()
for params in self.vlm.parameters():
params.requires_grad = False
else:
# To avoid unused params issue with distributed training
last_layers = [self.num_vlm_layers - 1]
if (
self.num_vlm_layers != self.num_expert_layers
and self.num_vlm_layers % self.num_expert_layers == 0
):
last_layers.append(self.num_vlm_layers - 2)
frozen_layers = [
"lm_head",
"text_model.model.norm.weight",
]
for layer in last_layers:
frozen_layers.append(f"text_model.model.layers.{layer}.")
for name, params in self.vlm.named_parameters():
if any(k in name for k in frozen_layers):
params.requires_grad = False
# To avoid unused params issue with distributed training
for name, params in self.lm_expert.named_parameters():
if "lm_head" in name:
params.requires_grad = False
def train(self, mode: bool = True):
super().train(mode)
if self.freeze_vision_encoder:
self.get_vlm_model().vision_model.eval()
if self.train_expert_only:
self.vlm.eval()
def embed_image(self, image: torch.Tensor):
patch_attention_mask = None
# Get sequence from the vision encoder
image_hidden_states = (
self.get_vlm_model()
.vision_model(
pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype),
patch_attention_mask=patch_attention_mask,
)
.last_hidden_state
)
# Modality projection & resampling
image_hidden_states = self.get_vlm_model().connector(image_hidden_states)
return image_hidden_states
def embed_language_tokens(self, tokens: torch.Tensor):
return self.get_vlm_model().text_model.get_input_embeddings()(tokens)
def forward_attn_layer(
self,
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache: bool = True,
fill_kv_cache: bool = True,
past_key_values=None,
) -> list[torch.Tensor]:
query_states = []
key_states = []
value_states = []
for i, hidden_states in enumerate(inputs_embeds):
layer = model_layers[i][layer_idx]
if hidden_states is None or layer is None:
continue
hidden_states = layer.input_layernorm(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
query_states.append(query_state)
key_states.append(key_state)
value_states.append(value_state)
# B,L,H,D with L sequence length, H number of heads, D head dim
# concatenate on the number of embeddings/tokens
query_states = torch.cat(query_states, dim=1)
key_states = torch.cat(key_states, dim=1)
value_states = torch.cat(value_states, dim=1)
seq_len = query_states.shape[1]
if seq_len < position_ids.shape[1]:
_position_ids = position_ids[:, :seq_len]
_attention_mask = attention_mask[:, :seq_len, :seq_len]
else:
_position_ids = position_ids
_attention_mask = attention_mask
attention_mask_ = _attention_mask
position_ids_ = _position_ids
query_states = apply_rope(query_states, position_ids_)
key_states = apply_rope(key_states, position_ids_)
if use_cache and past_key_values is None:
past_key_values = {}
if use_cache:
if fill_kv_cache:
past_key_values[layer_idx] = {
"key_states": key_states,
"value_states": value_states,
}
else:
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
# the max len, then we (for instance) double the cache size. This implementation already exists
# in `transformers`. (molbap)
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1)
attention_interface = self.get_attention_interface()
att_output = attention_interface(
attention_mask_, batch_size, head_dim, query_states, key_states, value_states
)
return [att_output], past_key_values
def forward_cross_attn_layer(
self,
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache: bool = True,
fill_kv_cache: bool = True,
past_key_values=None,
) -> list[torch.Tensor]:
attention_interface = self.get_attention_interface()
att_outputs = []
assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), (
f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}"
)
if len(inputs_embeds) == 2 and not past_key_values:
# Prefix attention
seq_len = inputs_embeds[0].shape[1]
position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:]
prefix_attention_mask = attention_mask[:, :seq_len, :seq_len]
layer = model_layers[0][layer_idx]
hidden_states = layer.input_layernorm(inputs_embeds[0])
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
# B,L,H,D with L sequence length, H number of heads, D head dim
query_states = apply_rope(query_state, position_id)
key_states = apply_rope(key_state, position_id)
att_output = attention_interface(
prefix_attention_mask, batch_size, head_dim, query_states, key_states, value_states
)
att_outputs.append(att_output)
else:
expert_position_id = position_ids
if use_cache and past_key_values is None:
past_key_values = {}
if use_cache:
if fill_kv_cache:
past_key_values[layer_idx] = {
"key_states": key_states,
"value_states": value_states,
}
else:
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
# the max len, then we (for instance) double the cache size. This implementation already exists
# in `transformers`. (molbap)
key_states = past_key_values[layer_idx]["key_states"]
value_states = past_key_values[layer_idx]["value_states"]
# Expert
expert_layer = model_layers[1][layer_idx]
if expert_layer is not None:
expert_hidden_states = expert_layer.input_layernorm(inputs_embeds[1])
expert_input_shape = expert_hidden_states.shape[:-1]
expert_hidden_shape = (*expert_input_shape, -1, expert_layer.self_attn.head_dim)
expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype)
expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape)
_key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(
*key_states.shape[:2], -1
)
expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(
*_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim
) # k_proj should have same dim as kv
_value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(
*value_states.shape[:2], -1
)
expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(
*_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim
)
expert_position_id = (
expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values
) # start from 0
expert_attention_mask = attention_mask[
:, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] :
] # take into account kv
expert_query_states = apply_rope(expert_query_state, expert_position_id)
att_output = attention_interface(
expert_attention_mask,
batch_size,
head_dim,
expert_query_states,
expert_key_states,
expert_value_states,
)
att_outputs.append(att_output)
else:
att_outputs.append(None)
# att_output = att_output.to(dtype=models[i].dtype)
return att_outputs, past_key_values
def get_model_layers(self, models: list) -> list:
vlm_layers = []
expert_layers = []
multiple_of = self.num_vlm_layers // self.num_expert_layers
for i in range(self.num_vlm_layers):
if multiple_of > 0 and i > 0 and i % multiple_of != 0:
expert_layer = None
else:
expert_layer_index = i // multiple_of if multiple_of > 0 else i
expert_layer = models[1].layers[expert_layer_index]
vlm_layers.append(models[0].layers[i])
expert_layers.append(expert_layer)
return [vlm_layers, expert_layers]
def forward(
self,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: List[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
fill_kv_cache: Optional[bool] = None,
):
models = [self.get_vlm_model().text_model, self.lm_expert]
model_layers = self.get_model_layers(models)
for hidden_states in inputs_embeds:
# TODO this is very inefficient
# dtype is always the same, batch size too (if > 1 len)
# device could be trickier in multi gpu edge cases but that's it
if hidden_states is None:
continue
batch_size = hidden_states.shape[0]
# RMSNorm
num_layers = self.num_vlm_layers
head_dim = self.vlm.config.text_config.head_dim
for layer_idx in range(num_layers):
if (
fill_kv_cache
or "cross" not in self.attention_mode
or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0)
):
att_outputs, past_key_values = self.forward_attn_layer(
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache=use_cache,
fill_kv_cache=fill_kv_cache,
past_key_values=past_key_values,
)
else:
att_outputs, past_key_values = self.forward_cross_attn_layer(
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache=use_cache,
fill_kv_cache=fill_kv_cache,
past_key_values=past_key_values,
)
outputs_embeds = []
start = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = model_layers[i][layer_idx]
att_output = (
att_outputs[i] if i < len(att_outputs) else att_outputs[0]
) # in case of self_attn
if hidden_states is not None:
if layer is None:
outputs_embeds.append(hidden_states)
continue
end = start + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
att_out = att_output[:, start:end]
out_emb = layer.self_attn.o_proj(att_out)
out_emb += hidden_states
after_first_residual = out_emb.clone()
out_emb = layer.post_attention_layernorm(out_emb)
out_emb = layer.mlp(out_emb)
out_emb += after_first_residual
outputs_embeds.append(out_emb)
start = end if len(att_outputs) == 1 else 0
else:
outputs_embeds.append(None)
inputs_embeds = outputs_embeds
# final norm
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
if hidden_states is not None:
out_emb = models[i].norm(hidden_states)
outputs_embeds.append(out_emb)
else:
outputs_embeds.append(None)
return outputs_embeds, past_key_values
def get_attention_interface(self):
attention_interface = self.eager_attention_forward
return attention_interface
def eager_attention_forward(
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
):
num_att_heads = self.num_attention_heads
num_key_value_heads = self.num_key_value_heads
num_key_value_groups = num_att_heads // num_key_value_heads
sequence_length = key_states.shape[1]
key_states = key_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
)
key_states = key_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
)
value_states = value_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
)
value_states = value_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
)
# Attention here is upcasted to float32 to match the original eager implementation.
query_states = query_states.to(dtype=torch.float32)
key_states = key_states.to(dtype=torch.float32)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
att_weights *= head_dim**-0.5
att_weights = att_weights.to(dtype=torch.float32)
big_neg = torch.finfo(att_weights.dtype).min # -2.3819763e38 # See gemma/modules.py
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
probs = nn.functional.softmax(masked_att_weights, dim=-1)
probs = probs.to(dtype=value_states.dtype)
att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
att_output = att_output.permute(0, 2, 1, 3)
# we use -1 because sequence length can change
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
return att_output

View File

@@ -110,7 +110,7 @@ class TDMPCPolicy(PreTrainedPolicy):
# CEM for the next step.
self._prev_mean: torch.Tensor | None = None
@torch.no_grad
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}

View File

@@ -124,14 +124,14 @@ class VQBeTPolicy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.action_chunk_size),
}
@torch.no_grad
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
@torch.no_grad
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.

View File

@@ -21,7 +21,7 @@ Example:
python -m lerobot.record \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{laptop: {type: opencv, camera_index: 0, width: 640, height: 480}}" \
--robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480}}" \
--robot.id=black \
--dataset.repo_id=aliberts/record-test \
--dataset.num_episodes=2 \
@@ -33,6 +33,41 @@ python -m lerobot.record \
# <- Policy optional if you want to record with a policy \
# --policy.path=${HF_USER}/my_policy \
```
Example with bilateral teleoperation:
```shell
python -m lerobot.record \
--robot.type=so101_follower_t \
--robot.port=/dev/tty.usbmodem58760432961 \
--robot.id=follower_arm_torque \
--dataset.repo_id=pepijn/bilateral-teleop-test \
--dataset.num_episodes=5 \
--dataset.single_task="Wipe the table" \
--biteleop=true \
--teleop.type=so101_follower_t \
--teleop.port=/dev/tty.usbmodem58760432571 \
--teleop.id=leader_arm_torque \
--dataset.fps=100 \
--robot.cameras="{side: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 100}}" \
--display_data=true
```
Example Eval with bilateral teleoperation:
```
python -m lerobot.record \
--robot.type=so101_follower_t \
--robot.port=/dev/tty.usbmodem58760432961 \
--robot.id=follower_arm_torque \
--robot.cameras="{side: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 100}}" \
--display_data=true \
--dataset.repo_id=pepijn223/eval_bilateral-wipe-large \
--dataset.single_task="Wipe the table" \
--policy.path=pepijn223/bilateral-wipe-large-single \
--dataset.fps=100 \
--biteleop=true
```
"""
import logging
@@ -57,14 +92,18 @@ from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
hope_jr,
koch_follower,
make_robot_from_config,
so100_follower,
so101_follower,
so101_follower_torque,
)
from lerobot.robots.so101_follower_torque import SO101FollowerT
from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
homunculus,
koch_leader,
make_teleoperator_from_config,
so100_leader,
@@ -87,6 +126,24 @@ from lerobot.utils.utils import (
from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
def split_interleaved_action(vec, motors):
"""
vec : 1D tensor/array, length = 3*len(motors)
motors : ['shoulder_pan', 'shoulder_lift', …]
returns : pos, vel, tau (three dicts keyed by joint name)
"""
pos = {}
vel = {}
tau = {}
for i, j in enumerate(motors):
base = 3 * i
pos[j] = float(vec[base + 0])
vel[j] = float(vec[base + 1])
tau[j] = float(vec[base + 2])
return pos, vel, tau
@dataclass
class DatasetRecordConfig:
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
@@ -131,7 +188,7 @@ class RecordConfig:
robot: RobotConfig
dataset: DatasetRecordConfig
# Whether to control the robot with a teleoperator
teleop: TeleoperatorConfig | None = None
teleop: TeleoperatorConfig | RobotConfig | None = None
# Whether to control the robot with a policy
policy: PreTrainedConfig | None = None
# Display all cameras on screen
@@ -140,6 +197,8 @@ class RecordConfig:
play_sounds: bool = True
# Resume recording on an existing dataset.
resume: bool = False
# Enable bilateral teleoperation with force feedback
biteleop: bool = False
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
@@ -164,15 +223,23 @@ def record_loop(
events: dict,
fps: int,
dataset: LeRobotDataset | None = None,
teleop: Teleoperator | List[Teleoperator] | None = None,
teleop: Teleoperator | List[Teleoperator] | Robot | None = None,
policy: PreTrainedPolicy | None = None,
control_time_s: int | None = None,
single_task: str | None = None,
display_data: bool = False,
biteleop: bool = False,
):
if dataset is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
if biteleop and policy is None:
if not isinstance(robot, SO101FollowerT):
raise ValueError(
"Bilateral teleoperation requires both robot and teleop to be of type SO101FollowerT"
)
logging.info("Bilateral teleoperation mode enabled")
teleop_arm = teleop_keyboard = None
if isinstance(teleop, list):
teleop_keyboard = next((t for t in teleop if isinstance(t, KeyboardTeleop)), None)
@@ -196,7 +263,11 @@ def record_loop(
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < control_time_s:
loop_count = 0
rerun_log_freq = max(1, int(fps / 10))
while control_time_s is not None and timestamp < control_time_s:
start_loop_t = time.perf_counter()
if events["exit_early"]:
@@ -208,7 +279,67 @@ def record_loop(
if policy is not None or dataset is not None:
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
if policy is not None:
if (
biteleop
and isinstance(robot, SO101FollowerT)
and isinstance(teleop, SO101FollowerT)
and policy is None
):
obs_f = observation # robot is the follower
obs_l = teleop.get_observation()
pos_f = {j: obs_f[f"{j}.pos"] for j in robot.bus.motors}
vel_f = {j: obs_f[f"{j}.vel"] for j in robot.bus.motors}
tau_reaction_f = {j: obs_f[f"{j}.effort"] for j in robot.bus.motors}
pos_l = {j: obs_l[f"{j}.pos"] for j in teleop.bus.motors}
vel_l = {j: obs_l[f"{j}.vel"] for j in teleop.bus.motors}
acc_l = {j: obs_l[f"{j}.acc"] for j in teleop.bus.motors}
tau_reaction_l = {j: obs_l[f"{j}.effort"] for j in teleop.bus.motors}
# Get control gains from robot
kp_gains = robot.kp_gains
kd_gains = robot.kd_gains
kf_gains = robot.kf_gains
# Compute torque commands
tau_cmd_f = [
(
kp_gains[j] * (pos_l[j] - pos_f[j]) # Position tracking
+ kd_gains[j] * (vel_l[j] - vel_f[j]) # Velocity damping
+ kf_gains[j] * (-tau_reaction_l[j] - tau_reaction_f[j])
) # Force reflection
for j in robot.bus.motors
]
tau_cmd_l = [
(
kp_gains[j] * (pos_f[j] - pos_l[j]) # Position tracking
+ kd_gains[j] * (vel_f[j] - vel_l[j]) # Velocity damping
+ kf_gains[j] * (-tau_reaction_f[j] - tau_reaction_l[j])
) # Force reflection
for j in teleop.bus.motors
]
action = {f"{m}.effort": tau_cmd_f[i] for i, m in enumerate(robot.bus.motors)}
teleop_action = {f"{m}.effort": tau_cmd_l[i] for i, m in enumerate(teleop.bus.motors)}
teleop.send_action(teleop_action)
robot.send_action(action)
# For bilateral teleoperation, create custom observation and action for dataset
bilateral_action = {}
for j in teleop.bus.motors:
bilateral_action[f"{j}.pos"] = pos_l[j]
bilateral_action[f"{j}.vel"] = vel_l[j]
bilateral_action[f"{j}.acc"] = acc_l[j]
bilateral_action[f"{j}.effort"] = -tau_reaction_l[j]
# Override the observation_frame and action for dataset recording
if dataset is not None:
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
action = bilateral_action
elif policy is not None and biteleop and isinstance(robot, SO101FollowerT):
action_values = predict_action(
observation_frame,
policy,
@@ -217,10 +348,57 @@ def record_loop(
task=single_task,
robot_type=robot.robot_type,
)
pos_f = {j: observation[f"{j}.pos"] for j in robot.bus.motors}
vel_f = {j: observation[f"{j}.vel"] for j in robot.bus.motors}
tau_reaction_f = {j: observation[f"{j}.effort"] for j in robot.bus.motors}
# The model returns [pos1, pos2, …, vel1, vel2, …, tau1, tau2, …]
motors = robot.bus.motors # 6 joints
pos_l, vel_l, neg_tau_reaction_l = split_interleaved_action(
action_values, motors
) # The model is trained and returns the effort already as negative: -tau_reaction_l
kp, kd, kf = robot.kp_gains, robot.kd_gains, robot.kf_gains
# Compute torque command for the follower robot
tau_cmd_f = [
(
kp[j] * (pos_l[j] - pos_f[j]) # Position tracking
+ kd[j] * (vel_l[j] - vel_f[j]) # Velocity damping
+ kf[j] * (neg_tau_reaction_l[j] - tau_reaction_f[j]) # Force reflection
)
for j in robot.bus.motors
]
# Format action with calculated torques and send to robot
action = {f"{m}.effort": tau_cmd_f[i] for i, m in enumerate(robot.bus.motors)}
robot.send_action(action)
bilateral_action = {}
for j in robot.bus.motors:
bilateral_action[f"{j}.pos"] = pos_l[j]
bilateral_action[f"{j}.vel"] = vel_l[j]
bilateral_action[f"{j}.effort"] = neg_tau_reaction_l[j]
# Override the observation_frame and action for dataset recording
if dataset is not None:
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
action = bilateral_action
elif policy is not None and not biteleop:
action_values = predict_action(
observation_frame,
policy,
get_safe_torch_device(policy.config.device),
policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
elif policy is None and isinstance(teleop, Teleoperator):
elif policy is None and isinstance(teleop, Teleoperator) and not biteleop:
action = teleop.get_action()
elif policy is None and isinstance(teleop, list):
elif policy is None and isinstance(teleop, list) and not biteleop:
# TODO(pepijn, steven): clean the record loop for use of multiple robots (possibly with pipeline)
arm_action = teleop_arm.get_action()
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
@@ -239,31 +417,62 @@ def record_loop(
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset.
sent_action = robot.send_action(action)
if not biteleop:
sent_action = robot.send_action(action)
if dataset is not None:
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
# For bilateral teleoperation, use the bilateral_action (leader pos & torque)
# For other modes, use sent_action as usual
if biteleop and isinstance(robot, SO101FollowerT):
action_frame = build_dataset_frame(dataset.features, action, prefix="action")
else:
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
frame = {**observation_frame, **action_frame}
dataset.add_frame(frame, task=single_task)
if display_data:
if display_data and loop_count % rerun_log_freq == 0:
log_rerun_data(observation, action)
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
timestamp = time.perf_counter() - start_episode_t
loop_count += 1
@parser.wrap()
def record(cfg: RecordConfig) -> LeRobotDataset:
init_logging()
logging.info(pformat(asdict(cfg)))
if cfg.display_data:
_init_rerun(session_name="recording")
robot = make_robot_from_config(cfg.robot)
teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None
if cfg.biteleop and cfg.teleop is not None:
print("Bilateral teleoperation enabled")
# For bilateral teleoperation, both arms must be SO101FollowerT robots
from lerobot.robots.so101_follower_torque.config_so101_follower_t import SO101FollowerTConfig
# Check if teleop config has the right type
if cfg.teleop.type != "so101_follower_t":
raise ValueError("Bilateral teleoperation requires teleop.type to be 'so101_follower_t'")
port = getattr(cfg.teleop, "port", None)
if port is None:
raise ValueError("Bilateral teleoperation requires teleop.port to be specified")
teleop_robot_config = SO101FollowerTConfig(
port=port,
id=getattr(cfg.teleop, "id", "leader_arm_torque"),
cameras=getattr(cfg.teleop, "cameras", {}),
disable_torque_on_disconnect=getattr(cfg.teleop, "disable_torque_on_disconnect", True),
)
teleop = SO101FollowerT(teleop_robot_config)
else:
teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None
action_features = hw_to_dataset_features(robot.action_features, "action", cfg.dataset.video)
obs_features = hw_to_dataset_features(robot.observation_features, "observation", cfg.dataset.video)
@@ -317,6 +526,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
control_time_s=cfg.dataset.episode_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
biteleop=cfg.biteleop,
)
# Execute a few seconds without recording to give time to manually reset the environment
@@ -333,6 +543,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
control_time_s=cfg.dataset.reset_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
biteleop=cfg.biteleop,
)
if events["rerecord_episode"]:

View File

@@ -25,6 +25,17 @@ python -m lerobot.replay \
--dataset.repo_id=aliberts/record-test \
--dataset.episode=2
```
Biteleop example:
```shell
python -m lerobot.replay \
--robot.type=so101_follower_t \
--robot.port=/dev/tty.usbmodem58760432961 \
--robot.id=follower_arm_torque \
--dataset.repo_id=pepijn223/bilateral-wipe-large \
--dataset.episode=10 \
--biteleop=true
```
"""
import logging
@@ -39,11 +50,14 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
hope_jr,
koch_follower,
make_robot_from_config,
so100_follower,
so101_follower,
so101_follower_torque,
)
from lerobot.robots.so101_follower_torque import SO101FollowerT
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import (
init_logging,
@@ -69,6 +83,8 @@ class ReplayConfig:
dataset: DatasetReplayConfig
# Use vocal synthesis to read events.
play_sounds: bool = True
# Use biteleop to replay the dataset
biteleop: bool = False
@draccus.wrap()
@@ -79,22 +95,70 @@ def replay(cfg: ReplayConfig):
robot = make_robot_from_config(cfg.robot)
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
actions = dataset.hf_dataset.select_columns("action")
if cfg.biteleop:
if not isinstance(robot, SO101FollowerT):
raise ValueError(
"Bilateral teleoperation replay requires the robot to be of type SO101FollowerT."
)
log_say("Bilateral teleoperation replay enabled.", cfg.play_sounds)
robot.connect()
log_say("Replaying episode", cfg.play_sounds, blocking=True)
start_time_all = time.perf_counter()
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
start_loop_t = time.perf_counter()
action_array = actions[idx]["action"]
action = {}
action_from_ds_array = actions[idx]["action"]
action_from_ds = {}
for i, name in enumerate(dataset.features["action"]["names"]):
action[name] = action_array[i]
action_from_ds[name] = action_from_ds_array[i]
robot.send_action(action)
# Bilateral teleoperation
if cfg.biteleop:
obs_f = robot.get_observation()
pos_f = {j: obs_f[f"{j}.pos"] for j in robot.bus.motors}
vel_f = {j: obs_f[f"{j}.vel"] for j in robot.bus.motors}
tau_reaction_f = {j: obs_f[f"{j}.effort"] for j in robot.bus.motors}
dt_s = time.perf_counter() - start_episode_t
pos_l = {j: action_from_ds[f"{j}.pos"] for j in robot.bus.motors}
vel_l = {j: action_from_ds[f"{j}.vel"] for j in robot.bus.motors}
# The saved effort in dataset is -tau_reaction_l
neg_tau_reaction_l = {j: action_from_ds[f"{j}.effort"] for j in robot.bus.motors}
# Get control gains from the robot instance
kp_gains = robot.kp_gains
kd_gains = robot.kd_gains
kf_gains = robot.kf_gains
# Compute torque command for the follower robot
tau_cmd_f = [
(
kp_gains[j] * (pos_l[j] - pos_f[j]) # Position tracking
+ kd_gains[j] * (vel_l[j] - vel_f[j]) # Velocity damping
+ kf_gains[j] * (neg_tau_reaction_l[j] - tau_reaction_f[j]) # Force reflection
)
for j in robot.bus.motors
]
# Format action with calculated torques and send to robot
action_to_send = {f"{m}.effort": tau_cmd_f[i] for i, m in enumerate(robot.bus.motors)}
robot.send_action(action_to_send)
else:
# Original logic for standard position-based replay
robot.send_action(action_from_ds)
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / dataset.fps - dt_s)
total_time = time.perf_counter() - start_time_all
actual_fps = idx / total_time if total_time > 0 else float("inf")
logging.info(f"Average FPS achieved over episode: {actual_fps:.2f}")
log_say(f"Average FPS achieved: {actual_fps:.2f}", cfg.play_sounds)
robot.disconnect()

View File

@@ -0,0 +1,3 @@
from .config_hope_jr import HopeJrArmConfig, HopeJrHandConfig
from .hope_jr_arm import HopeJrArm
from .hope_jr_hand import HopeJrHand

View File

@@ -0,0 +1,51 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from ..config import RobotConfig
@RobotConfig.register_subclass("hope_jr_hand")
@dataclass
class HopeJrHandConfig(RobotConfig):
port: str # Port to connect to the hand
side: str # "left" / "right"
disable_torque_on_disconnect: bool = True
cameras: dict[str, CameraConfig] = field(default_factory=dict)
def __post_init__(self):
super().__post_init__()
if self.side not in ["right", "left"]:
raise ValueError(self.side)
@RobotConfig.register_subclass("hope_jr_arm")
@dataclass
class HopeJrArmConfig(RobotConfig):
port: str # Port to connect to the hand
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
cameras: dict[str, CameraConfig] = field(default_factory=dict)

View File

@@ -0,0 +1,268 @@
# HopeJR
## Prerequisites
- [Hardware Setup](https://github.com/TheRobotStudio/HOPEJr)
## Install LeRobot
Follow the [installation instructions](https://github.com/huggingface/lerobot#installation) to install LeRobot.
Install LeRobot with HopeJR dependencies:
```bash
pip install -e ".[hopejr]"
```
## Device Configuration
Before starting calibration and operation, you need to identify the USB ports for each HopeJR component. Run this script to find the USB ports for the arm, hand, glove, and exoskeleton:
```bash
python -m lerobot.find_port
```
This will display the available USB ports and their associated devices. Make note of the port paths (e.g., `/dev/tty.usbmodem58760433331`, `/dev/tty.usbmodem11301`) as you'll need to specify them in the `--robot.port` and `--teleop.port` parameters when recording data, replaying episodes, or running teleoperation scripts.
## Step 1: Calibration
Before performing teleoperation, HopeJR's limbs need to be calibrated. Calibration files will be saved in `~/.cache/huggingface/lerobot/calibration`
### 1.1 Calibrate Robot Hand
```bash
python -m lerobot.calibrate \
--robot.type=hope_jr_hand \
--robot.port=/dev/tty.usbmodem58760432281 \
--robot.id=blue \
--robot.side=right
```
When running the calibration script, a calibration GUI will pop up. Finger joints are named as follows:
**Thumb**:
- **CMC**: base joint connecting thumb to hand
- **MCP**: knuckle joint
- **PIP**: first finger joint
- **DIP** : fingertip joint
**Index, Middle, Ring, and Pinky fingers**:
- **Radial flexor**: Moves base of finger towards the thumb
- **Ulnar flexor**: Moves base of finger towards the pinky
- **PIP/DIP**: Flexes the distal and proximal phalanx of the finger
Each one of these will need to be calibrated individually via the GUI.
Note that ulnar and radial flexors should have ranges of the same size (but with different offsets) in order to get symmetric movement.
<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/calibration_gui_1.png"
alt="Setting boundaries in the hand calibration GUI"
title="Setting boundaries in the hand calibration GUI"
width="100%">
</img>
</p>
Use the calibration interface to set the range boundaries for each joint as shown above.
<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/calibration_gui_2.png"
alt="Saving calibration values"
title="Saving calibration values"
width="100%">
</img>
</p>
Once you have set the appropriate boundaries for all joints, click "Save" to save the calibration values to the motors.
### 1.2 Calibrate Teleoperator Glove
```bash
python -m lerobot.calibrate \
--teleop.type=homunculus_glove \
--teleop.port=/dev/tty.usbmodem11201 \
--teleop.id=red \
--teleop.side=right
```
Move each finger through its full range of motion, starting from the thumb.
```
Move thumb through its entire range of motion.
Recording positions. Press ENTER to stop...
-------------------------------------------
NAME | MIN | POS | MAX
thumb_cmc | 1790 | 1831 | 1853
thumb_mcp | 1497 | 1514 | 1528
thumb_pip | 1466 | 1496 | 1515
thumb_dip | 1463 | 1484 | 1514
```
Continue with each finger:
```
Move middle through its entire range of motion.
Recording positions. Press ENTER to stop...
-------------------------------------------
NAME | MIN | POS | MAX
middle_mcp_abduction | 1598 | 1718 | 1820
middle_mcp_flexion | 1512 | 1658 | 2136
middle_dip | 1484 | 1500 | 1547
```
Once calibration is complete, the system will save the calibration to `/Users/your_username/.cache/huggingface/lerobot/calibration/teleoperators/homunculus_glove/red.json`
### 1.3 Calibrate Robot Arm
```bash
python -m lerobot.calibrate \
--robot.type=hope_jr_arm \
--robot.port=/dev/tty.usbserial-1110 \
--robot.id=white
```
This will open a calibration GUI where you can set the range limits for each motor. The arm motions are organized as follows:
- **Shoulder**: pitch, yaw, and roll
- **Elbow**: flex
- **Wrist**: pitch, yaw, and roll
<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/calibration_gui_2.png"
alt="Setting boundaries in the arm calibration GUI"
title="Setting boundaries in the arm calibration GUI"
width="100%">
</img>
</p>
Use the calibration interface to set the range boundaries for each joint. Move each joint through its full range of motion and adjust the minimum and maximum values accordingly. Once you have set the appropriate boundaries for all joints, save the calibration.
### 1.4 Calibrate Teleoperator Exoskeleton
```bash
python -m lerobot.calibrate \
--teleop.type=homunculus_arm \
--teleop.port=/dev/tty.usbmodem11201 \
--teleop.id=black
```
The exoskeleton allows one to control the robot arm. During calibration, you'll be prompted to move all joints through their full range of motion:
```
Move all joints through their entire range of motion.
Recording positions. Press ENTER to stop...
-------------------------------------------
-------------------------------------------
NAME | MIN | POS | MAX
shoulder_pitch | 586 | 736 | 895
shoulder_yaw | 1257 | 1374 | 1390
shoulder_roll | 449 | 1034 | 2564
elbow_flex | 3023 | 3117 | 3134
wrist_roll | 3073 | 3096 | 3147
wrist_yaw | 2143 | 2171 | 2185
wrist_pitch | 1975 | 1993 | 2074
Calibration saved to /Users/your_username/.cache/huggingface/lerobot/calibration/teleoperators/homunculus_arm/black.json
```
## Step 2: Teleoperation
Due to global variable conflicts in the Feetech middleware, teleoperation for arm and hand must run in separate shell sessions:
### Hand
```bash
python -m lerobot.teleoperate \
--robot.type=hope_jr_hand \
--robot.port=/dev/tty.usbmodem58760432281 \
--robot.id=blue \
--robot.side=right \
--teleop.type=homunculus_glove \
--teleop.port=/dev/tty.usbmodem11201 \
--teleop.id=red \
--teleop.side=right \
--display_data=true \
--fps=30
```
### Arm
```bash
python -m lerobot.teleoperate \
--robot.type=hope_jr_arm \
--robot.port=/dev/tty.usbserial-1110 \
--robot.id=white \
--teleop.type=homunculus_arm \
--teleop.port=/dev/tty.usbmodem11201 \
--teleop.id=black \
--display_data=true \
--fps=30
```
## Step 3: Record, Replay, Train
Record, Replay and Train with Hope-JR is still experimental.
### Record
This step records the dataset, which can be seen as an example [here](https://huggingface.co/datasets/nepyope/hand_record_test_with_video_data/settings).
```bash
python -m lerobot.record \
--robot.type=hope_jr_hand \
--robot.port=/dev/tty.usbmodem58760432281 \
--robot.id=right \
--robot.side=right \
--robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \
--teleop.type=homunculus_glove \
--teleop.port=/dev/tty.usbmodem1201 \
--teleop.id=right \
--teleop.side=right \
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
--dataset.single_task="Hand recording test with video data" \
--dataset.num_episodes=1 \
--dataset.episode_time_s=5 \
--dataset.push_to_hub=true \
--dataset.private=true \
--display_data=true
```
### Replay
```bash
python -m lerobot.replay \
--robot.type=hope_jr_hand \
--robot.port=/dev/tty.usbmodem58760432281 \
--robot.id=right \
--robot.side=right \
--dataset.repo_id=nepyope/hand_record_test_with_camera \
--dataset.episode=0
```
### Train
```bash
python -m lerobot.scripts.train \
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
--policy.type=act \
--output_dir=outputs/train/hopejr_hand \
--job_name=hopejr \
--policy.device=mps \
--wandb.enable=true \
--policy.repo_id=nepyope/hand_test_policy
```
### Evaluate
This training run can be viewed as an example [here](https://wandb.ai/tino/lerobot/runs/rp0k8zvw?nw=nwusertino).
```bash
python -m lerobot.record \
--robot.type=hope_jr_hand \
--robot.port=/dev/tty.usbmodem58760432281 \
--robot.id=right \
--robot.side=right \
--robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \
--display_data=false \
--dataset.repo_id=nepyope/eval_hopejr \
--dataset.single_task="Evaluate hopejr hand policy" \
--dataset.num_episodes=10 \
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
```

View File

@@ -0,0 +1,176 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from functools import cached_property
from typing import Any
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.motors import Motor, MotorNormMode
from lerobot.motors.calibration_gui import RangeFinderGUI
from lerobot.motors.feetech import (
FeetechMotorsBus,
)
from ..robot import Robot
from ..utils import ensure_safe_goal_position
from .config_hope_jr import HopeJrArmConfig
logger = logging.getLogger(__name__)
class HopeJrArm(Robot):
config_class = HopeJrArmConfig
name = "hope_jr_arm"
def __init__(self, config: HopeJrArmConfig):
super().__init__(config)
self.config = config
self.bus = FeetechMotorsBus(
port=self.config.port,
motors={
"shoulder_pitch": Motor(1, "sm8512bl", MotorNormMode.RANGE_M100_100),
"shoulder_yaw": Motor(2, "sts3250", MotorNormMode.RANGE_M100_100),
"shoulder_roll": Motor(3, "sts3250", MotorNormMode.RANGE_M100_100),
"elbow_flex": Motor(4, "sts3250", MotorNormMode.RANGE_M100_100),
"wrist_roll": Motor(5, "sts3250", MotorNormMode.RANGE_M100_100),
"wrist_yaw": Motor(6, "sts3250", MotorNormMode.RANGE_M100_100),
"wrist_pitch": Motor(7, "sts3250", MotorNormMode.RANGE_M100_100),
},
calibration=self.calibration,
)
self.cameras = make_cameras_from_configs(config.cameras)
# HACK
self.shoulder_pitch = "shoulder_pitch"
self.other_motors = [m for m in self.bus.motors if m != "shoulder_pitch"]
@property
def _motors_ft(self) -> dict[str, type]:
return {f"{motor}.pos": float for motor in self.bus.motors}
@property
def _cameras_ft(self) -> dict[str, tuple]:
return {
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
return {**self._motors_ft, **self._cameras_ft}
@cached_property
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
def connect(self, calibrate: bool = True) -> None:
"""
We assume that at connection time, arm is in a rest position,
and torque can be safely disabled to run calibration.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.bus.connect(handshake=False)
if not self.is_calibrated and calibrate:
self.calibrate()
# Connect the cameras
for cam in self.cameras.values():
cam.connect()
self.configure()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
return self.bus.is_calibrated
def calibrate(self, limb_name: str = None) -> None:
groups = {
"all": list(self.bus.motors.keys()),
"shoulder": ["shoulder_pitch", "shoulder_yaw", "shoulder_roll"],
"elbow": ["elbow_flex"],
"wrist": ["wrist_roll", "wrist_yaw", "wrist_pitch"],
}
self.calibration = RangeFinderGUI(self.bus, groups).run()
self._save_calibration()
print("Calibration saved to", self.calibration_fpath)
def configure(self) -> None:
with self.bus.torque_disabled():
self.bus.configure_motors(maximum_acceleration=30, acceleration=30)
def setup_motors(self) -> None:
# TODO: add docstring
for motor in reversed(self.bus.motors):
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
def get_observation(self) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Read arm position
start = time.perf_counter()
obs_dict = self.bus.sync_read("Present_Position", self.other_motors)
obs_dict[self.shoulder_pitch] = self.bus.read("Present_Position", self.shoulder_pitch)
obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()}
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
return obs_dict
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
# Cap goal position when too far away from present position.
# /!\ Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
present_pos = self.bus.sync_read("Present_Position")
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()}
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
self.bus.sync_write("Goal_Position", goal_pos)
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
def disconnect(self):
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.bus.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values():
cam.disconnect()
logger.info(f"{self} disconnected.")

View File

@@ -0,0 +1,200 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from functools import cached_property
from typing import Any
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.motors import Motor, MotorNormMode
from lerobot.motors.calibration_gui import RangeFinderGUI
from lerobot.motors.feetech import (
FeetechMotorsBus,
)
from ..robot import Robot
from .config_hope_jr import HopeJrHandConfig
logger = logging.getLogger(__name__)
RIGHT_HAND_INVERSIONS = [
"thumb_mcp",
"thumb_dip",
"index_ulnar_flexor",
"middle_ulnar_flexor",
"ring_ulnar_flexor",
"ring_pip_dip",
"pinky_ulnar_flexor",
"pinky_pip_dip",
]
LEFT_HAND_INVERSIONS = [
"thumb_cmc",
"thumb_mcp",
"thumb_dip",
"index_radial_flexor",
"index_pip_dip",
"middle_radial_flexor",
"middle_pip_dip",
"ring_radial_flexor",
"ring_pip_dip",
"pinky_radial_flexor",
# "pinky_pip_dip",
]
class HopeJrHand(Robot):
config_class = HopeJrHandConfig
name = "hope_jr_hand"
def __init__(self, config: HopeJrHandConfig):
super().__init__(config)
self.config = config
self.bus = FeetechMotorsBus(
port=self.config.port,
motors={
# Thumb
"thumb_cmc": Motor(1, "scs0009", MotorNormMode.RANGE_0_100),
"thumb_mcp": Motor(2, "scs0009", MotorNormMode.RANGE_0_100),
"thumb_pip": Motor(3, "scs0009", MotorNormMode.RANGE_0_100),
"thumb_dip": Motor(4, "scs0009", MotorNormMode.RANGE_0_100),
# Index
"index_radial_flexor": Motor(5, "scs0009", MotorNormMode.RANGE_0_100),
"index_ulnar_flexor": Motor(6, "scs0009", MotorNormMode.RANGE_0_100),
"index_pip_dip": Motor(7, "scs0009", MotorNormMode.RANGE_0_100),
# Middle
"middle_radial_flexor": Motor(8, "scs0009", MotorNormMode.RANGE_0_100),
"middle_ulnar_flexor": Motor(9, "scs0009", MotorNormMode.RANGE_0_100),
"middle_pip_dip": Motor(10, "scs0009", MotorNormMode.RANGE_0_100),
# Ring
"ring_radial_flexor": Motor(11, "scs0009", MotorNormMode.RANGE_0_100),
"ring_ulnar_flexor": Motor(12, "scs0009", MotorNormMode.RANGE_0_100),
"ring_pip_dip": Motor(13, "scs0009", MotorNormMode.RANGE_0_100),
# Pinky
"pinky_radial_flexor": Motor(14, "scs0009", MotorNormMode.RANGE_0_100),
"pinky_ulnar_flexor": Motor(15, "scs0009", MotorNormMode.RANGE_0_100),
"pinky_pip_dip": Motor(16, "scs0009", MotorNormMode.RANGE_0_100),
},
calibration=self.calibration,
protocol_version=1,
)
self.cameras = make_cameras_from_configs(config.cameras)
self.inverted_motors = RIGHT_HAND_INVERSIONS if config.side == "right" else LEFT_HAND_INVERSIONS
@property
def _motors_ft(self) -> dict[str, type]:
return {f"{motor}.pos": float for motor in self.bus.motors}
@property
def _cameras_ft(self) -> dict[str, tuple]:
return {
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
return {**self._motors_ft, **self._cameras_ft}
@cached_property
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
def connect(self, calibrate: bool = True) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.bus.connect()
if not self.is_calibrated and calibrate:
self.calibrate()
# Connect the cameras
for cam in self.cameras.values():
cam.connect()
self.configure()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
return self.bus.is_calibrated
def calibrate(self) -> None:
fingers = {}
for finger in ["thumb", "index", "middle", "ring", "pinky"]:
fingers[finger] = [motor for motor in self.bus.motors if motor.startswith(finger)]
self.calibration = RangeFinderGUI(self.bus, fingers).run()
for motor in self.inverted_motors:
self.calibration[motor].drive_mode = 1
self._save_calibration()
print("Calibration saved to", self.calibration_fpath)
def configure(self) -> None:
with self.bus.torque_disabled():
self.bus.configure_motors()
def setup_motors(self) -> None:
# TODO: add docstring
for motor in self.bus.motors:
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
def get_observation(self) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
obs_dict = {}
# Read hand position
start = time.perf_counter()
for motor in self.bus.motors:
obs_dict[f"{motor}.pos"] = self.bus.read("Present_Position", motor)
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
return obs_dict
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
self.bus.sync_write("Goal_Position", goal_pos)
return action
def disconnect(self):
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.bus.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values():
cam.disconnect()
logger.info(f"{self} disconnected.")

View File

@@ -0,0 +1,2 @@
from .config_so101_follower_t import SO101FollowerTConfig
from .so101_follower_t import SO101FollowerT

View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from ..config import RobotConfig
@RobotConfig.register_subclass("so101_follower_t")
@dataclass
class SO101FollowerTConfig(RobotConfig):
# Port to connect to the arm
port: str
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)

View File

@@ -0,0 +1,553 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import logging
import time
from functools import cached_property
from typing import Any
import numpy as np
import pinocchio as pin
from scipy.signal import butter, lfilter
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.feetech import (
FeetechMotorsBus,
)
from ..robot import Robot
from .config_so101_follower_t import SO101FollowerTConfig
logger = logging.getLogger(__name__)
class SO101FollowerT(Robot):
"""
SO-101 Arm with HLS3625 motors with current control.
"""
config_class = SO101FollowerTConfig
name = "so101_follower_t"
_CURRENT_STEP_A: float = 6.5e-3 # 6.5 mA per register LSB #http://doc.feetech.cn/#/prodinfodownload?srcType=FT-SMS-STS-emanual-229f4476422d4059abfb1cb0
_KT_NM_PER_AMP: float = 0.814 # Torque constant Kt [N·m/A] #https://www.feetechrc.com/811177.html
_MAX_CURRENT_A: float = 4.0 # Safe driver limit
# Position gains
_KP_GAINS = {
"shoulder_pan": 5.0,
"shoulder_lift": 7.0,
"elbow_flex": 7.0,
"wrist_flex": 5.0,
"wrist_roll": 5.0,
"gripper": 5.0,
}
# Velocity gains
_KD_GAINS = {
"shoulder_pan": 0.4,
"shoulder_lift": 0.6,
"elbow_flex": 0.6,
"wrist_flex": 0.4,
"wrist_roll": 0.4,
"gripper": 0.4,
}
# Force gains
_KF_GAINS = {
"shoulder_pan": 0.05,
"shoulder_lift": 0.05,
"elbow_flex": 0.05,
"wrist_flex": 0.05,
"wrist_roll": 0.05,
"gripper": 0.05,
}
# Viscous friction coefficient
_FRICTION_VISCOUS = {
"shoulder_pan": 0.05,
"shoulder_lift": 0.08,
"elbow_flex": 0.05,
"wrist_flex": 0.05,
"wrist_roll": 0.05,
"gripper": 0.05,
}
# Coulomb/static friction
_FRICTION_COULOMB = {
"shoulder_pan": 0.15,
"shoulder_lift": 0.25,
"elbow_flex": 0.25,
"wrist_flex": 0.20,
"wrist_roll": 0.20,
"gripper": 0.20,
}
def __init__(self, config: SO101FollowerTConfig):
super().__init__(config)
self.config = config
if self.calibration_fpath.is_file() and not self.calibration:
self._load_calibration()
self.bus = FeetechMotorsBus(
port=self.config.port,
motors={
"shoulder_pan": Motor(1, "hls3625", MotorNormMode.DEGREES),
"shoulder_lift": Motor(2, "hls3625", MotorNormMode.DEGREES),
"elbow_flex": Motor(3, "hls3625", MotorNormMode.DEGREES),
"wrist_flex": Motor(4, "hls3625", MotorNormMode.DEGREES),
"wrist_roll": Motor(5, "hls3625", MotorNormMode.DEGREES),
"gripper": Motor(6, "hls3625", MotorNormMode.DEGREES),
},
calibration=self.calibration,
)
self.cameras = make_cameras_from_configs(config.cameras)
self.pin_robot = pin.RobotWrapper.BuildFromURDF("urdf/so101_new_calib.urdf", "urdf")
flip = {
"shoulder_pan": True,
"shoulder_lift": True,
"elbow_flex": True,
"wrist_flex": True,
"wrist_roll": True,
"gripper": True,
}
self.torque_sign = {n: (-1.0 if flip[n] else 1.0) for n in self.bus.motors}
self._prev_pos_rad: dict[str, float] | None = None
self._prev_vel_rad: dict[str, float] | None = None
self._prev_t: float | None = None
# Butterworth low-pass filter parameters
self._cutoff_freq = 10.0 # Hz, cutoff frequency for the filter
self._filter_order = 2 # Filter order
self._sampling_freq = 100.0 # Hz, (control loop frequency)
nyquist_freq = self._sampling_freq / 2
normalized_cutoff = self._cutoff_freq / nyquist_freq
self._b, self._a = butter(self._filter_order, normalized_cutoff, btype="low")
# History buffers
self._pos_history = {m: collections.deque(maxlen=20) for m in self.bus.motors}
self._vel_raw_history = {m: collections.deque(maxlen=20) for m in self.bus.motors}
self._time_history = collections.deque(maxlen=20)
self._last_observation = None
@property
def _motors_ft(self) -> dict[str, type]:
d: dict[str, type] = {}
for motor in self.bus.motors:
d[f"{motor}.pos"] = float
d[f"{motor}.vel"] = float
d[f"{motor}.effort"] = float
return d
@property
def _cameras_ft(self) -> dict[str, tuple]:
return {
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
return {**self._motors_ft, **self._cameras_ft}
@cached_property
def action_features(self) -> dict[str, type]:
d: dict[str, type] = {}
for motor in self.bus.motors:
d[f"{motor}.pos"] = float
d[f"{motor}.vel"] = float
d[f"{motor}.effort"] = float
return d
@property
def is_connected(self) -> bool:
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
@property
def kp_gains(self) -> dict[str, float]:
"""Position control gains [Nm/rad] for bilateral teleoperation"""
return self._KP_GAINS.copy()
@property
def kd_gains(self) -> dict[str, float]:
"""Velocity control gains [Nm⋅s/rad] for bilateral teleoperation"""
return self._KD_GAINS.copy()
@property
def kf_gains(self) -> dict[str, float]:
"""Force control gains for bilateral teleoperation"""
return self._KF_GAINS.copy()
@property
def friction_viscous(self) -> dict[str, float]:
"""Viscous friction coefficients [Nm⋅s/rad] for friction compensation"""
return self._FRICTION_VISCOUS.copy()
@property
def friction_coulomb(self) -> dict[str, float]:
"""Coulomb friction coefficients [Nm] for friction compensation"""
return self._FRICTION_COULOMB.copy()
def set_butterworth_params(self, cutoff_freq: float = 10.0, order: int = 2) -> None:
"""Configure Butterworth low-pass filter parameters for velocity/acceleration estimation.
Args:
cutoff_freq: Cutoff frequency in Hz (default: 10 Hz)
order: Filter order (default: 2)
"""
if cutoff_freq <= 0:
raise ValueError("Cutoff frequency must be positive")
if cutoff_freq >= self._sampling_freq / 2:
raise ValueError(
f"Cutoff frequency must be less than Nyquist frequency ({self._sampling_freq / 2} Hz)"
)
if order < 1:
raise ValueError("Filter order must be at least 1")
self._cutoff_freq = cutoff_freq
self._filter_order = order
nyquist_freq = self._sampling_freq / 2
normalized_cutoff = self._cutoff_freq / nyquist_freq
self._b, self._a = butter(self._filter_order, normalized_cutoff, btype="low")
# Clear buffers
for m in self.bus.motors:
self._pos_history[m].clear()
self._vel_raw_history[m].clear()
self._time_history.clear()
logger.info(f"Butterworth filter updated: cutoff_freq={cutoff_freq} Hz, order={order}")
def _current_to_torque_nm(self, raw: dict[str, Any]) -> dict[str, float]:
"""Convert "Present_Current" register counts (±2047) → torque [Nm].
Values are clamped to ±3A before conversion for protection.
"""
max_cnt = int(round(self._MAX_CURRENT_A / self._CURRENT_STEP_A)) # ≈ 462
coef = self._CURRENT_STEP_A * self._KT_NM_PER_AMP
return {k: self.torque_sign[k] * max(-max_cnt, min(max_cnt, v)) * coef for k, v in raw.items()}
def _torque_nm_to_current(self, torque: dict[str, float]) -> dict[str, int]:
"""Convert torque [Nm] to register counts, clamped to ±3A (2.44 Nm)."""
inv_coef = 1.0 / (self._CURRENT_STEP_A * self._KT_NM_PER_AMP)
max_cnt = int(round(self._MAX_CURRENT_A / self._CURRENT_STEP_A))
counts = {}
for k, τ in torque.items():
cnt = τ * self.torque_sign[k] * inv_coef
cnt = max(-max_cnt, min(max_cnt, cnt))
counts[k] = int(round(cnt))
return counts
def _deg_to_rad(self, deg: dict[str, float | int]) -> dict[str, float]:
"""Degrees to radians."""
return {m: np.deg2rad(float(v)) for m, v in deg.items()}
def _gravity_from_q(self, q_rad: dict[str, float]) -> dict[str, float]:
"""
Compute g(q) [N m] for all joints in the robot.
The order of joints in the URDF matches self.bus.motors.
"""
q = np.zeros(self.pin_robot.model.nq)
for i, motor_name in enumerate(self.bus.motors):
q[i] = q_rad[motor_name]
g = pin.computeGeneralizedGravity(self.pin_robot.model, self.pin_robot.data, q)
return {motor_name: float(g[i]) for i, motor_name in enumerate(self.bus.motors)}
def _inertia_from_q_dq(
self, q_rad: dict[str, float], dq_rad: dict[str, float], ddq_rad: dict[str, float]
) -> dict[str, float]:
"""
Compute inertia torques τ_inertia = M(q) * ddq directly from URDF model.
"""
q = np.zeros(self.pin_robot.model.nq)
dq = np.zeros(self.pin_robot.model.nv)
ddq = np.zeros(self.pin_robot.model.nv)
for i, motor_name in enumerate(self.bus.motors):
q[i] = q_rad[motor_name]
dq[i] = dq_rad[motor_name]
ddq[i] = ddq_rad[motor_name]
# Compute mass matrix M(q)
mass_matrix = pin.crba(self.pin_robot.model, self.pin_robot.data, q)
# Compute inertia torques: τ_inertia = M(q) * ddq
tau_inertia = mass_matrix @ ddq
return {motor_name: float(tau_inertia[i]) for i, motor_name in enumerate(self.bus.motors)}
def _compute_model_based_disturbance(
self,
q_rad: dict[str, float],
dq_rad: dict[str, float],
ddq_rad: dict[str, float],
tau_measured: dict[str, float],
) -> dict[str, float]:
"""
Compute disturbance torques using direct model-based approach:
τ_disturbance = τ_measured - τ_gravity - τ_inertia - τ_friction
Args:
include_friction: If True, also removes friction from the disturbance calculation
"""
tau_gravity = self._gravity_from_q(q_rad)
tau_inertia = self._inertia_from_q_dq(q_rad, dq_rad, ddq_rad)
# Compute disturbance
tau_disturbance = {}
tau_friction = {}
for motor_name in self.bus.motors:
tau_dist = tau_measured[motor_name] - tau_gravity[motor_name] - tau_inertia[motor_name]
# Calculate friction torque
omega = dq_rad[motor_name]
tau_friction_motor = self._FRICTION_VISCOUS[motor_name] * omega + self._FRICTION_COULOMB[
motor_name
] * (1.0 if omega > 0.01 else -1.0 if omega < -0.01 else 0.0)
# Apply torque sign correction
tau_friction_motor = -tau_friction_motor
tau_friction[motor_name] = tau_friction_motor
tau_dist -= tau_friction_motor
tau_disturbance[motor_name] = tau_dist
return tau_disturbance
def connect(self, calibrate: bool = True) -> None:
"""
We assume that at connection time, arm is in a rest position,
and torque can be safely disabled to run calibration.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
# Ensure calibration is loaded from file if it exists
if self.calibration_fpath.is_file() and not self.calibration:
self._load_calibration()
# Update the bus with the loaded calibration
self.bus.calibration = self.calibration
self.bus.connect()
if not self.is_calibrated and calibrate:
self.calibrate()
for cam in self.cameras.values():
cam.connect()
self.configure()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
# Check if calibration file exists and is loaded
return self.calibration_fpath.is_file() and bool(self.calibration)
def calibrate(self) -> None:
logger.info(f"\nRunning calibration of {self}")
self.bus.disable_torque()
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, 2, num_retry=2) # Set to current mode
input(f"Move {self} to the middle of its range of motion and press ENTER....")
homing_offsets = self.bus.set_half_turn_homings()
print(
"Move all joints sequentially through their entire ranges "
"of motion.\nRecording positions. Press ENTER to stop..."
)
range_mins, range_maxes = self.bus.record_ranges_of_motion()
self.calibration = {}
for motor, m in self.bus.motors.items():
self.calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=0,
homing_offset=int(homing_offsets[motor]),
range_min=int(range_mins[motor]),
range_max=int(range_maxes[motor]),
)
# Update the bus calibration with the new values
self.bus.calibration = self.calibration
# Save calibration to file only
self._save_calibration()
print("Calibration saved to", self.calibration_fpath)
def configure(self) -> None:
self.bus.disable_torque() # here was issue at startup previously
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, 2, num_retry=2) # Set to current mode
self.bus.write("Present_Current", motor, 0, normalize=False, num_retry=5)
def setup_motors(self) -> None:
for motor in reversed(self.bus.motors):
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
def get_observation(self) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
t_now = time.perf_counter()
# Position
pos_deg = self.bus.sync_read("Present_Position", num_retry=5)
pos_rad = self._deg_to_rad(pos_deg)
# Store position and time history
for m in pos_rad:
self._pos_history[m].append(pos_rad[m])
self._time_history.append(t_now)
# Calculate raw velocity
vel_rad_raw = {}
if self._prev_pos_rad is None or self._prev_t is None:
vel_rad_raw = dict.fromkeys(pos_rad, 0.0)
else:
dt = t_now - self._prev_t
dt = max(dt, 1e-4) # Avoid division by zero
vel_rad_raw = {m: (pos_rad[m] - self._prev_pos_rad[m]) / dt for m in pos_rad}
# Store raw velocity history
for m in vel_rad_raw:
self._vel_raw_history[m].append(vel_rad_raw[m])
# Apply Butterworth low-pass filter to velocity
vel_rad = {}
for m in pos_rad:
if len(self._vel_raw_history[m]) >= 10:
vel_raw_array = np.array(list(self._vel_raw_history[m]))
# Apply Butterworth filter
vel_filtered = lfilter(self._b, self._a, vel_raw_array)
vel_rad[m] = vel_filtered[-1]
else:
vel_rad[m] = vel_rad_raw[m]
# Calculate acceleration from filtered velocity
acc_rad = {}
if self._prev_vel_rad is None or self._prev_t is None:
acc_rad = dict.fromkeys(pos_rad, 0.0)
else:
dt = t_now - self._prev_t
dt = max(dt, 1e-4) # Avoid division by zero
acc_rad = {m: (vel_rad[m] - self._prev_vel_rad[m]) / dt for m in vel_rad}
self._prev_pos_rad = pos_rad.copy()
self._prev_vel_rad = vel_rad.copy()
self._prev_t = t_now
# Measured torque (Nm)
cur_raw = self.bus.sync_read("Present_Current", normalize=False, num_retry=5)
tau_meas = self._current_to_torque_nm(cur_raw)
# Compute reaction torques using model-based approach
tau_reaction = self._compute_model_based_disturbance(pos_rad, vel_rad, acc_rad, tau_meas)
obs_dict = {}
obs_dict |= {f"{m}.pos": pos_rad[m] for m in self.bus.motors}
obs_dict |= {f"{m}.vel": vel_rad[m] for m in self.bus.motors}
obs_dict |= {f"{m}.acc": acc_rad[m] for m in self.bus.motors}
obs_dict |= {f"{m}.effort": tau_reaction[m] for m in self.bus.motors}
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
# Store observation for feedforward compensation
self._last_observation = obs_dict.copy()
return obs_dict
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
"""Command arm to move to a target torque for a joint.
Raises:
RobotDeviceNotConnectedError: if robot is not connected.
Returns:
the action sent to the motors.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Extract torque commands
tau_cmd_nm = {k.removesuffix(".effort"): float(v) for k, v in action.items() if k.endswith(".effort")}
if not tau_cmd_nm:
return action
# Add feedforward compensation if we have a last observation
if self._last_observation is not None:
# Extract position, velocity, acceleration from last observation
pos_rad = {m: self._last_observation[f"{m}.pos"] for m in self.bus.motors}
vel_rad = {m: self._last_observation[f"{m}.vel"] for m in self.bus.motors}
acc_rad = {m: self._last_observation[f"{m}.acc"] for m in self.bus.motors}
# Compute feedforward terms
tau_gravity = self._gravity_from_q(pos_rad)
tau_inertia = self._inertia_from_q_dq(pos_rad, vel_rad, acc_rad)
# Add feedforward compensation to commanded torques
for motor in tau_cmd_nm:
# Add gravity compensation
tau_cmd_nm[motor] += tau_gravity[motor]
# Add inertia compensation
tau_cmd_nm[motor] += tau_inertia[motor]
# Add friction compensation
omega = vel_rad[motor]
tau_friction = self._FRICTION_VISCOUS[motor] * omega + self._FRICTION_COULOMB[motor] * (
1.0 if omega > 0.01 else -1.0 if omega < -0.01 else 0.0
)
tau_friction = -tau_friction # Apply torque sign correction
tau_cmd_nm[motor] += tau_friction
inv_coef = 1.0 / (self._CURRENT_STEP_A * self._KT_NM_PER_AMP)
max_cnt = int(round(self._MAX_CURRENT_A / self._CURRENT_STEP_A))
counts = {}
for joint, τ in tau_cmd_nm.items():
cnt = τ * self.torque_sign[joint] * inv_coef # flip SIGN
cnt = max(-max_cnt, min(max_cnt, cnt))
counts[joint] = int(round(cnt))
self.bus.sync_write("Target_Torque", counts, normalize=False, num_retry=2)
self._last_cmd_nm = tau_cmd_nm
return action
def disconnect(self):
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.bus.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values():
cam.disconnect()
logger.info(f"{self} disconnected.")

View File

@@ -37,6 +37,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
from .so101_follower import SO101Follower
return SO101Follower(config)
elif config.type == "so101_follower_t":
from .so101_follower_torque import SO101FollowerT
return SO101FollowerT(config)
elif config.type == "lekiwi":
from .lekiwi import LeKiwi
@@ -49,6 +53,14 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
from .viperx import ViperX
return ViperX(config)
elif config.type == "hope_jr_hand":
from .hope_jr import HopeJrHand
return HopeJrHand(config)
elif config.type == "hope_jr_arm":
from .hope_jr import HopeJrArm
return HopeJrArm(config)
elif config.type == "mock_robot":
from tests.mocks.mock_robot import MockRobot

View File

@@ -317,7 +317,7 @@ def act_with_policy(
if done or truncated:
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
if len(list_transition_to_send_to_learner) > 0:
push_transitions_to_transport_queue(
@@ -642,9 +642,29 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
if bytes_state_dict is not None:
logging.info("[ACTOR] Load new parameters from Learner.")
state_dict = bytes_to_state_dict(bytes_state_dict)
state_dict = move_state_dict_to_device(state_dict, device=device)
policy.load_state_dict(state_dict)
state_dicts = bytes_to_state_dict(bytes_state_dict)
# TODO: check encoder parameter synchronization possible issues:
# 1. When shared_encoder=True, we're loading stale encoder params from actor's state_dict
# instead of the updated encoder params from critic (which is optimized separately)
# 2. When freeze_vision_encoder=True, we waste bandwidth sending/loading frozen params
# 3. Need to handle encoder params correctly for both actor and discrete_critic
# Potential fixes:
# - Send critic's encoder state when shared_encoder=True
# - Skip encoder params entirely when freeze_vision_encoder=True
# - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
# Load actor state dict
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
policy.actor.load_state_dict(actor_state_dict)
# Load discrete critic if present
if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts:
discrete_critic_state_dict = move_state_dict_to_device(
state_dicts["discrete_critic"], device=device
)
policy.discrete_critic.load_state_dict(discrete_critic_state_dict)
logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
#################################################

View File

@@ -1109,8 +1109,18 @@ def check_nan_in_transition(
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
logging.debug("[LEARNER] Pushing actor policy to the queue")
state_dict = move_state_dict_to_device(policy.actor.state_dict(), device="cpu")
state_bytes = state_to_bytes(state_dict)
# Create a dictionary to hold all the state dicts
state_dicts = {"policy": move_state_dict_to_device(policy.actor.state_dict(), device="cpu")}
# Add discrete critic if it exists
if hasattr(policy, "discrete_critic") and policy.discrete_critic is not None:
state_dicts["discrete_critic"] = move_state_dict_to_device(
policy.discrete_critic.state_dict(), device="cpu"
)
logging.debug("[LEARNER] Including discrete critic in state dict push")
state_bytes = state_to_bytes(state_dicts)
parameters_queue.put(state_bytes)

View File

@@ -0,0 +1,197 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Callable
import torch
from lerobot.robots.config import RobotConfig
from lerobot.scripts.server.constants import (
DEFAULT_FPS,
DEFAULT_INFERENCE_LATENCY,
DEFAULT_OBS_QUEUE_TIMEOUT,
)
# Aggregate function registry for CLI usage
AGGREGATE_FUNCTIONS = {
"weighted_average": lambda old, new: 0.3 * old + 0.7 * new,
"latest_only": lambda old, new: new,
"average": lambda old, new: 0.5 * old + 0.5 * new,
"conservative": lambda old, new: 0.7 * old + 0.3 * new,
}
def get_aggregate_function(name: str) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
"""Get aggregate function by name from registry."""
if name not in AGGREGATE_FUNCTIONS:
available = list(AGGREGATE_FUNCTIONS.keys())
raise ValueError(f"Unknown aggregate function '{name}'. Available: {available}")
return AGGREGATE_FUNCTIONS[name]
@dataclass
class PolicyServerConfig:
"""Configuration for PolicyServer.
This class defines all configurable parameters for the PolicyServer,
including networking settings and action chunking specifications.
"""
# Networking configuration
host: str = field(default="localhost", metadata={"help": "Host address to bind the server to"})
port: int = field(default=8080, metadata={"help": "Port number to bind the server to"})
# Timing configuration
fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"})
inference_latency: float = field(
default=DEFAULT_INFERENCE_LATENCY, metadata={"help": "Target inference latency in seconds"}
)
obs_queue_timeout: float = field(
default=DEFAULT_OBS_QUEUE_TIMEOUT, metadata={"help": "Timeout for observation queue in seconds"}
)
def __post_init__(self):
"""Validate configuration after initialization."""
if self.port < 1 or self.port > 65535:
raise ValueError(f"Port must be between 1 and 65535, got {self.port}")
if self.environment_dt <= 0:
raise ValueError(f"environment_dt must be positive, got {self.environment_dt}")
if self.inference_latency < 0:
raise ValueError(f"inference_latency must be non-negative, got {self.inference_latency}")
if self.obs_queue_timeout < 0:
raise ValueError(f"obs_queue_timeout must be non-negative, got {self.obs_queue_timeout}")
@classmethod
def from_dict(cls, config_dict: dict) -> "PolicyServerConfig":
"""Create a PolicyServerConfig from a dictionary."""
return cls(**config_dict)
@property
def environment_dt(self) -> float:
"""Environment time step, in seconds"""
return 1 / self.fps
def to_dict(self) -> dict:
"""Convert the configuration to a dictionary."""
return {
"host": self.host,
"port": self.port,
"fps": self.fps,
"environment_dt": self.environment_dt,
"inference_latency": self.inference_latency,
}
@dataclass
class RobotClientConfig:
"""Configuration for RobotClient.
This class defines all configurable parameters for the RobotClient,
including network connection, policy settings, and control behavior.
"""
# Policy configuration
policy_type: str = field(metadata={"help": "Type of policy to use"})
pretrained_name_or_path: str = field(metadata={"help": "Pretrained model name or path"})
# Robot configuration (for CLI usage - robot instance will be created from this)
robot: RobotConfig = field(metadata={"help": "Robot configuration"})
# Policies typically output K actions at max, but we can use less to avoid wasting bandwidth (as actions
# would be aggregated on the client side anyway, depending on the value of `chunk_size_threshold`)
actions_per_chunk: int = field(metadata={"help": "Number of actions per chunk"})
# Task instruction for the robot to execute (e.g., 'fold my tshirt')
task: str = field(default="", metadata={"help": "Task instruction for the robot to execute"})
# Network configuration
server_address: str = field(default="localhost:8080", metadata={"help": "Server address to connect to"})
# Device configuration
policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})
# Control behavior configuration
chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"})
fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"})
# Aggregate function configuration (CLI-compatible)
aggregate_fn_name: str = field(
default="weighted_average",
metadata={"help": f"Name of aggregate function to use. Options: {list(AGGREGATE_FUNCTIONS.keys())}"},
)
# Debug configuration
debug_visualize_queue_size: bool = field(
default=False, metadata={"help": "Visualize the action queue size"}
)
# Verification configuration
verify_robot_cameras: bool = field(
default=True, metadata={"help": "Verify that the robot cameras match the policy cameras"}
)
@property
def environment_dt(self) -> float:
"""Environment time step, in seconds"""
return 1 / self.fps
def __post_init__(self):
"""Validate configuration after initialization."""
if not self.server_address:
raise ValueError("server_address cannot be empty")
if not self.policy_type:
raise ValueError("policy_type cannot be empty")
if not self.pretrained_name_or_path:
raise ValueError("pretrained_name_or_path cannot be empty")
if not self.policy_device:
raise ValueError("policy_device cannot be empty")
if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1:
raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}")
if self.fps <= 0:
raise ValueError(f"fps must be positive, got {self.fps}")
if self.actions_per_chunk <= 0:
raise ValueError(f"actions_per_chunk must be positive, got {self.actions_per_chunk}")
self.aggregate_fn = get_aggregate_function(self.aggregate_fn_name)
@classmethod
def from_dict(cls, config_dict: dict) -> "RobotClientConfig":
"""Create a RobotClientConfig from a dictionary."""
return cls(**config_dict)
def to_dict(self) -> dict:
"""Convert the configuration to a dictionary."""
return {
"server_address": self.server_address,
"policy_type": self.policy_type,
"pretrained_name_or_path": self.pretrained_name_or_path,
"policy_device": self.policy_device,
"chunk_size_threshold": self.chunk_size_threshold,
"fps": self.fps,
"actions_per_chunk": self.actions_per_chunk,
"task": self.task,
"debug_visualize_queue_size": self.debug_visualize_queue_size,
"aggregate_fn_name": self.aggregate_fn_name,
}

View File

@@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Client side: The environment evolves with a time resolution equal to 1/fps"""
DEFAULT_FPS = 30
"""Server side: Running inference on (at most) 1/fps"""
DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
"""Server side: Timeout for observation queue in seconds"""
DEFAULT_OBS_QUEUE_TIMEOUT = 2
# All action chunking policies
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "pi0", "tdmpc", "vqbet"]
# TODO: Add all other robots
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower"]

View File

@@ -0,0 +1,386 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import logging.handlers
import os
import time
from dataclasses import dataclass
from pathlib import Path
from threading import Event
from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_IMAGES, OBS_STATE
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401
from lerobot.robots.robot import Robot
from lerobot.transport import async_inference_pb2
from lerobot.transport.utils import bytes_buffer_size
from lerobot.utils.utils import init_logging
Action = torch.Tensor
ActionChunk = torch.Tensor
# observation as received from the robot
RawObservation = dict[str, torch.Tensor]
# observation as those recorded in LeRobot dataset (keys are different)
LeRobotObservation = dict[str, torch.Tensor]
# observation, ready for policy inference (image keys resized)
Observation = dict[str, torch.Tensor]
def visualize_action_queue_size(action_queue_size: list[int]) -> None:
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.set_title("Action Queue Size Over Time")
ax.set_xlabel("Environment steps")
ax.set_ylabel("Action Queue Size")
ax.set_ylim(0, max(action_queue_size) * 1.1)
ax.grid(True, alpha=0.3)
ax.plot(range(len(action_queue_size)), action_queue_size)
plt.show()
def validate_robot_cameras_for_policy(
lerobot_observation_features: dict[str, dict], policy_image_features: dict[str, PolicyFeature]
) -> None:
image_keys = list(filter(is_image_key, lerobot_observation_features))
assert set(image_keys) == set(policy_image_features.keys()), (
f"Policy image features must match robot cameras! Received {list(policy_image_features.keys())} != {image_keys}"
)
def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]:
return hw_to_dataset_features(robot.observation_features, "observation", use_video=False)
def is_image_key(k: str) -> bool:
return k.startswith(OBS_IMAGES)
def resize_robot_observation_image(image: torch.tensor, resize_dims: tuple[int, int, int]) -> torch.tensor:
assert image.ndim == 3, f"Image must be (C, H, W)! Received {image.shape}"
# (H, W, C) -> (C, H, W) for resizing from robot obsevation resolution to policy image resolution
image = image.permute(2, 0, 1)
dims = (resize_dims[1], resize_dims[2])
# Add batch dimension for interpolate: (C, H, W) -> (1, C, H, W)
image_batched = image.unsqueeze(0)
# Interpolate and remove batch dimension: (1, C, H, W) -> (C, H, W)
resized = torch.nn.functional.interpolate(image_batched, size=dims, mode="bilinear", align_corners=False)
return resized.squeeze(0)
def raw_observation_to_observation(
raw_observation: RawObservation,
lerobot_features: dict[str, dict],
policy_image_features: dict[str, PolicyFeature],
device: str,
) -> Observation:
observation = {}
observation = prepare_raw_observation(raw_observation, lerobot_features, policy_image_features)
for k, v in observation.items():
if isinstance(v, torch.Tensor): # VLAs present natural-language instructions in observations
if "image" in k:
# Policy expects images in shape (B, C, H, W)
observation[k] = prepare_image(v).unsqueeze(0).to(device)
else:
observation[k] = v.to(device)
else:
observation[k] = v
return observation
def prepare_image(image: torch.Tensor) -> torch.Tensor:
"""Minimal preprocessing to turn int8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
image = image.type(torch.float32) / 255
image = image.contiguous()
return image
def extract_state_from_raw_observation(
lerobot_obs: RawObservation,
) -> torch.Tensor:
"""Extract the state from a raw observation."""
state = torch.tensor(lerobot_obs[OBS_STATE])
if state.ndim == 1:
state = state.unsqueeze(0)
return state
def extract_images_from_raw_observation(
lerobot_obs: RawObservation,
camera_key: str,
) -> dict[str, torch.Tensor]:
"""Extract the images from a raw observation."""
return torch.tensor(lerobot_obs[camera_key])
def make_lerobot_observation(
robot_obs: RawObservation,
lerobot_features: dict[str, dict],
) -> LeRobotObservation:
"""Make a lerobot observation from a raw observation."""
return build_dataset_frame(lerobot_features, robot_obs, prefix="observation")
def prepare_raw_observation(
robot_obs: RawObservation,
lerobot_features: dict[str, dict],
policy_image_features: dict[str, PolicyFeature],
) -> Observation:
"""Matches keys from the raw robot_obs dict to the keys expected by a given policy (passed as
policy_image_features)."""
# 1. {motor.pos1:value1, motor.pos2:value2, ..., laptop:np.ndarray} ->
# -> {observation.state:[value1,value2,...], observation.images.laptop:np.ndarray}
lerobot_obs = make_lerobot_observation(robot_obs, lerobot_features)
# 2. Greps all observation.images.<> keys
image_keys = list(filter(is_image_key, lerobot_obs))
# state's shape is expected as (B, state_dim)
state_dict = {OBS_STATE: extract_state_from_raw_observation(lerobot_obs)}
image_dict = {
image_k: extract_images_from_raw_observation(lerobot_obs, image_k) for image_k in image_keys
}
# Turns the image features to (C, H, W) with H, W matching the policy image features.
# This reduces the resolution of the images
image_dict = {
key: resize_robot_observation_image(torch.tensor(lerobot_obs[key]), policy_image_features[key].shape)
for key in image_keys
}
if "task" in robot_obs:
state_dict["task"] = robot_obs["task"]
return {**state_dict, **image_dict}
def get_logger(name: str, log_to_file: bool = True) -> logging.Logger:
"""
Get a logger using the standardized logging setup from utils.py.
Args:
name: Logger name (e.g., 'policy_server', 'robot_client')
log_to_file: Whether to also log to a file
Returns:
Configured logger instance
"""
# Create logs directory if logging to file
if log_to_file:
os.makedirs("logs", exist_ok=True)
log_file = Path(f"logs/{name}_{int(time.time())}.log")
else:
log_file = None
# Initialize the standardized logging
init_logging(log_file=log_file, display_pid=False)
# Return a named logger
return logging.getLogger(name)
@dataclass
class TimedData:
"""A data object with timestamp and timestep information.
Args:
timestamp: Unix timestamp relative to data's creation.
data: The actual data to wrap a timestamp around.
timestep: The timestep of the data.
"""
timestamp: float
timestep: int
def get_timestamp(self):
return self.timestamp
def get_timestep(self):
return self.timestep
@dataclass
class TimedAction(TimedData):
action: Action
def get_action(self):
return self.action
@dataclass
class TimedObservation(TimedData):
observation: RawObservation
must_go: bool = False
def get_observation(self):
return self.observation
@dataclass
class FPSTracker:
"""Utility class to track FPS metrics over time."""
target_fps: float
first_timestamp: float = None
total_obs_count: int = 0
def calculate_fps_metrics(self, current_timestamp: float) -> dict[str, float]:
"""Calculate average FPS vs target"""
self.total_obs_count += 1
# Initialize first observation time
if self.first_timestamp is None:
self.first_timestamp = current_timestamp
# Calculate overall average FPS (since start)
total_duration = current_timestamp - self.first_timestamp
avg_fps = (self.total_obs_count - 1) / total_duration if total_duration > 1e-6 else 0.0
return {"avg_fps": avg_fps, "target_fps": self.target_fps}
def reset(self):
"""Reset the FPS tracker state"""
self.first_timestamp = None
self.total_obs_count = 0
@dataclass
class RemotePolicyConfig:
policy_type: str
pretrained_name_or_path: str
lerobot_features: dict[str, PolicyFeature]
actions_per_chunk: int
device: str = "cpu"
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
"""Check if two observation states are similar, under a tolerance threshold"""
return bool(torch.linalg.norm(obs1_state - obs2_state) < atol)
def observations_similar(
obs1: TimedObservation, obs2: TimedObservation, lerobot_features: dict[str, dict], atol: float = 1
) -> bool:
"""Check if two observations are similar, under a tolerance threshold. Measures distance between
observations as the difference in joint-space between the two observations.
NOTE(fracapuano): This is a very simple check, and it is enough for the current use case.
An immediate next step is to use (fast) perceptual difference metrics comparing some camera views,
to surpass this joint-space similarity check.
"""
obs1_state = extract_state_from_raw_observation(
make_lerobot_observation(obs1.get_observation(), lerobot_features)
)
obs2_state = extract_state_from_raw_observation(
make_lerobot_observation(obs2.get_observation(), lerobot_features)
)
return _compare_observation_states(obs1_state, obs2_state, atol=atol)
def send_bytes_in_chunks(
buffer: bytes,
message_class: Any,
log_prefix: str = "",
silent: bool = True,
chunk_size: int = 3 * 1024 * 1024,
):
# NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.send_bytes_in_chunks. Duplication can't be avoided if we
# don't use a unique class for messages sent (due to the different transfer states sent). Also, I'd want more control over the
# chunk size as I am using it to send image observations.
buffer = io.BytesIO(buffer)
size_in_bytes = bytes_buffer_size(buffer)
sent_bytes = 0
logging_method = logging.info if not silent else logging.debug
logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
while sent_bytes < size_in_bytes:
transfer_state = async_inference_pb2.TransferState.TRANSFER_MIDDLE
if sent_bytes + chunk_size >= size_in_bytes:
transfer_state = async_inference_pb2.TransferState.TRANSFER_END
elif sent_bytes == 0:
transfer_state = async_inference_pb2.TransferState.TRANSFER_BEGIN
size_to_read = min(chunk_size, size_in_bytes - sent_bytes)
chunk = buffer.read(size_to_read)
yield message_class(transfer_state=transfer_state, data=chunk)
sent_bytes += size_to_read
logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}")
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
def receive_bytes_in_chunks(
iterator, continue_receiving: Event, logger: logging.Logger, log_prefix: str = ""
): # type: ignore
# NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.receive_bytes_in_chunks. Duplication can't be avoided if we
# don't use a unique class for messages sent (due to the different transfer states sent). Also, on the server side the logic for receiving
# is opposite then the HIL-SERL design (my event showcases keeping on running instead of shutdown)
bytes_buffer = io.BytesIO()
step = 0
logger.info(f"{log_prefix} Starting receiver")
for item in iterator:
logger.debug(f"{log_prefix} Received item")
if not continue_receiving.is_set():
logger.info(f"{log_prefix} Shutting down receiver")
return
if item.transfer_state == async_inference_pb2.TransferState.TRANSFER_BEGIN:
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
bytes_buffer.write(item.data)
logger.debug(f"{log_prefix} Received data at step 0")
elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_MIDDLE:
bytes_buffer.write(item.data)
step += 1
logger.debug(f"{log_prefix} Received data at step {step}")
elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_END:
bytes_buffer.write(item.data)
logger.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
complete_bytes = bytes_buffer.getvalue()
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
logger.debug(f"{log_prefix} Queue updated")
return complete_bytes
else:
logger.warning(f"{log_prefix} Received unknown transfer state {item.transfer_state}")
raise ValueError(f"Received unknown transfer state {item.transfer_state}")

View File

@@ -0,0 +1,403 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example:
```shell
python src/lerobot/scripts/server/policy_server.py \
--host=127.0.0.1 \
--port=8080 \
--fps=30 \
--inference_latency=0.033 \
--obs_queue_timeout=1
```
"""
import logging
import pickle # nosec
import threading
import time
from concurrent import futures
from dataclasses import asdict
from pprint import pformat
from queue import Empty, Queue
import draccus
import grpc
import torch
from lerobot.policies.factory import get_policy_class
from lerobot.scripts.server.configs import PolicyServerConfig
from lerobot.scripts.server.constants import SUPPORTED_POLICIES
from lerobot.scripts.server.helpers import (
FPSTracker,
Observation,
RemotePolicyConfig,
TimedAction,
TimedObservation,
get_logger,
observations_similar,
raw_observation_to_observation,
receive_bytes_in_chunks,
)
from lerobot.transport import (
async_inference_pb2, # type: ignore
async_inference_pb2_grpc, # type: ignore
)
class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
prefix = "policy_server"
logger = get_logger(prefix)
def __init__(self, config: PolicyServerConfig):
self.config = config
self._running_event = threading.Event()
# FPS measurement
self.fps_tracker = FPSTracker(target_fps=config.fps)
self.observation_queue = Queue(maxsize=1)
self._predicted_timesteps_lock = threading.Lock()
self._predicted_timesteps = set()
self.last_processed_obs = None
# Attributes will be set by SendPolicyInstructions
self.device = None
self.policy_type = None
self.lerobot_features = None
self.actions_per_chunk = None
self.policy = None
@property
def running(self):
return self._running_event.is_set()
@property
def policy_image_features(self):
return self.policy.config.image_features
def _reset_server(self) -> None:
"""Flushes server state when new client connects."""
# only running inference on the latest observation received by the server
self._running_event.clear()
self.observation_queue = Queue(maxsize=1)
with self._predicted_timesteps_lock:
self._predicted_timesteps = set()
def Ready(self, request, context): # noqa: N802
client_id = context.peer()
self.logger.info(f"Client {client_id} connected and ready")
self._reset_server()
self._running_event.set()
return async_inference_pb2.Empty()
def SendPolicyInstructions(self, request, context): # noqa: N802
"""Receive policy instructions from the robot client"""
if not self.running:
self.logger.warning("Server is not running. Ignoring policy instructions.")
return async_inference_pb2.Empty()
client_id = context.peer()
policy_specs = pickle.loads(request.data) # nosec
if not isinstance(policy_specs, RemotePolicyConfig):
raise TypeError(f"Policy specs must be a RemotePolicyConfig. Got {type(policy_specs)}")
if policy_specs.policy_type not in SUPPORTED_POLICIES:
raise ValueError(
f"Policy type {policy_specs.policy_type} not supported. "
f"Supported policies: {SUPPORTED_POLICIES}"
)
self.logger.info(
f"Receiving policy instructions from {client_id} | "
f"Policy type: {policy_specs.policy_type} | "
f"Pretrained name or path: {policy_specs.pretrained_name_or_path} | "
f"Actions per chunk: {policy_specs.actions_per_chunk} | "
f"Device: {policy_specs.device}"
)
self.device = policy_specs.device
self.policy_type = policy_specs.policy_type # act, pi0, etc.
self.lerobot_features = policy_specs.lerobot_features
self.actions_per_chunk = policy_specs.actions_per_chunk
policy_class = get_policy_class(self.policy_type)
start = time.perf_counter()
self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
self.policy.to(self.device)
end = time.perf_counter()
self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
return async_inference_pb2.Empty()
def SendObservations(self, request_iterator, context): # noqa: N802
"""Receive observations from the robot client"""
client_id = context.peer()
self.logger.debug(f"Receiving observations from {client_id}")
receive_time = time.time() # comparing timestamps so need time.time()
start_deserialize = time.perf_counter()
received_bytes = receive_bytes_in_chunks(
request_iterator, self._running_event, self.logger
) # blocking call while looping over request_iterator
timed_observation = pickle.loads(received_bytes) # nosec
deserialize_time = time.perf_counter() - start_deserialize
self.logger.debug(f"Received observation #{timed_observation.get_timestep()}")
obs_timestep = timed_observation.get_timestep()
obs_timestamp = timed_observation.get_timestamp()
# Calculate FPS metrics
fps_metrics = self.fps_tracker.calculate_fps_metrics(obs_timestamp)
self.logger.info(
f"Received observation #{obs_timestep} | "
f"Avg FPS: {fps_metrics['avg_fps']:.2f} | " # fps at which observations are received from client
f"Target: {fps_metrics['target_fps']:.2f} | "
f"One-way latency: {(receive_time - obs_timestamp) * 1000:.2f}ms"
)
self.logger.debug(
f"Server timestamp: {receive_time:.6f} | "
f"Client timestamp: {obs_timestamp:.6f} | "
f"Deserialization time: {deserialize_time:.6f}s"
)
if not self._enqueue_observation(
timed_observation # wrapping a RawObservation
):
self.logger.info(f"Observation #{obs_timestep} has been filtered out")
return async_inference_pb2.Empty()
def GetActions(self, request, context): # noqa: N802
"""Returns actions to the robot client. Actions are sent as a single
chunk, containing multiple actions."""
client_id = context.peer()
self.logger.debug(f"Client {client_id} connected for action streaming")
# Generate action based on the most recent observation and its timestep
try:
getactions_starts = time.perf_counter()
obs = self.observation_queue.get(timeout=self.config.obs_queue_timeout)
self.logger.info(
f"Running inference for observation #{obs.get_timestep()} (must_go: {obs.must_go})"
)
with self._predicted_timesteps_lock:
self._predicted_timesteps.add(obs.get_timestep())
start_time = time.perf_counter()
action_chunk = self._predict_action_chunk(obs)
inference_time = time.perf_counter() - start_time
start_time = time.perf_counter()
actions_bytes = pickle.dumps(action_chunk) # nosec
serialize_time = time.perf_counter() - start_time
# Create and return the action chunk
actions = async_inference_pb2.Actions(data=actions_bytes)
self.logger.info(
f"Action chunk #{obs.get_timestep()} generated | "
f"Total time: {(inference_time + serialize_time) * 1000:.2f}ms"
)
self.logger.debug(
f"Action chunk #{obs.get_timestep()} generated | "
f"Inference time: {inference_time:.2f}s |"
f"Serialize time: {serialize_time:.2f}s |"
f"Total time: {inference_time + serialize_time:.2f}s"
)
time.sleep(
max(0, self.config.inference_latency - max(0, time.perf_counter() - getactions_starts))
) # sleep controls inference latency
return actions
except Empty: # no observation added to queue in obs_queue_timeout
return async_inference_pb2.Empty()
except Exception as e:
self.logger.error(f"Error in StreamActions: {e}")
return async_inference_pb2.Empty()
def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool:
"""Check if the observation is valid to be processed by the policy"""
with self._predicted_timesteps_lock:
predicted_timesteps = self._predicted_timesteps
if obs.get_timestep() in predicted_timesteps:
self.logger.debug(f"Skipping observation #{obs.get_timestep()} - Timestep predicted already!")
return False
elif observations_similar(obs, previous_obs, lerobot_features=self.lerobot_features):
self.logger.debug(
f"Skipping observation #{obs.get_timestep()} - Observation too similar to last obs predicted!"
)
return False
else:
return True
def _enqueue_observation(self, obs: TimedObservation) -> bool:
"""Enqueue an observation if it must go through processing, otherwise skip it.
Observations not in queue are never run through the policy network"""
if (
obs.must_go
or self.last_processed_obs is None
or self._obs_sanity_checks(obs, self.last_processed_obs)
):
last_obs = self.last_processed_obs.get_timestep() if self.last_processed_obs else "None"
self.logger.debug(
f"Enqueuing observation. Must go: {obs.must_go} | Last processed obs: {last_obs}"
)
# If queue is full, get the old observation to make room
if self.observation_queue.full():
# pops from queue
_ = self.observation_queue.get_nowait()
self.logger.debug("Observation queue was full, removed oldest observation")
# Now put the new observation (never blocks as queue is non-full here)
self.observation_queue.put(obs)
return True
return False
def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
"""Turn a chunk of actions into a list of TimedAction instances,
with the first action corresponding to t_0 and the rest corresponding to
t_0 + i*environment_dt for i in range(len(action_chunk))
"""
return [
TimedAction(timestamp=t_0 + i * self.config.environment_dt, timestep=i_0 + i, action=action)
for i, action in enumerate(action_chunk)
]
def _prepare_observation(self, observation_t: TimedObservation) -> Observation:
"""
Prepare observation, ready for policy inference.
E.g.: To keep observation sampling rate high (and network packet tiny) we send int8 [0,255] images from the
client and then convert them to float32 [0,1] images here, before running inference.
"""
# RawObservation from robot.get_observation() - wrong keys, wrong dtype, wrong image shape
observation: Observation = raw_observation_to_observation(
observation_t.get_observation(),
self.lerobot_features,
self.policy_image_features,
self.device,
)
# processed Observation - right keys, right dtype, right image shape
return observation
def _get_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
"""Get an action chunk from the policy. The chunk contains only"""
chunk = self.policy.predict_action_chunk(observation)
if chunk.ndim != 3:
chunk = chunk.unsqueeze(0) # adding batch dimension, now shape is (B, chunk_size, action_dim)
return chunk[:, : self.actions_per_chunk, :]
def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
"""Predict an action chunk based on an observation"""
inference_starts = time.perf_counter()
"""1. Prepare observation"""
start_time = time.perf_counter()
observation = self._prepare_observation(observation_t)
preprocessing_time = time.perf_counter() - start_time
self.last_processed_obs: TimedObservation = observation_t
"""2. Get action chunk"""
start_time = time.perf_counter()
action_tensor = self._get_action_chunk(observation)
inference_time = time.perf_counter() - start_time
"""3. Post-inference processing"""
start_time = time.perf_counter()
# Move to CPU before serializing
action_tensor = action_tensor.cpu().squeeze(0)
action_chunk = self._time_action_chunk(
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
)
postprocessing_time = time.perf_counter() - start_time
inference_stops = time.perf_counter()
self.logger.info(
f"Observation {observation_t.get_timestep()} |"
f"Inference time: {1000 * (inference_stops - inference_starts):.2f}ms"
)
# full-process latency breakdown for debugging purposes
self.logger.debug(
f"Observation {observation_t.get_timestep()} | "
f"Preprocessing time: {1000 * (preprocessing_time - inference_starts):.2f}ms | "
f"Inference time: {1000 * (inference_time - preprocessing_time):.2f}ms | "
f"Postprocessing time: {1000 * (postprocessing_time - inference_time):.2f}ms | "
f"Total time: {1000 * (postprocessing_time - inference_starts):.2f}ms"
)
return action_chunk
def stop(self):
"""Stop the server"""
self._reset_server()
self.logger.info("Server stopping...")
@draccus.wrap()
def serve(cfg: PolicyServerConfig):
"""Start the PolicyServer with the given configuration.
Args:
config: PolicyServerConfig instance. If None, uses default configuration.
"""
logging.info(pformat(asdict(cfg)))
# Create the server instance first
policy_server = PolicyServer(cfg)
# Setup and start gRPC server
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
server.add_insecure_port(f"{cfg.host}:{cfg.port}")
policy_server.logger.info(f"PolicyServer started on {cfg.host}:{cfg.port}")
server.start()
server.wait_for_termination()
policy_server.logger.info("Server terminated")
if __name__ == "__main__":
serve()

View File

@@ -0,0 +1,509 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example command:
```shell
python src/lerobot/scripts/server/robot_client.py \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
--robot.id=black \
--task="dummy" \
--server_address=127.0.0.1:8080 \
--policy_type=act \
--pretrained_name_or_path=user/model \
--policy_device=mps \
--actions_per_chunk=50 \
--chunk_size_threshold=0.5 \
--aggregate_fn_name=weighted_average \
--debug_visualize_queue_size=True
```
"""
import logging
import pickle # nosec
import threading
import time
from dataclasses import asdict
from pprint import pformat
from queue import Queue
from typing import Any, Callable, Optional
import draccus
import grpc
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs.policies import PreTrainedConfig
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
koch_follower,
make_robot_from_config,
so100_follower,
so101_follower,
)
from lerobot.scripts.server.configs import RobotClientConfig
from lerobot.scripts.server.constants import SUPPORTED_ROBOTS
from lerobot.scripts.server.helpers import (
Action,
FPSTracker,
Observation,
RawObservation,
RemotePolicyConfig,
TimedAction,
TimedObservation,
get_logger,
map_robot_keys_to_lerobot_features,
send_bytes_in_chunks,
validate_robot_cameras_for_policy,
visualize_action_queue_size,
)
from lerobot.transport import (
async_inference_pb2, # type: ignore
async_inference_pb2_grpc, # type: ignore
)
class RobotClient:
prefix = "robot_client"
logger = get_logger(prefix)
def __init__(self, config: RobotClientConfig):
"""Initialize RobotClient with unified configuration.
Args:
config: RobotClientConfig containing all configuration parameters
"""
# Store configuration
self.config = config
self.robot = make_robot_from_config(config.robot)
self.robot.connect()
lerobot_features = map_robot_keys_to_lerobot_features(self.robot)
if config.verify_robot_cameras:
# Load policy config for validation
policy_config = PreTrainedConfig.from_pretrained(config.pretrained_name_or_path)
policy_image_features = policy_config.image_features
# The cameras specified for inference must match the one supported by the policy chosen
validate_robot_cameras_for_policy(lerobot_features, policy_image_features)
# Use environment variable if server_address is not provided in config
self.server_address = config.server_address
self.policy_config = RemotePolicyConfig(
config.policy_type,
config.pretrained_name_or_path,
lerobot_features,
config.actions_per_chunk,
config.policy_device,
)
self.channel = grpc.insecure_channel(self.server_address)
self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
self.logger.info(f"Initializing client to connect to server at {self.server_address}")
self._running_event = threading.Event()
# Initialize client side variables
self.latest_action_lock = threading.Lock()
self.latest_action = -1
self.action_chunk_size = -1
self._chunk_size_threshold = config.chunk_size_threshold
self.action_queue = Queue()
self.action_queue_lock = threading.Lock() # Protect queue operations
self.action_queue_size = []
self.start_barrier = threading.Barrier(2) # 2 threads: action receiver, control loop
# FPS measurement
self.fps_tracker = FPSTracker(target_fps=self.config.fps)
self.logger.info("Robot connected and ready")
# Use an event for thread-safe coordination
self.must_go = threading.Event()
self.must_go.set() # Initially set - observations qualify for direct processing
@property
def running(self):
return self._running_event.is_set()
def start(self):
"""Start the robot client and connect to the policy server"""
try:
# client-server handshake
start_time = time.perf_counter()
self.stub.Ready(async_inference_pb2.Empty())
end_time = time.perf_counter()
self.logger.debug(f"Connected to policy server in {end_time - start_time:.4f}s")
# send policy instructions
policy_config_bytes = pickle.dumps(self.policy_config)
policy_setup = async_inference_pb2.PolicySetup(data=policy_config_bytes)
self.logger.info("Sending policy instructions to policy server")
self.logger.debug(
f"Policy type: {self.policy_config.policy_type} | "
f"Pretrained name or path: {self.policy_config.pretrained_name_or_path} | "
f"Device: {self.policy_config.device}"
)
self.stub.SendPolicyInstructions(policy_setup)
self._running_event.set()
return True
except grpc.RpcError as e:
self.logger.error(f"Failed to connect to policy server: {e}")
return False
def stop(self):
"""Stop the robot client"""
self._running_event.clear()
self.robot.disconnect()
self.logger.debug("Robot disconnected")
self.channel.close()
self.logger.debug("Client stopped, channel closed")
def send_observation(
self,
obs: TimedObservation,
) -> bool:
"""Send observation to the policy server.
Returns True if the observation was sent successfully, False otherwise."""
if not self.running:
raise RuntimeError("Client not running. Run RobotClient.start() before sending observations.")
if not isinstance(obs, TimedObservation):
raise ValueError("Input observation needs to be a TimedObservation!")
start_time = time.perf_counter()
observation_bytes = pickle.dumps(obs)
serialize_time = time.perf_counter() - start_time
self.logger.debug(f"Observation serialization time: {serialize_time:.6f}s")
try:
observation_iterator = send_bytes_in_chunks(
observation_bytes,
async_inference_pb2.Observation,
log_prefix="[CLIENT] Observation",
silent=True,
)
_ = self.stub.SendObservations(observation_iterator)
obs_timestep = obs.get_timestep()
self.logger.info(f"Sent observation #{obs_timestep} | ")
return True
except grpc.RpcError as e:
self.logger.error(f"Error sending observation #{obs.get_timestep()}: {e}")
return False
def _inspect_action_queue(self):
with self.action_queue_lock:
queue_size = self.action_queue.qsize()
timestamps = sorted([action.get_timestep() for action in self.action_queue.queue])
self.logger.debug(f"Queue size: {queue_size}, Queue contents: {timestamps}")
return queue_size, timestamps
def _aggregate_action_queues(
self,
incoming_actions: list[TimedAction],
aggregate_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
):
"""Finds the same timestep actions in the queue and aggregates them using the aggregate_fn"""
if aggregate_fn is None:
# default aggregate function: take the latest action
def aggregate_fn(x1, x2):
return x2
future_action_queue = Queue()
with self.action_queue_lock:
internal_queue = self.action_queue.queue
current_action_queue = {action.get_timestep(): action.get_action() for action in internal_queue}
for new_action in incoming_actions:
with self.latest_action_lock:
latest_action = self.latest_action
# New action is older than the latest action in the queue, skip it
if new_action.get_timestep() <= latest_action:
continue
# If the new action's timestep is not in the current action queue, add it directly
elif new_action.get_timestep() not in current_action_queue:
future_action_queue.put(new_action)
continue
# If the new action's timestep is in the current action queue, aggregate it
# TODO: There is probably a way to do this with broadcasting of the two action tensors
future_action_queue.put(
TimedAction(
timestamp=new_action.get_timestamp(),
timestep=new_action.get_timestep(),
action=aggregate_fn(
current_action_queue[new_action.get_timestep()], new_action.get_action()
),
)
)
with self.action_queue_lock:
self.action_queue = future_action_queue
def receive_actions(self, verbose: bool = False):
"""Receive actions from the policy server"""
# Wait at barrier for synchronized start
self.start_barrier.wait()
self.logger.info("Action receiving thread starting")
while self.running:
try:
# Use StreamActions to get a stream of actions from the server
actions_chunk = self.stub.GetActions(async_inference_pb2.Empty())
if len(actions_chunk.data) == 0:
continue # received `Empty` from server, wait for next call
receive_time = time.time()
# Deserialize bytes back into list[TimedAction]
deserialize_start = time.perf_counter()
timed_actions = pickle.loads(actions_chunk.data) # nosec
deserialize_time = time.perf_counter() - deserialize_start
self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
# Calculate network latency if we have matching observations
if len(timed_actions) > 0 and verbose:
with self.latest_action_lock:
latest_action = self.latest_action
self.logger.debug(f"Current latest action: {latest_action}")
# Get queue state before changes
old_size, old_timesteps = self._inspect_action_queue()
if not old_timesteps:
old_timesteps = [latest_action] # queue was empty
# Get queue state before changes
old_size, old_timesteps = self._inspect_action_queue()
if not old_timesteps:
old_timesteps = [latest_action] # queue was empty
# Log incoming actions
incoming_timesteps = [a.get_timestep() for a in timed_actions]
first_action_timestep = timed_actions[0].get_timestep()
server_to_client_latency = (receive_time - timed_actions[0].get_timestamp()) * 1000
self.logger.info(
f"Received action chunk for step #{first_action_timestep} | "
f"Latest action: #{latest_action} | "
f"Incoming actions: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
f"Network latency (server->client): {server_to_client_latency:.2f}ms | "
f"Deserialization time: {deserialize_time * 1000:.2f}ms"
)
# Update action queue
start_time = time.perf_counter()
self._aggregate_action_queues(timed_actions, self.config.aggregate_fn)
queue_update_time = time.perf_counter() - start_time
self.must_go.set() # after receiving actions, next empty queue triggers must-go processing!
if verbose:
# Get queue state after changes
new_size, new_timesteps = self._inspect_action_queue()
with self.latest_action_lock:
latest_action = self.latest_action
self.logger.info(
f"Latest action: {latest_action} | "
f"Old action steps: {old_timesteps[0]}:{old_timesteps[-1]} | "
f"Incoming action steps: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
f"Updated action steps: {new_timesteps[0]}:{new_timesteps[-1]}"
)
self.logger.debug(
f"Queue update complete ({queue_update_time:.6f}s) | "
f"Before: {old_size} items | "
f"After: {new_size} items | "
)
except grpc.RpcError as e:
self.logger.error(f"Error receiving actions: {e}")
def actions_available(self):
"""Check if there are actions available in the queue"""
with self.action_queue_lock:
return not self.action_queue.empty()
def _action_tensor_to_action_dict(self, action_tensor: torch.Tensor) -> dict[str, float]:
action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
return action
def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
"""Reading and performing actions in local queue"""
# Lock only for queue operations
get_start = time.perf_counter()
with self.action_queue_lock:
self.action_queue_size.append(self.action_queue.qsize())
# Get action from queue
timed_action = self.action_queue.get_nowait()
get_end = time.perf_counter() - get_start
_performed_action = self.robot.send_action(
self._action_tensor_to_action_dict(timed_action.get_action())
)
with self.latest_action_lock:
self.latest_action = timed_action.get_timestep()
if verbose:
with self.action_queue_lock:
current_queue_size = self.action_queue.qsize()
self.logger.debug(
f"Ts={timed_action.get_timestamp()} | "
f"Action #{timed_action.get_timestep()} performed | "
f"Queue size: {current_queue_size}"
)
self.logger.debug(
f"Popping action from queue to perform took {get_end:.6f}s | Queue size: {current_queue_size}"
)
return _performed_action
def _ready_to_send_observation(self):
"""Flags when the client is ready to send an observation"""
with self.action_queue_lock:
return self.action_queue.qsize() / self.action_chunk_size <= self._chunk_size_threshold
def control_loop_observation(self, task: str, verbose: bool = False) -> RawObservation:
try:
# Get serialized observation bytes from the function
start_time = time.perf_counter()
raw_observation: RawObservation = self.robot.get_observation()
raw_observation["task"] = task
with self.latest_action_lock:
latest_action = self.latest_action
observation = TimedObservation(
timestamp=time.time(), # need time.time() to compare timestamps across client and server
observation=raw_observation,
timestep=max(latest_action, 0),
)
obs_capture_time = time.perf_counter() - start_time
# If there are no actions left in the queue, the observation must go through processing!
with self.action_queue_lock:
observation.must_go = self.must_go.is_set() and self.action_queue.empty()
current_queue_size = self.action_queue.qsize()
_ = self.send_observation(observation)
self.logger.debug(f"QUEUE SIZE: {current_queue_size} (Must go: {observation.must_go})")
if observation.must_go:
# must-go event will be set again after receiving actions
self.must_go.clear()
if verbose:
# Calculate comprehensive FPS metrics
fps_metrics = self.fps_tracker.calculate_fps_metrics(observation.get_timestamp())
self.logger.info(
f"Obs #{observation.get_timestep()} | "
f"Avg FPS: {fps_metrics['avg_fps']:.2f} | "
f"Target: {fps_metrics['target_fps']:.2f}"
)
self.logger.debug(
f"Ts={observation.get_timestamp():.6f} | Capturing observation took {obs_capture_time:.6f}s"
)
return raw_observation
except Exception as e:
self.logger.error(f"Error in observation sender: {e}")
def control_loop(self, task: str, verbose: bool = False) -> tuple[Observation, Action]:
"""Combined function for executing actions and streaming observations"""
# Wait at barrier for synchronized start
self.start_barrier.wait()
self.logger.info("Control loop thread starting")
_performed_action = None
_captured_observation = None
while self.running:
control_loop_start = time.perf_counter()
"""Control loop: (1) Performing actions, when available"""
if self.actions_available():
_performed_action = self.control_loop_action(verbose)
"""Control loop: (2) Streaming observations to the remote policy server"""
if self._ready_to_send_observation():
_captured_observation = self.control_loop_observation(task, verbose)
self.logger.info(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}")
# Dynamically adjust sleep time to maintain the desired control frequency
time.sleep(max(0, self.config.environment_dt - (time.perf_counter() - control_loop_start)))
return _captured_observation, _performed_action
@draccus.wrap()
def async_client(cfg: RobotClientConfig):
logging.info(pformat(asdict(cfg)))
if cfg.robot.type not in SUPPORTED_ROBOTS:
raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
client = RobotClient(cfg)
if client.start():
client.logger.info("Starting action receiver thread...")
# Create and start action receiver thread
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
# Start action receiver thread
action_receiver_thread.start()
try:
# The main thread runs the control loop
client.control_loop(task=cfg.task)
finally:
client.stop()
action_receiver_thread.join()
if cfg.debug_visualize_queue_size:
visualize_action_queue_size(client.action_queue_size)
client.logger.info("Client stopped")
if __name__ == "__main__":
async_client() # run the client

View File

@@ -16,7 +16,6 @@
import logging
import time
from contextlib import nullcontext
from functools import partial
from pprint import pformat
from typing import Any
@@ -30,7 +29,6 @@ from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.datasets.utils_must import multidataset_collate_fn
from lerobot.envs.factory import make_env
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy
@@ -175,23 +173,14 @@ def train(cfg: TrainPipelineConfig):
else:
shuffle = True
sampler = None
keys_to_max_dim = getattr(dataset.meta, "keys_to_max_dim", {})
keys_to_max_dim = {
"action": (32,),
"observation.state": (32,),
"observation.image": (3, 1080, 1920),
"observation.image2": (3, 1080, 1920),
}
collate_fn = partial(multidataset_collate_fn, keys_to_max_dim=keys_to_max_dim)
dataloader = torch.utils.data.DataLoader(
dataset,
collate_fn=collate_fn,
num_workers=cfg.num_workers,
batch_size=cfg.batch_size,
shuffle=shuffle,
sampler=sampler,
pin_memory=device.type != "cpu",
pin_memory=device.type == "cuda",
drop_last=False,
)
dl_iter = cycle(dataloader)
@@ -218,7 +207,7 @@ def train(cfg: TrainPipelineConfig):
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(device, non_blocking=True)
batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
train_tracker, output_dict = update_policy(
train_tracker,

View File

@@ -35,6 +35,7 @@ from lerobot.robots import ( # noqa: F401
make_robot_from_config,
so100_follower,
so101_follower,
so101_follower_torque,
)
from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig,
@@ -52,6 +53,7 @@ COMPATIBLE_DEVICES = [
"so101_follower",
"so101_leader",
"lekiwi",
"so101_follower_t",
]

View File

@@ -43,6 +43,7 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
hope_jr,
koch_follower,
make_robot_from_config,
so100_follower,
@@ -52,6 +53,7 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
gamepad,
homunculus,
koch_leader,
make_teleoperator_from_config,
so100_leader,

View File

@@ -0,0 +1,4 @@
from .config_homunculus import HomunculusArmConfig, HomunculusGloveConfig
from .homunculus_arm import HomunculusArm
from .homunculus_glove import HomunculusGlove
from .joints_translation import homunculus_glove_to_hope_jr_hand

View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("homunculus_glove")
@dataclass
class HomunculusGloveConfig(TeleoperatorConfig):
port: str # Port to connect to the glove
side: str # "left" / "right"
baud_rate: int = 115_200
def __post_init__(self):
if self.side not in ["right", "left"]:
raise ValueError(self.side)
@TeleoperatorConfig.register_subclass("homunculus_arm")
@dataclass
class HomunculusArmConfig(TeleoperatorConfig):
port: str # Port to connect to the arm
baud_rate: int = 115_200

View File

@@ -0,0 +1,310 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import threading
from collections import deque
from pprint import pformat
from typing import Deque, Dict, Optional
import serial
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.motors.motors_bus import MotorCalibration, MotorNormMode
from lerobot.utils.utils import enter_pressed, move_cursor_up
from ..teleoperator import Teleoperator
from .config_homunculus import HomunculusArmConfig
logger = logging.getLogger(__name__)
class HomunculusArm(Teleoperator):
"""
Homunculus Arm designed by Hugging Face.
"""
config_class = HomunculusArmConfig
name = "homunculus_arm"
def __init__(self, config: HomunculusArmConfig):
super().__init__(config)
self.config = config
self.serial = serial.Serial(config.port, config.baud_rate, timeout=1)
self.serial_lock = threading.Lock()
self.joints = {
"shoulder_pitch": MotorNormMode.RANGE_M100_100,
"shoulder_yaw": MotorNormMode.RANGE_M100_100,
"shoulder_roll": MotorNormMode.RANGE_M100_100,
"elbow_flex": MotorNormMode.RANGE_M100_100,
"wrist_roll": MotorNormMode.RANGE_M100_100,
"wrist_yaw": MotorNormMode.RANGE_M100_100,
"wrist_pitch": MotorNormMode.RANGE_M100_100,
}
n = 50
# EMA parameters ---------------------------------------------------
self.n: int = n
self.alpha: float = 2 / (n + 1)
# one deque *per joint* so we can inspect raw history if needed
self._buffers: Dict[str, Deque[int]] = {
joint: deque(maxlen=n)
for joint in (
"shoulder_pitch",
"shoulder_yaw",
"shoulder_roll",
"elbow_flex",
"wrist_roll",
"wrist_yaw",
"wrist_pitch",
)
}
# running EMA value per joint lazily initialised on first read
self._ema: Dict[str, Optional[float]] = dict.fromkeys(self._buffers)
self._state: dict[str, float] | None = None
self.new_state_event = threading.Event()
self.stop_event = threading.Event()
self.thread = threading.Thread(target=self._read_loop, daemon=True, name=f"{self} _read_loop")
self.state_lock = threading.Lock()
@property
def action_features(self) -> dict:
return {f"{joint}.pos": float for joint in self.joints}
@property
def feedback_features(self) -> dict:
return {}
@property
def is_connected(self) -> bool:
with self.serial_lock:
return self.serial.is_open and self.thread.is_alive()
def connect(self, calibrate: bool = True) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
if not self.serial.is_open:
self.serial.open()
self.thread.start()
# wait for the thread to ramp up & 1st state to be ready
if not self.new_state_event.wait(timeout=2):
raise TimeoutError(f"{self}: Timed out waiting for state after 2s.")
if not self.is_calibrated and calibrate:
self.calibrate()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
return self.calibration_fpath.is_file()
def calibrate(self) -> None:
print(
"\nMove all joints through their entire range of motion."
"\nRecording positions. Press ENTER to stop..."
)
range_mins, range_maxes = self._record_ranges_of_motion()
self.calibration = {}
for id_, joint in enumerate(self.joints):
self.calibration[joint] = MotorCalibration(
id=id_,
drive_mode=0,
homing_offset=0,
range_min=range_mins[joint],
range_max=range_maxes[joint],
)
self._save_calibration()
print("Calibration saved to", self.calibration_fpath)
# TODO(Steven): This function is copy/paste from the `HomunculusGlove` class. Consider moving it to an utility to reduce duplicated code.
def _record_ranges_of_motion(
self, joints: list[str] | None = None, display_values: bool = True
) -> tuple[dict[str, int], dict[str, int]]:
"""Interactively record the min/max encoder values of each joint.
Move the joints while the method streams live positions. Press :kbd:`Enter` to finish.
Args:
joints (list[str] | None, optional): Joints to record. Defaults to every joint (`None`).
display_values (bool, optional): When `True` (default) a live table is printed to the console.
Raises:
TypeError: `joints` is not `None` or a list.
ValueError: any joint's recorded min and max are the same.
Returns:
tuple[dict[str, int], dict[str, int]]: Two dictionaries *mins* and *maxes* with the extreme values
observed for each joint.
"""
if joints is None:
joints = list(self.joints)
elif not isinstance(joints, list):
raise TypeError(joints)
display_len = max(len(key) for key in joints)
start_positions = self._read(joints, normalize=False)
mins = start_positions.copy()
maxes = start_positions.copy()
user_pressed_enter = False
while not user_pressed_enter:
positions = self._read(joints, normalize=False)
mins = {joint: int(min(positions[joint], min_)) for joint, min_ in mins.items()}
maxes = {joint: int(max(positions[joint], max_)) for joint, max_ in maxes.items()}
if display_values:
print("\n-------------------------------------------")
print(f"{'NAME':<{display_len}} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}")
for joint in joints:
print(
f"{joint:<{display_len}} | {mins[joint]:>6} | {positions[joint]:>6} | {maxes[joint]:>6}"
)
if enter_pressed():
user_pressed_enter = True
if display_values and not user_pressed_enter:
# Move cursor up to overwrite the previous output
move_cursor_up(len(joints) + 3)
same_min_max = [joint for joint in joints if mins[joint] == maxes[joint]]
if same_min_max:
raise ValueError(f"Some joints have the same min and max values:\n{pformat(same_min_max)}")
return mins, maxes
def configure(self) -> None:
pass
# TODO(Steven): This function is copy/paste from the `HomunculusGlove` class. Consider moving it to an utility to reduce duplicated code.
def _normalize(self, values: dict[str, int]) -> dict[str, float]:
if not self.calibration:
raise RuntimeError(f"{self} has no calibration registered.")
normalized_values = {}
for joint, val in values.items():
min_ = self.calibration[joint].range_min
max_ = self.calibration[joint].range_max
drive_mode = self.calibration[joint].drive_mode
bounded_val = min(max_, max(min_, val))
if self.joints[joint] is MotorNormMode.RANGE_M100_100:
norm = (((bounded_val - min_) / (max_ - min_)) * 200) - 100
normalized_values[joint] = -norm if drive_mode else norm
elif self.joints[joint] is MotorNormMode.RANGE_0_100:
norm = ((bounded_val - min_) / (max_ - min_)) * 100
normalized_values[joint] = 100 - norm if drive_mode else norm
return normalized_values
def _apply_ema(self, raw: Dict[str, int]) -> Dict[str, float]:
"""Update buffers & running EMA values; return smoothed dict."""
smoothed: Dict[str, float] = {}
for joint, value in raw.items():
# maintain raw history
self._buffers[joint].append(value)
# initialise on first run
if self._ema[joint] is None:
self._ema[joint] = float(value)
else:
self._ema[joint] = self.alpha * value + (1 - self.alpha) * self._ema[joint]
smoothed[joint] = self._ema[joint]
return smoothed
def _read(
self, joints: list[str] | None = None, normalize: bool = True, timeout: float = 1
) -> dict[str, int | float]:
"""
Return the most recent (single) values from self.last_d,
optionally applying calibration.
"""
if not self.new_state_event.wait(timeout=timeout):
raise TimeoutError(f"{self}: Timed out waiting for state after {timeout}s.")
with self.state_lock:
state = self._state
self.new_state_event.clear()
if state is None:
raise RuntimeError(f"{self} Internal error: Event set but no state available.")
if joints is not None:
state = {k: v for k, v in state.items() if k in joints}
if normalize:
state = self._normalize(state)
state = self._apply_ema(state)
return state
def _read_loop(self):
"""
Continuously read from the serial buffer in its own thread and sends values to the main thread through
a queue.
"""
while not self.stop_event.is_set():
try:
raw_values = None
with self.serial_lock:
if self.serial.in_waiting > 0:
self.serial.flush()
raw_values = self.serial.readline().decode("utf-8").strip().split(" ")
if raw_values is None or len(raw_values) != 21: # 16 raw + 5 angle values
continue
joint_angles = {
"shoulder_pitch": int(raw_values[19]),
"shoulder_yaw": int(raw_values[18]),
"shoulder_roll": int(raw_values[20]),
"elbow_flex": int(raw_values[17]),
"wrist_roll": int(raw_values[16]),
"wrist_yaw": int(raw_values[1]),
"wrist_pitch": int(raw_values[0]),
}
with self.state_lock:
self._state = joint_angles
self.new_state_event.set()
except Exception as e:
logger.debug(f"Error reading frame in background thread for {self}: {e}")
def get_action(self) -> dict[str, float]:
joint_positions = self._read()
return {f"{joint}.pos": pos for joint, pos in joint_positions.items()}
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError
def disconnect(self) -> None:
if not self.is_connected:
DeviceNotConnectedError(f"{self} is not connected.")
self.stop_event.set()
self.thread.join(timeout=1)
self.serial.close()
logger.info(f"{self} disconnected.")

View File

@@ -0,0 +1,338 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import threading
from collections import deque
from pprint import pformat
from typing import Deque, Dict, Optional
import serial
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.motors import MotorCalibration
from lerobot.motors.motors_bus import MotorNormMode
from lerobot.teleoperators.homunculus.joints_translation import homunculus_glove_to_hope_jr_hand
from lerobot.utils.utils import enter_pressed, move_cursor_up
from ..teleoperator import Teleoperator
from .config_homunculus import HomunculusGloveConfig
logger = logging.getLogger(__name__)
LEFT_HAND_INVERSIONS = [
"thumb_cmc",
"index_dip",
"middle_mcp_abduction",
"middle_dip",
"pinky_mcp_abduction",
"pinky_dip",
]
RIGHT_HAND_INVERSIONS = [
"thumb_mcp",
"thumb_cmc",
"thumb_pip",
"thumb_dip",
"index_mcp_abduction",
# "index_dip",
"middle_mcp_abduction",
# "middle_dip",
"ring_mcp_abduction",
"ring_mcp_flexion",
# "ring_dip",
"pinky_mcp_abduction",
]
class HomunculusGlove(Teleoperator):
"""
Homunculus Glove designed by NepYope & Hugging Face.
"""
config_class = HomunculusGloveConfig
name = "homunculus_glove"
def __init__(self, config: HomunculusGloveConfig):
super().__init__(config)
self.config = config
self.serial = serial.Serial(config.port, config.baud_rate, timeout=1)
self.serial_lock = threading.Lock()
self.joints = {
"thumb_cmc": MotorNormMode.RANGE_0_100,
"thumb_mcp": MotorNormMode.RANGE_0_100,
"thumb_pip": MotorNormMode.RANGE_0_100,
"thumb_dip": MotorNormMode.RANGE_0_100,
"index_mcp_abduction": MotorNormMode.RANGE_M100_100,
"index_mcp_flexion": MotorNormMode.RANGE_0_100,
"index_dip": MotorNormMode.RANGE_0_100,
"middle_mcp_abduction": MotorNormMode.RANGE_M100_100,
"middle_mcp_flexion": MotorNormMode.RANGE_0_100,
"middle_dip": MotorNormMode.RANGE_0_100,
"ring_mcp_abduction": MotorNormMode.RANGE_M100_100,
"ring_mcp_flexion": MotorNormMode.RANGE_0_100,
"ring_dip": MotorNormMode.RANGE_0_100,
"pinky_mcp_abduction": MotorNormMode.RANGE_M100_100,
"pinky_mcp_flexion": MotorNormMode.RANGE_0_100,
"pinky_dip": MotorNormMode.RANGE_0_100,
}
self.inverted_joints = RIGHT_HAND_INVERSIONS if config.side == "right" else LEFT_HAND_INVERSIONS
n = 10
# EMA parameters ---------------------------------------------------
self.n: int = n
self.alpha: float = 2 / (n + 1)
# one deque *per joint* so we can inspect raw history if needed
self._buffers: Dict[str, Deque[int]] = {joint: deque(maxlen=n) for joint in self.joints}
# running EMA value per joint lazily initialised on first read
self._ema: Dict[str, Optional[float]] = dict.fromkeys(self._buffers)
self._state: dict[str, float] | None = None
self.new_state_event = threading.Event()
self.stop_event = threading.Event()
self.thread = threading.Thread(target=self._read_loop, daemon=True, name=f"{self} _read_loop")
self.state_lock = threading.Lock()
@property
def action_features(self) -> dict:
return {f"{joint}.pos": float for joint in self.joints}
@property
def feedback_features(self) -> dict:
return {}
@property
def is_connected(self) -> bool:
with self.serial_lock:
return self.serial.is_open and self.thread.is_alive()
def connect(self, calibrate: bool = True) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
if not self.serial.is_open:
self.serial.open()
self.thread.start()
# wait for the thread to ramp up & 1st state to be ready
if not self.new_state_event.wait(timeout=2):
raise TimeoutError(f"{self}: Timed out waiting for state after 2s.")
if not self.is_calibrated and calibrate:
self.calibrate()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
return self.calibration_fpath.is_file()
def calibrate(self) -> None:
range_mins, range_maxes = {}, {}
for finger in ["thumb", "index", "middle", "ring", "pinky"]:
print(
f"\nMove {finger} through its entire range of motion."
"\nRecording positions. Press ENTER to stop..."
)
finger_joints = [joint for joint in self.joints if joint.startswith(finger)]
finger_mins, finger_maxes = self._record_ranges_of_motion(finger_joints)
range_mins.update(finger_mins)
range_maxes.update(finger_maxes)
self.calibration = {}
for id_, joint in enumerate(self.joints):
self.calibration[joint] = MotorCalibration(
id=id_,
drive_mode=1 if joint in self.inverted_joints else 0,
homing_offset=0,
range_min=range_mins[joint],
range_max=range_maxes[joint],
)
self._save_calibration()
print("Calibration saved to", self.calibration_fpath)
# TODO(Steven): This function is copy/paste from the `HomunculusArm` class. Consider moving it to an utility to reduce duplicated code.
def _record_ranges_of_motion(
self, joints: list[str] | None = None, display_values: bool = True
) -> tuple[dict[str, int], dict[str, int]]:
"""Interactively record the min/max encoder values of each joint.
Move the joints while the method streams live positions. Press :kbd:`Enter` to finish.
Args:
joints (list[str] | None, optional): Joints to record. Defaults to every joint (`None`).
display_values (bool, optional): When `True` (default) a live table is printed to the console.
Raises:
TypeError: `joints` is not `None` or a list.
ValueError: any joint's recorded min and max are the same.
Returns:
tuple[dict[str, int], dict[str, int]]: Two dictionaries *mins* and *maxes* with the extreme values
observed for each joint.
"""
if joints is None:
joints = list(self.joints)
elif not isinstance(joints, list):
raise TypeError(joints)
display_len = max(len(key) for key in joints)
start_positions = self._read(joints, normalize=False)
mins = start_positions.copy()
maxes = start_positions.copy()
user_pressed_enter = False
while not user_pressed_enter:
positions = self._read(joints, normalize=False)
mins = {joint: int(min(positions[joint], min_)) for joint, min_ in mins.items()}
maxes = {joint: int(max(positions[joint], max_)) for joint, max_ in maxes.items()}
if display_values:
print("\n-------------------------------------------")
print(f"{'NAME':<{display_len}} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}")
for joint in joints:
print(
f"{joint:<{display_len}} | {mins[joint]:>6} | {positions[joint]:>6} | {maxes[joint]:>6}"
)
if enter_pressed():
user_pressed_enter = True
if display_values and not user_pressed_enter:
# Move cursor up to overwrite the previous output
move_cursor_up(len(joints) + 3)
same_min_max = [joint for joint in joints if mins[joint] == maxes[joint]]
if same_min_max:
raise ValueError(f"Some joints have the same min and max values:\n{pformat(same_min_max)}")
return mins, maxes
def configure(self) -> None:
pass
# TODO(Steven): This function is copy/paste from the `HomunculusArm` class. Consider moving it to an utility to reduce duplicated code.
def _normalize(self, values: dict[str, int]) -> dict[str, float]:
if not self.calibration:
raise RuntimeError(f"{self} has no calibration registered.")
normalized_values = {}
for joint, val in values.items():
min_ = self.calibration[joint].range_min
max_ = self.calibration[joint].range_max
drive_mode = self.calibration[joint].drive_mode
bounded_val = min(max_, max(min_, val))
if self.joints[joint] is MotorNormMode.RANGE_M100_100:
norm = (((bounded_val - min_) / (max_ - min_)) * 200) - 100
normalized_values[joint] = -norm if drive_mode else norm
elif self.joints[joint] is MotorNormMode.RANGE_0_100:
norm = ((bounded_val - min_) / (max_ - min_)) * 100
normalized_values[joint] = 100 - norm if drive_mode else norm
return normalized_values
def _apply_ema(self, raw: Dict[str, int]) -> Dict[str, int]:
"""Update buffers & running EMA values; return smoothed dict as integers."""
smoothed: Dict[str, int] = {}
for joint, value in raw.items():
# maintain raw history
self._buffers[joint].append(value)
# initialise on first run
if self._ema[joint] is None:
self._ema[joint] = float(value)
else:
self._ema[joint] = self.alpha * value + (1 - self.alpha) * self._ema[joint]
# Convert back to int for compatibility with normalization
smoothed[joint] = int(round(self._ema[joint]))
return smoothed
def _read(
self, joints: list[str] | None = None, normalize: bool = True, timeout: float = 1
) -> dict[str, int | float]:
"""
Return the most recent (single) values from self.last_d,
optionally applying calibration.
"""
if not self.new_state_event.wait(timeout=timeout):
raise TimeoutError(f"{self}: Timed out waiting for state after {timeout}s.")
with self.state_lock:
state = self._state
self.new_state_event.clear()
if state is None:
raise RuntimeError(f"{self} Internal error: Event set but no state available.")
if joints is not None:
state = {k: v for k, v in state.items() if k in joints}
# Apply EMA smoothing to raw values first
state = self._apply_ema(state)
# Then normalize if requested
if normalize:
state = self._normalize(state)
return state
def _read_loop(self):
"""
Continuously read from the serial buffer in its own thread and sends values to the main thread through
a queue.
"""
while not self.stop_event.is_set():
try:
positions = None
with self.serial_lock:
if self.serial.in_waiting > 0:
self.serial.flush()
positions = self.serial.readline().decode("utf-8").strip().split(" ")
if positions is None or len(positions) != len(self.joints):
continue
joint_positions = {joint: int(pos) for joint, pos in zip(self.joints, positions, strict=True)}
with self.state_lock:
self._state = joint_positions
self.new_state_event.set()
except Exception as e:
logger.debug(f"Error reading frame in background thread for {self}: {e}")
def get_action(self) -> dict[str, float]:
joint_positions = self._read()
return homunculus_glove_to_hope_jr_hand(
{f"{joint}.pos": pos for joint, pos in joint_positions.items()}
)
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError
def disconnect(self) -> None:
if not self.is_connected:
DeviceNotConnectedError(f"{self} is not connected.")
self.stop_event.set()
self.thread.join(timeout=1)
self.serial.close()
logger.info(f"{self} disconnected.")

View File

@@ -0,0 +1,63 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
INDEX_SPLAY = 0.3
MIDDLE_SPLAY = 0.3
RING_SPLAY = 0.3
PINKY_SPLAY = 0.5
def get_ulnar_flexion(flexion: float, abduction: float, splay: float):
return -abduction * splay + flexion * (1 - splay)
def get_radial_flexion(flexion: float, abduction: float, splay: float):
return abduction * splay + flexion * (1 - splay)
def homunculus_glove_to_hope_jr_hand(glove_action: dict[str, float]) -> dict[str, float]:
return {
"thumb_cmc.pos": glove_action["thumb_cmc.pos"],
"thumb_mcp.pos": glove_action["thumb_mcp.pos"],
"thumb_pip.pos": glove_action["thumb_pip.pos"],
"thumb_dip.pos": glove_action["thumb_dip.pos"],
"index_radial_flexor.pos": get_radial_flexion(
glove_action["index_mcp_flexion.pos"], glove_action["index_mcp_abduction.pos"], INDEX_SPLAY
),
"index_ulnar_flexor.pos": get_ulnar_flexion(
glove_action["index_mcp_flexion.pos"], glove_action["index_mcp_abduction.pos"], INDEX_SPLAY
),
"index_pip_dip.pos": glove_action["index_dip.pos"],
"middle_radial_flexor.pos": get_radial_flexion(
glove_action["middle_mcp_flexion.pos"], glove_action["middle_mcp_abduction.pos"], MIDDLE_SPLAY
),
"middle_ulnar_flexor.pos": get_ulnar_flexion(
glove_action["middle_mcp_flexion.pos"], glove_action["middle_mcp_abduction.pos"], MIDDLE_SPLAY
),
"middle_pip_dip.pos": glove_action["middle_dip.pos"],
"ring_radial_flexor.pos": get_radial_flexion(
glove_action["ring_mcp_flexion.pos"], glove_action["ring_mcp_abduction.pos"], RING_SPLAY
),
"ring_ulnar_flexor.pos": get_ulnar_flexion(
glove_action["ring_mcp_flexion.pos"], glove_action["ring_mcp_abduction.pos"], RING_SPLAY
),
"ring_pip_dip.pos": glove_action["ring_dip.pos"],
"pinky_radial_flexor.pos": get_radial_flexion(
glove_action["pinky_mcp_flexion.pos"], glove_action["pinky_mcp_abduction.pos"], PINKY_SPLAY
),
"pinky_ulnar_flexor.pos": get_ulnar_flexion(
glove_action["pinky_mcp_flexion.pos"], glove_action["pinky_mcp_abduction.pos"], PINKY_SPLAY
),
"pinky_pip_dip.pos": glove_action["pinky_dip.pos"],
}

View File

@@ -33,6 +33,12 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
from .so101_leader import SO101Leader
return SO101Leader(config)
elif config.type == "so101_follower_t":
# For bilateral teleoperation, SO101FollowerT is used as a robot, not a teleoperator
# This should be handled in the record.py file instead
raise ValueError(
"so101_follower_t should be created as a robot instance for bilateral teleoperation, not as a teleoperator"
)
elif config.type == "stretch3":
from .stretch3_gamepad import Stretch3GamePad
@@ -53,5 +59,13 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
from .keyboard.teleop_keyboard import KeyboardEndEffectorTeleop
return KeyboardEndEffectorTeleop(config)
elif config.type == "homunculus_glove":
from .homunculus import HomunculusGlove
return HomunculusGlove(config)
elif config.type == "homunculus_arm":
from .homunculus import HomunculusArm
return HomunculusArm(config)
else:
raise ValueError(config.type)

View File

@@ -0,0 +1,59 @@
// fmt: off
// flake8: noqa
// !/usr/bin/env python
// Copyright 2024 The HuggingFace Inc. team.
// All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package async_inference;
// AsyncInference: from Robot perspective
// Robot send observations to & executes action received from a remote Policy server
service AsyncInference {
// Robot -> Policy to share observations with a remote inference server
// Policy -> Robot to share actions predicted for given observations
rpc SendObservations(stream Observation) returns (Empty);
rpc GetActions(Empty) returns (Actions);
rpc SendPolicyInstructions(PolicySetup) returns (Empty);
rpc Ready(Empty) returns (Empty);
rpc Stop(Empty) returns (Empty);
}
enum TransferState {
TRANSFER_UNKNOWN = 0;
TRANSFER_BEGIN = 1;
TRANSFER_MIDDLE = 2;
TRANSFER_END = 3;
}
// Messages
message Observation {
// sent by Robot, to remote Policy
TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size
bytes data = 2;
}
message Actions {
// sent by remote Policy, to Robot
bytes data = 1;
}
message PolicySetup {
// sent by Robot to remote server, to init Policy
bytes data = 1;
}
message Empty {}

View File

@@ -0,0 +1,45 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: async_inference.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
29,
0,
'',
'async_inference.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xdd\x02\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12>\n\nGetActions\x12\x16.async_inference.Empty\x1a\x18.async_inference.Actions\x12N\n\x16SendPolicyInstructions\x12\x1c.async_inference.PolicySetup\x1a\x16.async_inference.Empty\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Empty\x12\x36\n\x04Stop\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=190
_globals['_TRANSFERSTATE']._serialized_end=286
_globals['_OBSERVATION']._serialized_start=42
_globals['_OBSERVATION']._serialized_end=125
_globals['_ACTIONS']._serialized_start=127
_globals['_ACTIONS']._serialized_end=150
_globals['_POLICYSETUP']._serialized_start=152
_globals['_POLICYSETUP']._serialized_end=179
_globals['_EMPTY']._serialized_start=181
_globals['_EMPTY']._serialized_end=188
_globals['_ASYNCINFERENCE']._serialized_start=289
_globals['_ASYNCINFERENCE']._serialized_end=638
# @@protoc_insertion_point(module_scope)

View File

@@ -0,0 +1,277 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
from lerobot.transport import async_inference_pb2 as async__inference__pb2
GRPC_GENERATED_VERSION = '1.71.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in async_inference_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
class AsyncInferenceStub:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.SendObservations = channel.stream_unary(
'/async_inference.AsyncInference/SendObservations',
request_serializer=async__inference__pb2.Observation.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
self.GetActions = channel.unary_unary(
'/async_inference.AsyncInference/GetActions',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Actions.FromString,
_registered_method=True)
self.SendPolicyInstructions = channel.unary_unary(
'/async_inference.AsyncInference/SendPolicyInstructions',
request_serializer=async__inference__pb2.PolicySetup.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
self.Ready = channel.unary_unary(
'/async_inference.AsyncInference/Ready',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
self.Stop = channel.unary_unary(
'/async_inference.AsyncInference/Stop',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
class AsyncInferenceServicer:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def SendObservations(self, request_iterator, context):
"""Robot -> Policy to share observations with a remote inference server
Policy -> Robot to share actions predicted for given observations
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GetActions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendPolicyInstructions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Ready(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Stop(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_AsyncInferenceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendObservations': grpc.stream_unary_rpc_method_handler(
servicer.SendObservations,
request_deserializer=async__inference__pb2.Observation.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
'GetActions': grpc.unary_unary_rpc_method_handler(
servicer.GetActions,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Actions.SerializeToString,
),
'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
servicer.SendPolicyInstructions,
request_deserializer=async__inference__pb2.PolicySetup.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
'Ready': grpc.unary_unary_rpc_method_handler(
servicer.Ready,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
'Stop': grpc.unary_unary_rpc_method_handler(
servicer.Stop,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'async_inference.AsyncInference', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class AsyncInference:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
@staticmethod
def SendObservations(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/async_inference.AsyncInference/SendObservations',
async__inference__pb2.Observation.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def GetActions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/GetActions',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Actions.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendPolicyInstructions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/SendPolicyInstructions',
async__inference__pb2.PolicySetup.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Ready(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/Ready',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Stop(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/Stop',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

View File

@@ -11,11 +11,11 @@
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.python -m grpc_tools.protoc -I src --python_out=src --grpc_python_out=src src/lerobot/transport/services.proto
// To generate a classes for transport part (services_pb2.py and services_pb2_grpc.py) use the following command:
//
// python -m grpc_tools.protoc -I . --python_out=. --grpc_python_out=. src/lerobot/transport/services.proto
// python -m grpc_tools.protoc -I src --python_out=src --grpc_python_out=src src/lerobot/transport/services.proto
//
// The command should be launched from the root of the project.

View File

@@ -1,6 +1,6 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: src/lerobot/transport/services.proto
# source: lerobot/transport/services.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
@@ -14,7 +14,7 @@ _runtime_version.ValidateProtobufRuntimeVersion(
29,
0,
'',
'src/lerobot/transport/services.proto'
'lerobot/transport/services.proto'
)
# @@protoc_insertion_point(imports)
@@ -23,23 +23,23 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$src/lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'src.lerobot.transport.services_pb2', _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.transport.services_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=302
_globals['_TRANSFERSTATE']._serialized_end=398
_globals['_TRANSITION']._serialized_start=51
_globals['_TRANSITION']._serialized_end=127
_globals['_PARAMETERS']._serialized_start=129
_globals['_PARAMETERS']._serialized_end=205
_globals['_INTERACTIONMESSAGE']._serialized_start=207
_globals['_INTERACTIONMESSAGE']._serialized_end=291
_globals['_EMPTY']._serialized_start=293
_globals['_EMPTY']._serialized_end=300
_globals['_LEARNERSERVICE']._serialized_start=401
_globals['_LEARNERSERVICE']._serialized_end=658
_globals['_TRANSFERSTATE']._serialized_start=298
_globals['_TRANSFERSTATE']._serialized_end=394
_globals['_TRANSITION']._serialized_start=47
_globals['_TRANSITION']._serialized_end=123
_globals['_PARAMETERS']._serialized_start=125
_globals['_PARAMETERS']._serialized_end=201
_globals['_INTERACTIONMESSAGE']._serialized_start=203
_globals['_INTERACTIONMESSAGE']._serialized_end=287
_globals['_EMPTY']._serialized_start=289
_globals['_EMPTY']._serialized_end=296
_globals['_LEARNERSERVICE']._serialized_start=397
_globals['_LEARNERSERVICE']._serialized_end=654
# @@protoc_insertion_point(module_scope)

View File

@@ -3,7 +3,7 @@
import grpc
import warnings
from src.lerobot.transport import services_pb2 as src_dot_lerobot_dot_transport_dot_services__pb2
from lerobot.transport import services_pb2 as lerobot_dot_transport_dot_services__pb2
GRPC_GENERATED_VERSION = '1.71.0'
GRPC_VERSION = grpc.__version__
@@ -18,7 +18,7 @@ except ImportError:
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in src/lerobot/transport/services_pb2_grpc.py depends on'
+ f' but the generated code in lerobot/transport/services_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
@@ -38,23 +38,23 @@ class LearnerServiceStub:
"""
self.StreamParameters = channel.unary_stream(
'/transport.LearnerService/StreamParameters',
request_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
response_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Parameters.FromString,
request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
response_deserializer=lerobot_dot_transport_dot_services__pb2.Parameters.FromString,
_registered_method=True)
self.SendTransitions = channel.stream_unary(
'/transport.LearnerService/SendTransitions',
request_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Transition.SerializeToString,
response_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
request_serializer=lerobot_dot_transport_dot_services__pb2.Transition.SerializeToString,
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
_registered_method=True)
self.SendInteractions = channel.stream_unary(
'/transport.LearnerService/SendInteractions',
request_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
response_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
request_serializer=lerobot_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
_registered_method=True)
self.Ready = channel.unary_unary(
'/transport.LearnerService/Ready',
request_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
response_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
_registered_method=True)
@@ -93,23 +93,23 @@ def add_LearnerServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'StreamParameters': grpc.unary_stream_rpc_method_handler(
servicer.StreamParameters,
request_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
response_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Parameters.SerializeToString,
request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
response_serializer=lerobot_dot_transport_dot_services__pb2.Parameters.SerializeToString,
),
'SendTransitions': grpc.stream_unary_rpc_method_handler(
servicer.SendTransitions,
request_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Transition.FromString,
response_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
request_deserializer=lerobot_dot_transport_dot_services__pb2.Transition.FromString,
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
),
'SendInteractions': grpc.stream_unary_rpc_method_handler(
servicer.SendInteractions,
request_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.InteractionMessage.FromString,
response_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
request_deserializer=lerobot_dot_transport_dot_services__pb2.InteractionMessage.FromString,
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
),
'Ready': grpc.unary_unary_rpc_method_handler(
servicer.Ready,
request_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
response_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
@@ -139,8 +139,8 @@ class LearnerService:
request,
target,
'/transport.LearnerService/StreamParameters',
src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
src_dot_lerobot_dot_transport_dot_services__pb2.Parameters.FromString,
lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
lerobot_dot_transport_dot_services__pb2.Parameters.FromString,
options,
channel_credentials,
insecure,
@@ -166,8 +166,8 @@ class LearnerService:
request_iterator,
target,
'/transport.LearnerService/SendTransitions',
src_dot_lerobot_dot_transport_dot_services__pb2.Transition.SerializeToString,
src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
lerobot_dot_transport_dot_services__pb2.Transition.SerializeToString,
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
@@ -193,8 +193,8 @@ class LearnerService:
request_iterator,
target,
'/transport.LearnerService/SendInteractions',
src_dot_lerobot_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
lerobot_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
@@ -220,8 +220,8 @@ class LearnerService:
request,
target,
'/transport.LearnerService/Ready',
src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString,
lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
options,
channel_credentials,
insecure,

View File

@@ -111,35 +111,46 @@ def is_amp_available(device: str):
raise ValueError(f"Unknown device '{device}.")
def init_logging(log_file: Path | None = None, display_pid: bool = False):
def custom_format(record):
def init_logging(
log_file: Path | None = None,
display_pid: bool = False,
console_level: str = "INFO",
file_level: str = "DEBUG",
):
def custom_format(record: logging.LogRecord) -> str:
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
fnameline = f"{record.pathname}:{record.lineno}"
# NOTE: Display PID is useful for multi-process logging.
if display_pid:
pid_str = f"[PID: {os.getpid()}]"
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}"
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
else:
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
return message
logging.basicConfig(level=logging.INFO)
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
formatter = logging.Formatter()
formatter.format = custom_format
logger = logging.getLogger()
logger.setLevel(logging.NOTSET) # Set the logger to the lowest level to capture all messages
# Remove unused default handlers
for handler in logger.handlers[:]:
logger.removeHandler(handler)
# Write logs to console
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
logging.getLogger().addHandler(console_handler)
console_handler.setLevel(console_level.upper())
logger.addHandler(console_handler)
# Additionally write logs to file
if log_file is not None:
# Additionally write logs to file
file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(formatter)
logging.getLogger().addHandler(file_handler)
file_handler.setLevel(file_level.upper())
logger.addHandler(file_handler)
def format_big_number(num, precision=0):

View File

@@ -28,19 +28,35 @@ def _init_rerun(session_name: str = "lerobot_control_loop") -> None:
rr.spawn(memory_limit=memory_limit)
def log_rerun_data(observation: dict[str | Any], action: dict[str | Any]):
def log_rerun_data(observation: dict[str, Any], action: dict[str, Any]):
for obs, val in observation.items():
if isinstance(val, float):
rr.log(f"observation.{obs}", rr.Scalar(val))
elif isinstance(val, dict):
# Handle dictionary of joint values
for joint_name, joint_val in val.items():
if isinstance(joint_val, (float, int)):
rr.log(f"observation.{obs}.{joint_name}", rr.Scalar(float(joint_val)))
elif isinstance(val, np.ndarray):
if val.ndim == 1:
for i, v in enumerate(val):
rr.log(f"observation.{obs}_{i}", rr.Scalar(float(v)))
else:
rr.log(f"observation.{obs}", rr.Image(val), static=True)
for act, val in action.items():
if isinstance(val, float):
rr.log(f"action.{act}", rr.Scalar(val))
elif isinstance(val, dict):
# Handle dictionary of joint values
for joint_name, joint_val in val.items():
if isinstance(joint_val, (float, int)):
rr.log(f"action.{act}.{joint_name}", rr.Scalar(float(joint_val)))
elif isinstance(val, np.ndarray):
for i, v in enumerate(val):
rr.log(f"action.{act}_{i}", rr.Scalar(float(v)))
elif isinstance(val, list):
# Handle list of values
for i, v in enumerate(val):
if isinstance(v, (float, int)):
rr.log(f"action.{act}_{i}", rr.Scalar(float(v)))

View File

@@ -0,0 +1,177 @@
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""End-to-end test of the asynchronous inference stack (client ↔ server).
This test spins up a lightweight gRPC `PolicyServer` instance with a stubbed
policy network and launches a `RobotClient` that uses a `MockRobot`. The goal
is to exercise the full communication loop:
1. Client sends policy specification → Server
2. Client streams observations → Server
3. Server streams action chunks → Client
4. Client executes received actions
The test succeeds if at least one action is executed and the server records at
least one predicted timestep - demonstrating that the gRPC round-trip works
end-to-end using real (but lightweight) protocol messages.
"""
from __future__ import annotations
import threading
from concurrent import futures
import pytest
import torch
# Skip entire module if grpc is not available
pytest.importorskip("grpc")
# -----------------------------------------------------------------------------
# End-to-end test
# -----------------------------------------------------------------------------
def test_async_inference_e2e(monkeypatch):
"""Tests the full asynchronous inference pipeline."""
# Import grpc-dependent modules inside the test function
import grpc
from lerobot.robots.utils import make_robot_from_config
from lerobot.scripts.server.configs import PolicyServerConfig, RobotClientConfig
from lerobot.scripts.server.helpers import map_robot_keys_to_lerobot_features
from lerobot.scripts.server.policy_server import PolicyServer
from lerobot.scripts.server.robot_client import RobotClient
from lerobot.transport import (
async_inference_pb2, # type: ignore
async_inference_pb2_grpc, # type: ignore
)
from tests.mocks.mock_robot import MockRobotConfig
# Create a stub policy similar to test_policy_server.py
class MockPolicy:
"""A minimal mock for an actual policy, returning zeros."""
class _Config:
robot_type = "dummy_robot"
@property
def image_features(self):
"""Empty image features since this test doesn't use images."""
return {}
def __init__(self):
self.config = self._Config()
def to(self, *args, **kwargs):
return self
def model(self, batch):
# Return a chunk of 20 dummy actions.
batch_size = len(batch["robot_type"])
return torch.zeros(batch_size, 20, 6)
# ------------------------------------------------------------------
# 1. Create PolicyServer instance with mock policy
# ------------------------------------------------------------------
policy_server_config = PolicyServerConfig(host="localhost", port=9999)
policy_server = PolicyServer(policy_server_config)
# Replace the real policy with our fast, deterministic stub.
policy_server.policy = MockPolicy()
policy_server.actions_per_chunk = 20
policy_server.device = "cpu"
# Set up robot config and features
robot_config = MockRobotConfig()
mock_robot = make_robot_from_config(robot_config)
lerobot_features = map_robot_keys_to_lerobot_features(mock_robot)
policy_server.lerobot_features = lerobot_features
# Force server to produce deterministic action chunks in test mode
policy_server.policy_type = "act"
def _fake_get_action_chunk(_self, _obs, _type="test"):
action_dim = 6
batch_size = 1
actions_per_chunk = policy_server.actions_per_chunk
return torch.zeros(batch_size, actions_per_chunk, action_dim)
monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True)
# Bypass potentially heavy model loading inside SendPolicyInstructions
def _fake_send_policy_instructions(self, request, context): # noqa: N802
return async_inference_pb2.Empty()
monkeypatch.setattr(PolicyServer, "SendPolicyInstructions", _fake_send_policy_instructions, raising=True)
# Build gRPC server running a PolicyServer
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="policy_server"))
async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
# Use the host/port specified in the fixture's config
server_address = f"{policy_server.config.host}:{policy_server.config.port}"
server.add_insecure_port(server_address)
server.start()
# ------------------------------------------------------------------
# 2. Create a RobotClient around the MockRobot
# ------------------------------------------------------------------
client_config = RobotClientConfig(
server_address=server_address,
robot=robot_config,
chunk_size_threshold=0.0,
policy_type="test",
pretrained_name_or_path="test",
actions_per_chunk=20,
verify_robot_cameras=False,
)
client = RobotClient(client_config)
assert client.start(), "Client failed initial handshake with the server"
# Track action chunks received without modifying RobotClient
action_chunks_received = {"count": 0}
original_aggregate = client._aggregate_action_queues
def counting_aggregate(*args, **kwargs):
action_chunks_received["count"] += 1
return original_aggregate(*args, **kwargs)
monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate)
# Start client threads
action_thread = threading.Thread(target=client.receive_actions, daemon=True)
control_thread = threading.Thread(target=client.control_loop, args=({"task": ""}), daemon=True)
action_thread.start()
control_thread.start()
# ------------------------------------------------------------------
# 3. System exchanges a few messages
# ------------------------------------------------------------------
# Wait for 5 seconds
server.wait_for_termination(timeout=5)
assert action_chunks_received["count"] > 0, "Client did not receive any action chunks"
assert len(policy_server._predicted_timesteps) > 0, "Server did not record any predicted timesteps"
# ------------------------------------------------------------------
# 4. Stop the system
# ------------------------------------------------------------------
client.stop()
action_thread.join()
control_thread.join()
policy_server.stop()
server.stop(grace=None)

View File

@@ -0,0 +1,459 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import pickle
import time
import numpy as np
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.scripts.server.helpers import (
FPSTracker,
TimedAction,
TimedObservation,
observations_similar,
prepare_image,
prepare_raw_observation,
raw_observation_to_observation,
resize_robot_observation_image,
)
# ---------------------------------------------------------------------
# FPSTracker
# ---------------------------------------------------------------------
def test_fps_tracker_first_observation():
"""First observation should initialize timestamp and return 0 FPS."""
tracker = FPSTracker(target_fps=30.0)
timestamp = 1000.0
metrics = tracker.calculate_fps_metrics(timestamp)
assert tracker.first_timestamp == timestamp
assert tracker.total_obs_count == 1
assert metrics["avg_fps"] == 0.0
assert metrics["target_fps"] == 30.0
def test_fps_tracker_single_interval():
"""Two observations 1 second apart should give 1 FPS."""
tracker = FPSTracker(target_fps=30.0)
# First observation at t=0
metrics1 = tracker.calculate_fps_metrics(0.0)
assert metrics1["avg_fps"] == 0.0
# Second observation at t=1 (1 second later)
metrics2 = tracker.calculate_fps_metrics(1.0)
expected_fps = 1.0 # (2-1) observations / 1.0 seconds = 1 FPS
assert math.isclose(metrics2["avg_fps"], expected_fps, rel_tol=1e-6)
def test_fps_tracker_multiple_intervals():
"""Multiple observations should calculate correct average FPS."""
tracker = FPSTracker(target_fps=30.0)
# Simulate 5 observations over 2 seconds (should be 2 FPS average)
timestamps = [0.0, 0.5, 1.0, 1.5, 2.0]
for i, ts in enumerate(timestamps):
metrics = tracker.calculate_fps_metrics(ts)
if i == 0:
assert metrics["avg_fps"] == 0.0
elif i == len(timestamps) - 1:
# After 5 observations over 2 seconds: (5-1)/2 = 2 FPS
expected_fps = 2.0
assert math.isclose(metrics["avg_fps"], expected_fps, rel_tol=1e-6)
def test_fps_tracker_irregular_intervals():
"""FPS calculation should work with irregular time intervals."""
tracker = FPSTracker(target_fps=30.0)
# Irregular timestamps: 0, 0.1, 0.5, 2.0, 3.0 seconds
timestamps = [0.0, 0.1, 0.5, 2.0, 3.0]
for ts in timestamps:
metrics = tracker.calculate_fps_metrics(ts)
# 5 observations over 3 seconds: (5-1)/3 = 1.333... FPS
expected_fps = 4.0 / 3.0
assert math.isclose(metrics["avg_fps"], expected_fps, rel_tol=1e-6)
# ---------------------------------------------------------------------
# TimedData helpers
# ---------------------------------------------------------------------
def test_timed_action_getters():
"""TimedAction stores & returns timestamp, action tensor and timestep."""
ts = time.time()
action = torch.arange(10)
ta = TimedAction(timestamp=ts, action=action, timestep=0)
assert math.isclose(ta.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
torch.testing.assert_close(ta.get_action(), action)
assert ta.get_timestep() == 0
def test_timed_observation_getters():
"""TimedObservation stores & returns timestamp, dict and timestep."""
ts = time.time()
obs_dict = {"observation.state": torch.ones(6)}
to = TimedObservation(timestamp=ts, observation=obs_dict, timestep=0)
assert math.isclose(to.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
assert to.get_observation() is obs_dict
assert to.get_timestep() == 0
def test_timed_data_deserialization_data_getters():
"""TimedAction / TimedObservation survive a round-trip through ``pickle``.
The async-inference stack uses ``pickle.dumps`` to move these objects across
the gRPC boundary (see RobotClient.send_observation and PolicyServer.StreamActions).
This test ensures that the payload keeps its content intact after
the (de)serialization round-trip.
"""
ts = time.time()
# ------------------------------------------------------------------
# TimedAction
# ------------------------------------------------------------------
original_action = torch.randn(6)
ta_in = TimedAction(timestamp=ts, action=original_action, timestep=13)
# Serialize → bytes → deserialize
ta_bytes = pickle.dumps(ta_in) # nosec
ta_out: TimedAction = pickle.loads(ta_bytes) # nosec B301
# Identity & content checks
assert math.isclose(ta_out.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
assert ta_out.get_timestep() == 13
torch.testing.assert_close(ta_out.get_action(), original_action)
# ------------------------------------------------------------------
# TimedObservation
# ------------------------------------------------------------------
obs_dict = {"observation.state": torch.arange(4).float()}
to_in = TimedObservation(timestamp=ts, observation=obs_dict, timestep=7, must_go=True)
to_bytes = pickle.dumps(to_in) # nosec
to_out: TimedObservation = pickle.loads(to_bytes) # nosec B301
assert math.isclose(to_out.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
assert to_out.get_timestep() == 7
assert to_out.must_go is True
assert to_out.get_observation().keys() == obs_dict.keys()
torch.testing.assert_close(to_out.get_observation()["observation.state"], obs_dict["observation.state"])
# ---------------------------------------------------------------------
# observations_similar()
# ---------------------------------------------------------------------
def _make_obs(state: torch.Tensor) -> TimedObservation:
"""Create a TimedObservation with raw robot observation format."""
return TimedObservation(
timestamp=time.time(),
observation={
"shoulder": state[0].item() if len(state) > 0 else 0.0,
"elbow": state[1].item() if len(state) > 1 else 0.0,
"wrist": state[2].item() if len(state) > 2 else 0.0,
"gripper": state[3].item() if len(state) > 3 else 0.0,
},
timestep=0,
)
def test_observations_similar_true():
"""Distance below atol → observations considered similar."""
# Create mock lerobot features for the similarity check
lerobot_features = {
"observation.state": {
"dtype": "float32",
"shape": [4],
"names": ["shoulder", "elbow", "wrist", "gripper"],
}
}
obs1 = _make_obs(torch.zeros(4))
obs2 = _make_obs(0.5 * torch.ones(4))
assert observations_similar(obs1, obs2, lerobot_features, atol=2.0)
obs3 = _make_obs(2.0 * torch.ones(4))
assert not observations_similar(obs1, obs3, lerobot_features, atol=2.0)
# ---------------------------------------------------------------------
# raw_observation_to_observation and helpers
# ---------------------------------------------------------------------
def _create_mock_robot_observation():
"""Create a mock robot observation with motor positions and camera images."""
return {
"shoulder": 1.0,
"elbow": 2.0,
"wrist": 3.0,
"gripper": 0.5,
"laptop": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8),
"phone": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8),
}
def _create_mock_lerobot_features():
"""Create mock lerobot features mapping similar to what hw_to_dataset_features returns."""
return {
"observation.state": {
"dtype": "float32",
"shape": [4],
"names": ["shoulder", "elbow", "wrist", "gripper"],
},
"observation.images.laptop": {
"dtype": "image",
"shape": [480, 640, 3],
"names": ["height", "width", "channels"],
},
"observation.images.phone": {
"dtype": "image",
"shape": [480, 640, 3],
"names": ["height", "width", "channels"],
},
}
def _create_mock_policy_image_features():
"""Create mock policy image features with different resolutions."""
return {
"observation.images.laptop": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224), # Policy expects smaller resolution
),
"observation.images.phone": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 160, 160), # Different resolution for second camera
),
}
def test_prepare_image():
"""Test image preprocessing: int8 → float32, normalization to [0,1]."""
# Create mock int8 image data
image_int8 = torch.randint(0, 256, size=(3, 224, 224), dtype=torch.uint8)
processed = prepare_image(image_int8)
# Check dtype conversion
assert processed.dtype == torch.float32
# Check normalization range
assert processed.min() >= 0.0
assert processed.max() <= 1.0
# Check that values are scaled correctly (255 → 1.0, 0 → 0.0)
if image_int8.max() == 255:
assert torch.isclose(processed.max(), torch.tensor(1.0), atol=1e-6)
if image_int8.min() == 0:
assert torch.isclose(processed.min(), torch.tensor(0.0), atol=1e-6)
# Check memory contiguity
assert processed.is_contiguous()
def test_resize_robot_observation_image():
"""Test image resizing from robot resolution to policy resolution."""
# Create mock image: (H=480, W=640, C=3)
original_image = torch.randint(0, 256, size=(480, 640, 3), dtype=torch.uint8)
target_shape = (3, 224, 224) # (C, H, W)
resized = resize_robot_observation_image(original_image, target_shape)
# Check output shape matches target
assert resized.shape == target_shape
# Check that original image had different dimensions
assert original_image.shape != resized.shape
# Check that resizing preserves value range
assert resized.min() >= 0
assert resized.max() <= 255
def test_prepare_raw_observation():
"""Test the preparation of raw robot observation to lerobot format."""
robot_obs = _create_mock_robot_observation()
lerobot_features = _create_mock_lerobot_features()
policy_image_features = _create_mock_policy_image_features()
prepared = prepare_raw_observation(robot_obs, lerobot_features, policy_image_features)
# Check that state is properly extracted and batched
assert "observation.state" in prepared
state = prepared["observation.state"]
assert isinstance(state, torch.Tensor)
assert state.shape == (1, 4) # Batched state
# Check that images are processed and resized
assert "observation.images.laptop" in prepared
assert "observation.images.phone" in prepared
laptop_img = prepared["observation.images.laptop"]
phone_img = prepared["observation.images.phone"]
# Check image shapes match policy requirements
assert laptop_img.shape == policy_image_features["observation.images.laptop"].shape
assert phone_img.shape == policy_image_features["observation.images.phone"].shape
# Check that images are tensors
assert isinstance(laptop_img, torch.Tensor)
assert isinstance(phone_img, torch.Tensor)
def test_raw_observation_to_observation_basic():
"""Test the main raw_observation_to_observation function."""
robot_obs = _create_mock_robot_observation()
lerobot_features = _create_mock_lerobot_features()
policy_image_features = _create_mock_policy_image_features()
device = "cpu"
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
# Check that all expected keys are present
assert "observation.state" in observation
assert "observation.images.laptop" in observation
assert "observation.images.phone" in observation
# Check state processing
state = observation["observation.state"]
assert isinstance(state, torch.Tensor)
assert state.device.type == device
assert state.shape == (1, 4) # Batched
# Check image processing
laptop_img = observation["observation.images.laptop"]
phone_img = observation["observation.images.phone"]
# Images should have batch dimension: (B, C, H, W)
assert laptop_img.shape == (1, 3, 224, 224)
assert phone_img.shape == (1, 3, 160, 160)
# Check device placement
assert laptop_img.device.type == device
assert phone_img.device.type == device
# Check image dtype and range (should be float32 in [0, 1])
assert laptop_img.dtype == torch.float32
assert phone_img.dtype == torch.float32
assert laptop_img.min() >= 0.0 and laptop_img.max() <= 1.0
assert phone_img.min() >= 0.0 and phone_img.max() <= 1.0
def test_raw_observation_to_observation_with_non_tensor_data():
"""Test that non-tensor data (like task strings) is preserved."""
robot_obs = _create_mock_robot_observation()
robot_obs["task"] = "pick up the red cube" # Add string instruction
lerobot_features = _create_mock_lerobot_features()
policy_image_features = _create_mock_policy_image_features()
device = "cpu"
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
# Check that task string is preserved
assert "task" in observation
assert observation["task"] == "pick up the red cube"
assert isinstance(observation["task"], str)
@torch.no_grad()
def test_raw_observation_to_observation_device_handling():
"""Test that tensors are properly moved to the specified device."""
device = "mps" if torch.backends.mps.is_available() else "cpu"
robot_obs = _create_mock_robot_observation()
lerobot_features = _create_mock_lerobot_features()
policy_image_features = _create_mock_policy_image_features()
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
# Check that all tensors are on the correct device
for key, value in observation.items():
if isinstance(value, torch.Tensor):
assert value.device.type == device, f"Tensor {key} not on {device}"
def test_raw_observation_to_observation_deterministic():
"""Test that the function produces consistent results for the same input."""
robot_obs = _create_mock_robot_observation()
lerobot_features = _create_mock_lerobot_features()
policy_image_features = _create_mock_policy_image_features()
device = "cpu"
# Run twice with same input
obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
# Results should be identical
assert set(obs1.keys()) == set(obs2.keys())
for key in obs1:
if isinstance(obs1[key], torch.Tensor):
torch.testing.assert_close(obs1[key], obs2[key])
else:
assert obs1[key] == obs2[key]
def test_image_processing_pipeline_preserves_content():
"""Test that the image processing pipeline preserves recognizable patterns."""
# Create an image with a specific pattern
original_img = np.zeros((100, 100, 3), dtype=np.uint8)
original_img[25:75, 25:75, :] = 255 # White square in center
robot_obs = {"shoulder": 1.0, "elbow": 1.0, "wrist": 1.0, "gripper": 1.0, "laptop": original_img}
lerobot_features = {
"observation.state": {
"dtype": "float32",
"shape": [4],
"names": ["shoulder", "elbow", "wrist", "gripper"],
},
"observation.images.laptop": {
"dtype": "image",
"shape": [100, 100, 3],
"names": ["height", "width", "channels"],
},
}
policy_image_features = {
"observation.images.laptop": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 50, 50), # Downsamples from 100x100
)
}
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu")
processed_img = observation["observation.images.laptop"].squeeze(0) # Remove batch dim
# Check that the center region has higher values than corners
# Due to bilinear interpolation, exact values will change but pattern should remain
center_val = processed_img[:, 25, 25].mean() # Center of 50x50 image
corner_val = processed_img[:, 5, 5].mean() # Corner
assert center_val > corner_val, "Image processing should preserve recognizable patterns"

View File

@@ -0,0 +1,215 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit-tests for the `PolicyServer` core logic.
Monkey-patch the `policy` attribute with a stub so that no real model inference is performed.
"""
from __future__ import annotations
import time
import pytest
import torch
from lerobot.configs.types import PolicyFeature
from tests.utils import require_package
# -----------------------------------------------------------------------------
# Test fixtures
# -----------------------------------------------------------------------------
class MockPolicy:
"""A minimal mock for an actual policy, returning zeros.
Refer to tests/policies for tests of the individual policies supported."""
class _Config:
robot_type = "dummy_robot"
@property
def image_features(self) -> dict[str, PolicyFeature]:
"""Empty image features since this test doesn't use images."""
return {}
def predict_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
"""Return a chunk of 20 dummy actions."""
batch_size = len(observation["observation.state"])
return torch.zeros(batch_size, 20, 6)
def __init__(self):
self.config = self._Config()
def to(self, *args, **kwargs):
# The server calls `policy.to(device)`. This stub ignores it.
return self
def model(self, batch: dict) -> torch.Tensor:
# Return a chunk of 20 dummy actions.
batch_size = len(batch["robot_type"])
return torch.zeros(batch_size, 20, 6)
@pytest.fixture
@require_package("grpc")
def policy_server():
"""Fresh `PolicyServer` instance with a stubbed-out policy model."""
# Import only when the test actually runs (after decorator check)
from lerobot.scripts.server.configs import PolicyServerConfig
from lerobot.scripts.server.policy_server import PolicyServer
test_config = PolicyServerConfig(host="localhost", port=9999)
server = PolicyServer(test_config)
# Replace the real policy with our fast, deterministic stub.
server.policy = MockPolicy()
server.actions_per_chunk = 20
server.device = "cpu"
# Add mock lerobot_features that the observation similarity functions need
server.lerobot_features = {
"observation.state": {
"dtype": "float32",
"shape": [6],
"names": ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"],
}
}
return server
# -----------------------------------------------------------------------------
# Helper utilities for tests
# -----------------------------------------------------------------------------
def _make_obs(state: torch.Tensor, timestep: int = 0, must_go: bool = False):
"""Create a TimedObservation with a given state vector."""
# Import only when needed
from lerobot.scripts.server.helpers import TimedObservation
return TimedObservation(
observation={
"joint1": state[0].item() if len(state) > 0 else 0.0,
"joint2": state[1].item() if len(state) > 1 else 0.0,
"joint3": state[2].item() if len(state) > 2 else 0.0,
"joint4": state[3].item() if len(state) > 3 else 0.0,
"joint5": state[4].item() if len(state) > 4 else 0.0,
"joint6": state[5].item() if len(state) > 5 else 0.0,
},
timestamp=time.time(),
timestep=timestep,
must_go=must_go,
)
# -----------------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------------
def test_time_action_chunk(policy_server):
"""Verify that `_time_action_chunk` assigns correct timestamps and timesteps."""
start_ts = time.time()
start_t = 10
# A chunk of 3 action tensors.
action_tensors = [torch.randn(6) for _ in range(3)]
timed_actions = policy_server._time_action_chunk(start_ts, action_tensors, start_t)
assert len(timed_actions) == 3
# Check timesteps
assert [ta.get_timestep() for ta in timed_actions] == [10, 11, 12]
# Check timestamps
expected_timestamps = [
start_ts,
start_ts + policy_server.config.environment_dt,
start_ts + 2 * policy_server.config.environment_dt,
]
for ta, expected_ts in zip(timed_actions, expected_timestamps, strict=True):
assert abs(ta.get_timestamp() - expected_ts) < 1e-6
def test_maybe_enqueue_observation_must_go(policy_server):
"""An observation with `must_go=True` is always enqueued."""
obs = _make_obs(torch.zeros(6), must_go=True)
assert policy_server._enqueue_observation(obs) is True
assert policy_server.observation_queue.qsize() == 1
assert policy_server.observation_queue.get_nowait() is obs
def test_maybe_enqueue_observation_dissimilar(policy_server):
"""A dissimilar observation (not `must_go`) is enqueued."""
# Set a last predicted observation.
policy_server.last_processed_obs = _make_obs(torch.zeros(6))
# Create a new, dissimilar observation.
new_obs = _make_obs(torch.ones(6) * 5) # High norm difference
assert policy_server._enqueue_observation(new_obs) is True
assert policy_server.observation_queue.qsize() == 1
def test_maybe_enqueue_observation_is_skipped(policy_server):
"""A similar observation (not `must_go`) is skipped."""
# Set a last predicted observation.
policy_server.last_processed_obs = _make_obs(torch.zeros(6))
# Create a new, very similar observation.
new_obs = _make_obs(torch.zeros(6) + 1e-4)
assert policy_server._enqueue_observation(new_obs) is False
assert policy_server.observation_queue.empty() is True
def test_obs_sanity_checks(policy_server):
"""Unit-test the private `_obs_sanity_checks` helper."""
prev = _make_obs(torch.zeros(6), timestep=0)
# Case 1 timestep already predicted
policy_server._predicted_timesteps.add(1)
obs_same_ts = _make_obs(torch.ones(6), timestep=1)
assert policy_server._obs_sanity_checks(obs_same_ts, prev) is False
# Case 2 observation too similar
policy_server._predicted_timesteps.clear()
obs_similar = _make_obs(torch.zeros(6) + 1e-4, timestep=2)
assert policy_server._obs_sanity_checks(obs_similar, prev) is False
# Case 3 genuinely new & dissimilar observation passes
obs_ok = _make_obs(torch.ones(6) * 5, timestep=3)
assert policy_server._obs_sanity_checks(obs_ok, prev) is True
def test_predict_action_chunk(monkeypatch, policy_server):
"""End-to-end test of `_predict_action_chunk` with a stubbed _get_action_chunk."""
# Import only when needed
from lerobot.scripts.server.policy_server import PolicyServer
# Force server to act-style policy; patch method to return deterministic tensor
policy_server.policy_type = "act"
action_dim = 6
batch_size = 1
actions_per_chunk = policy_server.actions_per_chunk
def _fake_get_action_chunk(_self, _obs, _type="act"):
return torch.zeros(batch_size, actions_per_chunk, action_dim)
monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True)
obs = _make_obs(torch.zeros(6), timestep=5)
timed_actions = policy_server._predict_action_chunk(obs)
assert len(timed_actions) == actions_per_chunk
assert [ta.get_timestep() for ta in timed_actions] == list(range(5, 5 + actions_per_chunk))
for i, ta in enumerate(timed_actions):
expected_ts = obs.get_timestamp() + i * policy_server.config.environment_dt
assert abs(ta.get_timestamp() - expected_ts) < 1e-6

View File

@@ -0,0 +1,234 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit-tests for the `RobotClient` action-queue logic (pure Python, no gRPC).
We monkey-patch `lerobot.common.robot_devices.robots.utils.make_robot` so that
no real hardware is accessed. Only the queue-update mechanism is verified.
"""
from __future__ import annotations
import time
from queue import Queue
import pytest
import torch
# Skip entire module if grpc is not available
pytest.importorskip("grpc")
# -----------------------------------------------------------------------------
# Test fixtures
# -----------------------------------------------------------------------------
@pytest.fixture()
def robot_client():
"""Fresh `RobotClient` instance for each test case (no threads started).
Uses DummyRobot."""
# Import only when the test actually runs (after decorator check)
from lerobot.scripts.server.configs import RobotClientConfig
from lerobot.scripts.server.robot_client import RobotClient
from tests.mocks.mock_robot import MockRobotConfig
test_config = MockRobotConfig()
# gRPC channel is not actually used in tests, so using a dummy address
test_config = RobotClientConfig(
robot=test_config,
server_address="localhost:9999",
policy_type="test",
pretrained_name_or_path="test",
actions_per_chunk=20,
verify_robot_cameras=False,
)
client = RobotClient(test_config)
# Initialize attributes that are normally set in start() method
client.chunks_received = 0
client.available_actions_size = []
yield client
if client.robot.is_connected:
client.stop()
# -----------------------------------------------------------------------------
# Helper utilities for tests
# -----------------------------------------------------------------------------
def _make_actions(start_ts: float, start_t: int, count: int):
"""Generate `count` consecutive TimedAction objects starting at timestep `start_t`."""
from lerobot.scripts.server.helpers import TimedAction
fps = 30 # emulates most common frame-rate
actions = []
for i in range(count):
timestep = start_t + i
timestamp = start_ts + i * (1 / fps)
action_tensor = torch.full((6,), timestep, dtype=torch.float32)
actions.append(TimedAction(action=action_tensor, timestep=timestep, timestamp=timestamp))
return actions
# -----------------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------------
def test_update_action_queue_discards_stale(robot_client):
"""`_update_action_queue` must drop actions with `timestep` <= `latest_action`."""
# Pretend we already executed up to action #4
robot_client.latest_action = 4
# Incoming chunk contains timesteps 3..7 -> expect 5,6,7 kept.
incoming = _make_actions(start_ts=time.time(), start_t=3, count=5) # 3,4,5,6,7
robot_client._aggregate_action_queues(incoming)
# Extract timesteps from queue
resulting_timesteps = [a.get_timestep() for a in robot_client.action_queue.queue]
assert resulting_timesteps == [5, 6, 7]
@pytest.mark.parametrize(
"weight_old, weight_new",
[
(1.0, 0.0),
(0.0, 1.0),
(0.5, 0.5),
(0.2, 0.8),
(0.8, 0.2),
(0.1, 0.9),
(0.9, 0.1),
],
)
def test_aggregate_action_queues_combines_actions_in_overlap(
robot_client, weight_old: float, weight_new: float
):
"""`_aggregate_action_queues` must combine actions on overlapping timesteps according
to the provided aggregate_fn, here tested with multiple coefficients."""
from lerobot.scripts.server.helpers import TimedAction
robot_client.chunks_received = 0
# Pretend we already executed up to action #4, and queue contains actions for timesteps 5..6
robot_client.latest_action = 4
current_actions = _make_actions(
start_ts=time.time(), start_t=5, count=2
) # actions are [torch.ones(6), torch.ones(6), ...]
current_actions = [
TimedAction(action=10 * a.get_action(), timestep=a.get_timestep(), timestamp=a.get_timestamp())
for a in current_actions
]
for a in current_actions:
robot_client.action_queue.put(a)
# Incoming chunk contains timesteps 3..7 -> expect 5,6,7 kept.
incoming = _make_actions(start_ts=time.time(), start_t=3, count=5) # 3,4,5,6,7
overlap_timesteps = [5, 6] # properly tested in test_aggregate_action_queues_discards_stale
nonoverlap_timesteps = [7]
robot_client._aggregate_action_queues(
incoming, aggregate_fn=lambda x1, x2: weight_old * x1 + weight_new * x2
)
queue_overlap_actions = []
queue_non_overlap_actions = []
for a in robot_client.action_queue.queue:
if a.get_timestep() in overlap_timesteps:
queue_overlap_actions.append(a)
elif a.get_timestep() in nonoverlap_timesteps:
queue_non_overlap_actions.append(a)
queue_overlap_actions = sorted(queue_overlap_actions, key=lambda x: x.get_timestep())
queue_non_overlap_actions = sorted(queue_non_overlap_actions, key=lambda x: x.get_timestep())
assert torch.allclose(
queue_overlap_actions[0].get_action(),
weight_old * current_actions[0].get_action() + weight_new * incoming[-3].get_action(),
)
assert torch.allclose(
queue_overlap_actions[1].get_action(),
weight_old * current_actions[1].get_action() + weight_new * incoming[-2].get_action(),
)
assert torch.allclose(queue_non_overlap_actions[0].get_action(), incoming[-1].get_action())
@pytest.mark.parametrize(
"chunk_size, queue_len, expected",
[
(20, 12, False), # 12 / 20 = 0.6 > g=0.5 threshold, not ready to send
(20, 8, True), # 8 / 20 = 0.4 <= g=0.5, ready to send
(10, 5, True),
(10, 6, False),
],
)
def test_ready_to_send_observation(robot_client, chunk_size: int, queue_len: int, expected: bool):
"""Validate `_ready_to_send_observation` ratio logic for various sizes."""
robot_client.action_chunk_size = chunk_size
# Clear any existing actions then fill with `queue_len` dummy entries ----
robot_client.action_queue = Queue()
dummy_actions = _make_actions(start_ts=time.time(), start_t=0, count=queue_len)
for act in dummy_actions:
robot_client.action_queue.put(act)
assert robot_client._ready_to_send_observation() is expected
@pytest.mark.parametrize(
"g_threshold, expected",
[
# The condition is `queue_size / chunk_size <= g`.
# Here, ratio = 6 / 10 = 0.6.
(0.0, False), # 0.6 <= 0.0 is False
(0.1, False),
(0.2, False),
(0.3, False),
(0.4, False),
(0.5, False),
(0.6, True), # 0.6 <= 0.6 is True
(0.7, True),
(0.8, True),
(0.9, True),
(1.0, True),
],
)
def test_ready_to_send_observation_with_varying_threshold(robot_client, g_threshold: float, expected: bool):
"""Validate `_ready_to_send_observation` with fixed sizes and varying `g`."""
# Fixed sizes for this test: ratio = 6 / 10 = 0.6
chunk_size = 10
queue_len = 6
robot_client.action_chunk_size = chunk_size
# This is the parameter we are testing
robot_client._chunk_size_threshold = g_threshold
# Fill queue with dummy actions
robot_client.action_queue = Queue()
dummy_actions = _make_actions(start_ts=time.time(), start_t=0, count=queue_len)
for act in dummy_actions:
robot_client.action_queue.put(act)
assert robot_client._ready_to_send_observation() is expected

View File

@@ -394,56 +394,37 @@ def test_factory(env_name, repo_id, policy_name):
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
# @pytest.mark.skip("TODO after fix multidataset")
@pytest.mark.skip("TODO after fix multidataset")
def test_multidataset_frames():
"""Check that all dataset frames are incorporated and aligned correctly."""
"""Check that all dataset frames are incorporated."""
# Note: use the image variants of the dataset to make the test approx 3x faster.
# Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
# logic that wouldn't be caught with two repo IDs.
repo_ids = [
"lerobot/aloha_sim_insertion_human_image",
"lerobot/aloha_sim_transfer_cube_human_image",
"lerobot/aloha_sim_insertion_scripted_image",
]
# dummy padding dimensions (simulate training setup)
MAX_ACTION_DIM = 14
MAX_STATE_DIM = 30
MAX_NUM_IMAGES = 3
MAX_IMAGE_DIM = 224
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
dataset = MultiLeRobotDataset(
repo_ids,
max_action_dim=MAX_ACTION_DIM,
max_state_dim=MAX_STATE_DIM,
max_num_images=MAX_NUM_IMAGES,
max_image_dim=MAX_IMAGE_DIM,
)
dataset = MultiLeRobotDataset(repo_ids)
assert len(dataset) == sum(len(d) for d in sub_datasets)
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
# check they match.
expected_dataset_indices = []
for i, sub_dataset in enumerate(sub_datasets):
expected_dataset_indices.extend([i] * len(sub_dataset))
for expected_dataset_index, sub_item, multi_item in zip(
for expected_dataset_index, sub_dataset_item, dataset_item in zip(
expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
):
dataset_index = multi_item.pop("dataset_index")
dataset_index = dataset_item.pop("dataset_index")
assert dataset_index == expected_dataset_index
# we ignore padding_mask and dataset_index keys in multi_item
extra_keys = {k for k in multi_item if "padding_mask" in k}
filtered_multi_keys = set(multi_item.keys()) - extra_keys
assert set(sub_item.keys()) == filtered_multi_keys, "mismatch in keys"
for k in sub_item:
if k not in multi_item:
continue
v1, v2 = sub_item[k], multi_item[k]
if isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor):
assert torch.equal(v1, v2), f"tensor mismatch on key: {k}"
else:
assert v1 == v2, f"value mismatch on key: {k}"
assert sub_dataset_item.keys() == dataset_item.keys()
for k in sub_dataset_item:
assert torch.equal(sub_dataset_item[k], dataset_item[k])
# TODO(aliberts): Move to more appropriate location

View File

@@ -219,7 +219,7 @@ def test__write(addr, length, id_, value, mock_motors, dummy_motors):
comm, error = bus._write(addr, length, id_, value)
assert mock_motors.stubs[stub].called
assert mock_motors.stubs[stub].wait_called()
assert comm == scs.COMM_SUCCESS
assert error == 0
@@ -371,9 +371,9 @@ def test_reset_calibration(mock_motors, dummy_motors):
bus.reset_calibration()
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
assert all(mock_motors.stubs[stub].called for stub in write_mins_stubs)
assert all(mock_motors.stubs[stub].called for stub in write_maxes_stubs)
assert all(mock_motors.stubs[stub].wait_called() for stub in write_homing_stubs)
assert all(mock_motors.stubs[stub].wait_called() for stub in write_mins_stubs)
assert all(mock_motors.stubs[stub].wait_called() for stub in write_maxes_stubs)
def test_set_half_turn_homings(mock_motors, dummy_motors):
@@ -410,7 +410,7 @@ def test_set_half_turn_homings(mock_motors, dummy_motors):
bus.reset_calibration.assert_called_once()
assert mock_motors.stubs[read_pos_stub].called
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
assert all(mock_motors.stubs[stub].wait_called() for stub in write_homing_stubs)
def test_record_ranges_of_motion(mock_motors, dummy_motors):

View File

@@ -0,0 +1,14 @@
{
"configuration": "default",
"documentId": "84d8ae1881704ebae1ffb70a",
"documentMicroversion": "0eea3500852bdb2f58b1cb79",
"documentVersion": "a5c3b0dfaa52ddd6829011cd",
"elementId": "22efbe4e0bef24fcd20f96e5",
"fullConfiguration": "default",
"id": "MCOhripg0ry51VlsC",
"isStandardContent": false,
"name": "Base_motor_holder_SO101 v1 <1>",
"partId": "JFD",
"suppressed": false,
"type": "Part"
}

Binary file not shown.

View File

@@ -0,0 +1,14 @@
{
"configuration": "default",
"documentId": "bf61a6bc85b1d1a8bf9ea51b",
"documentMicroversion": "20484d37162a32a8a41a37f2",
"documentVersion": "25801b070e5b360715de8a30",
"elementId": "312f32f0073fa6e8e36fba7a",
"fullConfiguration": "default",
"id": "MY69cJlqvSzIiODdH",
"isStandardContent": false,
"name": "Base_SO101 v2 <1>",
"partId": "JFD",
"suppressed": false,
"type": "Part"
}

BIN
urdf/assets/base_so101_v2.stl LFS Normal file

Binary file not shown.

View File

@@ -0,0 +1,14 @@
{
"configuration": "default",
"documentId": "652d5731024e57367badfda6",
"documentMicroversion": "56a8b8013480c176fd87df8d",
"documentVersion": "984ac31c92cac3664c8effb3",
"elementId": "6fb7b7f9315511b548d670ff",
"fullConfiguration": "default",
"id": "Mf4ZebMr4BkShucFj",
"isStandardContent": false,
"name": "Motor_holder_SO101_Base v1 <1>",
"partId": "JFD",
"suppressed": false,
"type": "Part"
}

Binary file not shown.

View File

@@ -0,0 +1,14 @@
{
"configuration": "default",
"documentId": "4bd66da73cacb4d946d43e44",
"documentMicroversion": "2bf56247e58b70e90806e318",
"documentVersion": "df78bb7089f1de7d5588d238",
"elementId": "d7dfe76e402c21bbd8124e43",
"fullConfiguration": "default",
"id": "MN9BZ1p69dQQtKTjq",
"isStandardContent": false,
"name": "Motor_holder_SO101_Wrist v1 <1>",
"partId": "JFD",
"suppressed": false,
"type": "Part"
}

Binary file not shown.

View File

@@ -0,0 +1,14 @@
{
"configuration": "default",
"documentId": "46218c02ef80d36172edbb35",
"documentMicroversion": "68b7d387e2500c451586ae59",
"documentVersion": "79c101d1a0207b77362b561a",
"elementId": "d4b1411d5d7333298f6e2458",
"fullConfiguration": "default",
"id": "MrHPLr9hZkrXwcSA4",
"isStandardContent": false,
"name": "Moving_Jaw_SO101 v1 <1>",
"partId": "JFD",
"suppressed": false,
"type": "Part"
}

Binary file not shown.

View File

@@ -0,0 +1,14 @@
{
"configuration": "default",
"documentId": "14078aa6723c502d07d6902e",
"documentMicroversion": "c0fca717407275159bcc6ed7",
"documentVersion": "3d9a887ff68fa477d98162b8",
"elementId": "43d24b3857ff686b275578bf",
"fullConfiguration": "default",
"id": "MrQ6Kmk9QDZlwbp95",
"isStandardContent": false,
"name": "Rotation_Pitch_SO101 v1 <1>",
"partId": "JFD",
"suppressed": false,
"type": "Part"
}

Binary file not shown.

View File

@@ -0,0 +1,14 @@
{
"configuration": "default",
"documentId": "56e5f3702dad85e17841d2e2",
"documentMicroversion": "7958a6acbc8e0d0a0a611746",
"documentVersion": "29a4c51b8bf277a22743a333",
"elementId": "8c14fb13a6557ec89ff5d227",
"fullConfiguration": "default",
"id": "MOcaIFg8XgL+Ybg9z",
"isStandardContent": false,
"name": "STS3215_03a_no_horn v1 <1>",
"partId": "JFD",
"suppressed": false,
"type": "Part"
}

Binary file not shown.

View File

@@ -0,0 +1,14 @@
{
"configuration": "default",
"documentId": "d2941bdba816affebdc6d6f0",
"documentMicroversion": "5904ef3cea04a0d0bc88b698",
"documentVersion": "dd4f7470101215836a4ae8c9",
"elementId": "e670b72d49b06f88fad5dbd8",
"fullConfiguration": "default",
"id": "M5vQNpe0onRFueych",
"isStandardContent": false,
"name": "STS3215_03a v1 <5>",
"partId": "JFD",
"suppressed": false,
"type": "Part"
}

BIN
urdf/assets/sts3215_03a_v1.stl LFS Normal file

Binary file not shown.

View File

@@ -0,0 +1,14 @@
{
"configuration": "default",
"documentId": "9f5d6db47eb112442b9f130f",
"documentMicroversion": "e99cf45162e34789bd99512b",
"documentVersion": "817ebf29c5663d412edc0753",
"elementId": "2813aaffe3c8a342616d3527",
"fullConfiguration": "default",
"id": "M9yAEiX02J3c4HqXa",
"isStandardContent": false,
"name": "Under_arm_SO101 v1 <1>",
"partId": "JFD",
"suppressed": false,
"type": "Part"
}

Binary file not shown.

Some files were not shown because too many files have changed in this diff Show More