Compare commits

..

140 Commits

Author SHA1 Message Date
Pepijn
7788db7838 Merge branch 'feat/add_pi' into feat/validate_pi_libero 2025-09-14 16:19:32 +02:00
Pepijn
d883c78a94 remove additional image augmentations, lerobot dataset already does this 2025-09-13 21:20:09 +02:00
Pepijn
de42da8225 Merge branch 'feat/add_pi' into feat/validate_pi_libero 2025-09-13 17:54:36 +02:00
Pepijn
d0d714be47 rename to loss 2025-09-13 16:15:29 +02:00
Pepijn
7d9b469eee fix override self.pretrained_path = None overwrite 2025-09-13 14:50:43 +02:00
Pepijn
6db39cad58 temp: hardcode base model 2025-09-13 14:43:09 +02:00
Pepijn
af0676f99e load from pretrained_path 2025-09-13 14:27:07 +02:00
Pepijn
b9df1a4ac5 use same name for action and state dim as lerobot pi0 and remove fixed image keys 2025-09-13 13:08:41 +02:00
Pepijn
5361346bec Do not add model prefix to normalization 2025-09-13 11:25:26 +02:00
Pepijn
f0b969ae48 do not rename normalization layers 2025-09-13 11:23:58 +02:00
Pepijn
a9d54cbddb Merge branch 'feat/add_pi' into feat/validate_pi_libero 2025-09-13 11:13:13 +02:00
Pepijn
c5a029a28a also compile forward method 2025-09-13 11:12:54 +02:00
Pepijn
c8163662ad add preprocess tests 2025-09-12 21:41:25 +02:00
Pepijn
376cc772ff fix from pretrained 2025-09-12 21:12:48 +02:00
Pepijn
d1eefd4e97 fix: remove unused param 2025-09-12 20:25:55 +02:00
Pepijn
7a03223693 use safeauto_docstring 2025-09-12 20:19:16 +02:00
Pepijn
f840d2e006 fix(modeling pi0): nit warning message 2025-09-12 20:06:06 +02:00
Pepijn
e94844fa59 revert to openpi transformer replace python 3.11 2025-09-12 20:00:21 +02:00
Pepijn
990f8e9cc9 update to python 3.11 2025-09-12 19:04:42 +02:00
Pepijn
6ce2a00135 also for pi05 2025-09-12 19:02:13 +02:00
Pepijn
bf90efa7e1 fix key match from pytorch state dict (similar keys to openpi implementation now) 2025-09-12 18:44:12 +02:00
Pepijn
5b4ac3068e Merge branch 'feat/add_pi' into feat/validate_pi_libero 2025-09-12 11:44:42 +02:00
Pepijn
dbe3406a69 add openpi image transforms for training and add more flexibility to _preprocess_images similar to lerobot pi0 2025-09-12 11:12:47 +02:00
Pepijn
1785767e61 clean up padding of state and action (more in line with lerobot pi0) 2025-09-12 10:38:24 +02:00
Pepijn
afd833f49e Merge branch 'feat/add_pi' into feat/validate_pi_libero 2025-09-12 09:41:13 +02:00
Pepijn
2234b851c0 rename action_horizon to chunk_size 2025-09-11 19:42:25 +02:00
Pepijn
e4a214d890 fetch 2025-09-11 17:49:36 +02:00
Pepijn
e8438aac59 Merge branch 'pr/1676' into feat/validate_pi_libero 2025-09-11 16:35:55 +02:00
pre-commit-ci[bot]
8fe977118b [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-09-11 12:30:09 +00:00
Jade Choghari
d09b2a28af remove 2025-09-11 14:28:46 +02:00
pre-commit-ci[bot]
f2530570e0 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-09-11 12:25:14 +00:00
Jade Choghari
8567ab60d8 remove unces 2025-09-11 14:24:06 +02:00
pre-commit-ci[bot]
9784123463 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-09-11 12:18:36 +00:00
Jade Choghari
4c2add41d7 remove files 2025-09-11 14:18:09 +02:00
pre-commit-ci[bot]
a19d7fb6bf [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-09-11 11:51:53 +00:00
Jade Choghari
565c992589 iterate on review 2025-09-11 13:47:58 +02:00
Jade Choghari
96cc634a66 add new changes 2025-09-11 12:21:21 +02:00
Pepijn
b044f3104b remove check 2025-09-11 11:03:41 +02:00
Pepijn
384ec52ec7 add pi05 to factory 2025-09-11 11:01:31 +02:00
Pepijn
8d1434c069 remove warning in config 2025-09-11 10:37:56 +02:00
Pepijn
f613a37cd2 add some comments, license and readme 2025-09-11 10:36:38 +02:00
Pepijn
494aa576b2 fix push to hub test 2025-09-11 09:18:20 +02:00
Pepijn
514625a7f6 fix test 2025-09-11 09:15:21 +02:00
Pepijn
9f7bfeb419 split pi0 and pi05 policy in seperate files 2025-09-11 09:04:46 +02:00
Jade Choghari
aa40c8c813 More things 2025-09-10 23:24:18 +02:00
Pepijn
d36bdac114 fix test 2025-09-10 21:58:35 +02:00
Pepijn
ff1666b216 fix transformer dependency 2025-09-10 21:57:43 +02:00
Pepijn
c57d3a9688 remove test 2025-09-10 21:54:41 +02:00
Pepijn
9ae11a087d all test pass! and fix tokenizer max length between 05 and 0 2025-09-10 21:51:40 +02:00
Pepijn
21e63b505f fix test 2025-09-10 21:41:05 +02:00
Pepijn
e9e7eb827a also shorten action_steps 2025-09-10 21:36:58 +02:00
Pepijn
ac323b0113 add pi05 2025-09-10 21:33:55 +02:00
Pepijn
b028907d21 use dummy stats 2025-09-10 20:42:48 +02:00
Pepijn
2eafcc7ca1 add model. prefix to all keys in state dict 2025-09-10 20:35:19 +02:00
Pepijn
b3b57a8288 do same in other files 2025-09-10 20:28:09 +02:00
Pepijn
eaaf1c1766 additionally 2025-09-10 20:25:46 +02:00
Pepijn
3bc3bf0391 fix autodocstring 2025-09-10 20:24:39 +02:00
Pepijn
8c5fe10d6c adhere to python 3.11 syntax 2025-09-10 20:20:31 +02:00
Pepijn
8178a06b90 do detailed import 2025-09-10 20:03:14 +02:00
Pepijn
9ea8bd029c change device in test 2025-09-10 19:50:49 +02:00
Pepijn
bd5c264c49 initial commit 2025-09-10 19:44:41 +02:00
Jade Choghari
5c628f1700 new things 2025-09-10 11:32:54 +02:00
Steven Palma
d602e8169c fix(scripts): revert deletion of rs cam config import introduced by #1767 (#1876) 2025-09-08 18:29:39 +02:00
Steven Gong
49baccdccb Disable torque before applying calibration logic (#1889) 2025-09-08 11:38:13 +02:00
Jade Choghari
9beafe0c19 quick install fix for testing 2025-09-05 14:53:55 +03:00
Jade Choghari
27c9db60a6 Merge branch 'main' into add-libero 2025-09-05 14:08:33 +03:00
Jade Choghari
fda5fb5e94 Merge branch 'add-libero' of https://github.com/jadechoghari/lerobot into add-libero 2025-09-05 13:47:58 +03:00
Jade Choghari
5f5438d6fa remove sh files 2025-09-05 13:47:23 +03:00
pre-commit-ci[bot]
2b779cd6c6 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-09-05 10:36:51 +00:00
Jade Choghari
3886af42a5 single line blank change 2025-09-05 13:36:27 +03:00
pre-commit-ci[bot]
38f7229078 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-09-05 09:55:43 +00:00
Jade
504421949c iterate on review 2025-09-05 12:54:07 +03:00
pre-commit-ci[bot]
28b9efc04f [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-09-05 09:23:32 +00:00
Jade Choghari
abba423e28 Update docs/source/libero.mdx
Co-authored-by: Dana Aubakirova <118912928+danaaubakirova@users.noreply.github.com>
Signed-off-by: Jade Choghari <chogharijade@gmail.com>
2025-09-05 12:23:22 +03:00
Jade Choghari
47a81c4150 Update docs/source/libero.mdx
Co-authored-by: Dana Aubakirova <118912928+danaaubakirova@users.noreply.github.com>
Signed-off-by: Jade Choghari <chogharijade@gmail.com>
2025-09-05 12:23:12 +03:00
Gaëlle Lannuzel
6a3d57031a 2 add reachy 2 to updated lerobot (#1767)
* Start adding Reachy 2 (no camera)

* Fix joint shape

* Remove print

* Modify observation_features

* Fix observation state

* Try adding a fake Reachy teleoperator

* Saving test scripts

* Add reachy2camera to cameras

* Add teleop_left camera to observation

* Create test_reachy2_camera.py

* Update utils.py

* Add all rgb cameras

* Future depth work

* Try adding mobile_base velocity

* Update tests

* Update data_acquisition_server.py

* Update with use_external_commands

* Replay

* Usable with or without mobile base

* No need for new isntance

* Use same ip for cameras

* Remove useless imports

* Add resume

* Divide joints in multiple dicts

* Divide joinits into several dicts in teleoperator

* Fix forgotten method call

* Create test_robot_client.py

* Open gripper on start

* Add arguments for cameras

* Modify get_frame() requested size

* Call generate_joints_dict on _init_

* black + isort

* Add reachy2 in imports

* Add reachy2 dependencies

* Add documentation

* Update reachy2.mdx

* Update reachy2.mdx

* Clean files and add types

* Fix type in send_action

* Remove print

* Delete test files

* Clean code

* Update cameras

* Disconnect from camera

* Run pre-commit hooks

* Update pyproject.toml

* Create test_reachy2.py

* Fix generate_joints

* Update test_reachy2.py

* Update send_action test

* Update reachy2_cameras depth + CameraManager

* Update reachy2_camera tests

* Remove useless import and args

* Rename reachy2_teleoperator

* Create test_reachy2_teleoperator.py

* Fix remainging fake_teleoperator

* Remove useless elements

* Mock cameras in test_reachy2

* Delete commented lines

* Add use_present_position to teleoperator

* Add cameras tests

* Add check no part + test

* Use disable_torque_on_disconnect

* Use odometry for vel with present_position

* Update documentation

* Fix vel value type

* Use ensure_safe_goal_position

* Import joints dict from classes

* Update reachy2.mdx

* Update reachy2.mdx

* Update minimal version

* Update minimal version

* fix(tests) fixes for reachy2 tests; removing reachy2 references from the script

* Add reachy2_sdk fake as plugins

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-09-05 11:03:14 +02:00
Justin Huang
d74494d92b Allow max_relative_target to be a float (#1837)
* Remove unused max_relative_target for stretch3

* Fix type annotation and allow integer max_relative_target values

* Configure max_relative_target to be floats instead of ints

* Update docs and types to reflect that max_relative_target can be a dict

* Remove unnecessary isinstance check for ints

* Fix typo in name

---------

Co-authored-by: Justin Huang <justin.huang@jpl.nasa.gov>
2025-09-05 09:58:47 +02:00
Jade Choghari
1ba896598e Merge branch 'train-smolvla' into add-multitraining
:wq
a
2025-09-04 14:32:06 +02:00
Jade Choghari
61e55830da add train 2025-09-04 12:12:10 +02:00
Jade Choghari
b7522da85d hotfix: flip actions 2025-09-04 10:32:06 +03:00
pre-commit-ci[bot]
98dc053e6d [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-09-03 15:57:04 +00:00
Jade Choghari
bbff93d20d skip test warning 2025-09-03 11:54:46 -04:00
Jade Choghari
32c1649085 update doc 2025-09-03 11:51:28 -04:00
Jade Choghari
eb564f8ddb update docs/script 2025-09-03 11:46:13 -04:00
Jade Choghari
a2958a8e0c fix docs 2025-09-03 02:55:20 -04:00
Jade Choghari
8f1679f309 remove brkpt 2025-09-02 11:00:27 -04:00
pre-commit-ci[bot]
b1473f11c8 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-09-02 12:12:45 +00:00
Jade Choghari
7b556079d8 update doc 2025-09-02 08:12:10 -04:00
pre-commit-ci[bot]
e91a773b93 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-09-02 12:10:50 +00:00
Jade Choghari
a9bd67eae9 fix 2025-09-02 08:10:00 -04:00
Jade Choghari
4a4ac759ec doc 2025-09-02 08:07:14 -04:00
pre-commit-ci[bot]
7dd8e015f8 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-09-02 11:33:58 +00:00
Jade Choghari
af2960c33e add docs for eval 2025-09-02 07:33:16 -04:00
Jade Choghari
a36e4619ad Merge branch 'main' into add-libero 2025-09-02 13:06:24 +03:00
pre-commit-ci[bot]
b397a757bb [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-09-02 09:19:57 +00:00
Jade Choghari (jchoghar)
92adf2218f iterate on review 2025-09-02 05:18:46 -04:00
Jade Choghari
f3614dd812 Delete libero-requirements.txt 2025-08-30 20:43:33 +03:00
Jade Choghari
b23b7a5bd7 improve install 2025-08-30 20:43:20 +03:00
Pepijn
882c80d446 Lower limits by 50% for current and torque for gripper motor (#1809)
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2025-08-29 16:06:55 +02:00
Jade Choghari
6ff5f318b2 cleanup 2 2025-08-29 10:22:29 +03:00
pre-commit-ci[bot]
2eae751977 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-29 07:20:21 +00:00
Jade Choghari
894878039d Merge branch 'add-libero' of github.com:jadechoghari/lerobot into add-libero 2025-08-29 10:19:23 +03:00
Jade Choghari
ab72471dda clean 2025-08-29 10:19:01 +03:00
pre-commit-ci[bot]
23849e0cb8 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-28 19:50:21 +00:00
Jade Choghari
cb18fc07ef cleanup (#5) 2025-08-28 22:49:32 +03:00
Jade Choghari
440e22c184 remove step1 2025-08-28 22:46:18 +03:00
Jade Choghari
28b69bf8ba quick fix 2025-08-28 22:41:00 +03:00
Jade Choghari
b997fdde96 update bash 2025-08-28 18:16:25 +03:00
Jade Choghari
6f975cf576 Merge branch 'main' into add-libero 2025-08-28 18:00:06 +03:00
Jade Choghari
2688731064 Add dep (#4)
* Add 'libero' dependencies to pyproject.toml

* Add Git dependencies for egl_probe and LIBERO

* Update libero-requirements.txt

* add future dep
2025-08-28 17:59:34 +03:00
Pepijn
61b0eeae4b Add feetech firmware update docs (#1793)
* Add feetech firmware update docs

* add bonus

* formatting

* adapt text

* feedback pr
2025-08-28 11:18:54 +02:00
Jade Choghari (jchoghar)
fe20437b62 ad 2025-08-25 14:58:46 -04:00
Jade Choghari (jchoghar)
ff861ba869 add safethread support 2025-08-25 14:52:35 -04:00
mgiac-hexagon
577cd10974 Removed dupicate lines of code (#1709) 2025-08-25 12:39:32 +02:00
pre-commit-ci[bot]
4be3942cbc [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-25 10:26:38 +00:00
Jade Choghari
fd5afdfbf0 Merge branch 'main' into add-libero 2025-08-25 13:25:55 +03:00
Jade Choghari (jchoghar)
8d2c66abd2 final refactor/fix 2025-08-25 06:25:02 -04:00
lxk
b0923ab74b fix(dataset): Use provided episode_data in save_episode (#1740)
The 'episode_data' parameter was previously ignored, causing an error if provided. This change ensures it is correctly used, which allows for asynchronous episode saving by passing a copy of the episode buffer, preventing conflicts with the main data collection loop.
2025-08-22 15:24:02 +02:00
Jack Vial
7f70b78f32 Add missing encoding table entries for Koch arm (#1534) 2025-08-20 17:24:05 +02:00
Jade Choghari
afad90ffaa Update .gitignore 2025-08-20 13:57:57 +03:00
pre-commit-ci[bot]
f5091448a8 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-20 10:56:33 +00:00
Jade Choghari (jchoghar)
cc46497f4c fix renaming issues with cams 2025-08-20 06:55:54 -04:00
pre-commit-ci[bot]
5d25f5bd40 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-19 13:41:16 +00:00
Jade Choghari (jchoghar)
ce83752f16 fix video paths and train.py 2025-08-19 09:39:14 -04:00
pre-commit-ci[bot]
4ed6cf159d [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-17 20:41:34 +00:00
Jade Choghari (jchoghar)
7626d26e6a bug remove 2025-08-17 16:40:11 -04:00
pre-commit-ci[bot]
14a59f576b [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-17 18:33:30 +00:00
Jade Choghari (jchoghar)
eb3649292b remove photos 2025-08-17 14:28:11 -04:00
Jade Choghari (jchoghar)
ac0993c2e3 add multitask 2025-08-17 14:27:53 -04:00
Steven Palma
55198de096 fix(ci): rename libegl1-mesa in deb13 trixie (#1735) 2025-08-14 11:12:06 +02:00
pre-commit-ci[bot]
c20bf75ba0 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-10 05:33:23 +00:00
Jade Choghari
a25480d363 add changes 2025-08-10 01:32:28 -04:00
Jade Choghari
4c19a71d7c Add LIBERO as a submodule 2025-08-10 01:30:19 -04:00
Jade Choghari
d2684d41cd add factory 2025-08-08 09:34:14 -04:00
Jade Choghari
4e76c1f88c Merge branch 'main' into add-libero 2025-08-08 09:24:42 -04:00
pre-commit-ci[bot]
3bf0c19be7 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-06 12:37:41 +00:00
Jade Choghari
ad4f510262 add 2025-08-06 08:37:16 -04:00
pre-commit-ci[bot]
9124b36b0a [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-06 04:06:03 +00:00
Jade Choghari
4bc356b7f3 backup 2025-08-06 00:00:45 -04:00
Jade Choghari
21a961ecbb add libero 2025-08-05 23:55:08 -04:00
176 changed files with 17224 additions and 14327 deletions

View File

@@ -29,7 +29,7 @@ ENV DEBIAN_FRONTEND=noninteractive \
# Install system dependencies and uv (as root)
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential git curl libglib2.0-0 libegl1-mesa ffmpeg \
build-essential git curl libglib2.0-0 libegl1-mesa-dev ffmpeg \
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
&& mv /root/.local/bin/uv /usr/local/bin/uv \

View File

@@ -19,6 +19,8 @@
title: Train RL in Simulation
- local: async
title: Use Async Inference
- local: libero
title: Using LIBERO
title: "Tutorials"
- sections:
- local: smolvla
@@ -35,10 +37,14 @@
title: Koch v1.1
- local: lekiwi
title: LeKiwi
- local: reachy2
title: Reachy 2
title: "Robots"
- sections:
- local: notebooks
title: Notebooks
- local: feetech
title: Updating Feetech Firmware
title: "Resources"
- sections:
- local: contributing

71
docs/source/feetech.mdx Normal file
View File

@@ -0,0 +1,71 @@
# Feetech Motor Firmware Update
This tutorial guides you through updating the firmware of Feetech motors using the official Feetech software.
## Prerequisites
- Windows computer (Feetech software is only available for Windows)
- Feetech motor control board
- USB cable to connect the control board to your computer
- Feetech motors connected to the control board
## Step 1: Download Feetech Software
1. Visit the official Feetech software download page: [https://www.feetechrc.com/software.html](https://www.feetechrc.com/software.html)
2. Download the latest version of the Feetech debugging software (FD)
3. Install the software on your Windows computer
## Step 2: Hardware Setup
1. Connect your Feetech motors to the motor control board
2. Connect the motor control board to your Windows computer via USB cable
3. Ensure power is supplied to the motors
## Step 3: Configure Connection
1. Launch the Feetech debugging software
2. Select the correct COM port from the port dropdown menu
- If unsure which port to use, check Windows Device Manager under "Ports (COM & LPT)"
3. Set the appropriate baud rate (typically 1000000 for most Feetech motors)
4. Click "Open" to establish communication with the control board
## Step 4: Scan for Motors
1. Once connected, click the "Search" button to detect all connected motors
2. The software will automatically discover and list all motors on the bus
3. Each motor will appear with its ID number
## Step 5: Update Firmware
For each motor you want to update:
1. **Select the motor** from the list by clicking on it
2. **Click on Upgrade tab**:
3. **Click on Online button**:
- If an potential firmware update is found, it will be displayed in the box
4. **Click on Upgrade button**:
- The update progress will be displayed
## Step 6: Verify Update
1. After the update completes, the software should automatically refresh the motor information
2. Verify that the firmware version has been updated to the expected version
## Important Notes
⚠️ **Warning**: Do not disconnect power or USB during firmware updates, it will potentially brick the motor.
## Bonus: Motor Debugging on Linux/macOS
For debugging purposes only, you can use the open-source Feetech Debug Tool:
- **Repository**: [FT_SCServo_Debug_Qt](https://github.com/CarolinePascal/FT_SCServo_Debug_Qt/tree/fix/port-search-timer)
### Installation Instructions
Follow the instructions in the repository to install the tool, for Ubuntu you can directly install it, for MacOS you need to build it from source.
**Limitations:**
- This tool is for debugging and parameter adjustment only
- Firmware updates must still be done on Windows with official Feetech software

View File

@@ -4,13 +4,7 @@ In this tutorial you will go through the full Human-in-the-Loop Sample-Efficient
HIL-SERL is a sample-efficient reinforcement learning algorithm that combines human demonstrations with online learning and human interventions. The approach starts from a small set of human demonstrations, uses them to train a reward classifier, and then employs an actor-learner architecture where humans can intervene during policy execution to guide exploration and correct unsafe behaviors. In this tutorial, you'll use a gamepad to provide interventions and control the robot during the learning process.
It combines three key ingredients:
1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point.
2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour.
3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe.
It combines three key ingredients: 1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point. 2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour. 3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe.
Together these elements let HIL-SERL reach near-perfect task success and faster cycle times than imitation-only baselines.
@@ -62,243 +56,30 @@ pip install -e ".[hilserl]"
### Understanding Configuration
The training process begins with proper configuration for the HILSerl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/scripts/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
The training process begins with proper configuration for the HILSerl environment. The configuration class of interest is `HILSerlRobotEnvConfig` in `lerobot/envs/configs.py`. Which is defined as:
<!-- prettier-ignore-start -->
```python
class GymManipulatorConfig:
env: HILSerlRobotEnvConfig # Environment configuration (nested)
dataset: DatasetConfig # Dataset recording/replay configuration (nested)
mode: str | None = None # "record", "replay", or None (for training)
device: str = "cpu" # Compute device
class HILSerlRobotEnvConfig(EnvConfig):
robot: RobotConfig | None = None # Main robot agent (defined in `lerobot/robots`)
teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm
processor: HILSerlProcessorConfig # Processing pipeline configuration (nested)
name: str = "real_robot" # Environment name
task: str | None = None # Task identifier
teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm, (defined in `lerobot/teleoperators`)
wrapper: EnvTransformConfig | None = None # Environment wrapper settings; check `lerobot/scripts/server/gym_manipulator.py`
fps: int = 10 # Control frequency
# Nested processor configuration
class HILSerlProcessorConfig:
control_mode: str = "gamepad" # Control mode
observation: ObservationConfig | None = None # Observation processing settings
image_preprocessing: ImagePreprocessingConfig | None = None # Image crop/resize settings
gripper: GripperConfig | None = None # Gripper control and penalty settings
reset: ResetConfig | None = None # Environment reset and timing settings
inverse_kinematics: InverseKinematicsConfig | None = None # IK processing settings
reward_classifier: RewardClassifierConfig | None = None # Reward classifier settings
max_gripper_pos: float | None = 100.0 # Maximum gripper position
# Sub-configuration classes
class ObservationConfig:
add_joint_velocity_to_observation: bool = False # Add joint velocities to state
add_current_to_observation: bool = False # Add motor currents to state
add_ee_pose_to_observation: bool = False # Add end-effector pose to state
display_cameras: bool = False # Display camera feeds during execution
class ImagePreprocessingConfig:
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None # Image cropping parameters
resize_size: tuple[int, int] | None = None # Target image size
class GripperConfig:
use_gripper: bool = True # Enable gripper control
gripper_penalty: float = 0.0 # Penalty for inappropriate gripper usage
gripper_penalty_in_reward: bool = False # Include gripper penalty in reward
class ResetConfig:
fixed_reset_joint_positions: Any | None = None # Joint positions for reset
reset_time_s: float = 5.0 # Time to wait during reset
control_time_s: float = 20.0 # Maximum episode duration
terminate_on_success: bool = True # Whether to terminate episodes on success detection
class InverseKinematicsConfig:
urdf_path: str | None = None # Path to robot URDF file
target_frame_name: str | None = None # End-effector frame name
end_effector_bounds: dict[str, list[float]] | None = None # EE workspace bounds
end_effector_step_sizes: dict[str, float] | None = None # EE step sizes per axis
class RewardClassifierConfig:
pretrained_path: str | None = None # Path to pretrained reward classifier
success_threshold: float = 0.5 # Success detection threshold
success_reward: float = 1.0 # Reward value for successful episodes
# Dataset configuration
class DatasetConfig:
repo_id: str # LeRobot dataset repository ID
dataset_root: str # Local dataset root directory
task: str # Task identifier
num_episodes: int # Number of episodes for recording
episode: int # Episode index for replay
push_to_hub: bool # Whether to push datasets to Hub
name: str = "real_robot" # Environment name
mode: str = None # "record", "replay", or None (for training)
repo_id: str | None = None # LeRobot dataset repository ID
dataset_root: str | None = None # Local dataset root (optional)
task: str = "" # Task identifier
num_episodes: int = 10 # Number of episodes for recording
episode: int = 0 # episode index for replay
device: str = "cuda" # Compute device
push_to_hub: bool = True # Whether to push the recorded datasets to Hub
pretrained_policy_name_or_path: str | None = None # For policy loading
reward_classifier_pretrained_path: str | None = None # For reward model
number_of_steps_after_success: int = 0 # For reward classifier, collect more positive examples after a success to train a classifier
```
<!-- prettier-ignore-end -->
### Processor Pipeline Architecture
HIL-SERL uses a modular processor pipeline architecture that processes robot observations and actions through a series of composable steps. The pipeline is divided into two main components:
#### Environment Processor Pipeline
The environment processor (`env_processor`) handles incoming observations and environment state:
1. **VanillaObservationProcessor**: Converts raw robot observations into standardized format
2. **JointVelocityProcessor** (optional): Adds joint velocity information to observations
3. **MotorCurrentProcessor** (optional): Adds motor current readings to observations
4. **ForwardKinematicsJointsToEE** (optional): Computes end-effector pose from joint positions
5. **ImageCropResizeProcessor** (optional): Crops and resizes camera images
6. **TimeLimitProcessor** (optional): Enforces episode time limits
7. **GripperPenaltyProcessor** (optional): Applies penalties for inappropriate gripper usage
8. **RewardClassifierProcessor** (optional): Automated reward detection using vision models
9. **ToBatchProcessor**: Converts data to batch format for neural network processing
10. **DeviceProcessor**: Moves data to the specified compute device (CPU/GPU)
#### Action Processor Pipeline
The action processor (`action_processor`) handles outgoing actions and human interventions:
1. **AddTeleopActionAsComplimentaryData**: Captures teleoperator actions for logging
2. **AddTeleopEventsAsInfo**: Records intervention events and episode control signals
3. **AddRobotObservationAsComplimentaryData**: Stores raw robot state for processing
4. **InterventionActionProcessor**: Handles human interventions and episode termination
5. **Inverse Kinematics Pipeline** (when enabled):
- **MapDeltaActionToRobotAction**: Converts delta actions to robot action format
- **EEReferenceAndDelta**: Computes end-effector reference and delta movements
- **EEBoundsAndSafety**: Enforces workspace safety bounds
- **InverseKinematicsEEToJoints**: Converts end-effector actions to joint targets
- **GripperVelocityToJoint**: Handles gripper control commands
#### Configuration Examples
**Basic Observation Processing**:
```json
{
"env": {
"processor": {
"observation": {
"add_joint_velocity_to_observation": true,
"add_current_to_observation": false,
"display_cameras": false
}
}
}
}
```
**Image Processing**:
```json
{
"env": {
"processor": {
"image_preprocessing": {
"crop_params_dict": {
"observation.images.front": [180, 250, 120, 150],
"observation.images.side": [180, 207, 180, 200]
},
"resize_size": [128, 128]
}
}
}
}
```
**Inverse Kinematics Setup**:
```json
{
"env": {
"processor": {
"inverse_kinematics": {
"urdf_path": "path/to/robot.urdf",
"target_frame_name": "end_effector",
"end_effector_bounds": {
"min": [0.16, -0.08, 0.03],
"max": [0.24, 0.2, 0.1]
},
"end_effector_step_sizes": {
"x": 0.02,
"y": 0.02,
"z": 0.02
}
}
}
}
}
```
### Advanced Observation Processing
The HIL-SERL framework supports additional observation processing features that can improve policy learning:
#### Joint Velocity Processing
Enable joint velocity estimation to provide the policy with motion information:
```json
{
"env": {
"processor": {
"observation": {
"add_joint_velocity_to_observation": true
}
}
}
}
```
This processor:
- Estimates joint velocities using finite differences between consecutive joint position readings
- Adds velocity information to the observation state vector
- Useful for policies that need motion awareness for dynamic tasks
#### Motor Current Processing
Monitor motor currents to detect contact forces and load conditions:
```json
{
"env": {
"processor": {
"observation": {
"add_current_to_observation": true
}
}
}
}
```
This processor:
- Reads motor current values from the robot's control system
- Adds current measurements to the observation state vector
- Helps detect contact events, object weights, and mechanical resistance
- Useful for contact-rich manipulation tasks
#### Combined Observation Processing
You can enable multiple observation processing features simultaneously:
```json
{
"env": {
"processor": {
"observation": {
"add_joint_velocity_to_observation": true,
"add_current_to_observation": true,
"add_ee_pose_to_observation": false,
"display_cameras": false
}
}
}
}
```
**Note**: Enabling additional observation features increases the state space dimensionality, which may require adjusting your policy network architecture and potentially collecting more training data.
### Finding Robot Workspace Bounds
Before collecting demonstrations, you need to determine the appropriate operational bounds for your robot.
@@ -349,56 +130,22 @@ With the bounds defined, you can safely collect demonstrations for training. Tra
Create a configuration file for recording demonstrations (or edit an existing one like [env_config_so100.json](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_so100.json)):
1. Set `mode` to `"record"` at the root level
2. Specify a unique `repo_id` for your dataset in the `dataset` section (e.g., "username/task_name")
3. Set `num_episodes` in the `dataset` section to the number of demonstrations you want to collect
4. Set `env.processor.image_preprocessing.crop_params_dict` to `{}` initially (we'll determine crops later)
5. Configure `env.robot`, `env.teleop`, and other hardware settings in the `env` section
1. Set `mode` to `"record"`
2. Specify a unique `repo_id` for your dataset (e.g., "username/task_name")
3. Set `num_episodes` to the number of demonstrations you want to collect
4. Set `crop_params_dict` to `null` initially (we'll determine crops later)
5. Configure `robot`, `cameras`, and other hardware settings
Example configuration section:
```json
{
"env": {
"type": "gym_manipulator",
"name": "real_robot",
"fps": 10,
"processor": {
"control_mode": "gamepad",
"observation": {
"display_cameras": false
},
"image_preprocessing": {
"crop_params_dict": {},
"resize_size": [128, 128]
},
"gripper": {
"use_gripper": true,
"gripper_penalty": 0.0
},
"reset": {
"reset_time_s": 5.0,
"control_time_s": 20.0
}
},
"robot": {
// ... robot configuration ...
},
"teleop": {
// ... teleoperator configuration ...
}
},
"dataset": {
"repo_id": "username/pick_lift_cube",
"dataset_root": null,
"task": "pick_and_lift",
"num_episodes": 15,
"episode": 0,
"push_to_hub": true
},
"mode": "record",
"device": "cpu"
}
"mode": "record",
"repo_id": "username/pick_lift_cube",
"dataset_root": null,
"task": "pick_and_lift",
"num_episodes": 15,
"episode": 0,
"push_to_hub": true
```
### Using a Teleoperation Device
@@ -444,20 +191,10 @@ The gamepad provides a very convenient way to control the robot and the episode
To setup the gamepad, you need to set the `control_mode` to `"gamepad"` and define the `teleop` section in the configuration file.
```json
{
"env": {
"teleop": {
"type": "gamepad",
"use_gripper": true
},
"processor": {
"control_mode": "gamepad",
"gripper": {
"type": "gamepad",
"use_gripper": true
}
}
}
}
},
```
<p align="center">
@@ -479,21 +216,11 @@ The SO101 leader arm has reduced gears that allows it to move and track the foll
To setup the SO101 leader, you need to set the `control_mode` to `"leader"` and define the `teleop` section in the configuration file.
```json
{
"env": {
"teleop": {
"type": "so101_leader",
"port": "/dev/tty.usbmodem585A0077921",
"use_degrees": true
"type": "so101_leader",
"port": "/dev/tty.usbmodem585A0077921", # check your port number
"use_degrees": true
},
"processor": {
"control_mode": "leader",
"gripper": {
"use_gripper": true
}
}
}
}
```
In order to annotate the success/failure of the episode, **you will need** to use a keyboard to press `s` for success, `esc` for failure.
@@ -524,7 +251,7 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/e
During recording:
1. The robot will reset to the initial position defined in the configuration file `env.processor.reset.fixed_reset_joint_positions`
1. The robot will reset to the initial position defined in the configuration file `fixed_reset_joint_positions`
2. Complete the task successfully
3. The episode ends with a reward of 1 when you press the "success" button
4. If the time limit is reached, or the fail button is pressed, the episode ends with a reward of 0
@@ -583,19 +310,11 @@ observation.images.front: [180, 250, 120, 150]
Add these crop parameters to your training configuration:
```json
{
"env": {
"processor": {
"image_preprocessing": {
"crop_params_dict": {
"observation.images.side": [180, 207, 180, 200],
"observation.images.front": [180, 250, 120, 150]
},
"resize_size": [128, 128]
}
}
}
}
"crop_params_dict": {
"observation.images.side": [180, 207, 180, 200],
"observation.images.front": [180, 250, 120, 150]
},
"resize_size": [128, 128]
```
**Recommended image resolution**
@@ -624,52 +343,26 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/r
**Key Parameters for Data Collection**
- **mode**: set it to `"record"` to collect a dataset (at root level)
- **dataset.repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub
- **dataset.num_episodes**: Number of episodes to record
- **env.processor.reset.terminate_on_success**: Whether to automatically terminate episodes when success is detected (default: `true`)
- **env.fps**: Number of frames per second to record
- **dataset.push_to_hub**: Whether to push the dataset to the hub
- **mode**: set it to `"record"` to collect a dataset
- **repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub
- **num_episodes**: Number of episodes to record
- **number_of_steps_after_success**: Number of additional frames to record after a success (reward=1) is detected
- **fps**: Number of frames per second to record
- **push_to_hub**: Whether to push the dataset to the hub
The `env.processor.reset.terminate_on_success` parameter allows you to control episode termination behavior. When set to `false`, episodes will continue even after success is detected, allowing you to collect more positive examples with the reward=1 label. This is crucial for training reward classifiers as it provides more success state examples in your dataset. When set to `true` (default), episodes terminate immediately upon success detection.
**Important**: For reward classifier training, set `terminate_on_success: false` to collect sufficient positive examples. For regular HIL-SERL training, keep it as `true` to enable automatic episode termination when the task is completed successfully.
The `number_of_steps_after_success` parameter is crucial as it allows you to collect more positive examples. When a success is detected, the system will continue recording for the specified number of steps while maintaining the reward=1 label. Otherwise, there won't be enough states in the dataset labeled to 1 to train a good classifier.
Example configuration section for data collection:
```json
{
"env": {
"type": "gym_manipulator",
"name": "real_robot",
"fps": 10,
"processor": {
"reset": {
"reset_time_s": 5.0,
"control_time_s": 20.0,
"terminate_on_success": false
},
"gripper": {
"use_gripper": true
}
},
"robot": {
// ... robot configuration ...
},
"teleop": {
// ... teleoperator configuration ...
}
},
"dataset": {
"repo_id": "hf_username/dataset_name",
"dataset_root": "data/your_dataset",
"task": "reward_classifier_task",
"num_episodes": 20,
"episode": 0,
"push_to_hub": true
},
"mode": "record",
"device": "cpu"
"repo_id": "hf_username/dataset_name",
"dataset_root": "data/your_dataset",
"num_episodes": 20,
"push_to_hub": true,
"fps": 10,
"number_of_steps_after_success": 15
}
```
@@ -728,17 +421,9 @@ To use your trained reward classifier, configure the `HILSerlRobotEnvConfig` to
<!-- prettier-ignore-start -->
```python
config = GymManipulatorConfig(
env=HILSerlRobotEnvConfig(
processor=HILSerlProcessorConfig(
reward_classifier=RewardClassifierConfig(
pretrained_path="path_to_your_pretrained_trained_model"
)
),
# Other environment parameters
),
dataset=DatasetConfig(...),
mode=None # For training
env_config = HILSerlRobotEnvConfig(
reward_classifier_pretrained_path="path_to_your_pretrained_trained_model",
# Other environment parameters
)
```
<!-- prettier-ignore-end -->
@@ -747,18 +432,7 @@ or set the argument in the json config file.
```json
{
"env": {
"processor": {
"reward_classifier": {
"pretrained_path": "path_to_your_pretrained_model",
"success_threshold": 0.7,
"success_reward": 1.0
},
"reset": {
"terminate_on_success": true
}
}
}
"reward_classifier_pretrained_path": "path_to_your_pretrained_model"
}
```

View File

@@ -32,12 +32,9 @@ To use `gym_hil` with LeRobot, you need to create a configuration file. An examp
```json
{
"env": {
"type": "gym_manipulator",
"name": "gym_hil",
"task": "PandaPickCubeGamepad-v0",
"fps": 10
},
"type": "hil",
"name": "franka_sim",
"task": "PandaPickCubeGamepad-v0",
"device": "cuda"
}
```
@@ -48,40 +45,28 @@ Available tasks:
- `PandaPickCubeGamepad-v0`: With gamepad control
- `PandaPickCubeKeyboard-v0`: With keyboard control
### Processor Configuration
### Gym Wrappers Configuration
```json
{
"env": {
"processor": {
"control_mode": "gamepad",
"gripper": {
"use_gripper": true,
"gripper_penalty": -0.02
},
"reset": {
"control_time_s": 15.0,
"fixed_reset_joint_positions": [
0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785
]
},
"inverse_kinematics": {
"end_effector_step_sizes": {
"x": 0.025,
"y": 0.025,
"z": 0.025
}
}
"wrapper": {
"gripper_penalty": -0.02,
"control_time_s": 15.0,
"use_gripper": true,
"fixed_reset_joint_positions": [0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785],
"end_effector_step_sizes": {
"x": 0.025,
"y": 0.025,
"z": 0.025
},
"control_mode": "gamepad"
}
}
}
```
Important parameters:
- `gripper.gripper_penalty`: Penalty for excessive gripper movement
- `gripper.use_gripper`: Whether to enable gripper control
- `inverse_kinematics.end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector
- `gripper_penalty`: Penalty for excessive gripper movement
- `use_gripper`: Whether to enable gripper control
- `end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector
- `control_mode`: Set to `"gamepad"` to use a gamepad controller
## Running with HIL RL of LeRobot
@@ -90,50 +75,39 @@ Important parameters:
To run the environment, set mode to null:
```bash
<!-- prettier-ignore-start -->
```python
python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json
```
<!-- prettier-ignore-end -->
### Recording a Dataset
To collect a dataset, set the mode to `record` whilst defining the repo_id and number of episodes to record:
```json
{
"env": {
"type": "gym_manipulator",
"name": "gym_hil",
"task": "PandaPickCubeGamepad-v0"
},
"dataset": {
"repo_id": "username/sim_dataset",
"dataset_root": null,
"task": "pick_cube",
"num_episodes": 10,
"episode": 0,
"push_to_hub": true
},
"mode": "record"
}
```
```bash
<!-- prettier-ignore-start -->
```python
python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json
```
<!-- prettier-ignore-end -->
### Training a Policy
To train a policy, checkout the configuration example available [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/train_gym_hil_env.json) and run the actor and learner servers:
```bash
<!-- prettier-ignore-start -->
```python
python -m lerobot.scripts.rl.actor --config_path path/to/train_gym_hil_env.json
```
<!-- prettier-ignore-end -->
In a different terminal, run the learner server:
```bash
<!-- prettier-ignore-start -->
```python
python -m lerobot.scripts.rl.learner --config_path path/to/train_gym_hil_env.json
```
<!-- prettier-ignore-end -->
The simulation environment provides a safe and repeatable way to develop and test your Human-In-the-Loop reinforcement learning components before deploying to real robots.

View File

@@ -519,14 +519,11 @@ from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun
from lerobot.record import record_loop
from lerobot.policies.factory import make_processor
NUM_EPISODES = 5
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
# Create the robot configuration
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
@@ -538,7 +535,7 @@ robot_config = SO100FollowerConfig(
robot = SO100Follower(robot_config)
# Initialize the policy
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
policy = ACTPolicy.from_pretrained("<hf_username>/<my_policy_repo_id>")
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
@@ -547,7 +544,7 @@ dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_DATASET_ID,
repo_id="<hf_username>/eval_<dataset_repo_id>",
fps=FPS,
features=dataset_features,
robot_type=robot.name,
@@ -562,12 +559,6 @@ _init_rerun(session_name="recording")
# Connect the robot
robot.connect()
preprocessor, postprocessor = make_processor(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
)
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
@@ -577,8 +568,6 @@ for episode_idx in range(NUM_EPISODES):
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,

View File

@@ -24,36 +24,11 @@ pip install -e ".[hilserl]"
To use `gym_hil` with LeRobot, you need to use a configuration file. An example config file can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_gym_hil_il.json).
To teleoperate and collect a dataset, we need to modify this config file. Here's an example configuration for imitation learning data collection:
To teleoperate and collect a dataset, we need to modify this config file and you should add your `repo_id` here: `"repo_id": "il_gym",` and `"num_episodes": 30,` and make sure you set `mode` to `record`, "mode": "record".
```json
{
"env": {
"type": "gym_manipulator",
"name": "gym_hil",
"task": "PandaPickCubeGamepad-v0",
"fps": 10
},
"dataset": {
"repo_id": "your_username/il_gym",
"dataset_root": null,
"task": "pick_cube",
"num_episodes": 30,
"episode": 0,
"push_to_hub": true
},
"mode": "record",
"device": "cuda"
}
```
If you do not have a Nvidia GPU also change `"device": "cuda"` parameter in the config file (for example to `mps` for MacOS).
Key configuration points:
- Set your `repo_id` in the `dataset` section: `"repo_id": "your_username/il_gym"`
- Set `num_episodes: 30` to collect 30 demonstration episodes
- Ensure `mode` is set to `"record"`
- If you don't have an NVIDIA GPU, change `"device": "cuda"` to `"mps"` for macOS or `"cpu"`
- To use keyboard instead of gamepad, change `"task"` to `"PandaPickCubeKeyboard-v0"`
By default the config file assumes you use a controller. To use your keyboard please change the envoirment specified at `"task"` in the config file and set it to `"PandaPickCubeKeyboard-v0"`.
Then we can run this command to start:
@@ -165,32 +140,9 @@ huggingface-cli upload ${HF_USER}/il_sim_test${CKPT} \
## Evaluate your policy in Sim
To evaluate your policy we have to use a configuration file. An example can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/eval_config_gym_hil.json).
To evaluate your policy we have to use the config file that can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/eval_config_gym_hil.json).
Here's an example evaluation configuration:
```json
{
"env": {
"type": "gym_manipulator",
"name": "gym_hil",
"task": "PandaPickCubeGamepad-v0",
"fps": 10
},
"dataset": {
"repo_id": "your_username/il_sim_dataset",
"dataset_root": null,
"task": "pick_cube"
},
"pretrained_policy_name_or_path": "your_username/il_sim_model",
"device": "cuda"
}
```
Make sure to replace:
- `repo_id` with the dataset you trained on (e.g., `your_username/il_sim_dataset`)
- `pretrained_policy_name_or_path` with your model ID (e.g., `your_username/il_sim_model`)
Make sure to replace the `repo_id` with the dataset you trained on, for example `pepijn223/il_sim_dataset` and replace the `pretrained_policy_name_or_path` with your model id, for example `pepijn223/il_sim_model`
Then you can run this command to visualize your trained policy

230
docs/source/libero.mdx Normal file
View File

@@ -0,0 +1,230 @@
# LIBERO
**LIBERO** is a benchmark designed to study **lifelong robot learning**. The idea is that robots wont just be pretrained once in a factory, theyll need to keep learning and adapting with their human users over time. This ongoing adaptation is called **lifelong learning in decision making (LLDM)**, and its a key step toward building robots that become truly personalized helpers. The benchmark was first introduced in the [LIBERO paper](https://arxiv.org/abs/2306.03310) and the [original repository](https://github.com/Lifelong-Robot-Learning/LIBERO).
To make progress on this challenge, LIBERO provides a set of standardized tasks that focus on **knowledge transfer**: how well a robot can apply what it has already learned to new situations. By evaluating on LIBERO, different algorithms can be compared fairly and researchers can build on each others work.
LIBERO includes **five task suites**:
- **LIBERO-Spatial (`libero_spatial`)** tasks that require reasoning about spatial relations.
- **LIBERO-Object (`libero_object`)** tasks centered on manipulating different objects.
- **LIBERO-Goal (`libero_goal`)** goal-conditioned tasks where the robot must adapt to changing targets.
- **LIBERO-90 (`libero_90`)** 90 short-horizon tasks from the LIBERO-100 collection.
- **LIBERO-Long (`libero_10`)** 10 long-horizon tasks from the LIBERO-100 collection.
Together, these suites cover **130 tasks**, ranging from simple object manipulations to complex multi-step scenarios. LIBERO is meant to grow over time, and to serve as a shared benchmark where the community can test and improve lifelong learning algorithms.
![An overview of the LIBERO benchmark](https://libero-project.github.io/assets/img/libero/fig1.png)
_Figure 1: An overview of the LIBERO benchmark._
## Evaluating with LIBERO
At **LeRobot**, we ported [LIBERO](https://github.com/Lifelong-Robot-Learning/LIBERO) into our framework and used it primarily to **benchmark [SmolVLA](https://huggingface.co/docs/lerobot/en/smolvla)**, our lightweight Vision-Language-Action model, comparing it against state-of-the-art VLA models such as Pi0, OpenVLA, Octo, and Diffusion Policy.
LIBERO is now part of our **multi-eval supported simulation**, allowing you to benchmark your policies either on a **single suite of tasks** or across **multiple suites at once** with just a single flag.
To install LIBERO, first follow the [LeRobot Installation Guide](https://huggingface.co/docs/lerobot/installation).
Once LeRobot is installed, there are two options:
1. **Install via pip** (recommended):
```bash
pip install "lerobot[libero,smolvla]"
```
2. **Install from source**:
```bash
git clone https://github.com/huggingface/lerobot.git
cd lerobot
pip install -e ".[libero,smolvla]"
```
### Single-suite evaluation
Evaluate a policy on one LIBERO suite:
```bash
python src/lerobot/scripts/eval.py \
--policy.path="your-policy-id" \
--env.type=libero \
--env.task=libero_object \
--env.multitask_eval=False \
--eval.batch_size=2 \
--eval.n_episodes=3
```
- `--env.task` picks the suite (`libero_object`, `libero_spatial`, etc.).
- `--eval.batch_size` controls how many environments run in parallel.
- `--eval.n_episodes` sets how many episodes to run in total.
---
### Multi-suite evaluation
Benchmark a policy across multiple suites at once:
```bash
python src/lerobot/scripts/eval.py \
--policy.path="your-policy-id" \
--env.type=libero \
--env.task=libero_object \
--env.multitask_eval=True \
--eval.batch_size=1 \
--eval.n_episodes=2
```
- Pass a comma-separated list to `--env.task` for multi-suite evaluation.
- Set `-env.multitask_eval=True` to enable evaluation across all tasks in those suites.
### Policy inputs and outputs
When using LIBERO through LeRobot, policies interact with the environment via **observations** and **actions**:
- **Observations**
- `observation.state` proprioceptive features (agent state).
- `observation.images.image` main camera view (`agentview_image`).
- `observation.images.image2` wrist camera view (`robot0_eye_in_hand_image`).
⚠️ **Note:** LeRobot enforces the `.images.*` prefix for any visual features. Make sure your dataset metadata keys match this convention when evaluating.
## Input Features and Metadata Alignment
To train or evaluate a policy, you use `make_policy`, which builds a feature-naming dictionary for the observations the policy expects.
This mapping can come from:
- Dataset metadata
- The evaluation environment
- The policy path (if a pretrained repo ID is provided)
### Common Issues
A common problem is when the keys in the dataset, environment, and policy config do not match. For example:
- `wrist_image` vs `observation.images.image2`
- `observation.image2` (as in SmolVLA) vs the `.images.*` prefix convention
Such mismatches will cause `KeyError`s. This may be due to assumptions in `make_policy` or missing error handling.
***
### How to Check Expected Features
- Open your policy config (`config.json`), e.g. [example here](https://huggingface.co/jadechoghari/smolvla-libero/blob/main/config.json).
- Or add a breakpoint in `train.py` and inspect:
````python
print(policy.config.input_features)
To ensure you can just check what your policy expects as `input_features`:
- Open your policy config (`config.json`), e.g. [example here](https://huggingface.co/jadechoghari/smolvla-libero/blob/main/config.json).
- Or add a breakpoint in `train.py` and inspect:
```python
print(policy.config.input_features)
Fixing KeyErrors (Preprocessing Example)
````
## Fixing KeyErrors (Preprocessing Example)
If your dataset columns do not follow the expected naming, you can rename them in-place before training:
````python
import pyarrow.parquet as pq
import shutil
def rename_columns(parquet_path, rename_map):
table = pq.read_table(parquet_path)
schema = table.schema
new_names = [rename_map.get(name, name) for name in schema.names]
renamed_table = table.rename_columns(new_names)
backup_path = parquet_path + ".bak"
shutil.copy(parquet_path, backup_path)
pq.write_table(renamed_table, parquet_path)
print(f"patched {parquet_path}, backup at {backup_path}")
# example mapping: align dataset keys to LeRobot convention
rename_map = {
"image": "observation.images.image",
"wrist_image": "observation.images.image2",
}
rename_columns("episode_000001.parquet", rename_map)
- **Actions**
- Continuous control values in a `Box(-1, 1, shape=(7,))` space.
We also provide a notebook for quick testing:
Training with LIBERO
## Training with LIBERO
When training on LIBERO tasks, make sure your dataset parquet and metadata keys follow the LeRobot convention.
The environment expects:
- `observation.state` → 8-dim agent state
- `observation.images.image` → main camera (`agentview_image`)
- `observation.images.image2` → wrist camera (`robot0_eye_in_hand_image`)
⚠️ Cleaning the dataset upfront is **cleaner and more efficient** than remapping keys inside the code. We plan to provide a script to easily preprocess such data.
To avoid potential mismatches and `KeyError`s, we provide a **preprocessed LIBERO dataset** that is fully compatible with the current LeRobot codebase and requires no additional manipulations.
- 🔗 [Preprocessed LIBERO dataset (Hugging Face LeRobot org)](https://huggingface.co/datasets/HuggingFaceVLA/libero)
- 🔗 [Original LIBERO dataset (physical-intelligence)](https://huggingface.co/datasets/physical-intelligence/libero)
The preprocessed dataset follows LeRobot naming conventions (e.g., `.images.*` prefix for visual features) and aligns with policy configs out-of-the-box.
The original dataset is acknowledged here as the primary source.
---
### Example training command
```bash
python src/lerobot/scripts/train.py \
--policy.type=smolvla \
--policy.repo_id=${HF_USER}/libero-test \
--dataset.repo_id=jadechoghari/smol-libero3 \
--env.type=libero \
--env.task=libero_10 \
--output_dir=./outputs/ \
--steps=100000 \
--batch_size=4 \
--env.multitask_eval=True \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--eval_freq=1000 \
````
---
### Note on rendering
LeRobot uses MuJoCo for simulation. You need to set the rendering backend before training or evaluation:
- `export MUJOCO_GL=egl` → for headless servers (e.g. HPC, cloud)
---
## Colab Note on Parallel Evaluation
When running evaluation on Colab, you may encounter warnings such as:
```
UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
```
This happens because Colabs rendering contexts are **not thread-safe**, and `ThreadPoolExecutor(max_workers=num_workers)` can trigger segfaults or leaked semaphore warnings.
**Colab Note:**
Parallel evaluation is not supported in Colab. To avoid these issues, run sequentially or disable multitask evaluation:
Run sequentially:
```bash
--env.max_parallel_tasks=1
```
Or disable multitask evaluation:
```bash
--env.multitask_eval=False
```
If you want to take advantage of **parallel evaluation**, we recommend **not using Colab**. Instead, run locally or on a proper compute environment where multi-threaded rendering is easily supported.

288
docs/source/reachy2.mdx Normal file
View File

@@ -0,0 +1,288 @@
# Reachy 2
Reachy 2 is an open-source humanoid robot made by Pollen Robotics, specifically designed for the development of embodied AI and real-world applications.
Check out [Pollen Robotics website](https://www.pollen-robotics.com/reachy/), or access [Reachy 2 documentation](https://docs.pollen-robotics.com/) for more information on the platform!
## Teleoperate Reachy 2
Currently, there are two ways to teleoperate Reachy 2:
- Pollen Robotics VR teleoperation (not included in LeRobot).
- Robot-to-robot teleoperation (use one Reachy 2 to control another).
## Reachy 2 Simulation
**(Linux only)** You can run Reachy 2 in simulation (Gazebo or MuJoCo) using the provided [Docker image](https://hub.docker.com/r/pollenrobotics/reachy2_core).
1. Install [Docker Engine](https://docs.docker.com/engine/).
2. Run (for MuJoCo):
```
docker run --rm -it \
--name reachy \
--privileged \
--network host \
--ipc host \
--device-cgroup-rule='c 189:* rwm' \
--group-add audio \
-e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \
-e DISPLAY="$DISPLAY" \
-e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \
-e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \
-v /dev:/dev \
-v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \
-v "$HOME/.reachy.log":/home/reachy/.ros/log \
-v /usr/lib/x86_64-linux-gnu:/opt/host-libs \
--entrypoint /package/launch.sh \
pollenrobotics/reachy2_core:1.7.5.9_deploy \
start_rviz:=true start_sdk_server:=true mujoco:=true
```
> If MuJoCo runs slowly (low simulation frequency), append `-e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \` to the previous command to improve performance:
>
> ```
> docker run --rm -it \
> --name reachy \
> --privileged \
> --network host \
> --ipc host \
> --device-cgroup-rule='c 189:* rwm' \
> --group-add audio \
> -e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \
> -e DISPLAY="$DISPLAY" \
> -e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \
> -e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \
> -e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \
> -v /dev:/dev \
> -v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \
> -v "$HOME/.reachy.log":/home/reachy/.ros/log \
> -v /usr/lib/x86_64-linux-gnu:/opt/host-libs \
> --entrypoint /package/launch.sh \
> pollenrobotics/reachy2_core:1.7.5.9_deploy \
> start_rviz:=true start_sdk_server:=true mujoco:=true
> ```
## Setup
### Prerequisites
- On your robot, check the **service images** meet the minimum versions:
- **reachy2-core >= 1.7.5.2**
- **webrtc >= 2.0.1.1**
Then, if you want to use VR teleoperation:
- Install the [Reachy 2 teleoperation application](https://docs.pollen-robotics.com/teleoperation/teleoperation-introduction/discover-teleoperation/).
Use version **>=v1.2.0**
We recommend using two computers: one for teleoperation (Windows required) and another for recording with LeRobot.
### Install LeRobot
Follow the [installation instructions](https://github.com/huggingface/lerobot#installation) to install LeRobot.
Install LeRobot with Reachy 2 dependencies:
```bash
pip install -e ".[reachy2]"
```
### (Optional but recommended) Install pollen_data_acquisition_server
How you manage Reachy 2 recording sessions is up to you, but the **easiest** way is to use this server so you can control sessions directly from the VR teleoperation app.
> **Note:** Currently, only the VR teleoperation application works as a client for this server, so this step primarily targets teleoperation. Youre free to develop custom clients to manage sessions to your needs.
In your LeRobot environment, install the server from source:
```bash
git clone https://github.com/pollen-robotics/pollen_data_acquisition_server.git
cd pollen_data_acquisition_server
pip install -e .
```
Find the [pollen_data_acquisition_server documentation here](https://github.com/pollen-robotics/pollen_data_acquisition_server).
## Step 1: Recording
### Get Reachy 2 IP address
Before starting teleoperation and data recording, find the [robot's IP address](https://docs.pollen-robotics.com/getting-started/setup-reachy2/connect-reachy2/).
We strongly recommend connecting all devices (PC and robot) via **Ethernet**.
### Launch recording
There are two ways to manage recording sessions when using the Reachy 2 VR teleoperation application:
- **Using the data acquisition server (recommended for VR teleop)**: The VR app orchestrates sessions (via the server it tells LeRobot when to create datasets, start/stop episodes) while also controlling the robots motions.
- **Using LeRobots record script**: LeRobot owns session control and decides when to start/stop episodes. If you also use the VR teleop app, its only for motion control.
### Option 1: Using Pollen data acquisition server (recommended for VR teleop)
Make sure you have installed pollen_data_acquisition_server, as explained in the Setup section.
Launch the data acquisition server to be able to manage your session directly from the teleoperation application:
```bash
python -m pollen_data_acquisition_server.server
```
Then get into the teleoperation application and choose "Data acquisition session".
You can finally setup your session by following the screens displayed.
> Even without the VR app, you can use the `pollen_data_acquisition_server` with your own client implementation.
### Option 2: Using lerobot.record
Reachy 2 is fully supported by LeRobots recording features.
If you choose this option but still want to use the VR teleoperation application, select "Standard session" in the app.
**Example: start a recording without the mobile base:**
First add reachy2 and reachy2_teleoperator to the imports of the record script. Then you can use the following command:
```bash
python -m lerobot.record \
--robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \
--robot.id=r2-0000 \
--robot.use_external_commands=true \
--robot.with_mobile_base=false \
--teleop.type=reachy2_teleoperator \
--teleop.ip_address=192.168.0.200 \
--teleop.with_mobile_base=false \
--dataset.repo_id=pollen_robotics/record_test \
--dataset.single_task="Reachy 2 recording test" \
--dataset.num_episodes=1 \
--dataset.episode_time_s=5 \
--dataset.fps=15 \
--dataset.push_to_hub=true \
--dataset.private=true \
--display_data=true
```
#### Specific Options
**Extended setup overview (all options included):**
```bash
python -m lerobot.record \
--robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \
--robot.use_external_commands=true \
--robot.with_mobile_base=true \
--robot.with_l_arm=true \
--robot.with_r_arm=true \
--robot.with_neck=true \
--robot.with_antennas=true \
--robot.with_left_teleop_camera=true \
--robot.with_right_teleop_camera=true \
--robot.with_torso_camera=false \
--robot.disable_torque_on_disconnect=false \
--robot.max_relative_target=5.0 \
--teleop.type=reachy2_teleoperator \
--teleop.ip_address=192.168.0.200 \
--teleop.use_present_position=false \
--teleop.with_mobile_base=false \
--teleop.with_l_arm=true \
--teleop.with_r_arm=true \
--teleop.with_neck=true \
--teleop.with_antennas=true \
--dataset.repo_id=pollen_robotics/record_test \
--dataset.single_task="Reachy 2 recording test" \
--dataset.num_episodes=1 \
--dataset.episode_time_s=5 \
--dataset.fps=15 \
--dataset.push_to_hub=true \
--dataset.private=true \
--display_data=true
```
##### `--robot.use_external_commands`
Determine whether LeRobot robot.send_action() sends commands to the robot.
**Must** be set to false while using the VR teleoperation application, as the app already sends commands.
##### `--teleop.use_present_position`
Determine whether the teleoperator reads the goal or present position of the robot.
Must be set to true if a compliant Reachy 2 is used to control another one.
##### Use the relevant parts
From our initial tests, recording **all** joints when only some are moving can reduce model quality with certain policies.
To avoid this, you can exclude specific parts from recording and replay using:
````
--robot.with_<part>=false
```,
with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
It determine whether the corresponding part is recorded in the observations. True if not set.
By default, **all parts are recorded**.
The same per-part mechanism is available in `reachy2_teleoperator` as well.
````
--teleop.with\_<part>
```
with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
Determine whether the corresponding part is recorded in the actions. True if not set.
> **Important:** In a given session, the **enabled parts must match** on both the robot and the teleoperator.
For example, if the robot runs with `--robot.with_mobile_base=false`, the teleoperator must disable the same part `--teleoperator.with_mobile_base=false`.
##### Use the relevant cameras
You can do the same for **cameras**. By default, only the **teleoperation cameras** are recorded (both `left_teleop_camera` and `right_teleop_camera`). Enable or disable each camera with:
```
--robot.with_left_teleop_camera=<true|false>
--robot.with_right_teleop_camera=<true|false>
--robot.with_torso_camera=<true|false>
````
## Step 2: Replay
Make sure the robot is configured with the same parts as the dataset:
```bash
python -m lerobot.replay \
--robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \
--robot.use_external_commands=false \
--robot.with_mobile_base=false \
--dataset.repo_id=pollen_robotics/record_test \
--dataset.episode=0
--display_data=true
````
## Step 3: Train
```bash
python -m lerobot.scripts.train \
--dataset.repo_id=pollen_robotics/record_test \
--policy.type=act \
--output_dir=outputs/train/reachy2_test \
--job_name=reachy2 \
--policy.device=mps \
--wandb.enable=true \
--policy.repo_id=pollen_robotics/record_test_policy
```
## Step 4: Evaluate
```bash
python -m lerobot.record \
--robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \
--display_data=false \
--dataset.repo_id=pollen_robotics/eval_record_test \
--dataset.single_task="Evaluate reachy2 policy" \
--dataset.num_episodes=10 \
--policy.path=outputs/train/reachy2_test/checkpoints/last/pretrained_model
```

View File

@@ -0,0 +1,58 @@
#!/bin/bash
# storage / caches
RAID=/raid/jade
export TRANSFORMERS_CACHE=$RAID/.cache/huggingface/transformers
export HF_HOME=$RAID/.cache/huggingface
export HF_DATASETS_CACHE=$RAID/.cache/huggingface/datasets
export HF_LEROBOT_HOME=$RAID/.cache/huggingface/lerobot
export WANDB_CACHE_DIR=$RAID/.cache/wandb
export TMPDIR=$RAID/.cache/tmp
mkdir -p $TMPDIR
export WANDB_MODE=offline
export HF_DATASETS_OFFLINE=1
export HF_HUB_OFFLINE=1
export TOKENIZERS_PARALLELISM=false
export MUJOCO_GL=egl
export CUDA_VISIBLE_DEVICES=2
# CONFIGURATION
POLICY_PATH="/raid/jade/logs/lerobot/lerobot_2_HuggingFaceVLA_libero_smolvla_lr1e-4bs32steps100000/checkpoints/100000/pretrained_model"
POLICY_PATH="/raid/jade/models/smolvlamust"
TASK=libero_spatial,libero_object
ENV_TYPE="libero"
BATCH_SIZE=1
N_EPISODES=1
# storage / caches
RAID=/raid/jade
N_ACTION_STEPS=1
export TRANSFORMERS_CACHE=$RAID/.cache/huggingface/transformers
export HF_HOME=$RAID/.cache/huggingface
export HF_DATASETS_CACHE=$RAID/.cache/huggingface/datasets
export HF_LEROBOT_HOME=$RAID/.cache/huggingface/lerobot
export WANDB_CACHE_DIR=$RAID/.cache/wandb
export TMPDIR=$RAID/.cache/tmp
mkdir -p $TMPDIR
export WANDB_MODE=offline
# export HF_DATASETS_OFFLINE=1
# export HF_HUB_OFFLINE=1
export TOKENIZERS_PARALLELISM=false
export MUJOCO_GL=egl
export MUJOCO_GL=egl
unset HF_HUB_OFFLINE
# RUN EVALUATION
python src/lerobot/scripts/eval.py \
--policy.path="$POLICY_PATH" \
--env.type="$ENV_TYPE" \
--eval.batch_size="$BATCH_SIZE" \
--eval.n_episodes="$N_EPISODES" \
--env.multitask_eval=True \
--env.task=$TASK \
# python examples/evaluate_libero.py \
# --policy_path "$POLICY_PATH" \
# --task_suite_name "$TASK" \
# --num_steps_wait 10 \
# --num_trials_per_task 10 \
# --video_out_path "data/libero/videos" \
# --device "cuda" \
# --seed 7

View File

@@ -1,7 +1,6 @@
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_processor
from lerobot.record import record_loop
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.utils.control_utils import init_keyboard_listener
@@ -12,14 +11,12 @@ NUM_EPISODES = 2
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
# Create the robot and teleoperator configurations
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
robot = LeKiwiClient(robot_config)
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
policy = ACTPolicy.from_pretrained("<hf_username>/<policy_repo_id>")
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
@@ -28,7 +25,7 @@ dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_DATASET_ID,
repo_id="<hf_username>/<eval_dataset_repo_id>",
fps=FPS,
features=dataset_features,
robot_type=robot.name,
@@ -46,12 +43,6 @@ listener, events = init_keyboard_listener()
if not robot.is_connected:
raise ValueError("Robot is not connected!")
preprocessor, postprocessor = make_processor(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
)
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
@@ -62,8 +53,6 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,

View File

@@ -38,7 +38,7 @@ while True:
keyboard_keys = keyboard.get_action()
base_action = robot._from_keyboard_to_base_action(keyboard_keys)
log_rerun_data(observation=observation, action={**arm_action, **base_action})
log_rerun_data(observation, {**arm_action, **base_action})
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action

View File

@@ -1,158 +0,0 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
from lerobot.datasets.utils import merge_features
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_processor
from lerobot.processor.converters import (
to_output_robot_action,
to_transition_robot_observation,
)
from lerobot.processor.pipeline import RobotProcessor
from lerobot.record import record_loop
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
AddRobotObservationAsComplimentaryData,
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun
NUM_EPISODES = 5
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
# Initialize the robot with degrees
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471",
id="my_awesome_follower_arm",
cameras=camera_config,
use_degrees=True,
)
# Initialize the robot
robot = SO100Follower(robot_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints = RobotProcessor(
steps=[
AddRobotObservationAsComplimentaryData(robot=robot),
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
initial_guess_current_joints=True,
),
],
to_transition=lambda tr: tr,
to_output=to_output_robot_action,
)
# Build pipeline to convert joint observation to ee pose observation
robot_joints_to_ee_pose = RobotProcessor(
steps=[
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
],
to_transition=to_transition_robot_observation,
to_output=lambda tr: tr,
)
# Build dataset action and gripper features
action_ee_and_gripper = aggregate_pipeline_dataset_features(
pipeline=robot_ee_to_joints,
initial_features={},
use_videos=True,
patterns=["action.ee", "action.gripper.pos", "observation.state.gripper.pos"],
) # Get all ee action features + gripper pos action features
# Build dataset observation features
obs_ee = aggregate_pipeline_dataset_features(
pipeline=robot_joints_to_ee_pose,
initial_features=robot.observation_features,
use_videos=True,
patterns=["observation.state.ee"],
) # Get all ee observation features
dataset_features = merge_features(obs_ee, action_ee_and_gripper)
print("All dataset features: ", dataset_features)
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_DATASET_ID,
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
_init_rerun(session_name="recording_phone")
# Connect the robot and teleoperator
robot.connect()
episode_idx = 0
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
preprocessor, postprocessor = make_processor(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
)
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
robot_action_processor=robot_ee_to_joints,
robot_observation_processor=robot_joints_to_ee_pose,
)
dataset.save_episode()
# Clean up
log_say("Stop recording")
robot.disconnect()
dataset.push_to_hub()

View File

@@ -1,215 +0,0 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
from lerobot.datasets.utils import merge_features
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor.converters import (
to_output_robot_action,
to_transition_robot_observation,
to_transition_teleop_action,
)
from lerobot.processor.pipeline import RobotProcessor
from lerobot.record import record_loop
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
AddRobotObservationAsComplimentaryData,
EEBoundsAndSafety,
EEReferenceAndDelta,
ForwardKinematicsJointsToEE,
GripperVelocityToJoint,
InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone import Phone
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun
NUM_EPISODES = 10
FPS = 30
EPISODE_TIME_SEC = 60
RESET_TIME_SEC = 30
TASK_DESCRIPTION = "My task description"
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
# Initialize the robot and teleoperator
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471",
id="my_awesome_follower_arm",
cameras=camera_config,
use_degrees=True,
)
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
# Initialize the robot and teleoperator
robot = SO100Follower(robot_config)
phone = Phone(teleop_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# Build pipeline to convert phone action to ee pose action
phone_to_robot_ee_pose = RobotProcessor(
steps=[
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
AddRobotObservationAsComplimentaryData(robot=robot),
EEReferenceAndDelta(
kinematics=kinematics_solver,
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
motor_names=list(robot.bus.motors.keys()),
),
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.20,
max_ee_twist_step_rad=0.50,
),
],
to_transition=to_transition_teleop_action,
to_output=lambda tr: tr,
)
# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints = RobotProcessor(
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
initial_guess_current_joints=True,
),
GripperVelocityToJoint(
motor_names=list(robot.bus.motors.keys()),
speed_factor=20.0,
),
],
to_transition=lambda tr: tr,
to_output=to_output_robot_action,
)
# Build pipeline to convert joint observation to ee pose observation
robot_joints_to_ee_pose = RobotProcessor(
steps=[
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
],
to_transition=to_transition_robot_observation,
to_output=lambda tr: tr,
)
# Build dataset ee action features
action_ee = aggregate_pipeline_dataset_features(
pipeline=phone_to_robot_ee_pose,
initial_features=phone.action_features,
use_videos=True,
patterns=["action.ee"],
)
# Get gripper pos action features
gripper = aggregate_pipeline_dataset_features(
pipeline=robot_ee_to_joints,
initial_features={},
use_videos=True,
patterns=["action.gripper.pos", "observation.state.gripper.pos"],
)
# Build dataset ee observation features
observation_ee = aggregate_pipeline_dataset_features(
pipeline=robot_joints_to_ee_pose,
initial_features=robot.observation_features,
use_videos=True,
patterns=["observation.state.ee"],
)
dataset_features = merge_features(action_ee, gripper, observation_ee)
print("All dataset features: ", dataset_features)
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_REPO_ID,
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
_init_rerun(session_name="recording_phone")
# Connect the robot and teleoperator
robot.connect()
phone.connect()
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=phone,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose,
robot_action_processor=robot_ee_to_joints,
robot_observation_processor=robot_joints_to_ee_pose,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=phone,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose,
robot_action_processor=robot_ee_to_joints,
robot_observation_processor=robot_joints_to_ee_pose,
)
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
dataset.save_episode()
episode_idx += 1
# Clean up
log_say("Stop recording")
robot.disconnect()
phone.disconnect()
dataset.push_to_hub()

View File

@@ -1,106 +0,0 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor.converters import to_output_robot_action
from lerobot.processor.pipeline import RobotProcessor
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
AddRobotObservationAsComplimentaryData,
InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
EPISODE_IDX = 0
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", use_degrees=True
)
robot = SO100Follower(robot_config)
robot.connect()
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
actions = dataset.hf_dataset.select_columns("action")
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# This method converts the action from the dataset to a transition for pipeline
def action_to_transition(action: dict):
act = {}
# EE pose
for k in ("ee.x", "ee.y", "ee.z", "ee.wx", "ee.wy", "ee.wz"):
if k in action:
act[f"action.{k}"] = float(action[k])
# Gripper: your dataset has absolute position
if "gripper.pos" in action:
act["action.gripper.pos"] = float(action["gripper.pos"])
return {
"observation": None,
"action": act,
"reward": None,
"done": False,
"truncated": False,
"info": {},
"complementary_data": {},
}
# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints = RobotProcessor(
steps=[
AddRobotObservationAsComplimentaryData(robot=robot),
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
initial_guess_current_joints=False, # Because replay is open loop
),
],
to_transition=action_to_transition,
to_output=to_output_robot_action,
)
robot_ee_to_joints.reset()
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(dataset.num_frames):
t0 = time.perf_counter()
ee_action = {
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
}
joint_action = robot_ee_to_joints(ee_action)
action_sent = robot.send_action(joint_action)
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
robot.disconnect()

View File

@@ -1,109 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specif
import time
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotProcessor
from lerobot.processor.converters import to_output_robot_action, to_transition_teleop_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
AddRobotObservationAsComplimentaryData,
EEBoundsAndSafety,
EEReferenceAndDelta,
GripperVelocityToJoint,
InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone import Phone
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
# Initialize the robot and teleoperator
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", use_degrees=True
)
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
# Initialize the robot and teleoperator
robot = SO100Follower(robot_config)
teleop_device = Phone(teleop_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# Build pipeline to convert phone action to ee pose action
phone_to_robot_ee_pose = RobotProcessor(
steps=[
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
AddRobotObservationAsComplimentaryData(robot=robot),
EEReferenceAndDelta(
kinematics=kinematics_solver,
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
motor_names=list(robot.bus.motors.keys()),
),
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10,
max_ee_twist_step_rad=0.50,
),
],
to_transition=to_transition_teleop_action,
to_output=lambda tr: tr,
)
# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints = RobotProcessor(
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
),
GripperVelocityToJoint(
motor_names=list(robot.bus.motors.keys()),
speed_factor=20.0,
),
],
to_transition=lambda tr: tr,
to_output=to_output_robot_action,
)
robot.connect()
teleop_device.connect()
print("Starting teleop loop. Move your phone to teleoperate the robot.")
while True:
phone_obs = teleop_device.get_action()
if not phone_obs:
time.sleep(0.01)
continue
# Get teleop observation
phone_obs = teleop_device.get_action()
# Phone to EE pose transition
ee_transition = phone_to_robot_ee_pose(phone_obs)
# EE pose to Joints transition
joint_action = robot_ee_to_joints(ee_transition)
if joint_action:
robot.send_action(joint_action)
time.sleep(0.01)

193
push_pi0_to_hub.py Normal file
View File

@@ -0,0 +1,193 @@
#!/usr/bin/env python
"""Script to create and push a PI0OpenPI model to HuggingFace hub with proper config format."""
import tempfile
from pathlib import Path
import torch
from huggingface_hub import HfApi, create_repo
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy
def create_and_push_model(
repo_id: str,
private: bool = False,
token: str = None,
):
"""Create a PI0OpenPI model with proper config and push to HuggingFace hub.
Args:
repo_id: HuggingFace repository ID (e.g., "username/model-name")
private: Whether to create a private repository
token: HuggingFace API token (optional, will use cached token if not provided)
"""
print("=" * 60)
print("PI0OpenPI Model Hub Upload")
print("=" * 60)
# Create configuration
print("\nCreating PI0OpenPI configuration...")
config = PI0OpenPIConfig(
# Model architecture
paligemma_variant="gemma_2b",
action_expert_variant="gemma_300m",
pi05=False, # Use PI0 (not PI0.5)
dtype="float32", # Use float32 for compatibility
# Input/output dimensions
action_dim=32, # see openpi `Pi0Config`
state_dim=32,
chunk_size=50,
n_action_steps=50,
# Image inputs, see openpi `model.py, IMAGE_KEYS`
image_keys=(
"observation.images.base_0_rgb",
"observation.images.left_wrist_0_rgb",
"observation.images.right_wrist_0_rgb",
),
# Training settings
gradient_checkpointing=False,
compile_model=False,
device=None, # Auto-detect
# Tokenizer settings
tokenizer_max_length=48, # see openpi `__post_init__`, use pi0=48 and pi05=200
)
print(f" - Config type: {config.__class__.__name__}")
print(f" - PaliGemma variant: {config.paligemma_variant}")
print(f" - Action expert variant: {config.action_expert_variant}")
print(f" - Action dim: {config.action_dim}")
print(f" - State dim: {config.state_dim}")
# Create dummy dataset stats for normalization
print("\nCreating dataset statistics...")
dataset_stats = {
"observation.state": {
"mean": torch.zeros(config.state_dim),
"std": torch.ones(config.state_dim),
"min": torch.full((config.state_dim,), -5.0),
"max": torch.full((config.state_dim,), 5.0),
},
"action": {
"mean": torch.zeros(config.action_dim),
"std": torch.ones(config.action_dim),
"min": torch.full((config.action_dim,), -1.0),
"max": torch.full((config.action_dim,), 1.0),
},
}
# Add image stats
for key in config.image_keys:
dataset_stats[key] = {
"mean": torch.tensor([0.485, 0.456, 0.406]), # TODO(pepijn): fix this, now its ImageNet mean
"std": torch.tensor([0.229, 0.224, 0.225]), # TODO(pepijn): fix this, now its ImageNet std
"min": torch.tensor([0.0, 0.0, 0.0]),
"max": torch.tensor([1.0, 1.0, 1.0]),
}
# Create the policy
print("\nInitializing PI0OpenPI policy...")
print(" (This may take a moment as it loads the tokenizer and initializes the model)")
policy = PI0OpenPIPolicy(config, dataset_stats)
# Initialize with small random weights (optional - for testing)
# Note: In practice, you would load your trained weights here
print("\nInitializing model weights...")
for name, param in policy.named_parameters():
if "weight" in name:
if "norm" in name.lower() or "layernorm" in name.lower():
torch.nn.init.ones_(param)
elif len(param.shape) >= 2:
torch.nn.init.xavier_uniform_(param, gain=0.01)
else:
torch.nn.init.normal_(param, mean=0.0, std=0.01)
elif "bias" in name:
torch.nn.init.zeros_(param)
print(f" - Total parameters: {sum(p.numel() for p in policy.parameters()):,}")
print(f" - Trainable parameters: {sum(p.numel() for p in policy.parameters() if p.requires_grad):,}")
# Create temporary directory for saving
with tempfile.TemporaryDirectory() as tmpdir:
save_path = Path(tmpdir) / "model"
save_path.mkdir(exist_ok=True)
print(f"\nSaving model to temporary directory: {save_path}")
# Save the model using LeRobot's save_pretrained method
# This ensures the config is saved in the correct format
policy.save_pretrained(save_path)
# List saved files
saved_files = list(save_path.glob("*"))
print("\nSaved files:")
for file in saved_files:
size = file.stat().st_size
print(f" - {file.name}: {size:,} bytes")
# Create or get repository
print(f"\nCreating/accessing repository: {repo_id}")
api = HfApi(token=token)
try:
# Create repo if it doesn't exist
create_repo(
repo_id,
private=private,
token=token,
exist_ok=True,
)
print(f" ✓ Repository ready: https://huggingface.co/{repo_id}")
except Exception as e:
print(f" ⚠️ Note: {e}")
# Upload to hub
print("\nUploading to HuggingFace hub...")
api.upload_folder(
folder_path=str(save_path),
repo_id=repo_id,
repo_type="model",
token=token,
commit_message="Upload PI0OpenPI model with proper LeRobot config format",
)
print(f"\n✓ Model successfully uploaded to: https://huggingface.co/{repo_id}")
print("\n" + "=" * 60)
print("✓ Process complete!")
print("=" * 60)
return policy
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Push PI0OpenPI model to HuggingFace hub")
parser.add_argument(
"--repo-id",
type=str,
default="test-user/pi0-openpi-test",
help="HuggingFace repository ID (e.g., 'username/model-name')",
)
parser.add_argument(
"--private",
action="store_true",
help="Create a private repository",
)
parser.add_argument(
"--token",
type=str,
default=None,
help="HuggingFace API token (optional, uses cached token if not provided)",
)
args = parser.parse_args()
# Run the upload
create_and_push_model(
repo_id=args.repo_id,
private=args.private,
token=args.token,
)

View File

@@ -29,7 +29,7 @@ version = "0.3.4"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
readme = "README.md"
license = { text = "Apache-2.0" }
requires-python = ">=3.10"
requires-python = ">=3.11"
authors = [
{ name = "Rémi Cadène", email = "re.cadene@gmail.com" },
{ name = "Simon Alibert", email = "alibert.sim@gmail.com" },
@@ -50,7 +50,7 @@ classifiers = [
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Topic :: Software Development :: Build Tools",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
@@ -73,7 +73,6 @@ dependencies = [
"pynput>=1.7.7",
"pyserial>=3.5",
"wandb>=0.20.0",
"scipy>=1.15.2",
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
@@ -96,7 +95,7 @@ dependencies = [
# Common
pygame-dep = ["pygame>=2.5.1"]
placo-dep = ["placo>=0.9.6"]
transformers-dep = ["transformers<=4.52.0"]
transformers-dep = ["transformers==4.53.2"]
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
# Motors
@@ -107,12 +106,12 @@ dynamixel = ["dynamixel-sdk>=3.7.31"]
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0"]
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1"]
reachy2 = ["reachy2_sdk>=1.0.14"]
kinematics = ["lerobot[placo-dep]"]
intelrealsense = [
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
"pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'",
]
phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"]
# stretch = [
# "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'",
# "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
@@ -136,13 +135,33 @@ video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
aloha = ["gym-aloha>=0.1.1"]
pusht = ["gym-pusht>=0.1.5", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
xarm = ["gym-xarm>=0.1.1"]
libero = [
"hydra-core>=1.2,<1.4",
"numpy",
"wandb",
"easydict",
"transformers",
"opencv-python",
"robomimic==0.2.0",
"einops",
"thop",
"robosuite==1.4.0",
"mujoco>=2.3.7,<3.0.0",
"bddl==1.0.1",
"matplotlib",
"cloudpickle",
"future",
"gym",
"egl_probe @ git+https://github.com/jadechoghari/egl_probe.git#egg=egl_probe",
"libero @ git+https://github.com/jadechoghari/LIBERO.git@main#egg=libero",
]
# All
all = [
"lerobot[dynamixel]",
"lerobot[gamepad]",
"lerobot[hopejr]",
"lerobot[lekiwi]",
"lerobot[reachy2]",
"lerobot[kinematics]",
"lerobot[intelrealsense]",
"lerobot[pi0]",
@@ -155,7 +174,7 @@ all = [
"lerobot[aloha]",
"lerobot[pusht]",
"lerobot[xarm]",
"lerobot[phone]",
"lerobot[libero]"
]
[project.scripts]
@@ -261,7 +280,7 @@ default.extend-ignore-identifiers-re = [
# paths = ["src/lerobot"]
# [tool.mypy]
# python_version = "3.10"
# python_version = "3.11"
# warn_return_any = true
# warn_unused_configs = true
# ignore_missing_imports = false

View File

@@ -1,6 +1,4 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_phone import PhoneConfig
from .phone import Phone
from .configuration_reachy2_camera import Reachy2CameraConfig
from .reachy2_camera import Reachy2Camera

View File

@@ -0,0 +1,78 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..configs import CameraConfig, ColorMode
@CameraConfig.register_subclass("reachy2_camera")
@dataclass
class Reachy2CameraConfig(CameraConfig):
"""Configuration class for Reachy 2 camera devices.
This class provides configuration options for Reachy 2 cameras,
supporting both the teleop and depth cameras. It includes settings
for resolution, frame rate, color mode, and the selection of the cameras.
Example configurations:
```python
# Basic configurations
Reachy2CameraConfig(
name="teleop",
image_type="left",
ip_address="192.168.0.200", # IP address of the robot
fps=15,
width=640,
height=480,
color_mode=ColorMode.RGB,
) # Left teleop camera, 640x480 @ 15FPS
```
Attributes:
name: Name of the camera device. Can be "teleop" or "depth".
image_type: Type of image stream. For "teleop" camera, can be "left" or "right".
For "depth" camera, can be "rgb" or "depth". (depth is not supported yet)
fps: Requested frames per second for the color stream.
width: Requested frame width in pixels for the color stream.
height: Requested frame height in pixels for the color stream.
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
ip_address: IP address of the robot. Defaults to "localhost".
port: Port number for the camera server. Defaults to 50065.
Note:
- Only 3-channel color output (RGB/BGR) is currently supported.
"""
name: str
image_type: str
color_mode: ColorMode = ColorMode.RGB
ip_address: str | None = "localhost"
port: int = 50065
# use_depth: bool = False
def __post_init__(self):
if self.name not in ["teleop", "depth"]:
raise ValueError(f"`name` is expected to be 'teleop' or 'depth', but {self.name} is provided.")
if (self.name == "teleop" and self.image_type not in ["left", "right"]) or (
self.name == "depth" and self.image_type not in ["rgb", "depth"]
):
raise ValueError(
f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided."
)
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)

View File

@@ -0,0 +1,288 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Provides the Reachy2Camera class for capturing frames from Reachy 2 cameras using Reachy 2's CameraManager.
"""
import logging
import os
import platform
import time
from threading import Event, Lock, Thread
from typing import Any
# Fix MSMF hardware transform compatibility for Windows before importing cv2
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2
import numpy as np
from reachy2_sdk.media.camera import CameraView
from reachy2_sdk.media.camera_manager import CameraManager
from lerobot.errors import DeviceNotConnectedError
from ..camera import Camera
from .configuration_reachy2_camera import ColorMode, Reachy2CameraConfig
logger = logging.getLogger(__name__)
class Reachy2Camera(Camera):
"""
Manages Reachy 2 camera using Reachy 2 CameraManager.
This class provides a high-level interface to connect to, configure, and read
frames from Reachy 2 cameras. It supports both synchronous and asynchronous
frame reading.
An Reachy2Camera instance requires a camera name (e.g., "teleop") and an image
type (e.g., "left") to be specified in the configuration.
The camera's default settings (FPS, resolution, color mode) are used unless
overridden in the configuration.
"""
def __init__(self, config: Reachy2CameraConfig):
"""
Initializes the Reachy2Camera instance.
Args:
config: The configuration settings for the camera.
"""
super().__init__(config)
self.config = config
self.fps = config.fps
self.color_mode = config.color_mode
self.cam_manager: CameraManager | None = None
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: np.ndarray | None = None
self.new_frame_event: Event = Event()
def __str__(self) -> str:
return f"{self.__class__.__name__}({self.config.name}, {self.config.image_type})"
@property
def is_connected(self) -> bool:
"""Checks if the camera is currently connected and opened."""
if self.config.name == "teleop":
return self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
elif self.config.name == "depth":
return self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
else:
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
def connect(self, warmup: bool = True):
"""
Connects to the Reachy2 CameraManager as specified in the configuration.
"""
self.cam_manager = CameraManager(host=self.config.ip_address, port=self.config.port)
self.cam_manager.initialize_cameras()
logger.info(f"{self} connected.")
@staticmethod
def find_cameras(ip_address: str = "localhost", port: int = 50065) -> list[dict[str, Any]]:
"""
Detects available Reachy 2 cameras.
Returns:
List[Dict[str, Any]]: A list of dictionaries,
where each dictionary contains 'name', 'stereo',
and the default profile properties (width, height, fps).
"""
initialized_cameras = []
camera_manager = CameraManager(host=ip_address, port=port)
for camera in [camera_manager.teleop, camera_manager.depth]:
if camera is None:
continue
height, width, _, _, _, _, _ = camera.get_parameters()
camera_info = {
"name": camera._cam_info.name,
"stereo": camera._cam_info.stereo,
"default_profile": {
"width": width,
"height": height,
"fps": 30,
},
}
initialized_cameras.append(camera_info)
camera_manager.disconnect()
return initialized_cameras
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
"""
Reads a single frame synchronously from the camera.
This is a blocking call.
Args:
color_mode (Optional[ColorMode]): If specified, overrides the default
color mode (`self.color_mode`) for this read operation (e.g.,
request RGB even if default is BGR).
Returns:
np.ndarray: The captured frame as a NumPy array in the format
(height, width, channels), using the specified or default
color mode and applying any configured rotation.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
start_time = time.perf_counter()
frame = None
if self.cam_manager is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
else:
if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"):
if self.config.image_type == "left":
frame = self.cam_manager.teleop.get_frame(CameraView.LEFT, size=(640, 480))[0]
elif self.config.image_type == "right":
frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT, size=(640, 480))[0]
elif self.config.name == "depth" and hasattr(self.cam_manager, "depth"):
if self.config.image_type == "depth":
frame = self.cam_manager.depth.get_depth_frame()[0]
elif self.config.image_type == "rgb":
frame = self.cam_manager.depth.get_frame(size=(640, 480))[0]
if frame is None:
return np.empty((0, 0, 3), dtype=np.uint8)
if self.config.color_mode == "rgb":
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return frame
def _read_loop(self):
"""
Internal loop run by the background thread for asynchronous reading.
On each iteration:
1. Reads a color frame
2. Stores result in latest_frame (thread-safe)
3. Sets new_frame_event to notify listeners
Stops on DeviceNotConnectedError, logs other errors and continues.
"""
while not self.stop_event.is_set():
try:
color_image = self.read()
with self.frame_lock:
self.latest_frame = color_image
self.new_frame_event.set()
except DeviceNotConnectedError:
break
except Exception as e:
logger.warning(f"Error reading frame in background thread for {self}: {e}")
def _start_read_thread(self) -> None:
"""Starts or restarts the background read thread if it's not running."""
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=0.1)
if self.stop_event is not None:
self.stop_event.set()
self.stop_event = Event()
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
self.thread.daemon = True
self.thread.start()
def _stop_read_thread(self) -> None:
"""Signals the background read thread to stop and waits for it to join."""
if self.stop_event is not None:
self.stop_event.set()
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
self.thread = None
self.stop_event = None
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
"""
Reads the latest available frame asynchronously.
This method retrieves the most recent frame captured by the background
read thread. It does not block waiting for the camera hardware directly,
but may wait up to timeout_ms for the background thread to provide a frame.
Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame
to become available. Defaults to 200ms (0.2 seconds).
Returns:
np.ndarray: The latest captured frame as a NumPy array in the format
(height, width, channels), processed according to configuration.
Raises:
DeviceNotConnectedError: If the camera is not connected.
TimeoutError: If no frame becomes available within the specified timeout.
RuntimeError: If an unexpected error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
self._start_read_thread()
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
thread_alive = self.thread is not None and self.thread.is_alive()
raise TimeoutError(
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
f"Read thread alive: {thread_alive}."
)
with self.frame_lock:
frame = self.latest_frame
self.new_frame_event.clear()
if frame is None:
raise RuntimeError(f"Internal error: Event set but no frame available for {self}.")
return frame
def disconnect(self):
"""
Stops the background read thread (if running).
Raises:
DeviceNotConnectedError: If the camera is already disconnected.
"""
if not self.is_connected and self.thread is None:
raise DeviceNotConnectedError(f"{self} not connected.")
if self.thread is not None:
self._stop_read_thread()
if self.cam_manager is not None:
self.cam_manager.disconnect()
logger.info(f"{self} disconnected.")

View File

@@ -37,8 +37,14 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
from .realsense.camera_realsense import RealSenseCamera
cameras[key] = RealSenseCamera(cfg)
elif cfg.type == "reachy2_camera":
from .reachy2_camera.reachy2_camera import Reachy2Camera
cameras[key] = Reachy2Camera(cfg)
else:
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
raise ValueError(f"The camera type '{cfg.type}' is not valid.")
return cameras

View File

@@ -26,7 +26,7 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.optim.optimizers import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
@@ -53,6 +53,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
"""
n_obs_steps: int = 1
normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
@@ -71,9 +72,11 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
tags: list[str] | None = None
# Add tags to your policy on the hub.
license: str | None = None
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
pretrained_path: str | None = None
def __post_init__(self):
self.pretrained_path = None
if not self.device or not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")

View File

@@ -24,7 +24,6 @@ class FeatureType(str, Enum):
ENV = "ENV"
ACTION = "ACTION"
REWARD = "REWARD"
LANGUAGE = "LANGUAGE"
class NormalizationMode(str, Enum):

View File

@@ -21,7 +21,6 @@ OBS_ENV_STATE = "observation.environment_state"
OBS_STATE = "observation.state"
OBS_IMAGE = "observation.image"
OBS_IMAGES = "observation.images"
OBS_LANGUAGE = "observation.language"
ACTION = "action"
REWARD = "next.reward"
@@ -40,9 +39,6 @@ OPTIMIZER_STATE = "optimizer_state.safetensors"
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
SCHEDULER_STATE = "scheduler_state.json"
PREPROCESSOR_DEFAULT_NAME = "robot_preprocessor"
POSTPROCESSOR_DEFAULT_NAME = "robot_postprocessor"
if "LEROBOT_HOME" in os.environ:
raise ValueError(
f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"

View File

@@ -825,6 +825,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
if not episode_data:
episode_buffer = self.episode_buffer
else:
episode_buffer = episode_data
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)

View File

@@ -1,94 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import Any
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.processor.pipeline import RobotProcessor
def aggregate_pipeline_dataset_features(
pipeline: RobotProcessor,
initial_features: dict[str, Any],
*,
use_videos: bool = True,
patterns: Sequence[str] | None = None,
) -> dict[str, dict]:
"""
Aggregates the pipeline's features and returns a features dict ready for the dataset,
filtered to only those keys matching any of the given patterns (for action/state only).
- `initial_features`: raw camera specs, e.g. {"front": (h,w,c), ...}
- `use_videos`: whether to treat image features as video streams
- `patterns`: regexes to filter action & state features; images are included
whenever use_videos=True, regardless of patterns.
"""
import re
# Gather everything the pipeline features specifies, seeded with hardware cams:
all_features = pipeline.transform_features(initial_features)
# Helper to decide which action/state keys survive the `patterns` filter:
def keep(key: str) -> bool:
if patterns is None:
return True
return any(re.search(pat, key) for pat in patterns)
# Start with hardware dict, injecting initial cameras if videos are ON:
hw: dict[str, dict[str, Any]] = {}
if use_videos:
cams = {
name: shape
for name, shape in initial_features.items()
if isinstance(shape, tuple) and len(shape) == 3
}
if cams:
hw["observation"] = dict(cams)
# Go over every feature from the pipeline and merge:
for full_key, ty in all_features.items():
if full_key.startswith("action."):
# action.<feat>
if not keep(full_key):
continue
name = full_key[len("action.") :]
hw.setdefault("action", {})[name] = ty
elif full_key.startswith("observation.state."):
# observation.state.<feat>
if not keep(full_key):
continue
name = full_key[len("observation.state.") :]
hw.setdefault("observation", {})[name] = ty
elif full_key.startswith("observation.images."):
# observation.images.<cam>
# images obey ONLY the use_videos flag, not patterns
if not use_videos:
continue
name = full_key[len("observation.images.") :]
hw.setdefault("observation", {})[name] = ty
else:
# anything else (e.g. policy-only features) is ignored here
continue
out: dict[str, dict] = {}
if "action" in hw:
out.update(hw_to_dataset_features(hw["action"], "action", use_videos))
if "observation" in hw:
out.update(hw_to_dataset_features(hw["observation"], "observation", use_videos))
return out

View File

@@ -470,50 +470,6 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
return policy_features
def merge_features(*dicts: dict) -> dict:
"""
Merge LeRobot grouped feature dicts.
- For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape.
- For others (observation.images.*), last one wins (if they are identical).
"""
out: dict = {}
for d in dicts:
for key, value in d.items():
if not isinstance(value, dict):
out[key] = value
continue
dtype = value.get("dtype")
shape = value.get("shape")
is_vector = (
dtype not in ("image", "video", "string")
and isinstance(shape, tuple)
and len(shape) == 1
and "names" in value
)
if is_vector:
# Initialize or retrieve the accumulating dict for this feature key
target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)})
# Ensure consistent data types across merged entries
if "dtype" in target and dtype != target["dtype"]:
raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")
# Merge feature names: append only new ones to preserve order without duplicates
seen = set(target["names"])
for n in value["names"]:
if n not in seen:
target["names"].append(n)
seen.add(n)
# Recompute the shape to reflect the updated number of features
target["shape"] = (len(target["names"]),)
else:
# For images/videos and non-1D entries: override with the latest definition
out[key] = value
return out
def create_empty_dataset_info(
codebase_version: str,
fps: int,

View File

@@ -30,6 +30,8 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
fps: int = 30
features: dict[str, PolicyFeature] = field(default_factory=dict)
features_map: dict[str, str] = field(default_factory=dict)
multitask_eval: bool = False
max_parallel_tasks: int = 5
@property
def type(self) -> str:
@@ -161,71 +163,33 @@ class XarmEnv(EnvConfig):
@dataclass
class ImagePreprocessingConfig:
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
resize_size: tuple[int, int] | None = None
class VideoRecordConfig:
"""Configuration for video recording in ManiSkill environments."""
enabled: bool = False
record_dir: str = "videos"
trajectory_name: str = "trajectory"
@dataclass
class RewardClassifierConfig:
"""Configuration for reward classification."""
pretrained_path: str | None = None
success_threshold: float = 0.5
success_reward: float = 1.0
@dataclass
class InverseKinematicsConfig:
"""Configuration for inverse kinematics processing."""
urdf_path: str | None = None
target_frame_name: str | None = None
end_effector_bounds: dict[str, list[float]] | None = None
end_effector_step_sizes: dict[str, float] | None = None
@dataclass
class ObservationConfig:
"""Configuration for observation processing."""
class EnvTransformConfig:
"""Configuration for environment wrappers."""
# ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
control_mode: str = "gamepad"
display_cameras: bool = False
add_joint_velocity_to_observation: bool = False
add_current_to_observation: bool = False
add_ee_pose_to_observation: bool = False
display_cameras: bool = False
@dataclass
class GripperConfig:
"""Configuration for gripper control and penalties."""
use_gripper: bool = True
gripper_penalty: float = 0.0
gripper_penalty_in_reward: bool = False
@dataclass
class ResetConfig:
"""Configuration for environment reset behavior."""
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
resize_size: tuple[int, int] | None = None
control_time_s: float = 20.0
fixed_reset_joint_positions: Any | None = None
reset_time_s: float = 5.0
control_time_s: float = 20.0
terminate_on_success: bool = True
@dataclass
class HILSerlProcessorConfig:
"""Configuration for environment processing pipeline."""
control_mode: str = "gamepad"
observation: ObservationConfig | None = None
image_preprocessing: ImagePreprocessingConfig | None = None
gripper: GripperConfig | None = None
reset: ResetConfig | None = None
inverse_kinematics: InverseKinematicsConfig | None = None
reward_classifier: RewardClassifierConfig | None = None
max_gripper_pos: float | None = 100.0
use_gripper: bool = True
gripper_quantization_threshold: float | None = 0.8
gripper_penalty: float = 0.0
gripper_penalty_in_reward: bool = False
@EnvConfig.register_subclass(name="gym_manipulator")
@@ -235,10 +199,127 @@ class HILSerlRobotEnvConfig(EnvConfig):
robot: RobotConfig | None = None
teleop: TeleoperatorConfig | None = None
processor: HILSerlProcessorConfig = field(default_factory=HILSerlProcessorConfig)
wrapper: EnvTransformConfig | None = None
fps: int = 10
name: str = "real_robot"
mode: str | None = None # Either "record", "replay", None
repo_id: str | None = None
dataset_root: str | None = None
task: str | None = ""
num_episodes: int = 10 # only for record mode
episode: int = 0
device: str = "cuda"
push_to_hub: bool = True
pretrained_policy_name_or_path: str | None = None
reward_classifier_pretrained_path: str | None = None
# For the reward classifier, to record more positive examples after a success
number_of_steps_after_success: int = 0
@property
def gym_kwargs(self) -> dict:
return {}
@EnvConfig.register_subclass("hil")
@dataclass
class HILEnvConfig(EnvConfig):
"""Configuration for the HIL environment."""
name: str = "PandaPickCube"
task: str | None = "PandaPickCubeKeyboard-v0"
use_viewer: bool = True
gripper_penalty: float = 0.0
use_gamepad: bool = True
state_dim: int = 18
action_dim: int = 4
fps: int = 100
episode_length: int = 100
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"observation.image": OBS_IMAGE,
"observation.state": OBS_STATE,
}
)
################# args from hilserlrobotenv
reward_classifier_pretrained_path: str | None = None
robot_config: RobotConfig | None = None
teleop_config: TeleoperatorConfig | None = None
wrapper: EnvTransformConfig | None = None
mode: str | None = None # Either "record", "replay", None
repo_id: str | None = None
dataset_root: str | None = None
num_episodes: int = 10 # only for record mode
episode: int = 0
device: str = "cuda"
push_to_hub: bool = True
pretrained_policy_name_or_path: str | None = None
# For the reward classifier, to record more positive examples after a success
number_of_steps_after_success: int = 0
############################
@property
def gym_kwargs(self) -> dict:
return {
"use_viewer": self.use_viewer,
"use_gamepad": self.use_gamepad,
"gripper_penalty": self.gripper_penalty,
}
@EnvConfig.register_subclass("libero")
@dataclass
class LiberoEnv(EnvConfig):
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
fps: int = 30
episode_length: int = 520
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
init_states: bool = True
multitask_eval: bool = True
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"agent_pos": OBS_STATE,
"pixels/agentview_image": f"{OBS_IMAGES}.image",
"pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2",
}
)
def __post_init__(self):
if self.obs_type == "pixels":
self.features["pixels/agentview_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(360, 360, 3)
)
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(360, 360, 3)
)
elif self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))
self.features["pixels/agentview_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(360, 360, 3)
)
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(360, 360, 3)
)
@property
def gym_kwargs(self) -> dict:
return {
"obs_type": self.obs_type,
"render_mode": self.render_mode,
}

View File

@@ -17,7 +17,7 @@ import importlib
import gymnasium as gym
from lerobot.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, LiberoEnv, PushtEnv, XarmEnv
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -27,11 +27,17 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
return PushtEnv(**kwargs)
elif env_type == "xarm":
return XarmEnv(**kwargs)
elif env_type == "hil":
return HILEnvConfig(**kwargs)
elif env_type == "libero":
return LiberoEnv(**kwargs)
else:
raise ValueError(f"Policy type '{env_type}' is not available.")
def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
def make_env(
cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False
) -> gym.vector.VectorEnv | dict[str, dict[int, gym.vector.VectorEnv]]:
"""Makes a gym vector environment according to the config.
Args:
@@ -46,24 +52,43 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
Returns:
gym.vector.VectorEnv: The parallelized gym.env instance.
dict[str, dict[int, gym.vector.VectorEnv]]: A mapping from task suite
names to indexed vectorized environments (when multitask eval is used).
"""
if n_envs < 1:
raise ValueError("`n_envs must be at least 1")
raise ValueError("`n_envs` must be at least 1")
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
if "libero" in cfg.type:
from lerobot.envs.libero import create_libero_envs
return create_libero_envs(
task=cfg.task,
n_envs=n_envs,
camera_name=cfg.camera_name,
init_states=cfg.init_states,
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
multitask_eval=cfg.multitask_eval,
)
package_name = f"gym_{cfg.type}"
try:
importlib.import_module(package_name)
except ModuleNotFoundError as e:
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
raise e
raise ModuleNotFoundError(
f'{package_name} is not installed. Install with: pip install "lerobot[{cfg.type}]"'
) from e
gym_handle = f"{package_name}/{cfg.task}"
# batched version of the env that returns an observation of shape (b, c)
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
env = env_cls(
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
)
def _make_one():
return gym.make(gym_handle, disable_env_checker=True, **(cfg.gym_kwargs or {}))
return env
vec = env_cls([_make_one for _ in range(n_envs)])
# normalize to {suite: {task_id: vec_env}} for consistency
suite_name = cfg.type # e.g., "pusht", "aloha"
return {suite_name: {0: vec}}

497
src/lerobot/envs/libero.py Normal file
View File

@@ -0,0 +1,497 @@
from __future__ import annotations
import logging
import math
import os
from collections import defaultdict
from collections.abc import Callable, Iterable, Mapping, Sequence
from itertools import chain
from typing import Any
import gymnasium as gym
import numpy as np
import torch
from gymnasium import spaces
from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
logger = logging.getLogger(__name__)
# ---- Helpers -----------------------------------------------------------------
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
"""Normalize camera_name into a non-empty list of strings."""
if isinstance(camera_name, str):
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
elif isinstance(camera_name, (list, tuple)):
cams = [str(c).strip() for c in camera_name if str(c).strip()]
else:
raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")
if not cams:
raise ValueError("camera_name resolved to an empty list.")
return cams
def _get_suite(name: str):
"""Instantiate a LIBERO suite by name with clear validation."""
bench = benchmark.get_benchmark_dict()
if name not in bench:
raise ValueError(f"Unknown LIBERO suite '{name}'. Available: {', '.join(sorted(bench.keys()))}")
suite = bench[name]()
if not getattr(suite, "tasks", None):
raise ValueError(f"Suite '{name}' has no tasks.")
return suite
def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> list[int]:
"""Validate/normalize task ids. If None → all tasks."""
if task_ids is None:
return list(range(total_tasks))
ids = sorted({int(t) for t in task_ids})
for t in ids:
if t < 0 or t >= total_tasks:
raise ValueError(f"task_id {t} out of range [0, {total_tasks - 1}].")
return ids
def _make_env_fns(
*,
suite,
suite_name: str,
task_id: int,
n_envs: int,
camera_names: list[str],
init_states: bool,
gym_kwargs: Mapping[str, Any],
LiberoEnv: type, # injected to avoid forward ref issues if needed
) -> list[Callable[[], LiberoEnv]]:
"""Build n_envs factory callables for a single (suite, task_id)."""
joined_cams = ",".join(camera_names) # keep backward-compat: downstream expects a string
fns: list[Callable[[], LiberoEnv]] = []
for i in range(n_envs):
def _mk(
i=i,
suite=suite,
task_id=task_id,
suite_name=suite_name,
joined_cams=joined_cams,
init_states=init_states,
gym_kwargs=dict(gym_kwargs),
):
return LiberoEnv(
task_suite=suite,
task_id=task_id,
task_suite_name=suite_name,
camera_name=joined_cams,
init_states=init_states,
episode_index=i,
**gym_kwargs,
)
fns.append(_mk)
return fns
# ---- Main API ----------------------------------------------------------------
def create_libero_envs(
task: str,
n_envs: int,
gym_kwargs: dict[str, Any] | None = None,
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
init_states: bool = True,
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
multitask_eval: bool = True, # kept for signature compatibility; return type is consistent regardless
) -> dict[str, dict[int, Any]]:
"""
Create vectorized LIBERO environments with a consistent return shape.
Returns:
dict[suite_name][task_id] -> vec_env (env_cls([...]) with exactly n_envs factories)
Notes:
- n_envs is the number of rollouts *per task* (episode_index = 0..n_envs-1).
- `task` can be a single suite or a comma-separated list of suites.
- You may pass `task_ids` (list[int]) inside `gym_kwargs` to restrict tasks per suite.
"""
if env_cls is None or not callable(env_cls):
raise ValueError("env_cls must be a callable that wraps a list of environment factory callables.")
if not isinstance(n_envs, int) or n_envs <= 0:
raise ValueError(f"n_envs must be a positive int; got {n_envs}.")
gym_kwargs = dict(gym_kwargs or {})
task_ids_filter = gym_kwargs.pop("task_ids", None) # optional: limit to specific tasks
# Avoid circular import/type issues: assume LiberoEnv is defined in this module
try:
LiberoEnv # type: ignore[name-defined]
except NameError:
# If LiberoEnv is in the same file, this won't run. If it's elsewhere, import here.
exit()
# from .libero_env import LiberoEnv # adjust if your class lives in another module
camera_names = _parse_camera_names(camera_name)
suite_names = [s.strip() for s in str(task).split(",") if s.strip()]
if not suite_names:
raise ValueError("`task` must contain at least one LIBERO suite name.")
logger.info(
"Creating LIBERO envs | suites=%s | n_envs(per task)=%d | init_states=%s | multitask_eval=%s",
suite_names,
n_envs,
init_states,
bool(multitask_eval),
)
if task_ids_filter is not None:
logger.info("Restricting to task_ids=%s", task_ids_filter)
out: dict[str, dict[int, Any]] = defaultdict(dict)
for suite_name in suite_names:
suite = _get_suite(suite_name)
total = len(suite.tasks)
selected = _select_task_ids(total, task_ids_filter)
if not selected:
raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")
for tid in selected:
fns = _make_env_fns(
suite=suite,
suite_name=suite_name,
task_id=tid,
n_envs=n_envs,
camera_names=camera_names,
init_states=init_states,
gym_kwargs=gym_kwargs,
LiberoEnv=LiberoEnv,
)
out[suite_name][tid] = env_cls(fns)
logger.debug("Built vec env | suite=%s | task_id=%d | n_envs=%d", suite_name, tid, n_envs)
# return plain dicts for predictability
return {suite: dict(task_map) for suite, task_map in out.items()}
def quat2axisangle(quat):
"""
Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
Converts quaternion to axis-angle format.
Returns a unit vector direction scaled by its angle in radians.
Args:
quat (np.array): (x,y,z,w) vec4 float angles
Returns:
np.array: (ax,ay,az) axis-angle exponential coordinates
"""
# clip quaternion
if quat[3] > 1.0:
quat[3] = 1.0
elif quat[3] < -1.0:
quat[3] = -1.0
den = np.sqrt(1.0 - quat[3] * quat[3])
if math.isclose(den, 0.0):
# This is (close to) a zero degree rotation, immediately return
return np.zeros(3)
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
def get_task_init_states(task_suite, i):
init_states_path = os.path.join(
get_libero_path("init_states"),
task_suite.tasks[i].problem_folder,
task_suite.tasks[i].init_states_file,
)
init_states = torch.load(init_states_path, weights_only=False) # nosec B614
return init_states
def get_libero_dummy_action():
"""Get dummy/no-op action, used to roll out the simulation while the robot does nothing."""
return [0, 0, 0, 0, 0, 0, -1]
OBS_STATE_DIM = 8
ACTION_DIM = 7
class LiberoEnv(gym.Env):
metadata = {"render_modes": ["rgb_array"], "render_fps": 80}
def __init__(
self,
task_suite,
task_id,
task_suite_name,
camera_name="agentview_image,robot0_eye_in_hand_image",
obs_type="pixels",
render_mode="rgb_array",
observation_width=256,
observation_height=256,
visualization_width=640,
visualization_height=480,
init_states=True,
episode_index=0,
):
super().__init__()
self.task_id = task_id
self.obs_type = obs_type
self.render_mode = render_mode
self.observation_width = observation_width
self.observation_height = observation_height
self.visualization_width = visualization_width
self.visualization_height = visualization_height
self.init_states = init_states
self.camera_name = camera_name.split(
","
) # agentview_image (main) or robot0_eye_in_hand_image (wrist)
# Map raw camera names to "image1" and "image2".
# The preprocessing step `preprocess_observation` will then prefix these with `.images.*`,
# following the LeRobot convention (e.g., `observation.images.image`, `observation.images.image2`).
# This ensures the policy consistently receives observations in the
# expected format regardless of the original camera naming.
self.camera_name_mapping = {
"agentview_image": "image",
"robot0_eye_in_hand_image": "image2",
}
self.num_steps_wait = (
10 # Do nothing for the first few timesteps to wait for the simulator drops objects
)
self.episode_index = episode_index
self._env = self._make_envs_task(task_suite, self.task_id)
TASK_SUITE_MAX_STEPS: dict[str, int] = {
"libero_spatial": 220, # longest training demo has 193 steps
"libero_object": 280, # longest training demo has 254 steps
"libero_goal": 300, # longest training demo has 270 steps
"libero_10": 520, # longest training demo has 505 steps
"libero_90": 400, # longest training demo has 373 steps
}
default_steps = 500
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
images = {}
for cam in self.camera_name:
images[self.camera_name_mapping[cam]] = spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
)
if self.obs_type == "state":
raise NotImplementedError(
"The 'state' observation type is not supported in LiberoEnv. "
"Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')."
)
elif self.obs_type == "pixels":
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(images),
}
)
elif self.obs_type == "pixels_agent_pos":
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(images),
"agent_pos": spaces.Box(
low=-1000.0,
high=1000.0,
shape=(OBS_STATE_DIM,),
dtype=np.float64,
),
}
)
self.action_space = spaces.Box(low=-1, high=1, shape=(ACTION_DIM,), dtype=np.float32)
def render(self):
raw_obs = self._env.env._get_observations()
image = self._format_raw_obs(raw_obs)["pixels"]["image"]
return image
def _make_envs_task(self, task_suite, task_id: int = 0):
task = task_suite.get_task(task_id)
self.task = task.name
self.task_description = task.language
task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
env_args = {
"bddl_file_name": task_bddl_file,
"camera_heights": self.observation_height,
"camera_widths": self.observation_width,
}
env = OffScreenRenderEnv(**env_args)
env.reset()
if self.init_states:
init_states = get_task_init_states(
task_suite, task_id
) # for benchmarking purpose, we fix the a set of initial states FIXME(mshukor): should be in the reset()?
init_state_id = self.episode_index # episode index
env.set_init_state(init_states[init_state_id])
return env
def _format_raw_obs(self, raw_obs):
images = {}
for camera_name in self.camera_name:
image = raw_obs[camera_name]
image = image[::-1, ::-1] # rotate 180 degrees
images[self.camera_name_mapping[camera_name]] = image
state = np.concatenate(
(
raw_obs["robot0_eef_pos"],
quat2axisangle(raw_obs["robot0_eef_quat"]),
raw_obs["robot0_gripper_qpos"],
)
)
agent_pos = state
if self.obs_type == "state":
raise NotImplementedError(
"The 'state' observation type is not supported in LiberoEnv. "
"Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')."
)
elif self.obs_type == "pixels":
obs = {"pixels": images.copy()}
elif self.obs_type == "pixels_agent_pos":
obs = {
"pixels": images.copy(),
"agent_pos": agent_pos,
}
return obs
def reset(self, seed=None, **kwargs):
super().reset(seed=seed)
self._env.seed(seed)
raw_obs = self._env.reset()
# Do nothing for the first few timesteps to wait for the simulator drops objects
for _ in range(self.num_steps_wait):
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
observation = self._format_raw_obs(raw_obs)
info = {"is_success": False}
return observation, info
def step(self, action):
if action.ndim != 1:
raise ValueError(
f"Expected action to be 1-D (shape (action_dim,)), "
f"but got shape {action.shape} with ndim={action.ndim}"
)
raw_obs, reward, done, info = self._env.step(action)
is_success = self._env.check_success()
terminated = done or is_success
info["is_success"] = done # is_success
observation = self._format_raw_obs(raw_obs)
if done:
self.reset()
print(self.task, self.task_id, done, is_success)
truncated = False
return observation, reward, terminated, truncated, info
def close(self):
self._env.close()
def create_libero_envs1(
task: str,
n_envs: int,
gym_kwargs: dict[str, Any] = None,
camera_name: str = "agentview_image,robot0_eye_in_hand_image",
init_states: bool = True,
env_cls: Callable = None,
multitask_eval: bool = True,
) -> dict[str, dict[str, Any]]:
"""
Here n_envs is per task and equal to the number of rollouts.
Returns:
dict[str, dict[str, list[LiberoEnv]]]: keys are task_suite and values are list of LiberoEnv envs.
"""
print("num envs", n_envs)
print("multitask_eval", multitask_eval)
print("gym_kwargs", gym_kwargs)
if gym_kwargs is None:
gym_kwargs = {}
if not multitask_eval:
benchmark_dict = benchmark.get_benchmark_dict()
task_suite = benchmark_dict[task]() # can also choose libero_spatial, libero_object, libero_10 etc.
tasks_id = list(range(len(task_suite.tasks)))
episode_indices = [0 for i in range(len(tasks_id))]
if len(tasks_id) == 1:
tasks_id = [tasks_id[0] for _ in range(n_envs)]
episode_indices = list(range(n_envs))
elif len(tasks_id) < n_envs and n_envs % len(tasks_id) == 0:
n_repeat = n_envs // len(tasks_id)
print("n_repeat", n_repeat)
episode_indices = []
for _ in range(len(tasks_id)):
episode_indices.extend(list(range(n_repeat)))
tasks_id = list(chain.from_iterable([[item] * n_repeat for item in tasks_id]))
elif n_envs < len(tasks_id):
tasks_id = tasks_id[:n_envs]
episode_indices = list(range(n_envs))[:n_envs]
print(f"WARNING: n_envs < len(tasks_id), evaluating only on {tasks_id}")
print(f"Creating Libero envs with task ids {tasks_id} from suite {task}")
assert n_envs == len(tasks_id), (
f"len(n_envs) and tasks_id should be the same, got {n_envs} and {len(tasks_id)}"
)
return env_cls(
[
lambda i=i: LiberoEnv(
task_suite=task_suite,
task_id=tasks_id[i],
task_suite_name=task,
camera_name=camera_name,
init_states=init_states,
episode_index=episode_indices[i],
**gym_kwargs,
)
for i in range(n_envs)
]
)
else:
envs = defaultdict(dict)
benchmark_dict = benchmark.get_benchmark_dict()
task = task.split(",")
for _task in task:
task_suite = benchmark_dict[
_task
]() # can also choose libero_spatial, libero_object, libero_10 etc.
tasks_ids = list(range(len(task_suite.tasks)))
for tasks_id in tasks_ids:
episode_indices = list(range(n_envs))
print(
f"Creating Libero envs with task ids {tasks_id} from suite {_task}, episode_indices: {episode_indices}"
)
envs_list = [
(
lambda i=i,
task_suite=task_suite,
tasks_id=tasks_id,
_task=_task,
episode_indices=episode_indices: LiberoEnv(
task_suite=task_suite,
task_id=tasks_id,
task_suite_name=_task,
camera_name=camera_name,
init_states=init_states,
episode_index=episode_indices[i],
**gym_kwargs,
)
)
for i in range(n_envs)
]
envs[_task][tasks_id] = env_cls(envs_list)
return envs

View File

@@ -134,3 +134,49 @@ def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dic
num_envs = observation[list(observation.keys())[0]].shape[0]
observation["task"] = ["" for _ in range(num_envs)]
return observation
def _close_single_env(env: Any) -> None:
"""Try to close a single env object if it exposes .close()."""
try:
close_fn = getattr(env, "close", None)
if callable(close_fn):
close_fn()
except Exception as exc:
# Best-effort close: log but don't raise
LOG.debug("Exception while closing env %s: %s", env, exc)
def close_envs(env_or_collection: Any) -> None:
"""
Close a single env or any nested structure of envs.
Accepts:
- a single env with .close()
- a Mapping of things (e.g. dict)
- a Sequence of things (list/tuple) but NOT str/bytes
- nested combinations of the above
This is intentionally permissive and best-effort: it will swallow exceptions
encountered while closing individual envs and continue.
"""
# Guard: single object with close()
if hasattr(env_or_collection, "close") and not isinstance(env_or_collection, (Mapping, Sequence)):
_close_single_env(env_or_collection)
return
# Mapping (e.g., {suite: {task_id: vec_env}})
if isinstance(env_or_collection, Mapping):
for v in env_or_collection.values():
close_envs(v)
return
# Sequence (list/tuple) but skip str/bytes
if isinstance(env_or_collection, Sequence) and not isinstance(env_or_collection, (str, bytes)):
for v in env_or_collection:
close_envs(v)
return
# Fallback: try to close if possible
if hasattr(env_or_collection, "close"):
_close_single_env(env_or_collection)

View File

@@ -107,6 +107,8 @@ X_SERIES_ENCODINGS_TABLE = {
"Goal_PWM": X_SERIES_CONTROL_TABLE["Goal_PWM"][1],
"Goal_Current": X_SERIES_CONTROL_TABLE["Goal_Current"][1],
"Goal_Velocity": X_SERIES_CONTROL_TABLE["Goal_Velocity"][1],
"Goal_Position": X_SERIES_CONTROL_TABLE["Goal_Position"][1],
"Present_Position": X_SERIES_CONTROL_TABLE["Present_Position"][1],
"Present_PWM": X_SERIES_CONTROL_TABLE["Present_PWM"][1],
"Present_Current": X_SERIES_CONTROL_TABLE["Present_Current"][1],
"Present_Velocity": X_SERIES_CONTROL_TABLE["Present_Velocity"][1],

View File

@@ -15,17 +15,6 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0.processor_pi0 import Pi0NewLineProcessor
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
__all__ = [
"ACTConfig",
"DiffusionConfig",
"PI0Config",
"SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",
]

View File

@@ -35,6 +35,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.constants import ACTION, OBS_IMAGES
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
@@ -50,16 +51,27 @@ class ACTPolicy(PreTrainedPolicy):
def __init__(
self,
config: ACTConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.model = ACT(config)
if config.temporal_ensemble_coeff is not None:
@@ -125,19 +137,23 @@ class ACTPolicy(PreTrainedPolicy):
"""Predict a chunk of actions given environment observations."""
self.eval()
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
actions = self.model(batch)[0]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
@@ -287,7 +303,7 @@ class ACT(nn.Module):
└───────────────────────┘
"""
def __init__(self, config: ACTConfig, dataset_stats=None):
def __init__(self, config: ACTConfig):
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
super().__init__()

View File

@@ -1,51 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
)
def make_act_processor(
config: ACTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)

View File

@@ -35,6 +35,7 @@ from torch import Tensor, nn
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import (
get_device_from_parameters,
@@ -56,6 +57,7 @@ class DiffusionPolicy(PreTrainedPolicy):
def __init__(
self,
config: DiffusionConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
@@ -68,6 +70,14 @@ class DiffusionPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
@@ -96,6 +106,9 @@ class DiffusionPolicy(PreTrainedPolicy):
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.diffusion.generate_actions(batch)
# TODO(rcadene): make above methods return output dictionary?
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
@torch.no_grad()
@@ -124,6 +137,7 @@ class DiffusionPolicy(PreTrainedPolicy):
if ACTION in batch:
batch.pop(ACTION)
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
@@ -139,9 +153,11 @@ class DiffusionPolicy(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
# no output_dict so returning None
return loss, None

View File

@@ -1,52 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
)
def make_diffusion_processor(
config: DiffusionConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)

View File

@@ -14,14 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from typing import Any, TypedDict, cast
import torch
from torch import nn
from typing_extensions import Unpack
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType
@@ -32,17 +27,18 @@ from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pi05_openpi.configuration_pi05openpi import PI05OpenPIConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor.pipeline import RobotProcessor
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
def get_policy_class(name: str) -> PreTrainedPolicy:
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
if name == "tdmpc":
from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
@@ -68,6 +64,14 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
return PI0FASTPolicy
elif name == "pi0_openpi":
from lerobot.policies.pi0_openpi.modeling_pi0openpi import PI0OpenPIPolicy
return PI0OpenPIPolicy
elif name == "pi05_openpi":
from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy
return PI05OpenPIPolicy
elif name == "sac":
from lerobot.policies.sac.modeling_sac import SACPolicy
@@ -97,6 +101,10 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return PI0Config(**kwargs)
elif policy_type == "pi0fast":
return PI0FASTConfig(**kwargs)
elif policy_type == "pi0_openpi":
return PI0OpenPIConfig(**kwargs)
elif policy_type == "pi05_openpi":
return PI05OpenPIConfig(**kwargs)
elif policy_type == "sac":
return SACConfig(**kwargs)
elif policy_type == "smolvla":
@@ -107,123 +115,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
raise ValueError(f"Policy type '{policy_type}' is not available.")
class ProcessorConfigKwargs(TypedDict, total=False):
"""Keyword arguments for the processor config."""
preprocessor_config_filename: str | None
postprocessor_config_filename: str | None
preprocessor_overrides: dict[str, Any] | None
postprocessor_overrides: dict[str, Any] | None
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
def make_processor(
policy_cfg: PreTrainedConfig,
pretrained_path: str | None = None,
**kwargs: Unpack[ProcessorConfigKwargs],
) -> tuple[RobotProcessor, RobotProcessor]:
"""Make a processor instance for a given policy type.
This function creates the appropriate processor configuration based on the policy type.
Each policy type has its own processor with specific preprocessing steps.
Args:
policy_cfg: The config of the policy to create a processor for (e.g., "act", "diffusion", etc.)
pretrained_path: Optional path to load a pretrained processor from. If provided, loads
the processor from this path instead of creating a new one.
**kwargs: Additional keyword arguments passed to the processor creation.
Returns:
Tuple of (input_processor, output_processor) for the policy.
Raises:
NotImplementedError: If the policy type doesn't have a processor implemented.
"""
if pretrained_path:
return (
RobotProcessor.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get("preprocessor_config_filename", "robot_preprocessor.json"),
overrides=kwargs.get("preprocessor_overrides", {}),
),
RobotProcessor.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get("postprocessor_config_filename", "robot_postprocessor.json"),
overrides=kwargs.get("postprocessor_overrides", {}),
),
)
# Create a new processor based on policy type
if policy_cfg.type == "tdmpc":
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_processor
processors = make_tdmpc_processor(
config=cast(TDMPCConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "diffusion":
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_processor
processors = make_diffusion_processor(
cast(DiffusionConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "act":
from lerobot.policies.act.processor_act import make_act_processor
processors = make_act_processor(
config=cast(ACTConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "vqbet":
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_processor
processors = make_vqbet_processor(
config=cast(VQBeTConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "pi0":
from lerobot.policies.pi0.processor_pi0 import make_pi0_processor
processors = make_pi0_processor(
config=cast(PI0Config, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "pi0fast":
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_processor
processors = make_pi0fast_processor(
cast(PI0Config, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "sac":
from lerobot.policies.sac.processor_sac import make_sac_processor
processors = make_sac_processor(
cast(SACConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "reward_classifier":
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
processors = make_classifier_processor(
cast(RewardClassifierConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "smolvla":
from lerobot.policies.smolvla.processor_smolvla import make_smolvla_processor
processors = make_smolvla_processor(
cast(SmolVLAConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
else:
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
return processors
def make_policy(
cfg: PreTrainedConfig,
ds_meta: LeRobotDatasetMetadata | None = None,
@@ -270,6 +161,7 @@ def make_policy(
kwargs = {}
if ds_meta is not None:
features = dataset_to_policy_features(ds_meta.features)
kwargs["dataset_stats"] = ds_meta.stats
else:
if not cfg.pretrained_path:
logging.warning(
@@ -277,8 +169,6 @@ def make_policy(
"rather than a dataset. Normalization modules inside the policy will have infinite values "
"by default without stats from a dataset."
)
if env_cfg is None:
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
features = env_to_policy_features(env_cfg)
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
@@ -296,7 +186,5 @@ def make_policy(
policy.to(cfg.device)
assert isinstance(policy, nn.Module)
# policy = torch.compile(policy, mode="reduce-overhead")
return policy

View File

@@ -0,0 +1,420 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from torch import Tensor, nn
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
def create_stats_buffers(
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
) -> dict[str, dict[str, nn.ParameterDict]]:
"""
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
statistics.
Args: (see Normalize and Unnormalize)
Returns:
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
`nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.
"""
stats_buffers = {}
for key, ft in features.items():
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
assert isinstance(norm_mode, NormalizationMode)
shape = tuple(ft.shape)
if ft.type is FeatureType.VISUAL:
# sanity checks
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
c, h, w = shape
assert c < h and c < w, f"{key} is not channel first ({shape=})"
# override image shape to be invariant to height and width
shape = (c, 1, 1)
# Note: we initialize mean, std, min, max to infinity. They should be overwritten
# downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
# we assert they are not infinity anymore.
buffer = {}
if norm_mode is NormalizationMode.MEAN_STD:
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
std = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict(
{
"mean": nn.Parameter(mean, requires_grad=False),
"std": nn.Parameter(std, requires_grad=False),
}
)
elif norm_mode is NormalizationMode.MIN_MAX:
min = torch.ones(shape, dtype=torch.float32) * torch.inf
max = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict(
{
"min": nn.Parameter(min, requires_grad=False),
"max": nn.Parameter(max, requires_grad=False),
}
)
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
if stats:
if isinstance(stats[key]["mean"], np.ndarray):
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
elif isinstance(stats[key]["mean"], torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
else:
type_ = type(stats[key]["mean"])
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
stats_buffers[key] = buffer
return stats_buffers
def _no_stats_error_str(name: str) -> str:
return (
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
"pretrained model."
)
class Normalize(nn.Module):
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
def __init__(
self,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
are their normalization modes among:
- "mean_std": subtract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
and values are dictionaries of statistic types and their values (e.g.
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
training the model for the first time, these statistics will overwrite the default buffers. If
not provided, as expected for finetuning or evaluation, the default buffers should to be
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
dataset is not needed to get the stats, since they are already in the policy state_dict.
"""
super().__init__()
self.features = features
self.norm_map = norm_map
self.stats = stats
stats_buffers = create_stats_buffers(features, norm_map, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# TODO: Remove this shallow copy
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():
if key not in batch:
# FIXME(aliberts, rcadene): This might lead to silent fail!
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if norm_mode is NormalizationMode.MEAN_STD:
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = (batch[key] - mean) / (std + 1e-8)
elif norm_mode is NormalizationMode.MIN_MAX:
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(min).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max")
# normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min + 1e-8)
# normalize to [-1, 1]
batch[key] = batch[key] * 2 - 1
else:
raise ValueError(norm_mode)
return batch
class Unnormalize(nn.Module):
"""
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
original range used by the environment.
"""
def __init__(
self,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
are their normalization modes among:
- "mean_std": subtract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
and values are dictionaries of statistic types and their values (e.g.
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
training the model for the first time, these statistics will overwrite the default buffers. If
not provided, as expected for finetuning or evaluation, the default buffers should to be
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
dataset is not needed to get the stats, since they are already in the policy state_dict.
"""
super().__init__()
self.features = features
self.norm_map = norm_map
self.stats = stats
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
stats_buffers = create_stats_buffers(features, norm_map, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():
if key not in batch:
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if norm_mode is NormalizationMode.MEAN_STD:
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = batch[key] * std + mean
elif norm_mode is NormalizationMode.MIN_MAX:
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(min).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max")
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min
else:
raise ValueError(norm_mode)
return batch
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
# and remove the `Normalize` and `Unnormalize` classes.
def _initialize_stats_buffers(
module: nn.Module,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
) -> None:
"""Register statistics buffers (mean/std or min/max) on the given *module*.
The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`,
but is factored out so it can be reused by both classes and stay in sync.
"""
for key, ft in features.items():
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
shape: tuple[int, ...] = tuple(ft.shape)
if ft.type is FeatureType.VISUAL:
# reduce spatial dimensions, keep channel dimension only
c, *_ = shape
shape = (c, 1, 1)
prefix = key.replace(".", "_")
if norm_mode is NormalizationMode.MEAN_STD:
mean = torch.full(shape, torch.inf, dtype=torch.float32)
std = torch.full(shape, torch.inf, dtype=torch.float32)
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
mean_data = stats[key]["mean"]
std_data = stats[key]["std"]
if isinstance(mean_data, torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
mean = mean_data.clone().to(dtype=torch.float32)
std = std_data.clone().to(dtype=torch.float32)
else:
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
module.register_buffer(f"{prefix}_mean", mean)
module.register_buffer(f"{prefix}_std", std)
continue
if norm_mode is NormalizationMode.MIN_MAX:
min_val = torch.full(shape, torch.inf, dtype=torch.float32)
max_val = torch.full(shape, torch.inf, dtype=torch.float32)
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
min_data = stats[key]["min"]
max_data = stats[key]["max"]
if isinstance(min_data, torch.Tensor):
min_val = min_data.clone().to(dtype=torch.float32)
max_val = max_data.clone().to(dtype=torch.float32)
else:
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
module.register_buffer(f"{prefix}_min", min_val)
module.register_buffer(f"{prefix}_max", max_val)
continue
raise ValueError(norm_mode)
class NormalizeBuffer(nn.Module):
"""Same as `Normalize` but statistics are stored as registered buffers rather than parameters."""
def __init__(
self,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
super().__init__()
self.features = features
self.norm_map = norm_map
_initialize_stats_buffers(self, features, norm_map, stats)
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch)
for key, ft in self.features.items():
if key not in batch:
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
prefix = key.replace(".", "_")
if norm_mode is NormalizationMode.MEAN_STD:
mean = getattr(self, f"{prefix}_mean")
std = getattr(self, f"{prefix}_std")
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = (batch[key] - mean) / (std + 1e-8)
continue
if norm_mode is NormalizationMode.MIN_MAX:
min_val = getattr(self, f"{prefix}_min")
max_val = getattr(self, f"{prefix}_max")
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8)
batch[key] = batch[key] * 2 - 1
continue
raise ValueError(norm_mode)
return batch
class UnnormalizeBuffer(nn.Module):
"""Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics."""
def __init__(
self,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
super().__init__()
self.features = features
self.norm_map = norm_map
_initialize_stats_buffers(self, features, norm_map, stats)
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# batch = dict(batch)
for key, ft in self.features.items():
if key not in batch:
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
prefix = key.replace(".", "_")
if norm_mode is NormalizationMode.MEAN_STD:
mean = getattr(self, f"{prefix}_mean")
std = getattr(self, f"{prefix}_std")
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = batch[key] * std + mean
continue
if norm_mode is NormalizationMode.MIN_MAX:
min_val = getattr(self, f"{prefix}_min")
max_val = getattr(self, f"{prefix}_max")
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max_val - min_val) + min_val
continue
raise ValueError(norm_mode)
return batch

View File

@@ -56,15 +56,18 @@ from collections import deque
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from transformers import AutoTokenizer
from lerobot.constants import ACTION, OBS_LANGUAGE, OBS_STATE
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.paligemma_with_expert import (
PaliGemmaWithExpertConfig,
PaliGemmaWithExpertModel,
)
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.utils import get_safe_dtype
from lerobot.policies.utils import log_model_loading_keys
from lerobot.utils.utils import get_safe_dtype, init_logging
def create_sinusoidal_pos_embedding(
@@ -220,17 +223,28 @@ class PI0Policy(PreTrainedPolicy):
def __init__(
self,
config: PI0Config,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
self.model = PI0FlowMatching(config)
self.reset()
@@ -239,6 +253,99 @@ class PI0Policy(PreTrainedPolicy):
"""This should be called whenever the environment is reset."""
self._action_queue = deque([], maxlen=self.config.n_action_steps)
@classmethod
def _transform_state_dict_keys(cls, state_dict: dict) -> dict:
"""
Transform state dict keys to match expected model structure.
Transformations:
- model.paligemma_with_expert.paligemma.language_model.lm_head ->
model.paligemma_with_expert.paligemma.lm_head
- model.paligemma_with_expert.paligemma.language_model.model ->
model.paligemma_with_expert.paligemma.model.language_model
- model.paligemma_with_expert.paligemma.vision_tower ->
model.paligemma_with_expert.paligemma.model.vision_tower
- model.paligemma_with_expert.paligemma.multi_modal_projector ->
model.paligemma_with_expert.paligemma.model.multi_modal_projector
Also handles tied weights between lm_head.weight and
embed_tokens.weight.
"""
import re
transformed_dict = {}
transformations = [
(
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.lm_head"),
".paligemma_with_expert.paligemma.lm_head",
),
(
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.model"),
".paligemma_with_expert.paligemma.model.language_model",
),
(
re.compile(r"\.paligemma_with_expert\.paligemma\.vision_tower"),
".paligemma_with_expert.paligemma.model.vision_tower",
),
(
re.compile(r"\.paligemma_with_expert\.paligemma\.multi_modal_projector"),
".paligemma_with_expert.paligemma.model.multi_modal_projector",
),
]
for key, value in state_dict.items():
new_key = key
for pattern, replacement in transformations:
new_key = pattern.sub(replacement, new_key)
transformed_dict[new_key] = value
# Handle tied weights: lm_head.weight and embed_tokens.weight share memory
lm_head_key = None
embed_tokens_key = None
for key in transformed_dict:
if key.endswith(".paligemma_with_expert.paligemma.lm_head.weight"):
lm_head_key = key
elif key.endswith(".paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"):
embed_tokens_key = key
if lm_head_key and embed_tokens_key:
break
if lm_head_key and not embed_tokens_key:
embed_tokens_key = lm_head_key.replace(
".lm_head.weight", ".model.language_model.embed_tokens.weight"
)
transformed_dict[embed_tokens_key] = transformed_dict[lm_head_key]
elif embed_tokens_key and not lm_head_key:
lm_head_key = embed_tokens_key.replace(
".model.language_model.embed_tokens.weight", ".lm_head.weight"
)
transformed_dict[lm_head_key] = transformed_dict[embed_tokens_key]
return transformed_dict
@classmethod
def _load_as_safetensor(
cls, model: "PI0Policy", model_file: str, map_location: str, strict: bool
) -> "PI0Policy":
"""Override to apply key transformations before loading."""
from safetensors.torch import load_file
init_logging()
# Load the state dict from file safely
state_dict = load_file(model_file, device=map_location)
# Apply key transformations
transformed_state_dict = cls._transform_state_dict_keys(state_dict)
# Load the transformed state dict
msg = model.load_state_dict(transformed_state_dict, strict=strict)
# Log message
log_model_loading_keys(msg.missing_keys, msg.unexpected_keys)
return model
def get_optim_params(self) -> dict:
return self.parameters()
@@ -270,13 +377,14 @@ class PI0Policy(PreTrainedPolicy):
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch = self.normalize_inputs(batch)
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.model.sample_actions(
images, img_masks, lang_tokens, lang_masks, state, noise=noise
@@ -286,6 +394,8 @@ class PI0Policy(PreTrainedPolicy):
original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
@@ -300,10 +410,12 @@ class PI0Policy(PreTrainedPolicy):
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.prepare_action(batch)
actions_is_pad = batch.get("action_is_pad")
@@ -370,6 +482,26 @@ class PI0Policy(PreTrainedPolicy):
return images, img_masks
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
"""Tokenize the text input"""
device = batch[OBS_STATE].device
tasks = batch["task"]
# PaliGemma prompt has to end with a new line
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
tokenized_prompt = self.language_tokenizer.__call__(
tasks,
padding="max_length",
padding_side="right",
max_length=self.config.tokenizer_max_length,
return_tensors="pt",
)
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
return lang_tokens, lang_masks
def _pi_aloha_decode_state(self, state):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
@@ -435,7 +567,7 @@ class PI0FlowMatching(nn.Module):
└──────────────────────────────┘
"""
def __init__(self, config: PI0Config):
def __init__(self, config):
super().__init__()
self.config = config

View File

@@ -1,121 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RobotProcessor,
ToBatchProcessor,
TokenizerProcessor,
UnnormalizerProcessor,
)
from lerobot.processor.pipeline import (
EnvTransition,
ProcessorStep,
ProcessorStepRegistry,
TransitionKey,
)
from lerobot.processor.rename_processor import RenameProcessor
@ProcessorStepRegistry.register(name="pi0_new_line_processor")
class Pi0NewLineProcessor(ProcessorStep):
"""Add a new line to the end of the task if it doesn't have one.
This is required for the PaliGemma tokenizer.
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Check if complementary_data exists
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None or "task" not in complementary_data:
return transition
task = complementary_data["task"]
if task is None:
return transition
# Handle both string and list of strings
if isinstance(task, str):
# Single string: add newline if not present
if not task.endswith("\n"):
complementary_data["task"] = f"{task}\n"
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
# List of strings: add newline to each if not present
complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
# If task is neither string nor list of strings, leave unchanged
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Add tokenized task features to the features."""
return features
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {}
def make_pi0_processor(
config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
# Add remaining processors
input_steps: list[ProcessorStep] = [
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
NormalizerProcessor(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
Pi0NewLineProcessor(), # Add newlines before tokenization for PaliGemma
TokenizerProcessor(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
),
DeviceProcessor(device=config.device),
]
output_steps: list[ProcessorStep] = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)

View File

@@ -0,0 +1,92 @@
# π₀.₅ (pi05)
This repository contains the Hugging Face port of **π₀.₅**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by the Physical Intelligence.
It is designed as a **Vision-Language-Action model with open-world generalization**.
---
### ⚠️ WARNING ⚠️
This project requires **patching the Hugging Face `transformers` library**.
1. Make sure you have the exact version installed:
```bash
pip show transformers
```
It must be version **4.53.2**.
2. Apply the custom patches by copying the modified files into your environment:
```bash
cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* \
$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")
```
These patches overwrite parts of `transformers` to:
- Support the **AdaRMS optimizer**,
- Correctly control the precision of activations,
- Allow the KV cache to be used without updates.
**Important:**
- This permanently modifies your `transformers` installation.
- The changes survive reinstalls unless you explicitly remove the patched files or recreate the environment.
To undo and restore a clean state:
```bash
pip uninstall transformers
pip install transformers==4.53.2
```
---
## Model Overview
| Feature | π₀ | π₀.₅ |
| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
| State Embedding | Uses `state_proj` layer | No state embedding |
| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
| AdaRMS | Not used | Used in action expert |
| Tokenizer Length | 48 tokens | 200 tokens |
| Discrete State Input | False | True |
| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) |
---
## Citation
If you use this work, please cite both **OpenPI** and the π₀.₅ paper:
```bibtex
@misc{openpi2024,
author = {Physical Intelligence Lab},
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
year = {2024},
publisher = {GitHub},
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
license = {Apache-2.0}
}
@misc{intelligence2025pi05visionlanguageactionmodelopenworld,
title = {π₀.₅: a Vision-Language-Action Model with Open-World Generalization},
author = {Physical Intelligence and Kevin Black and Noah Brown and James Darpinian and Karan Dhabalia and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Manuel Y. Galliker and Dibya Ghosh and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and James Tanner and Quan Vuong and Homer Walke and Anna Walling and Haohuan Wang and Lili Yu and Ury Zhilinsky},
year = {2025},
eprint = {2504.16054},
archivePrefix= {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2504.16054},
}
```
---
## License
This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
```
```

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,23 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from enum import Enum
from .configuration_pi05openpi import PI05OpenPIConfig
from .modeling_pi05openpi import PI05OpenPIPolicy
import numpy as np
from ..config import TeleoperatorConfig
class PhoneOS(Enum):
ANDROID = "android"
IOS = "ios"
@TeleoperatorConfig.register_subclass("phone")
@dataclass
class PhoneConfig(TeleoperatorConfig):
phone_os: PhoneOS = PhoneOS.IOS
camera_offset = np.array(
[0.0, -0.02, 0.04]
) # iPhone 14 Pro camera is 2cm off center and 4cm above center
__all__ = ["PI05OpenPIConfig", "PI05OpenPIPolicy"]

View File

@@ -0,0 +1,137 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("pi05_openpi")
@dataclass
class PI05OpenPIConfig(PreTrainedConfig):
# Model architecture
paligemma_variant: str = "gemma_2b"
action_expert_variant: str = "gemma_300m"
discrete_state_input: bool | None = (
True # Whether to use discrete state input # see openpi `Pi0Config, __post_init__`
)
dtype: str = "float32" # Options: "bfloat16", "float32"
# Input / output structure
n_obs_steps: int = 1
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
n_action_steps: int = 50 # Number of action steps to execute
# Shorter state and action vectors will be padded to these dimensions
max_state_dim: int = 32 # State dimension (will be padded to 32)
max_action_dim: int = 32 # Action dimension (will be padded to 32)
# Flow matching parameters: see openpi `PI0Pytorch`
num_inference_steps: int = 10 # Number of denoising steps during inference
time_sampling_beta_alpha: float = 1.5 # Beta distribution alpha parameter for time sampling
time_sampling_beta_beta: float = 1.0 # Beta distribution beta parameter for time sampling
min_period: float = 4e-3 # Min period for sinusoidal positional encoding
max_period: float = 4.0 # Max period for sinusoidal positional encoding
# Image preprocessing
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
# Normalization
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY, # Images are normalized to [-1, 1] in preprocessing
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Training settings
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
compile_model: bool = False # Whether to use torch.compile for model optimization
compile_mode: str = "max-autotune" # Torch compile mode
device: str | None = None # Device to use for the model (None = auto-detect)
# Optimizer settings: see openpi `AdamW` and
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.01
optimizer_grad_clip_norm: float = 1.0
# Scheduler settings: see openpi `CosineDecaySchedule`
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
tokenizer_max_length: int = 200 # see openpi `__post_init__`
def __post_init__(self):
super().__post_init__()
# Validate configuration
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
)
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
if self.dtype not in ["bfloat16", "float32"]:
raise ValueError(f"Invalid dtype: {self.dtype}")
def validate_features(self) -> None:
"""Validate and set up input/output features."""
# Image features are now handled dynamically through dataset configuration
# No need to auto-add hardcoded image keys
# State and action features are also handled dynamically through dataset configuration
# The actual dimensions come from the feature shapes, max dimensions are used for padding only
pass
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,173 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_gemma.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig
class GemmaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Gemma-7B.
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`GemmaModel`]
hidden_size (`int`, *optional*, defaults to 3072):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 24576):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 28):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 16):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
`num_attention_heads`.
head_dim (`int`, *optional*, defaults to 256):
The attention head dimension.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The legacy activation function. It is overwritten by the `hidden_activation`.
hidden_activation (`str` or `function`, *optional*):
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
max_position_embeddings (`int`, *optional*, defaults to 8192):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
eos_token_id (`int`, *optional*, defaults to 1):
End of stream token id.
bos_token_id (`int`, *optional*, defaults to 2):
Beginning of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
use_adarms (`bool`, *optional*, defaults to `False`):
Whether to use ADARMS.
adarms_cond_dim (`int`, *optional*, defaults to `None`):
The dimension of the ADARMS condition.
```python
>>> from transformers import GemmaModel, GemmaConfig
>>> # Initializing a Gemma gemma-7b style configuration
>>> configuration = GemmaConfig()
>>> # Initializing a model from the gemma-7b style configuration
>>> model = GemmaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "gemma"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=256000,
hidden_size=3072,
intermediate_size=24576,
num_hidden_layers=28,
num_attention_heads=16,
num_key_value_heads=16,
head_dim=256,
hidden_act="gelu_pytorch_tanh",
hidden_activation=None,
max_position_embeddings=8192,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
eos_token_id=1,
bos_token_id=2,
tie_word_embeddings=True,
rope_theta=10000.0,
attention_bias=False,
attention_dropout=0.0,
use_adarms: bool = False,
adarms_cond_dim: int | None = None,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.hidden_activation = hidden_activation
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.use_adarms = use_adarms
self.adarms_cond_dim = adarms_cond_dim
# Set default for adarms_cond_dim if use_adarms is True
if self.use_adarms and self.adarms_cond_dim is None:
self.adarms_cond_dim = self.hidden_size
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
__all__ = ["GemmaConfig"]

View File

@@ -0,0 +1,895 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_gemma.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable
import torch
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_gemma import GemmaConfig
logger = logging.get_logger(__name__)
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
def safe_auto_docstring(func=None, **kwargs):
"""Auto docstring decorator that handles Python 3.10+ UnionType gracefully."""
def decorator(f):
try:
return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f)
except (AttributeError, TypeError):
# If auto_docstring fails due to UnionType, just return the function unchanged
return f
if func is None:
# Called with arguments, return the decorator
return decorator
else:
# Called without arguments, apply directly
return decorator(func)
class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None):
super().__init__()
self.eps = eps
self.dim = dim
self.cond_dim = cond_dim
# Dense layer for adaptive normalization (if cond_dim is provided)
if cond_dim is not None:
# self.dense = nn.Linear(cond_dim, dim * 3, bias=True, dtype=torch.bfloat16)
self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
# Initialize with zeros (matches source implementation)
nn.init.zeros_(self.dense.weight)
else:
self.weight = nn.Parameter(torch.zeros(dim, dtype=torch.bfloat16))
self.dense = None
def _norm(self, x):
# Compute variance in float32 (like the source implementation)
var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
# Compute normalization in float32
normed_inputs = x * torch.rsqrt(var + self.eps)
return normed_inputs
def forward(self, x, cond=None):
dtype = x.dtype # original dtype, could be half-precision
normed_inputs = self._norm(x)
if cond is None or self.dense is None:
# regular RMSNorm
# scale by learned parameter in float32 (matches source implementation)
normed_inputs = normed_inputs * (1.0 + self.weight.float())
return normed_inputs.to(dtype), None # return in original dtype with None gate
# adaptive RMSNorm (if cond is provided and dense layer exists)
if cond.shape[-1] != self.cond_dim:
raise ValueError(f"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}")
# self.dense.to(dtype=torch.bfloat16).to(dtype=torch.float32)
modulation = self.dense(cond)
# Reshape modulation to broadcast properly: [batch, 1, features] for [batch, seq, features]
if len(x.shape) == 3: # [batch, seq, features]
modulation = modulation.unsqueeze(1)
scale, shift, gate = torch.chunk(modulation, 3, dim=-1)
# Apply adaptive normalization: use model weight dtype to ensure compatibility
# model_dtype = self.dense.weight.dtype # Use the model's dtype (bfloat16)
# scale = scale.to(model_dtype)
# shift = shift.to(model_dtype)
# gate = gate.to(model_dtype)
# normed_inputs = normed_inputs.to(model_dtype) # Convert normed_inputs to model dtype
normed_inputs = normed_inputs * (1 + scale.to(torch.float32)) + shift.to(torch.float32)
return normed_inputs.to(dtype), gate.to(dtype)
def extra_repr(self):
repr_str = f"{tuple(self.weight.shape)}, eps={self.eps}"
if self.dense is not None:
repr_str += f", adaptive=True, cond_dim={self.cond_dim}"
return repr_str
class GemmaMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class GemmaRotaryEmbedding(nn.Module):
def __init__(self, config: GemmaConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = (
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def _gated_residual(x, y, gate):
"""
Applies gated residual connection with optional gate parameter.
Args:
x: Input tensor (residual)
y: Output tensor to be added
gate: Optional gate tensor to modulate the addition
Returns:
x + y if gate is None, otherwise x + y * gate
"""
if x is None and y is None:
return None
if x is None or y is None:
return x if x is not None else y
if gate is None:
return x + y
return x + y * gate
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None,
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class GemmaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: GemmaConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: torch.Tensor | None,
past_key_value: Cache | None = None,
cache_position: torch.LongTensor | None = None,
use_cache: bool = False,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# Use cache if provided
if past_key_value is not None:
if use_cache:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
else:
key_states = torch.cat([past_key_value[self.layer_idx][0], key_states], dim=2)
value_states = torch.cat([past_key_value[self.layer_idx][1], value_states], dim=2)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class GemmaDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: GemmaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
self.mlp = GemmaMLP(config)
cond_dim = getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
self.post_attention_layernorm = GemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
output_attentions: bool | None = False,
use_cache: bool | None = False,
cache_position: torch.LongTensor | None = None,
position_embeddings: None
| (tuple[torch.Tensor, torch.Tensor]) = None, # necessary, but kept here for BC
adarms_cond: torch.Tensor | None = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
residual = hidden_states
hidden_states, gate = self.input_layernorm(hidden_states, adarms_cond)
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = _gated_residual(residual, hidden_states, gate)
# Fully Connected
residual = hidden_states
hidden_states, gate = self.post_attention_layernorm(hidden_states, adarms_cond)
hidden_states = self.mlp(hidden_states)
hidden_states = _gated_residual(residual, hidden_states, gate)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
@safe_auto_docstring
class GemmaPreTrainedModel(PreTrainedModel):
config_class = GemmaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["GemmaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, GemmaRMSNorm):
if hasattr(module, "weight"):
module.weight.data.fill_(1.0)
@safe_auto_docstring
class GemmaModel(GemmaPreTrainedModel):
def __init__(self, config: GemmaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
cond_dim = getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
self.rotary_emb = GemmaRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
@can_return_tuple
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
cache_position: torch.LongTensor | None = None,
adarms_cond: torch.Tensor | None = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
"""
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
Condition for ADARMS.
"""
output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
# embed positions
hidden_states = inputs_embeds
# Convert to bfloat16 if the first layer uses bfloat16
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.bfloat16)
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# normalized
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
_normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
# hidden_states = hidden_states * normalizer
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
adarms_cond=adarms_cond,
**kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states, _ = self.norm(hidden_states, adarms_cond)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@safe_auto_docstring
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config):
super().__init__(config)
self.model = GemmaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
@can_return_tuple
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
cache_position: torch.LongTensor | None = None,
logits_to_keep: int | torch.Tensor = 0,
adarms_cond: torch.Tensor | None = None,
**kwargs: Unpack[KwargsForCausalLM],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
Condition for ADARMS.
Example:
```python
>>> from transformers import AutoTokenizer, GemmaForCausalLM
>>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
>>> prompt = "What is your favorite condiment?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""
output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
adarms_cond=adarms_cond,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@safe_auto_docstring(
custom_intro="""
The Gemma Model transformer with a sequence classification head on top (linear layer).
[`GemmaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-2) do.
Since it does classification on the last token, it requires to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
"""
)
class GemmaForSequenceClassification(GemmaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = GemmaModel(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@can_return_tuple
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
adarms_cond: torch.Tensor | None = None,
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
Condition for ADARMS.
"""
transformer_outputs: BaseModelOutputWithPast = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
adarms_cond=adarms_cond,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config
)
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@safe_auto_docstring
class GemmaForTokenClassification(GemmaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = GemmaModel(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@can_return_tuple
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
adarms_cond: torch.Tensor | None = None,
) -> TokenClassifierOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
Condition for ADARMS.
"""
outputs: BaseModelOutputWithPast = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
adarms_cond=adarms_cond,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.config)
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
__all__ = [
"GemmaModel",
"GemmaForCausalLM",
"GemmaForSequenceClassification",
"GemmaForTokenClassification",
"GemmaPreTrainedModel",
]

View File

@@ -0,0 +1,666 @@
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch PaliGemmamodel."""
from dataclasses import dataclass
import torch
import torch.utils.checkpoint
from torch import nn
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
ModelOutput,
auto_docstring,
can_return_tuple,
is_torchdynamo_compiling,
logging,
)
from ..auto import AutoModel
from .configuration_paligemma import PaliGemmaConfig
logger = logging.get_logger(__name__)
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
def safe_auto_docstring(func=None, **kwargs):
"""Auto docstring decorator that handles Python 3.10+ UnionType gracefully."""
def decorator(f):
try:
return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f)
except (AttributeError, TypeError):
# If auto_docstring fails due to UnionType, just return the function unchanged
return f
if func is None:
# Called with arguments, return the decorator
return decorator
else:
# Called without arguments, apply directly
return decorator(func)
@dataclass
@safe_auto_docstring(
custom_intro="""
Base class for Paligemma outputs, with hidden states and attentions.
"""
)
class PaligemmaModelOutputWithPast(BaseModelOutputWithPast):
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
image_hidden_states: torch.FloatTensor | None = None
@dataclass
@safe_auto_docstring(
custom_intro="""
Base class for PaliGemma causal language model (or autoregressive) outputs.
"""
)
class PaliGemmaCausalLMOutputWithPast(ModelOutput):
r"""
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
"""
loss: torch.FloatTensor | None = None
logits: torch.FloatTensor | None = None
past_key_values: list[torch.FloatTensor] | Cache | None = None
hidden_states: tuple[torch.FloatTensor] | None = None
attentions: tuple[torch.FloatTensor] | None = None
image_hidden_states: torch.FloatTensor | None = None
class PaliGemmaMultiModalProjector(nn.Module):
def __init__(self, config: PaliGemmaConfig):
super().__init__()
self.linear = nn.Linear(
config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True
)
def forward(self, image_features):
hidden_states = self.linear(image_features)
return hidden_states
@safe_auto_docstring
class PaliGemmaPreTrainedModel(PreTrainedModel):
config_class = PaliGemmaConfig
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["PaliGemmaMultiModalProjector"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_attention_backend = True
def _init_weights(self, module):
# important: this ported version of PaliGemmaisn't meant for training from scratch - only
# inference and fine-tuning
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
@safe_auto_docstring(
custom_intro="""
The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
"""
)
class PaliGemmaModel(PaliGemmaPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
# we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
accepts_loss_kwargs = False
def __init__(self, config: PaliGemmaConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
language_model = AutoModel.from_config(config=config.text_config)
self.language_model = language_model
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()
# Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
# Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def set_decoder(self, decoder):
self.language_model = decoder
def get_decoder(self):
return self.language_model
def _update_causal_mask(
self,
attention_mask,
token_type_ids=None,
past_key_values=None,
cache_position=None,
input_tensor=None,
is_training: bool | None = None,
):
if self.config.text_config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
is_training = is_training if is_training is not None else self.training
using_static_cache = isinstance(past_key_values, StaticCache)
min_dtype = torch.finfo(self.dtype).min
if input_tensor is None:
input_tensor = attention_mask
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
elif isinstance(past_key_values, HybridCache):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else cache_position[0] + sequence_length + 1
)
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
return attention_mask
causal_mask = torch.full(
(sequence_length, target_length),
fill_value=min_dtype,
dtype=self.dtype,
device=cache_position.device,
)
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
if sequence_length != 1:
if is_training:
causal_mask = torch.triu(causal_mask, diagonal=1)
else:
causal_mask[:, :sequence_length] = 0.0
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
# First unmask prefix tokens during training
if is_training:
if token_type_ids is None:
raise ValueError("Token type ids must be provided during training")
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
)
# Then apply padding mask (will mask pad tokens)
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
def get_image_features(self, pixel_values: torch.FloatTensor):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
The tensors corresponding to the input images.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
image_outputs = self.vision_tower(pixel_values)
selected_image_feature = image_outputs.last_hidden_state
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
@can_return_tuple
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | Cache | None = None,
token_type_ids: torch.LongTensor | None = None,
cache_position: torch.LongTensor | None = None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple | PaligemmaModelOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
>>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
>>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
>>> prompt = "Where is the cat standing?"
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs,)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Where is the cat standing?\nsnow"
```"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
is_training = token_type_ids is not None and labels is not None
# Replace image id with PAD if the image token if OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_id
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed
# Merge text and images
if pixel_values is not None:
image_features = self.get_image_features(pixel_values)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if (
not is_torchdynamo_compiling()
and inputs_embeds[special_image_mask].numel() != image_features.numel()
):
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
"tokens from image embeddings."
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
causal_mask = self._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
)
outputs = self.language_model(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
return PaligemmaModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@safe_auto_docstring(
custom_intro="""
The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
"""
)
class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: PaliGemmaConfig):
super().__init__(config)
self.model = PaliGemmaModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model.set_decoder(decoder)
def get_decoder(self):
return self.model.get_decoder()
def get_image_features(self, pixel_values):
return self.model.get_image_features(pixel_values)
# Make modules available through conditional class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
@can_return_tuple
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | Cache | None = None,
token_type_ids: torch.LongTensor | None = None,
cache_position: torch.LongTensor | None = None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
logits_to_keep: int | torch.Tensor = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> tuple | PaliGemmaCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
>>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
>>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
>>> prompt = "Where is the cat standing?"
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs,)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Where is the cat standing?\nsnow"
```"""
output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return PaliGemmaCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
pixel_values=None,
attention_mask=None,
token_type_ids=None,
use_cache=True,
logits_to_keep=None,
labels=None,
**kwargs,
):
# Overwritten -- custom `position_ids` and `pixel_values` handling
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
cache_position=cache_position,
use_cache=use_cache,
logits_to_keep=logits_to_keep,
token_type_ids=token_type_ids,
**kwargs,
)
# position_ids in Paligemma are 1-indexed
if model_inputs.get("position_ids") is not None:
model_inputs["position_ids"] += 1
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
if cache_position[0] == 0:
model_inputs["pixel_values"] = pixel_values
is_training = token_type_ids is not None and labels is not None
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
causal_mask = self.model._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
)
model_inputs["attention_mask"] = causal_mask
return model_inputs
@staticmethod
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length),
fill_value=min_dtype,
dtype=dtype,
device=cache_position.device,
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel", "PaliGemmaModel"]

View File

@@ -0,0 +1,5 @@
import transformers
def check_whether_transformers_replace_is_installed_correctly():
return transformers.__version__ == "4.53.2"

View File

@@ -0,0 +1,92 @@
# π₀ (pi0)
This repository contains the Hugging Face port of **π₀**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by the Physical Intelligence.
It is designed as a **Vision-Language-Action flow model for general robot control**.
---
### ⚠️ WARNING ⚠️
This project requires **patching the Hugging Face `transformers` library**.
1. Make sure you have the exact version installed:
```bash
pip show transformers
```
It must be version **4.53.2**.
2. Apply the custom patches by copying the modified files into your environment:
```bash
cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* \
$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")
```
These patches overwrite parts of `transformers` to:
- Support the **AdaRMS optimizer**,
- Correctly control the precision of activations,
- Allow the KV cache to be used without updates.
**Important:**
- This permanently modifies your `transformers` installation.
- The changes survive reinstalls unless you explicitly remove the patched files or recreate the environment.
To undo and restore a clean state:
```bash
pip uninstall transformers
pip install transformers==4.53.2
```
---
## Model Overview
| Feature | π₀ | π₀.₅ |
| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
| State Embedding | Uses `state_proj` layer | No state embedding |
| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
| AdaRMS | Not used | Used in action expert |
| Tokenizer Length | 48 tokens | 200 tokens |
| Discrete State Input | False | True |
| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) |
---
## Citation
If you use this work, please cite both **OpenPI** and the π₀ paper:
```bibtex
@misc{openpi2024,
author = {Physical Intelligence Lab},
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
year = {2024},
publisher = {GitHub},
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
license = {Apache-2.0}
}
@misc{black2024pi0visionlanguageactionflowmodel,
title = {π₀: A Vision-Language-Action Flow Model for General Robot Control},
author = {Kevin Black and Noah Brown and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Lucy Xiaoyang Shi and James Tanner and Quan Vuong and Anna Walling and Haohuan Wang and Ury Zhilinsky},
year = {2024},
eprint = {2410.24164},
archivePrefix= {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2410.24164},
}
```
---
## License
This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
```
```

View File

@@ -0,0 +1,20 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_pi0openpi import PI0OpenPIConfig
from .modeling_pi0openpi import PI0OpenPIPolicy
__all__ = ["PI0OpenPIConfig", "PI0OpenPIPolicy"]

View File

@@ -0,0 +1,134 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("pi0_openpi")
@dataclass
class PI0OpenPIConfig(PreTrainedConfig):
# Model architecture
paligemma_variant: str = "gemma_2b"
action_expert_variant: str = "gemma_300m"
dtype: str = "float32" # Options: "bfloat16", "float32"
# Input / output structure
n_obs_steps: int = 1
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
n_action_steps: int = 50 # Number of action steps to execute
# Shorter state and action vectors will be padded to these dimensions
max_state_dim: int = 32 # State dimension (will be padded to 32)
max_action_dim: int = 32 # Action dimension (will be padded to 32)
# Flow matching parameters: see openpi `PI0Pytorch`
num_inference_steps: int = 10 # Number of denoising steps during inference
time_sampling_beta_alpha: float = 1.5 # Beta distribution alpha parameter for time sampling
time_sampling_beta_beta: float = 1.0 # Beta distribution beta parameter for time sampling
min_period: float = 4e-3 # Min period for sinusoidal positional encoding
max_period: float = 4.0 # Max period for sinusoidal positional encoding
# Image preprocessing
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
# Normalization
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY, # Images are normalized to [-1, 1] in preprocessing
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Training settings
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
compile_model: bool = False # Whether to use torch.compile for model optimization
compile_mode: str = "max-autotune" # Torch compile mode
device: str | None = None # Device to use for the model (None = auto-detect)
# Optimizer settings: see openpi `AdamW` and
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.01
optimizer_grad_clip_norm: float = 1.0
# Scheduler settings: see openpi `CosineDecaySchedule`
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
tokenizer_max_length: int = 48 # pi0=48, see openpi `__post_init__`
def __post_init__(self):
super().__post_init__()
# Validate configuration
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
)
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
if self.dtype not in ["bfloat16", "float32"]:
raise ValueError(f"Invalid dtype: {self.dtype}")
def validate_features(self) -> None:
"""Validate and set up input/output features."""
# Image features are now handled dynamically through dataset configuration
# No need to auto-add hardcoded image keys
# State and action features are also handled dynamically through dataset configuration
# The actual dimensions come from the feature shapes, max dimensions are used for padding only
pass
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,173 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_gemma.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig
class GemmaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Gemma-7B.
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`GemmaModel`]
hidden_size (`int`, *optional*, defaults to 3072):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 24576):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 28):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 16):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
`num_attention_heads`.
head_dim (`int`, *optional*, defaults to 256):
The attention head dimension.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The legacy activation function. It is overwritten by the `hidden_activation`.
hidden_activation (`str` or `function`, *optional*):
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
max_position_embeddings (`int`, *optional*, defaults to 8192):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
eos_token_id (`int`, *optional*, defaults to 1):
End of stream token id.
bos_token_id (`int`, *optional*, defaults to 2):
Beginning of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
use_adarms (`bool`, *optional*, defaults to `False`):
Whether to use ADARMS.
adarms_cond_dim (`int`, *optional*, defaults to `None`):
The dimension of the ADARMS condition.
```python
>>> from transformers import GemmaModel, GemmaConfig
>>> # Initializing a Gemma gemma-7b style configuration
>>> configuration = GemmaConfig()
>>> # Initializing a model from the gemma-7b style configuration
>>> model = GemmaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "gemma"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=256000,
hidden_size=3072,
intermediate_size=24576,
num_hidden_layers=28,
num_attention_heads=16,
num_key_value_heads=16,
head_dim=256,
hidden_act="gelu_pytorch_tanh",
hidden_activation=None,
max_position_embeddings=8192,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
eos_token_id=1,
bos_token_id=2,
tie_word_embeddings=True,
rope_theta=10000.0,
attention_bias=False,
attention_dropout=0.0,
use_adarms: bool = False,
adarms_cond_dim: int | None = None,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.hidden_activation = hidden_activation
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.use_adarms = use_adarms
self.adarms_cond_dim = adarms_cond_dim
# Set default for adarms_cond_dim if use_adarms is True
if self.use_adarms and self.adarms_cond_dim is None:
self.adarms_cond_dim = self.hidden_size
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
__all__ = ["GemmaConfig"]

View File

@@ -0,0 +1,895 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_gemma.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable
import torch
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_gemma import GemmaConfig
logger = logging.get_logger(__name__)
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
def safe_auto_docstring(func=None, **kwargs):
"""Auto docstring decorator that handles Python 3.10+ UnionType gracefully."""
def decorator(f):
try:
return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f)
except (AttributeError, TypeError):
# If auto_docstring fails due to UnionType, just return the function unchanged
return f
if func is None:
# Called with arguments, return the decorator
return decorator
else:
# Called without arguments, apply directly
return decorator(func)
class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None):
super().__init__()
self.eps = eps
self.dim = dim
self.cond_dim = cond_dim
# Dense layer for adaptive normalization (if cond_dim is provided)
if cond_dim is not None:
# self.dense = nn.Linear(cond_dim, dim * 3, bias=True, dtype=torch.bfloat16)
self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
# Initialize with zeros (matches source implementation)
nn.init.zeros_(self.dense.weight)
else:
self.weight = nn.Parameter(torch.zeros(dim, dtype=torch.bfloat16))
self.dense = None
def _norm(self, x):
# Compute variance in float32 (like the source implementation)
var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
# Compute normalization in float32
normed_inputs = x * torch.rsqrt(var + self.eps)
return normed_inputs
def forward(self, x, cond=None):
dtype = x.dtype # original dtype, could be half-precision
normed_inputs = self._norm(x)
if cond is None or self.dense is None:
# regular RMSNorm
# scale by learned parameter in float32 (matches source implementation)
normed_inputs = normed_inputs * (1.0 + self.weight.float())
return normed_inputs.to(dtype), None # return in original dtype with None gate
# adaptive RMSNorm (if cond is provided and dense layer exists)
if cond.shape[-1] != self.cond_dim:
raise ValueError(f"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}")
# self.dense.to(dtype=torch.bfloat16).to(dtype=torch.float32)
modulation = self.dense(cond)
# Reshape modulation to broadcast properly: [batch, 1, features] for [batch, seq, features]
if len(x.shape) == 3: # [batch, seq, features]
modulation = modulation.unsqueeze(1)
scale, shift, gate = torch.chunk(modulation, 3, dim=-1)
# Apply adaptive normalization: use model weight dtype to ensure compatibility
# model_dtype = self.dense.weight.dtype # Use the model's dtype (bfloat16)
# scale = scale.to(model_dtype)
# shift = shift.to(model_dtype)
# gate = gate.to(model_dtype)
# normed_inputs = normed_inputs.to(model_dtype) # Convert normed_inputs to model dtype
normed_inputs = normed_inputs * (1 + scale.to(torch.float32)) + shift.to(torch.float32)
return normed_inputs.to(dtype), gate.to(dtype)
def extra_repr(self):
repr_str = f"{tuple(self.weight.shape)}, eps={self.eps}"
if self.dense is not None:
repr_str += f", adaptive=True, cond_dim={self.cond_dim}"
return repr_str
class GemmaMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class GemmaRotaryEmbedding(nn.Module):
def __init__(self, config: GemmaConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = (
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def _gated_residual(x, y, gate):
"""
Applies gated residual connection with optional gate parameter.
Args:
x: Input tensor (residual)
y: Output tensor to be added
gate: Optional gate tensor to modulate the addition
Returns:
x + y if gate is None, otherwise x + y * gate
"""
if x is None and y is None:
return None
if x is None or y is None:
return x if x is not None else y
if gate is None:
return x + y
return x + y * gate
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None,
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class GemmaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: GemmaConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: torch.Tensor | None,
past_key_value: Cache | None = None,
cache_position: torch.LongTensor | None = None,
use_cache: bool = False,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# Use cache if provided
if past_key_value is not None:
if use_cache:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
else:
key_states = torch.cat([past_key_value[self.layer_idx][0], key_states], dim=2)
value_states = torch.cat([past_key_value[self.layer_idx][1], value_states], dim=2)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class GemmaDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: GemmaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
self.mlp = GemmaMLP(config)
cond_dim = getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
self.post_attention_layernorm = GemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
output_attentions: bool | None = False,
use_cache: bool | None = False,
cache_position: torch.LongTensor | None = None,
position_embeddings: None
| (tuple[torch.Tensor, torch.Tensor]) = None, # necessary, but kept here for BC
adarms_cond: torch.Tensor | None = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
residual = hidden_states
hidden_states, gate = self.input_layernorm(hidden_states, adarms_cond)
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = _gated_residual(residual, hidden_states, gate)
# Fully Connected
residual = hidden_states
hidden_states, gate = self.post_attention_layernorm(hidden_states, adarms_cond)
hidden_states = self.mlp(hidden_states)
hidden_states = _gated_residual(residual, hidden_states, gate)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
@safe_auto_docstring
class GemmaPreTrainedModel(PreTrainedModel):
config_class = GemmaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["GemmaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, GemmaRMSNorm):
if hasattr(module, "weight"):
module.weight.data.fill_(1.0)
@safe_auto_docstring
class GemmaModel(GemmaPreTrainedModel):
def __init__(self, config: GemmaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
cond_dim = getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
self.rotary_emb = GemmaRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
@can_return_tuple
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
cache_position: torch.LongTensor | None = None,
adarms_cond: torch.Tensor | None = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
"""
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
Condition for ADARMS.
"""
output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
# embed positions
hidden_states = inputs_embeds
# Convert to bfloat16 if the first layer uses bfloat16
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.bfloat16)
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# normalized
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
_normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
# hidden_states = hidden_states * normalizer
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
adarms_cond=adarms_cond,
**kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states, _ = self.norm(hidden_states, adarms_cond)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@safe_auto_docstring
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config):
super().__init__(config)
self.model = GemmaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
@can_return_tuple
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
cache_position: torch.LongTensor | None = None,
logits_to_keep: int | torch.Tensor = 0,
adarms_cond: torch.Tensor | None = None,
**kwargs: Unpack[KwargsForCausalLM],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
Condition for ADARMS.
Example:
```python
>>> from transformers import AutoTokenizer, GemmaForCausalLM
>>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
>>> prompt = "What is your favorite condiment?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""
output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
adarms_cond=adarms_cond,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@safe_auto_docstring(
custom_intro="""
The Gemma Model transformer with a sequence classification head on top (linear layer).
[`GemmaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-2) do.
Since it does classification on the last token, it requires to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
"""
)
class GemmaForSequenceClassification(GemmaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = GemmaModel(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@can_return_tuple
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
adarms_cond: torch.Tensor | None = None,
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
Condition for ADARMS.
"""
transformer_outputs: BaseModelOutputWithPast = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
adarms_cond=adarms_cond,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config
)
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@safe_auto_docstring
class GemmaForTokenClassification(GemmaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = GemmaModel(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@can_return_tuple
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
adarms_cond: torch.Tensor | None = None,
) -> TokenClassifierOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
Condition for ADARMS.
"""
outputs: BaseModelOutputWithPast = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
adarms_cond=adarms_cond,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.config)
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
__all__ = [
"GemmaModel",
"GemmaForCausalLM",
"GemmaForSequenceClassification",
"GemmaForTokenClassification",
"GemmaPreTrainedModel",
]

View File

@@ -0,0 +1,666 @@
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch PaliGemmamodel."""
from dataclasses import dataclass
import torch
import torch.utils.checkpoint
from torch import nn
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
ModelOutput,
auto_docstring,
can_return_tuple,
is_torchdynamo_compiling,
logging,
)
from ..auto import AutoModel
from .configuration_paligemma import PaliGemmaConfig
logger = logging.get_logger(__name__)
# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring
def safe_auto_docstring(func=None, **kwargs):
"""Auto docstring decorator that handles Python 3.10+ UnionType gracefully."""
def decorator(f):
try:
return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f)
except (AttributeError, TypeError):
# If auto_docstring fails due to UnionType, just return the function unchanged
return f
if func is None:
# Called with arguments, return the decorator
return decorator
else:
# Called without arguments, apply directly
return decorator(func)
@dataclass
@safe_auto_docstring(
custom_intro="""
Base class for Paligemma outputs, with hidden states and attentions.
"""
)
class PaligemmaModelOutputWithPast(BaseModelOutputWithPast):
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
image_hidden_states: torch.FloatTensor | None = None
@dataclass
@safe_auto_docstring(
custom_intro="""
Base class for PaliGemma causal language model (or autoregressive) outputs.
"""
)
class PaliGemmaCausalLMOutputWithPast(ModelOutput):
r"""
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
"""
loss: torch.FloatTensor | None = None
logits: torch.FloatTensor | None = None
past_key_values: list[torch.FloatTensor] | Cache | None = None
hidden_states: tuple[torch.FloatTensor] | None = None
attentions: tuple[torch.FloatTensor] | None = None
image_hidden_states: torch.FloatTensor | None = None
class PaliGemmaMultiModalProjector(nn.Module):
def __init__(self, config: PaliGemmaConfig):
super().__init__()
self.linear = nn.Linear(
config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True
)
def forward(self, image_features):
hidden_states = self.linear(image_features)
return hidden_states
@safe_auto_docstring
class PaliGemmaPreTrainedModel(PreTrainedModel):
config_class = PaliGemmaConfig
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["PaliGemmaMultiModalProjector"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_attention_backend = True
def _init_weights(self, module):
# important: this ported version of PaliGemmaisn't meant for training from scratch - only
# inference and fine-tuning
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
@safe_auto_docstring(
custom_intro="""
The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
"""
)
class PaliGemmaModel(PaliGemmaPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
# we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
accepts_loss_kwargs = False
def __init__(self, config: PaliGemmaConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
language_model = AutoModel.from_config(config=config.text_config)
self.language_model = language_model
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()
# Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
# Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def set_decoder(self, decoder):
self.language_model = decoder
def get_decoder(self):
return self.language_model
def _update_causal_mask(
self,
attention_mask,
token_type_ids=None,
past_key_values=None,
cache_position=None,
input_tensor=None,
is_training: bool | None = None,
):
if self.config.text_config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
is_training = is_training if is_training is not None else self.training
using_static_cache = isinstance(past_key_values, StaticCache)
min_dtype = torch.finfo(self.dtype).min
if input_tensor is None:
input_tensor = attention_mask
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
elif isinstance(past_key_values, HybridCache):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else cache_position[0] + sequence_length + 1
)
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
return attention_mask
causal_mask = torch.full(
(sequence_length, target_length),
fill_value=min_dtype,
dtype=self.dtype,
device=cache_position.device,
)
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
if sequence_length != 1:
if is_training:
causal_mask = torch.triu(causal_mask, diagonal=1)
else:
causal_mask[:, :sequence_length] = 0.0
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
# First unmask prefix tokens during training
if is_training:
if token_type_ids is None:
raise ValueError("Token type ids must be provided during training")
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
)
# Then apply padding mask (will mask pad tokens)
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
def get_image_features(self, pixel_values: torch.FloatTensor):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
The tensors corresponding to the input images.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
image_outputs = self.vision_tower(pixel_values)
selected_image_feature = image_outputs.last_hidden_state
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
@can_return_tuple
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | Cache | None = None,
token_type_ids: torch.LongTensor | None = None,
cache_position: torch.LongTensor | None = None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple | PaligemmaModelOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
>>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
>>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
>>> prompt = "Where is the cat standing?"
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs,)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Where is the cat standing?\nsnow"
```"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
is_training = token_type_ids is not None and labels is not None
# Replace image id with PAD if the image token if OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_id
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed
# Merge text and images
if pixel_values is not None:
image_features = self.get_image_features(pixel_values)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if (
not is_torchdynamo_compiling()
and inputs_embeds[special_image_mask].numel() != image_features.numel()
):
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
"tokens from image embeddings."
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
causal_mask = self._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
)
outputs = self.language_model(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
return PaligemmaModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@safe_auto_docstring(
custom_intro="""
The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
"""
)
class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: PaliGemmaConfig):
super().__init__(config)
self.model = PaliGemmaModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model.set_decoder(decoder)
def get_decoder(self):
return self.model.get_decoder()
def get_image_features(self, pixel_values):
return self.model.get_image_features(pixel_values)
# Make modules available through conditional class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
@can_return_tuple
@safe_auto_docstring
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | Cache | None = None,
token_type_ids: torch.LongTensor | None = None,
cache_position: torch.LongTensor | None = None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
logits_to_keep: int | torch.Tensor = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> tuple | PaliGemmaCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
>>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
>>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
>>> prompt = "Where is the cat standing?"
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs,)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Where is the cat standing?\nsnow"
```"""
output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return PaliGemmaCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
pixel_values=None,
attention_mask=None,
token_type_ids=None,
use_cache=True,
logits_to_keep=None,
labels=None,
**kwargs,
):
# Overwritten -- custom `position_ids` and `pixel_values` handling
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
cache_position=cache_position,
use_cache=use_cache,
logits_to_keep=logits_to_keep,
token_type_ids=token_type_ids,
**kwargs,
)
# position_ids in Paligemma are 1-indexed
if model_inputs.get("position_ids") is not None:
model_inputs["position_ids"] += 1
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
if cache_position[0] == 0:
model_inputs["pixel_values"] = pixel_values
is_training = token_type_ids is not None and labels is not None
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
causal_mask = self.model._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
)
model_inputs["attention_mask"] = causal_mask
return model_inputs
@staticmethod
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length),
fill_value=min_dtype,
dtype=dtype,
device=cache_position.device,
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel", "PaliGemmaModel"]

View File

@@ -0,0 +1,5 @@
import transformers
def check_whether_transformers_replace_is_installed_correctly():
return transformers.__version__ == "4.53.2"

View File

@@ -58,6 +58,7 @@ from transformers.cache_utils import HybridCache, StaticCache
from transformers.models.auto import CONFIG_MAPPING
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pretrained import PreTrainedPolicy
@@ -145,6 +146,14 @@ class PI0FASTPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
self.model = PI0FAST(config)
@@ -212,6 +221,8 @@ class PI0FASTPolicy(PreTrainedPolicy):
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch = self.normalize_inputs(batch)
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
@@ -224,6 +235,8 @@ class PI0FASTPolicy(PreTrainedPolicy):
] # self.config.max_action_dim # self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
@@ -236,6 +249,8 @@ class PI0FASTPolicy(PreTrainedPolicy):
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
loss_dict = self.model.forward(batch)
return loss_dict["loss"], loss_dict

View File

@@ -1,52 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
)
def make_pi0fast_processor(
config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
NormalizerProcessor(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)

View File

@@ -28,6 +28,7 @@ import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
from lerobot.policies.normalize import NormalizeBuffer
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature
from lerobot.policies.utils import get_device_from_parameters
@@ -44,6 +45,7 @@ class SACPolicy(
def __init__(
self,
config: SACConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
super().__init__(config)
config.validate_features()
@@ -51,6 +53,7 @@ class SACPolicy(
# Determine action dimension and initialize all components
continuous_action_dim = config.output_features["action"].shape[0]
self._init_normalization(dataset_stats)
self._init_encoders()
self._init_critics(continuous_action_dim)
self._init_actor(continuous_action_dim)
@@ -85,7 +88,8 @@ class SACPolicy(
observations_features = None
if self.shared_encoder and self.actor.encoder.has_images:
observations_features = self.actor.encoder.get_cached_image_features(batch)
# Cache and normalize image features
observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True)
actions, _, _ = self.actor(batch, observations_features)
@@ -387,12 +391,28 @@ class SACPolicy(
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
return actor_loss
def _init_normalization(self, dataset_stats):
"""Initialize input/output normalization modules."""
self.normalize_inputs = nn.Identity()
self.normalize_targets = nn.Identity()
if self.config.dataset_stats is not None:
params = _convert_normalization_params_to_tensor(self.config.dataset_stats)
self.normalize_inputs = NormalizeBuffer(
self.config.input_features, self.config.normalization_mapping, params
)
stats = dataset_stats or params
self.normalize_targets = NormalizeBuffer(
self.config.output_features, self.config.normalization_mapping, stats
)
def _init_encoders(self):
"""Initialize shared or separate encoders for actor and critic."""
self.shared_encoder = self.config.shared_encoder
self.encoder_critic = SACObservationEncoder(self.config)
self.encoder_critic = SACObservationEncoder(self.config, self.normalize_inputs)
self.encoder_actor = (
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
self.encoder_critic
if self.shared_encoder
else SACObservationEncoder(self.config, self.normalize_inputs)
)
def _init_critics(self, continuous_action_dim):
@@ -404,7 +424,9 @@ class SACPolicy(
)
for _ in range(self.config.num_critics)
]
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
self.critic_ensemble = CriticEnsemble(
encoder=self.encoder_critic, ensemble=heads, output_normalization=self.normalize_targets
)
target_heads = [
CriticHead(
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
@@ -412,7 +434,9 @@ class SACPolicy(
)
for _ in range(self.config.num_critics)
]
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
self.critic_target = CriticEnsemble(
encoder=self.encoder_critic, ensemble=target_heads, output_normalization=self.normalize_targets
)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
if self.config.use_torch_compile:
@@ -466,9 +490,10 @@ class SACPolicy(
class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: SACConfig) -> None:
def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
super().__init__()
self.config = config
self.input_normalization = input_normalizer
self._init_image_layers()
self._init_state_layers()
self._compute_output_dim()
@@ -543,10 +568,11 @@ class SACObservationEncoder(nn.Module):
def forward(
self, obs: dict[str, Tensor], cache: dict[str, Tensor] | None = None, detach: bool = False
) -> Tensor:
obs = self.input_normalization(obs)
parts = []
if self.has_images:
if cache is None:
cache = self.get_cached_image_features(obs)
cache = self.get_cached_image_features(obs, normalize=False)
parts.append(self._encode_images(cache, detach))
if self.has_env:
parts.append(self.env_encoder(obs["observation.environment_state"]))
@@ -559,7 +585,7 @@ class SACObservationEncoder(nn.Module):
"No parts to concatenate, you should have at least one image or environment state or state"
)
def get_cached_image_features(self, obs: dict[str, Tensor]) -> dict[str, Tensor]:
def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]:
"""Extract and optionally cache image features from observations.
This function processes image observations through the vision encoder once and returns
@@ -571,17 +597,26 @@ class SACObservationEncoder(nn.Module):
- The vision encoder forward pass is typically the main computational bottleneck during training and inference
- Caching these features can provide 2-4x speedup in training and inference
Normalization behavior:
- When called from inside forward(): set normalize=False since inputs are already normalized
- When called from outside forward(): set normalize=True to ensure proper input normalization
Usage patterns:
- Called in select_action()
- Called in select_action() with normalize=True
- Called in learner.py's get_observation_features() to pre-compute features for all policy components
- Called internally by forward()
- Called internally by forward() with normalize=False
Args:
obs: Dictionary of observation tensors containing image keys
normalize: Whether to normalize observations before encoding
Set to True when calling directly from outside the encoder's forward method
Set to False when calling from within forward() where inputs are already normalized
Returns:
Dictionary mapping image keys to their corresponding encoded features
"""
if normalize:
obs = self.input_normalization(obs)
batched = torch.cat([obs[k] for k in self.image_keys], dim=0)
out = self.image_encoder(batched)
chunks = torch.chunk(out, len(self.image_keys), dim=0)
@@ -712,6 +747,7 @@ class CriticEnsemble(nn.Module):
Args:
encoder (SACObservationEncoder): encoder for observations.
ensemble (List[CriticHead]): list of critic heads.
output_normalization (nn.Module): normalization layer for actions.
init_final (float | None): optional initializer scale for final layers.
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
@@ -721,11 +757,13 @@ class CriticEnsemble(nn.Module):
self,
encoder: SACObservationEncoder,
ensemble: list[CriticHead],
output_normalization: nn.Module,
init_final: float | None = None,
):
super().__init__()
self.encoder = encoder
self.init_final = init_final
self.output_normalization = output_normalization
self.critics = nn.ModuleList(ensemble)
def forward(
@@ -737,6 +775,11 @@ class CriticEnsemble(nn.Module):
device = get_device_from_parameters(self)
# Move each tensor in observations to device
observations = {k: v.to(device) for k, v in observations.items()}
# NOTE: We normalize actions it helps for sample efficiency
actions: dict[str, torch.tensor] = {"action": actions}
# NOTE: Normalization layer took dict in input and outputs a dict that why
actions = self.output_normalization(actions)["action"]
actions = actions.to(device)
obs_enc = self.encoder(observations, cache=observation_features)

View File

@@ -1,53 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
)
def make_sac_processor(
config: SACConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)

View File

@@ -20,6 +20,7 @@ import torch
from torch import Tensor, nn
from lerobot.constants import OBS_IMAGE, REWARD
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
@@ -107,12 +108,22 @@ class Classifier(PreTrainedPolicy):
def __init__(
self,
config: RewardClassifierConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
from transformers import AutoModel
super().__init__(config)
self.config = config
# Initialize normalization (standardized with the policy framework)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
# Set up encoder
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
# Extract vision model if we're given a multimodal model
@@ -236,6 +247,10 @@ class Classifier(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
"""Standard forward pass for training compatible with train.py."""
# Normalize inputs if needed
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
# Extract images and labels
images, labels = self.extract_images_and_labels(batch)

View File

@@ -1,42 +0,0 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.processor import (
DeviceProcessor,
IdentityProcessor,
NormalizerProcessor,
RobotProcessor,
)
def make_classifier_processor(
config: RewardClassifierConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
NormalizerProcessor(
features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
NormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
DeviceProcessor(device=config.device),
]
output_steps = [DeviceProcessor(device="cpu"), IdentityProcessor()]
return RobotProcessor(steps=input_steps, name="classifier_preprocessor"), RobotProcessor(
steps=output_steps, name="classifier_postprocessor"
)

View File

@@ -53,13 +53,21 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
"""
import math
import os
import re
from collections import deque
import safetensors
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from transformers import AutoProcessor
from lerobot.constants import ACTION, OBS_LANGUAGE, OBS_STATE
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import (
Normalize,
Unnormalize,
)
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
@@ -68,6 +76,102 @@ from lerobot.policies.utils import (
)
from lerobot.utils.utils import get_safe_dtype
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
def canonicalise(k: str) -> str:
"""
Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
normalisation-buffer key.
"""
return _VARIANT_RE.sub(".buffer_", k)
def standardise_state_dict(
checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
) -> tuple[dict[str, torch.Tensor], list[str]]:
"""
• Re-keys `checkpoint ` so that every entry matches the *reference* key set.
• If several variant keys collapse to the same canonical name we keep the
first one and log the collision.
• Returns the new dict + a list of entries that could not be matched.
"""
out, collisions, unmatched = {}, {}, []
for k, v in checkpoint.items():
canon = canonicalise(k)
if canon in ref_keys:
if canon in out: # duplicate after collapsing
collisions.setdefault(canon, []).append(k)
else:
out[canon] = v
else:
unmatched.append(k)
if verbose:
for canon, variants in collisions.items():
print(f"[standardise_state_dict] '{canon}'{variants}")
if unmatched:
print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys")
out.update({k: checkpoint[k] for k in unmatched})
return out, unmatched
def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
"""
Renames keys in a checkpoint dictionary based on the given rename string.
Args:
checkpoint (dict): The checkpoint dictionary.
rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2".
Returns:
dict: The modified checkpoint with renamed keys.
"""
rename_dict = dict(pair.split("//") for pair in rename_str.split(","))
new_checkpoint = {}
for k, v in checkpoint.items():
for old_key, new_key in rename_dict.items():
if old_key in k:
k = k.replace(old_key, new_key)
new_checkpoint[k] = v
return new_checkpoint
def load_smolvla(
model: torch.nn.Module,
filename: str | os.PathLike,
*,
device: str = "cpu",
checkpoint_keys_mapping: str = "",
) -> torch.nn.Module:
state_dict = safetensors.torch.load_file(filename, device=device)
# Optional user-supplied renames (e.g. "model._orig_mod.//model.")
if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)
state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))
# HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset
norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")
state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)}
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if not all(key.startswith(norm_keys) for key in missing) or unexpected:
raise RuntimeError(
"SmolVLA %d missing / %d unexpected keys",
len(missing),
len(unexpected),
)
return model
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
@@ -222,17 +326,28 @@ class SmolVLAPolicy(PreTrainedPolicy):
def __init__(
self,
config: SmolVLAConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
self.model = VLAFlowMatching(config)
self.reset()
@@ -242,6 +357,23 @@ class SmolVLAPolicy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps),
}
# HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues
@classmethod
def _load_as_safetensor(
cls,
model: "SmolVLAPolicy",
model_file: str,
map_location: str,
strict: bool,
):
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
return load_smolvla(
model,
model_file,
device=map_location,
checkpoint_keys_mapping="model._orig_mod.//model.",
)
def get_optim_params(self) -> dict:
return self.parameters()
@@ -257,8 +389,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
@@ -266,6 +397,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
@@ -275,6 +408,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch = self.normalize_inputs(batch)
return batch
@torch.no_grad()
@@ -315,11 +450,11 @@ class SmolVLAPolicy(PreTrainedPolicy):
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.prepare_action(batch)
actions_is_pad = batch.get("actions_id_pad")
loss_dict = {}
@@ -383,6 +518,30 @@ class SmolVLAPolicy(PreTrainedPolicy):
img_masks.append(mask)
return images, img_masks
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
"""Tokenize the text input"""
device = batch[OBS_STATE].device
tasks = batch["task"]
if isinstance(tasks, str):
tasks = [tasks]
if len(tasks) == 1:
tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
tokenized_prompt = self.language_tokenizer.__call__(
tasks,
padding=self.config.pad_language_to,
padding_side="right",
max_length=self.config.tokenizer_max_length,
return_tensors="pt",
)
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
return lang_tokens, lang_masks
def _pi_aloha_decode_state(self, state):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:

View File

@@ -1,110 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
TokenizerProcessor,
UnnormalizerProcessor,
)
from lerobot.processor.pipeline import EnvTransition, ProcessorStep, ProcessorStepRegistry, TransitionKey
def make_smolvla_processor(
config: SmolVLAConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
NormalizerProcessor(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
SmolVLANewLineProcessor(),
TokenizerProcessor(
tokenizer_name=config.vlm_model_name,
padding=config.pad_language_to,
padding_side="right",
max_length=config.tokenizer_max_length,
),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)
@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
class SmolVLANewLineProcessor(ProcessorStep):
"""Add a new line to the end of the task if it doesn't have one."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Check if complementary_data exists
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None or "task" not in complementary_data:
return transition
task = complementary_data["task"]
if task is None:
return transition
# Handle both string and list of strings
if isinstance(task, str):
# Single string: add newline if not present
if not task.endswith("\n"):
complementary_data["task"] = f"{task}\n"
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
# List of strings: add newline to each if not present
complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
# If task is neither string nor list of strings, leave unchanged
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Adds nothing to the features."""
return features
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {}

View File

@@ -36,6 +36,7 @@ import torch.nn.functional as F # noqa: N812
from torch import Tensor
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
@@ -62,19 +63,26 @@ class TDMPCPolicy(PreTrainedPolicy):
config_class = TDMPCConfig
name = "tdmpc"
def __init__(
self,
config: TDMPCConfig,
):
def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.model = TDMPCTOLD(config)
self.model_target = deepcopy(self.model)
for param in self.model_target.parameters():
@@ -129,6 +137,7 @@ class TDMPCPolicy(PreTrainedPolicy):
actions = torch.clamp(actions, -1, +1)
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
@torch.no_grad()
@@ -138,12 +147,11 @@ class TDMPCPolicy(PreTrainedPolicy):
if ACTION in batch:
batch.pop(ACTION)
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
if ACTION in batch:
batch.pop(ACTION)
self._queues = populate_queues(self._queues, batch)
@@ -312,9 +320,11 @@ class TDMPCPolicy(PreTrainedPolicy):
"""
device = get_device_from_parameters(self)
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
batch = self.normalize_targets(batch)
info = {}

View File

@@ -1,52 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
)
def make_tdmpc_processor(
config: TDMPCConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)

View File

@@ -28,6 +28,7 @@ import torchvision
from torch import Tensor, nn
from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
@@ -47,6 +48,7 @@ class VQBeTPolicy(PreTrainedPolicy):
def __init__(
self,
config: VQBeTConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
@@ -59,6 +61,14 @@ class VQBeTPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.vqbet = VQBeTModel(config)
self.reset()
@@ -118,6 +128,7 @@ class VQBeTPolicy(PreTrainedPolicy):
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
@torch.no_grad()
@@ -131,12 +142,10 @@ class VQBeTPolicy(PreTrainedPolicy):
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
if ACTION in batch:
batch.pop(ACTION)
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
# NOTE: It's important that this happens after stacking the images into a single key.
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
if ACTION in batch:
batch.pop(ACTION)
self._queues = populate_queues(self._queues, batch)
@@ -156,8 +165,10 @@ class VQBeTPolicy(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
batch = self.normalize_targets(batch)
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://huggingface.co/papers/2403.03181)
if not self.vqbet.action_head.vqvae_model.discretized.item():
# loss: total loss of training RVQ

View File

@@ -1,53 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
)
def make_vqbet_processor(
config: VQBeTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
RenameProcessor(rename_map={}), # Let the possibility to the user to rename the keys
NormalizerProcessor(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)

View File

@@ -14,22 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .batch_processor import ToBatchProcessor
from .delta_action_processor import MapDeltaActionToRobotAction
from .device_processor import DeviceProcessor
from .hil_processor import (
AddTeleopActionAsComplimentaryData,
AddTeleopEventsAsInfo,
GripperPenaltyProcessor,
ImageCropResizeProcessor,
InterventionActionProcessor,
Numpy2TorchActionProcessor,
RewardClassifierProcessor,
TimeLimitProcessor,
Torch2NumpyActionProcessor,
)
from .joint_observations_processor import JointVelocityProcessor, MotorCurrentProcessor
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor, hotswap_stats
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
from .observation_processor import VanillaObservationProcessor
from .pipeline import (
ActionProcessor,
@@ -46,39 +32,22 @@ from .pipeline import (
TruncatedProcessor,
)
from .rename_processor import RenameProcessor
from .tokenizer_processor import TokenizerProcessor
__all__ = [
"ActionProcessor",
"AddTeleopActionAsComplimentaryData",
"AddTeleopEventsAsInfo",
"DeviceProcessor",
"DoneProcessor",
"MapDeltaActionToRobotAction",
"EnvTransition",
"GripperPenaltyProcessor",
"IdentityProcessor",
"ImageCropResizeProcessor",
"InfoProcessor",
"InterventionActionProcessor",
"JointVelocityProcessor",
"MapDeltaActionToRobotAction",
"MotorCurrentProcessor",
"NormalizerProcessor",
"UnnormalizerProcessor",
"hotswap_stats",
"ObservationProcessor",
"ProcessorStep",
"ProcessorStepRegistry",
"RenameProcessor",
"RewardClassifierProcessor",
"RewardProcessor",
"RobotProcessor",
"ToBatchProcessor",
"TokenizerProcessor",
"TimeLimitProcessor",
"Numpy2TorchActionProcessor",
"Torch2NumpyActionProcessor",
"TransitionKey",
"TruncatedProcessor",
"VanillaObservationProcessor",

View File

@@ -1,139 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any
import torch
from torch import Tensor
from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor")
class ToBatchProcessor:
"""Processor that adds batch dimensions to observations and actions when needed.
This processor ensures that observations and actions have proper batch dimensions for model processing:
- For state observations (observation.state, observation.environment_state):
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
- For image observations (observation.image, observation.images.*):
Adds batch dimension (unsqueeze at dim=0) if tensor is 3-dimensional (H, W, C)
- For actions:
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
- For task field in complementary data:
Wraps string task in a list to add batch dimension
(task must be a string or list of strings)
This is useful when processing single transitions that need to be batched for
model inference or when converting from unbatched environment outputs to
batched model inputs.
The processor only modifies tensors that need batching and leaves already
batched tensors unchanged.
Example:
```python
# State: (7,) -> (1, 7)
# Image: (224, 224, 3) -> (1, 224, 224, 3)
# Action: (4,) -> (1, 4)
# Task: "pick_cube" -> ["pick_cube"]
# Already batched: (1, 7) -> (1, 7) [unchanged]
```
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
self._process_observation(transition)
self._process_action(transition)
self._process_complementary_data(transition)
return transition
def _process_observation(self, transition: EnvTransition) -> None:
"""Process observation component in-place, adding batch dimensions where needed."""
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
return
# Process state observations - add batch dim if 1D
for state_key in [OBS_STATE, OBS_ENV_STATE]:
if state_key in observation:
state_value = observation[state_key]
if isinstance(state_value, Tensor) and state_value.dim() == 1:
observation[state_key] = state_value.unsqueeze(0)
# Process single image observation - add batch dim if 3D
if OBS_IMAGE in observation:
image_value = observation[OBS_IMAGE]
if isinstance(image_value, Tensor) and image_value.dim() == 3:
observation[OBS_IMAGE] = image_value.unsqueeze(0)
# Process multiple image observations - add batch dim if 3D
for key, value in observation.items():
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
observation[key] = value.unsqueeze(0)
def _process_action(self, transition: EnvTransition) -> None:
"""Process action component in-place, adding batch dimension if needed."""
action = transition.get(TransitionKey.ACTION)
if action is not None and isinstance(action, Tensor) and action.dim() == 1:
transition[TransitionKey.ACTION] = action.unsqueeze(0)
def _process_complementary_data(self, transition: EnvTransition) -> None:
"""Process complementary data in-place, handling task field batching."""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
return
# Process task field - wrap string in list to add batch dimension
if "task" in complementary_data:
task_value = complementary_data["task"]
if isinstance(task_value, str):
complementary_data["task"] = [task_value]
# Process index field - add batch dim if 0D
if "index" in complementary_data:
index_value = complementary_data["index"]
if isinstance(index_value, Tensor) and index_value.dim() == 0:
complementary_data["index"] = index_value.unsqueeze(0)
# Process task_index field - add batch dim if 0D
if "task_index" in complementary_data:
task_index_value = complementary_data["task_index"]
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
complementary_data["task_index"] = task_index_value.unsqueeze(0)
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features

View File

@@ -1,225 +0,0 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Iterable, Sequence
from copy import deepcopy
from typing import Any
import numpy as np
import torch
from scipy.spatial.transform import Rotation
from .pipeline import EnvTransition, TransitionKey
def _to_tensor(x: torch.Tensor | np.ndarray | Sequence[int | float]):
if isinstance(x, torch.Tensor):
return x
if isinstance(x, np.ndarray):
# Keep images (uint8 HWC) and python objects as-is
if x.dtype == np.uint8 or x.dtype == np.object_:
return x
# Scalars/arrays to float32 tensor
return torch.as_tensor(x, dtype=torch.float32)
# Anything else to float32 tensor
return torch.as_tensor(x, dtype=torch.float32)
def _from_tensor(x: Any):
if isinstance(x, torch.Tensor):
return x.item() if x.numel() == 1 else x.detach().cpu().numpy()
return x
def _is_image(arr: Any) -> bool:
return isinstance(arr, np.ndarray) and arr.dtype == np.uint8 and arr.ndim == 3
def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
state, images = {}, {}
for k, v in obs.items():
if _is_image(v):
images[k] = v
else:
state[k] = v
return state, images
def make_obs_act_transition(
*, obs: dict[str, Any] | None = None, act: dict[str, Any] | None = None
) -> EnvTransition:
return {
TransitionKey.OBSERVATION: {} if obs is None else obs,
TransitionKey.ACTION: {} if act is None else act,
TransitionKey.INFO: {},
TransitionKey.COMPLEMENTARY_DATA: {},
TransitionKey.REWARD: None,
TransitionKey.DONE: None,
TransitionKey.TRUNCATED: None,
}
def to_transition_teleop_action(action: dict[str, Any]) -> EnvTransition:
"""
Convert a raw teleop action dict into an EnvTransition under the ACTION TransitionKey.
"""
act_dict: dict[str, Any] = {}
for k, v in action.items():
# Check if the value is a type that should not be converted to a tensor.
if isinstance(v, (Rotation, dict)):
act_dict[f"action.{k}"] = v
continue
arr = np.array(v) if np.isscalar(v) else v
act_dict[f"action.{k}"] = _to_tensor(arr)
return make_obs_act_transition(act=act_dict)
# TODO(Adil, Pepijn): Overtime we can maybe add these converters to pipeline.py itself
def to_transition_robot_observation(observation: dict[str, Any]) -> EnvTransition:
"""
Convert a raw robot observation dict into an EnvTransition under the OBSERVATION TransitionKey.
"""
state, images = _split_obs_to_state_and_images(observation)
obs_dict: dict[str, Any] = {}
for k, v in state.items():
arr = np.array(v) if np.isscalar(v) else v
obs_dict[f"observation.state.{k}"] = _to_tensor(arr)
for cam, img in images.items():
obs_dict[f"observation.images.{cam}"] = img
return make_obs_act_transition(obs=obs_dict)
def to_output_robot_action(transition: EnvTransition) -> dict[str, Any]:
"""
Converts a EnvTransition under the ACTION TransitionKey to a dict with keys ending in '.pos' for raw robot actions.
"""
out: dict[str, Any] = {}
action_dict = transition.get(TransitionKey.ACTION) or {}
for k, v in action_dict.items():
if isinstance(k, str) and k.startswith("action.") and k.endswith((".pos", ".vel")):
out_key = k[len("action.") :] # Strip the 'action.' prefix.
out[out_key] = float(v)
return out
def to_dataset_frame(
transitions_or_transition: EnvTransition | Iterable[EnvTransition], features: dict[str, dict]
) -> dict[str, any]:
"""
Converts a single EnvTransition or an iterable of them into a flat,
dataset-friendly dictionary for training or evaluation, according to
the provided `features` spec.
Args:
transitions_or_transition: Either a single EnvTransition dict
or an iterable of them (which will be merged).
features (dict[str, dict]):
A feature specification dictionary:
- 'action': dict with 'names': list of action feature names
- 'observation.state': dict with 'names': list of state feature names
- keys starting with 'observation.images.' are passed through
Returns:
batch (dict[str, any]): Flat dictionary containing:
- numpy arrays for "observation.state" and "action"
- any image tensors defined in features
- next.{reward,done,truncated}
- info dict
- *_is_pad flags and task from complementary_data
"""
action_names = features.get("action", {}).get("names", [])
obs_state_names = features.get("observation.state", {}).get("names", [])
image_keys = [k for k in features if k.startswith("observation.images.")]
def _merge(base: EnvTransition, other: EnvTransition) -> EnvTransition:
out = deepcopy(base)
for key in (
TransitionKey.OBSERVATION,
TransitionKey.ACTION,
TransitionKey.INFO,
TransitionKey.COMPLEMENTARY_DATA,
):
if other.get(key):
out.setdefault(key, {}).update(deepcopy(other[key]))
for k in (TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED):
if k in other:
out[k] = other[k]
return out
def _ensure_transition(obj) -> EnvTransition:
# single transition
if isinstance(obj, dict) and any(isinstance(k, TransitionKey) for k in obj):
return obj
# iterable of transitions
if isinstance(obj, Iterable):
items = list(obj)
if not items:
return {}
acc = items[0]
for t in items[1:]:
acc = _merge(acc, t)
return acc
raise TypeError("Expected EnvTransition or iterable of them")
tr = _ensure_transition(transitions_or_transition)
obs = tr.get(TransitionKey.OBSERVATION, {}) or {}
act = tr.get(TransitionKey.ACTION, {}) or {}
batch: dict[str, any] = {}
# Images passthrough
for k in image_keys:
if k in obs:
batch[k] = obs[k]
# Observation.state vector
if obs_state_names:
vals = [_from_tensor(obs.get(f"observation.state.{n}", 0.0)) for n in obs_state_names]
batch["observation.state"] = np.asarray(vals, dtype=np.float32)
# Action vector
if action_names:
vals = [_from_tensor(act.get(f"action.{n}", 0.0)) for n in action_names]
batch["action"] = np.asarray(vals, dtype=np.float32)
# Next.* fields
if tr.get(TransitionKey.REWARD) is not None:
batch["next.reward"] = _from_tensor(tr[TransitionKey.REWARD])
if tr.get(TransitionKey.DONE) is not None:
batch["next.done"] = _from_tensor(tr[TransitionKey.DONE])
if tr.get(TransitionKey.TRUNCATED) is not None:
batch["next.truncated"] = _from_tensor(tr[TransitionKey.TRUNCATED])
# Complementary data flags and task
comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if comp:
# pad flags
for k, v in comp.items():
if k.endswith("_is_pad"):
batch[k] = v
# task label
if comp.get("task") is not None:
batch["task"] = comp["task"]
return batch

View File

@@ -1,125 +0,0 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
@ProcessorStepRegistry.register("map_delta_action_to_robot_action")
@dataclass
class MapDeltaActionToRobotAction(ActionProcessor):
"""
Map delta actions from teleoperators (gamepad, keyboard) to robot target actions
for use with inverse kinematics processors.
Expected input ACTION keys:
{
"action.delta_x": float,
"action.delta_y": float,
"action.delta_z": float,
"action.gripper": float (optional),
}
Output ACTION keys:
{
"action.enabled": bool,
"action.target_x": float,
"action.target_y": float,
"action.target_z": float,
"action.target_wx": float,
"action.target_wy": float,
"action.target_wz": float,
"action.gripper": float,
}
"""
# Scale factors for delta movements
position_scale: float = 1.0
rotation_scale: float = 0.0 # No rotation deltas for gamepad/keyboard
gripper_deadzone: float = 0.1 # Threshold for gripper activation
_prev_enabled: bool = field(default=False, init=False, repr=False)
def action(self, action: dict | Tensor | None) -> dict:
if action is None:
return {}
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
# TODO (maractingi): changing this target_xyz naming convention from the teleop_devices
if isinstance(action, dict):
delta_x = action.pop("action.delta_x", 0.0)
delta_y = action.pop("action.delta_y", 0.0)
delta_z = action.pop("action.delta_z", 0.0)
gripper = action.pop("action.gripper", 1.0) # Default to "stay" (1.0)
else:
delta_x = action[0].item()
delta_y = action[1].item()
delta_z = action[2].item()
gripper = action[3].item()
# Determine if the teleoperator is actively providing input
# Consider enabled if any significant movement delta is detected
position_magnitude = abs(delta_x) + abs(delta_y) + abs(delta_z)
enabled = position_magnitude > 1e-6 # Small threshold to avoid noise
# Scale the deltas appropriately
scaled_delta_x = float(delta_x) * self.position_scale
scaled_delta_y = float(delta_y) * self.position_scale
scaled_delta_z = float(delta_z) * self.position_scale
# For gamepad/keyboard, we don't have rotation input, so set to 0
# These could be extended in the future for more sophisticated teleoperators
target_wx = 0.0
target_wy = 0.0
target_wz = 0.0
# Update action with robot target format
action = {
"action.enabled": enabled,
"action.target_x": scaled_delta_x,
"action.target_y": scaled_delta_y,
"action.target_z": scaled_delta_z,
"action.target_wx": target_wx,
"action.target_wy": target_wy,
"action.target_wz": target_wz,
"action.gripper": float(gripper),
}
self._prev_enabled = enabled
return action
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transform features to match output format."""
# Update features to reflect the new action format
features.update(
{
"action.enabled": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
"action.target_x": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
"action.target_y": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
"action.target_z": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
"action.target_wx": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
"action.target_wy": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
"action.target_wz": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
"action.gripper": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
}
)
return features
def reset(self):
self._prev_enabled = False

View File

@@ -19,80 +19,24 @@ from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
from lerobot.processor.pipeline import EnvTransition, TransitionKey
from lerobot.utils.utils import get_safe_torch_device
@ProcessorStepRegistry.register("device_processor")
@dataclass
class DeviceProcessor:
"""Processes transitions by moving tensors to the specified device and optionally converting float dtypes.
"""Processes transitions by moving tensors to the specified device.
This processor ensures that all tensors in the transition are moved to the
specified device (CPU or GPU) before they are returned. It can also convert
floating-point tensors to a specified dtype while preserving non-float types
(int, long, bool, etc.).
specified device (CPU or GPU) before they are returned.
"""
device: str = "cpu"
float_dtype: str | None = None
_device: torch.device | None = None
device: torch.device = "cpu"
def __post_init__(self):
self._device = get_safe_torch_device(self.device)
self.device = self._device.type
self.device = get_safe_torch_device(self.device)
self.non_blocking = "cuda" in str(self.device)
# Validate and convert float_dtype string to torch dtype
if self.float_dtype is not None:
dtype_mapping = {
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
"bfloat16": torch.bfloat16,
"half": torch.float16,
"float": torch.float32,
"double": torch.float64,
}
if self.float_dtype not in dtype_mapping:
available_dtypes = list(dtype_mapping.keys())
raise ValueError(
f"Invalid float_dtype '{self.float_dtype}'. Available options: {available_dtypes}"
)
self._target_float_dtype = dtype_mapping[self.float_dtype]
else:
self._target_float_dtype = None
def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
"""Process a tensor by moving to device and optionally converting float dtype.
If the tensor is already on a GPU and we're configured for a GPU, it preserves
that GPU placement (useful for multi-GPU training with Accelerate).
Otherwise, it moves to the configured device.
"""
# Determine target device
if tensor.is_cuda and self._device.type == "cuda":
# Both tensor and target are on GPU - preserve tensor's GPU placement
# This handles multi-GPU scenarios where Accelerate has already placed
# tensors on the correct GPU for each process
target_device = tensor.device
else:
# Either tensor is on CPU, or we're configured for CPU
# In both cases, use the configured device
target_device = self._device
# Only move if necessary
if tensor.device != target_device:
tensor = tensor.to(target_device, non_blocking=self.non_blocking)
# Convert float dtype if specified and tensor is floating point
if self._target_float_dtype is not None and tensor.is_floating_point():
tensor = tensor.to(dtype=self._target_float_dtype)
return tensor
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Create a copy of the transition
new_transition = transition.copy()
@@ -101,7 +45,7 @@ class DeviceProcessor:
observation = transition.get(TransitionKey.OBSERVATION)
if observation is not None:
new_observation = {
k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v
k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v
for k, v in observation.items()
}
new_transition[TransitionKey.OBSERVATION] = new_observation
@@ -109,54 +53,30 @@ class DeviceProcessor:
# Process action tensor
action = transition.get(TransitionKey.ACTION)
if action is not None and isinstance(action, torch.Tensor):
new_transition[TransitionKey.ACTION] = self._process_tensor(action)
new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking)
# Process reward tensor
reward = transition.get(TransitionKey.REWARD)
if reward is not None and isinstance(reward, torch.Tensor):
new_transition[TransitionKey.REWARD] = self._process_tensor(reward)
new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking)
# Process done tensor
done = transition.get(TransitionKey.DONE)
if done is not None and isinstance(done, torch.Tensor):
new_transition[TransitionKey.DONE] = self._process_tensor(done)
new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking)
# Process truncated tensor
truncated = transition.get(TransitionKey.TRUNCATED)
if truncated is not None and isinstance(truncated, torch.Tensor):
new_transition[TransitionKey.TRUNCATED] = self._process_tensor(truncated)
# Process complementary data tensors
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is not None:
new_complementary_data = {}
# Process all items in complementary_data
for key, value in complementary_data.items():
if isinstance(value, torch.Tensor):
new_complementary_data[key] = self._process_tensor(value)
else:
new_complementary_data[key] = value
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
new_transition[TransitionKey.TRUNCATED] = truncated.to(
self.device, non_blocking=self.non_blocking
)
return new_transition
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {"device": self.device, "float_dtype": self.float_dtype}
return {"device": self.device}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features

View File

@@ -1,418 +0,0 @@
import time
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
import torchvision.transforms.functional as F # noqa: N812
from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import (
ActionProcessor,
ComplementaryDataProcessor,
EnvTransition,
InfoProcessor,
ObservationProcessor,
ProcessorStepRegistry,
TransitionKey,
)
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents
GRIPPER_KEY = "gripper"
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
@dataclass
class AddTeleopActionAsComplimentaryData(ComplementaryDataProcessor):
"""Add teleoperator action to transition complementary data."""
teleop_device: Teleoperator
def complementary_data(self, complementary_data: dict | None) -> dict:
complementary_data = {} if complementary_data is None else dict(complementary_data)
complementary_data["teleop_action"] = self.teleop_device.get_action()
return complementary_data
@ProcessorStepRegistry.register("add_teleop_action_as_info")
@dataclass
class AddTeleopEventsAsInfo(InfoProcessor):
"""Add teleoperator control events to transition info."""
teleop_device: Teleoperator
def info(self, info: dict | None) -> dict:
info = {} if info is None else dict(info)
teleop_events = getattr(self.teleop_device, "get_teleop_events", lambda: {})()
info.update(teleop_events)
return info
@ProcessorStepRegistry.register("torch2numpy_action_processor")
@dataclass
class Torch2NumpyActionProcessor(ActionProcessor):
"""Convert PyTorch tensor actions to NumPy arrays."""
squeeze_batch_dim: bool = True
def action(self, action: torch.Tensor | None) -> np.ndarray | None:
if action is None:
return None
if not isinstance(action, torch.Tensor):
raise TypeError(
f"Expected torch.Tensor or None, got {type(action).__name__}. "
"Use appropriate processor for non-tensor actions."
)
numpy_action = action.detach().cpu().numpy()
# Remove batch dimensions but preserve action dimensions
# Only squeeze if there's a batch dimension (first dim == 1)
if (
self.squeeze_batch_dim
and numpy_action.shape
and len(numpy_action.shape) > 1
and numpy_action.shape[0] == 1
):
numpy_action = numpy_action.squeeze(0)
return numpy_action
@ProcessorStepRegistry.register("numpy2torch_action_processor")
@dataclass
class Numpy2TorchActionProcessor(ActionProcessor):
"""Convert NumPy array action to PyTorch tensor."""
def action(self, action: np.ndarray | None) -> torch.Tensor | None:
if action is None:
return None
if not isinstance(action, np.ndarray):
raise TypeError(
f"Expected np.ndarray or None, got {type(action).__name__}. "
"Use appropriate processor for non-tensor actions."
)
torch_action = torch.from_numpy(action)
return torch_action
@ProcessorStepRegistry.register("image_crop_resize_processor")
@dataclass
class ImageCropResizeProcessor(ObservationProcessor):
"""Crop and resize image observations."""
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
resize_size: tuple[int, int] | None = None
def observation(self, observation: dict | None) -> dict | None:
if observation is None:
return None
if self.resize_size is None and not self.crop_params_dict:
return observation
new_observation = dict(observation)
# Process all image keys in the observation
for key in observation:
if "image" not in key:
continue
image = observation[key]
device = image.device
# NOTE (maractingi): No mps kernel for crop and resize, so we need to move to cpu
if device.type == "mps":
image = image.cpu()
# Crop if crop params are provided for this key
if self.crop_params_dict is not None and key in self.crop_params_dict:
crop_params = self.crop_params_dict[key]
image = F.crop(image, *crop_params)
if self.resize_size is not None:
image = F.resize(image, self.resize_size)
image = image.clamp(0.0, 1.0)
new_observation[key] = image.to(device)
return new_observation
def get_config(self) -> dict[str, Any]:
return {
"crop_params_dict": self.crop_params_dict,
"resize_size": self.resize_size,
}
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
if self.resize_size is None:
return features
for key in features:
if "image" in key:
features[key] = PolicyFeature(type=features[key].type, shape=self.resize_size)
return features
@dataclass
@ProcessorStepRegistry.register("time_limit_processor")
class TimeLimitProcessor:
"""Track episode steps and enforce time limits."""
max_episode_steps: int
current_step: int = 0
def __call__(self, transition: EnvTransition) -> EnvTransition:
truncated = transition.get(TransitionKey.TRUNCATED)
if truncated is None:
return transition
self.current_step += 1
if self.current_step >= self.max_episode_steps:
truncated = True
new_transition = transition.copy()
new_transition[TransitionKey.TRUNCATED] = truncated
return new_transition
def get_config(self) -> dict[str, Any]:
return {
"max_episode_steps": self.max_episode_steps,
}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
self.current_step = 0
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessor:
"""Apply penalty for inappropriate gripper usage."""
penalty: float = -0.01
max_gripper_pos: float = 30.0
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Calculate gripper penalty and add to complementary data."""
action = transition.get(TransitionKey.ACTION)
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None or action is None:
return transition
current_gripper_pos = complementary_data.get("raw_joint_positions", None).get(GRIPPER_KEY, None)
if current_gripper_pos is None:
return transition
gripper_action = action[f"action.{GRIPPER_KEY}.pos"]
gripper_action_normalized = gripper_action / self.max_gripper_pos
# Normalize gripper state and action
gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
# Calculate penalty boolean as in original
gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
)
gripper_penalty = self.penalty * int(gripper_penalty_bool)
# Add penalty information to complementary data
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
# Create new complementary data with penalty info
new_complementary_data = dict(complementary_data)
new_complementary_data["discrete_penalty"] = gripper_penalty
# Create new transition with updated complementary data
new_transition = transition.copy()
existing_comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
existing_comp_data.update(new_complementary_data)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = existing_comp_data # type: ignore[misc]
return new_transition
def get_config(self) -> dict[str, Any]:
return {
"penalty": self.penalty,
"max_gripper_pos": self.max_gripper_pos,
}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
"""Reset the processor state."""
self.last_gripper_state = None
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register("intervention_action_processor")
class InterventionActionProcessor:
"""Handle human intervention actions and episode termination."""
use_gripper: bool = False
terminate_on_success: bool = True
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is None:
return transition
# Get intervention signals from complementary data
info = transition.get(TransitionKey.INFO, {})
teleop_action = info.get("teleop_action", {})
is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False)
terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False)
success = info.get(TeleopEvents.SUCCESS, False)
rerecord_episode = info.get(TeleopEvents.RERECORD_EPISODE, False)
new_transition = transition.copy()
# Override action if intervention is active
if is_intervention and teleop_action is not None:
if isinstance(teleop_action, dict):
# Convert teleop_action dict to tensor format
action_list = [
teleop_action.get("action.delta_x", 0.0),
teleop_action.get("action.delta_y", 0.0),
teleop_action.get("action.delta_z", 0.0),
]
if self.use_gripper:
action_list.append(teleop_action.get("gripper", 1.0))
elif isinstance(teleop_action, np.ndarray):
action_list = teleop_action.tolist()
else:
action_list = teleop_action
teleop_action_tensor = torch.tensor(action_list, dtype=action.dtype, device=action.device)
new_transition[TransitionKey.ACTION] = teleop_action_tensor
# Handle episode termination
new_transition[TransitionKey.DONE] = bool(terminate_episode) or (
self.terminate_on_success and success
)
new_transition[TransitionKey.REWARD] = float(success)
# Update info with intervention metadata
info = new_transition.get(TransitionKey.INFO, {})
info[TeleopEvents.IS_INTERVENTION] = is_intervention
info[TeleopEvents.RERECORD_EPISODE] = rerecord_episode
info[TeleopEvents.SUCCESS] = success
new_transition[TransitionKey.INFO] = info
# Update complementary data with teleop action
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
complementary_data["teleop_action"] = new_transition.get(TransitionKey.ACTION)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return new_transition
def get_config(self) -> dict[str, Any]:
return {
"use_gripper": self.use_gripper,
}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register("reward_classifier_processor")
class RewardClassifierProcessor:
"""Apply reward classification to image observations."""
pretrained_path: str | None = None
device: str = "cpu"
success_threshold: float = 0.5
success_reward: float = 1.0
terminate_on_success: bool = True
reward_classifier: Any = None
def __post_init__(self):
"""Initialize the reward classifier after dataclass initialization."""
if self.pretrained_path is not None:
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
self.reward_classifier = Classifier.from_pretrained(self.pretrained_path)
self.reward_classifier.to(self.device)
self.reward_classifier.eval()
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None or self.reward_classifier is None:
return transition
# Extract images from observation
images = {key: value for key, value in observation.items() if "image" in key}
if not images:
return transition
# Run reward classifier
start_time = time.perf_counter()
with torch.inference_mode():
success = self.reward_classifier.predict_reward(images, threshold=self.success_threshold)
classifier_frequency = 1 / (time.perf_counter() - start_time)
# Calculate reward and termination
reward = transition.get(TransitionKey.REWARD, 0.0)
terminated = transition.get(TransitionKey.DONE, False)
if success == 1.0:
reward = self.success_reward
if self.terminate_on_success:
terminated = True
# Update transition
new_transition = transition.copy()
new_transition[TransitionKey.REWARD] = reward
new_transition[TransitionKey.DONE] = terminated
# Update info with classifier frequency
info = new_transition.get(TransitionKey.INFO, {})
info["reward_classifier_frequency"] = classifier_frequency
new_transition[TransitionKey.INFO] = info
return new_transition
def get_config(self) -> dict[str, Any]:
return {
"device": self.device,
"success_threshold": self.success_threshold,
"success_reward": self.success_reward,
"terminate_on_success": self.terminate_on_success,
}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features

View File

@@ -1,116 +0,0 @@
from dataclasses import dataclass
from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import (
ObservationProcessor,
ProcessorStepRegistry,
)
from lerobot.robots import Robot
@dataclass
@ProcessorStepRegistry.register("joint_velocity_processor")
class JointVelocityProcessor:
"""Add joint velocity information to observations."""
joint_velocity_limits: float = 100.0
dt: float = 1.0 / 10
num_dof: int | None = None
last_joint_positions: torch.Tensor | None = None
def observation(self, observation: dict | None) -> dict | None:
if observation is None:
return None
# Get current joint positions (assuming they're in observation.state)
current_positions = observation.get("observation.state")
if current_positions is None:
return observation
# Initialize last joint positions if not already set
if self.last_joint_positions is None:
self.last_joint_positions = current_positions.clone()
# Compute velocities
joint_velocities = (current_positions - self.last_joint_positions) / self.dt
self.last_joint_positions = current_positions.clone()
# Extend observation with velocities
extended_state = torch.cat([current_positions, joint_velocities], dim=-1)
# Create new observation dict
new_observation = dict(observation)
new_observation["observation.state"] = extended_state
return new_observation
def get_config(self) -> dict[str, Any]:
return {
"joint_velocity_limits": self.joint_velocity_limits,
"dt": self.dt,
}
def reset(self) -> None:
self.last_joint_positions = None
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
if "observation.state" in features and self.num_dof is not None:
from lerobot.configs.types import PolicyFeature
original_feature = features["observation.state"]
# Double the shape to account for positions + velocities
new_shape = (original_feature.shape[0] + self.num_dof,) + original_feature.shape[1:]
features["observation.state"] = PolicyFeature(type=original_feature.type, shape=new_shape)
return features
@dataclass
@ProcessorStepRegistry.register("current_processor")
class MotorCurrentProcessor(ObservationProcessor):
"""Add motor current information to observations."""
robot: Robot | None = None
def observation(self, observation: dict | None) -> dict | None:
if observation is None:
return None
# Get current values from robot state
if self.robot is None:
return observation
present_current_dict = self.robot.bus.sync_read("Present_Current") # type: ignore[attr-defined]
motor_currents = torch.tensor(
[present_current_dict[name] for name in self.robot.bus.motors], # type: ignore[attr-defined]
dtype=torch.float32,
).unsqueeze(0)
current_state = observation.get("observation.state")
if current_state is None:
return observation
extended_state = torch.cat([current_state, motor_currents], dim=-1)
# Create new observation dict
new_observation = dict(observation)
new_observation["observation.state"] = extended_state
return new_observation
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
if "observation.state" in features and self.robot is not None:
from lerobot.configs.types import PolicyFeature
original_feature = features["observation.state"]
# Add motor current dimensions to the original state shape
num_motors = 0
if hasattr(self.robot, "bus") and hasattr(self.robot.bus, "motors"): # type: ignore[attr-defined]
num_motors = len(self.robot.bus.motors) # type: ignore[attr-defined]
if num_motors > 0:
new_shape = (original_feature.shape[0] + num_motors,) + original_feature.shape[1:]
features["observation.state"] = PolicyFeature(type=original_feature.type, shape=new_shape)
return features

View File

@@ -1,502 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generic script to migrate any policy model with normalization layers to the new pipeline-based system.
This script:
1. Loads an existing pretrained policy model
2. Extracts normalization statistics from the model
3. Creates both preprocessor and postprocessor:
- Preprocessor: normalizes both inputs (observations) and outputs (actions) for training
- Postprocessor: unnormalizes outputs (actions) for inference
4. Removes normalization layers from the model state_dict
5. Saves the new model and both processors
Usage:
python src/lerobot/processor/migrate_policy_normalization.py \
--pretrained-path lerobot/act_aloha_sim_transfer_cube_human \
--policy-type act \
--push-to-hub
"""
import argparse
import importlib
import json
import os
from copy import deepcopy
from pathlib import Path
from typing import Any
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_safetensors
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor.batch_processor import ToBatchProcessor
from lerobot.processor.device_processor import DeviceProcessor
from lerobot.processor.normalize_processor import NormalizerProcessor, UnnormalizerProcessor
from lerobot.processor.pipeline import RobotProcessor
from lerobot.processor.rename_processor import RenameProcessor
# Policy type to class mapping
POLICY_CLASSES = {
"act": "lerobot.policies.act.modeling_act.ACTPolicy",
"diffusion": "lerobot.policies.diffusion.modeling_diffusion.DiffusionPolicy",
"pi0": "lerobot.policies.pi0.modeling_pi0.PI0Policy",
"pi0fast": "lerobot.policies.pi0fast.modeling_pi0fast.PI0FASTPolicy",
"smolvla": "lerobot.policies.smolvla.modeling_smolvla.SmolVLAPolicy",
"tdmpc": "lerobot.policies.tdmpc.modeling_tdmpc.TDMPCPolicy",
"vqbet": "lerobot.policies.vqbet.modeling_vqbet.VQBeTPolicy",
"sac": "lerobot.policies.sac.modeling_sac.SACPolicy",
"classifier": "lerobot.policies.classifier.modeling_classifier.ClassifierPolicy",
}
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
"""Extract normalization statistics from model state_dict."""
stats = {}
# Define patterns to match and their prefixes to remove
normalization_patterns = [
"normalize_inputs.buffer_",
"unnormalize_outputs.buffer_",
"normalize_targets.buffer_",
"normalize.", # Must come after normalize_* patterns
"unnormalize.", # Must come after unnormalize_* patterns
"input_normalizer.",
"output_normalizer.",
]
# Process each key in state_dict
for key, tensor in state_dict.items():
# Try each pattern
for pattern in normalization_patterns:
if key.startswith(pattern):
# Extract the remaining part after the pattern
remaining = key[len(pattern) :]
parts = remaining.split(".")
# Need at least feature name and stat type
if len(parts) >= 2:
# Last part is the stat type (mean, std, min, max, etc.)
stat_type = parts[-1]
# Everything else is the feature name
feature_name = ".".join(parts[:-1]).replace("_", ".")
# Add to stats
if feature_name not in stats:
stats[feature_name] = {}
stats[feature_name][stat_type] = tensor.clone()
# Only process the first matching pattern
break
return stats
def detect_features_and_norm_modes(
config: dict[str, Any], stats: dict[str, dict[str, torch.Tensor]]
) -> tuple[dict[str, PolicyFeature], dict[FeatureType, NormalizationMode]]:
"""Detect features and normalization modes from config and stats."""
features = {}
norm_modes = {}
# First, check if there's a normalization_mapping in the config
if "normalization_mapping" in config:
print(f"Found normalization_mapping in config: {config['normalization_mapping']}")
# Extract normalization modes from config
for feature_name, mode_str in config["normalization_mapping"].items():
# Convert string to NormalizationMode enum
if mode_str == "mean_std":
mode = NormalizationMode.MEAN_STD
elif mode_str == "min_max":
mode = NormalizationMode.MIN_MAX
else:
print(f"Warning: Unknown normalization mode '{mode_str}' for feature '{feature_name}'")
continue
# Determine feature type from feature name
if "image" in feature_name or "visual" in feature_name:
feature_type = FeatureType.VISUAL
elif "state" in feature_name:
feature_type = FeatureType.STATE
elif "action" in feature_name:
feature_type = FeatureType.ACTION
else:
feature_type = FeatureType.STATE
norm_modes[feature_type] = mode
# Try to extract from config
if "features" in config:
for key, feature_config in config["features"].items():
shape = feature_config.get("shape", feature_config.get("dim"))
shape = (shape,) if isinstance(shape, int) else tuple(shape)
# Determine feature type
if "image" in key or "visual" in key:
feature_type = FeatureType.VISUAL
elif "state" in key:
feature_type = FeatureType.STATE
elif "action" in key:
feature_type = FeatureType.ACTION
else:
feature_type = FeatureType.STATE # Default
features[key] = PolicyFeature(feature_type, shape)
# If no features in config, infer from stats
if not features:
for key, stat_dict in stats.items():
# Get shape from any stat tensor
tensor = next(iter(stat_dict.values()))
shape = tuple(tensor.shape)
# Determine feature type based on key
if "image" in key or "visual" in key or "pixels" in key:
feature_type = FeatureType.VISUAL
elif "state" in key or "joint" in key or "position" in key:
feature_type = FeatureType.STATE
elif "action" in key:
feature_type = FeatureType.ACTION
else:
feature_type = FeatureType.STATE
features[key] = PolicyFeature(feature_type, shape)
# If normalization modes weren't in config, determine based on available stats
if not norm_modes:
for key, stat_dict in stats.items():
if key in features:
if "mean" in stat_dict and "std" in stat_dict:
feature_type = features[key].type
if feature_type not in norm_modes:
norm_modes[feature_type] = NormalizationMode.MEAN_STD
elif "min" in stat_dict and "max" in stat_dict:
feature_type = features[key].type
if feature_type not in norm_modes:
norm_modes[feature_type] = NormalizationMode.MIN_MAX
# Default normalization modes if not detected
if FeatureType.VISUAL not in norm_modes:
norm_modes[FeatureType.VISUAL] = NormalizationMode.MEAN_STD
if FeatureType.STATE not in norm_modes:
norm_modes[FeatureType.STATE] = NormalizationMode.MIN_MAX
if FeatureType.ACTION not in norm_modes:
norm_modes[FeatureType.ACTION] = NormalizationMode.MEAN_STD
return features, norm_modes
def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Remove normalization layers from state_dict."""
new_state_dict = {}
# Patterns to remove
remove_patterns = [
"normalize_inputs.",
"unnormalize_outputs.",
"normalize_targets.", # Added pattern for target normalization
"normalize.",
"unnormalize.",
"input_normalizer.",
"output_normalizer.",
"normalizer.",
]
for key, tensor in state_dict.items():
should_remove = any(pattern in key for pattern in remove_patterns)
if not should_remove:
new_state_dict[key] = tensor
return new_state_dict
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
"""Convert features from old format to PolicyFeature objects."""
converted_features = {}
for key, feature_dict in features_dict.items():
# Determine feature type based on key
if "image" in key or "visual" in key:
feature_type = FeatureType.VISUAL
elif "state" in key:
feature_type = FeatureType.STATE
elif "action" in key:
feature_type = FeatureType.ACTION
else:
feature_type = FeatureType.STATE
# Get shape from feature dict
shape = feature_dict.get("shape", feature_dict.get("dim"))
shape = (shape,) if isinstance(shape, int) else tuple(shape)
converted_features[key] = PolicyFeature(feature_type, shape)
return converted_features
def load_model_from_hub(
repo_id: str, revision: str = None
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
"""Load model state_dict and config from hub."""
# Download files
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)
# Load state_dict
state_dict = load_safetensors(safetensors_path)
# Load config
with open(config_path) as f:
config = json.load(f)
with open(train_config_path) as f:
train_config = json.load(f)
return state_dict, config, train_config
def main():
parser = argparse.ArgumentParser(
description="Migrate policy models with normalization layers to new pipeline system"
)
parser.add_argument(
"--pretrained-path",
type=str,
required=True,
help="Path to pretrained model (hub repo or local directory)",
)
parser.add_argument(
"--output-dir",
type=str,
default=None,
help="Output directory for migrated model (default: same as pretrained-path)",
)
parser.add_argument("--push-to-hub", action="store_true", help="Push migrated model to hub")
parser.add_argument(
"--hub-repo-id",
type=str,
default=None,
help="Hub repository ID for pushing (default: same as pretrained-path)",
)
parser.add_argument("--revision", type=str, default=None, help="Revision of the model to load")
parser.add_argument("--private", action="store_true", help="Make the hub repository private")
args = parser.parse_args()
# Load model and config
print(f"Loading model from {args.pretrained_path}...")
if os.path.isdir(args.pretrained_path):
# Local directory
state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors"))
with open(os.path.join(args.pretrained_path, "config.json")) as f:
config = json.load(f)
with open(os.path.join(args.pretrained_path, "train_config.json")) as f:
train_config = json.load(f)
else:
# Hub repository
state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision)
# Extract normalization statistics
print("Extracting normalization statistics...")
stats = extract_normalization_stats(state_dict)
print(f"Found normalization statistics for: {list(stats.keys())}")
# Detect input features and normalization modes
print("Detecting features and normalization modes...")
features, norm_map = detect_features_and_norm_modes(config, stats)
print(f"Detected features: {list(features.keys())}")
print(f"Normalization modes: {norm_map}")
# Remove normalization layers from state_dict
print("Removing normalization layers from model...")
new_state_dict = remove_normalization_layers(state_dict)
removed_keys = set(state_dict.keys()) - set(new_state_dict.keys())
if removed_keys:
print(f"Removed {len(removed_keys)} normalization layer keys")
# Determine output path
if args.output_dir:
output_dir = Path(args.output_dir)
else:
if os.path.isdir(args.pretrained_path):
output_dir = Path(args.pretrained_path).parent / f"{Path(args.pretrained_path).name}_migrated"
else:
output_dir = Path(f"./{args.pretrained_path.replace('/', '_')}_migrated")
output_dir.mkdir(parents=True, exist_ok=True)
# Clean up config - remove normalization_mapping field
cleaned_config = dict(config)
if "normalization_mapping" in cleaned_config:
print("Removing 'normalization_mapping' field from config")
del cleaned_config["normalization_mapping"]
policy_type = deepcopy(cleaned_config["type"])
del cleaned_config["type"]
# Instantiate the policy model with cleaned config and load the cleaned state dict
print(f"Instantiating {policy_type} policy model...")
policy_class_path = POLICY_CLASSES[policy_type]
module_path, class_name = policy_class_path.rsplit(".", 1)
module = importlib.import_module(module_path)
policy_class = getattr(module, class_name)
# Create config class instance
config_module_path = module_path.replace("modeling", "configuration")
config_module = importlib.import_module(config_module_path)
# Handle special cases for config class names
config_class_names = {
"act": "ACTConfig",
"diffusion": "DiffusionConfig",
"pi0": "PI0Config",
"pi0fast": "PI0FASTConfig",
"smolvla": "SmolVLAConfig",
"tdmpc": "TDMPCConfig",
"vqbet": "VQBeTConfig",
"sac": "SACConfig",
"classifier": "ClassifierConfig",
}
config_class_name = config_class_names.get(policy_type, f"{policy_type.upper()}Config")
config_class = getattr(config_module, config_class_name)
# Convert input_features and output_features to PolicyFeature objects - these are mandatory
if "input_features" not in cleaned_config:
raise ValueError("Missing mandatory 'input_features' in config")
if "output_features" not in cleaned_config:
raise ValueError("Missing mandatory 'output_features' in config")
cleaned_config["input_features"] = convert_features_to_policy_features(cleaned_config["input_features"])
cleaned_config["output_features"] = convert_features_to_policy_features(cleaned_config["output_features"])
# Create config instance from cleaned config dict
policy_config = config_class(**cleaned_config)
# Create policy instance - some policies expect dataset_stats
policy = policy_class(policy_config)
# Load the cleaned state dict
policy.load_state_dict(new_state_dict, strict=True)
print("Successfully loaded cleaned state dict into policy model")
# Now create preprocessor and postprocessor with cleaned_config available
print("Creating preprocessor and postprocessor...")
# The pattern from existing processor factories:
# - Preprocessor has two NormalizerProcessors: one for input_features, one for output_features
# - Postprocessor has one UnnormalizerProcessor for output_features only
# Get features from cleaned_config (now they're PolicyFeature objects)
input_features = cleaned_config.get("input_features", {})
output_features = cleaned_config.get("output_features", {})
# Create preprocessor with two normalizers (following the pattern from processor factories)
preprocessor_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
features={**input_features, **output_features},
norm_map=norm_map,
stats=stats,
),
ToBatchProcessor(),
DeviceProcessor(device=policy_config.device),
]
preprocessor = RobotProcessor(steps=preprocessor_steps, name="robot_preprocessor")
# Create postprocessor with unnormalizer for outputs only
postprocessor_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
]
postprocessor = RobotProcessor(steps=postprocessor_steps, name="robot_postprocessor")
# Determine hub repo ID if pushing to hub
if args.push_to_hub:
if args.hub_repo_id:
hub_repo_id = args.hub_repo_id
else:
if not os.path.isdir(args.pretrained_path):
# Use same repo with "_migrated" suffix
hub_repo_id = f"{args.pretrained_path}_migrated"
else:
raise ValueError("--hub-repo-id must be specified when pushing local model to hub")
else:
hub_repo_id = None
# Save preprocessor and postprocessor to root directory
print(f"Saving preprocessor to {output_dir}...")
preprocessor.save_pretrained(output_dir)
if args.push_to_hub:
preprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
print(f"Saving postprocessor to {output_dir}...")
postprocessor.save_pretrained(output_dir)
if args.push_to_hub:
postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
# Save model using the policy's save_pretrained method
print(f"Saving model to {output_dir}...")
policy.save_pretrained(
output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private
)
# Generate and save model card
print("Generating model card...")
# Get metadata from original config
dataset_repo_id = train_config.get("repo_id", "unknown")
license = config.get("license", "apache-2.0")
tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type]
tags = set(tags).union({"robotics", "lerobot", policy_type})
tags = list(tags)
# Generate model card
card = policy.generate_model_card(
dataset_repo_id=dataset_repo_id, model_type=policy_type, license=license, tags=tags
)
# Save model card locally
card.save(str(output_dir / "README.md"))
print(f"Model card saved to {output_dir / 'README.md'}")
# Push model card to hub if requested
if args.push_to_hub:
from huggingface_hub import HfApi
api = HfApi()
api.upload_file(
path_or_fileobj=str(output_dir / "README.md"),
path_in_repo="README.md",
repo_id=hub_repo_id,
repo_type="model",
commit_message="Add model card for migrated model",
)
print("Model card pushed to hub")
print("\nMigration complete!")
print(f"Migrated model saved to: {output_dir}")
if args.push_to_hub:
print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")
if __name__ == "__main__":
main()

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
from collections.abc import Mapping
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any
@@ -11,7 +10,7 @@ from torch import Tensor
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, RobotProcessor, TransitionKey
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
@@ -116,7 +115,7 @@ class NormalizerProcessor:
if self.normalize_keys is not None and not isinstance(self.normalize_keys, set):
self.normalize_keys = set(self.normalize_keys)
def _normalize_obs(self, observation, normalized_info):
def _normalize_obs(self, observation):
if observation is None:
return None
@@ -129,20 +128,7 @@ class NormalizerProcessor:
processed = dict(observation)
for key in keys_to_norm:
if key not in processed or key not in self.features:
continue
# Check the normalization mode for this feature type
feature = self.features[key]
norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY)
# Skip normalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
normalized_info[key] = "IDENTITY"
continue
# Skip if no stats available for this key
if key not in self._tensor_stats:
if key not in processed or key not in self._tensor_stats:
continue
orig_val = processed[key]
@@ -153,35 +139,16 @@ class NormalizerProcessor:
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = (tensor - mean) / (std + self.eps)
normalized_info[key] = "MEAN_STD"
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
normalized_info[key] = "MIN_MAX"
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = (tensor - mean) / (std + self.eps)
elif "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
return processed
def _normalize_action(self, action, normalized_info):
if action is None:
return action
# Check the normalization mode for actions
norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY)
# Skip normalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
normalized_info["action"] = "IDENTITY"
return action
# Skip if no stats available for actions
if "action" not in self._tensor_stats:
def _normalize_action(self, action):
if action is None or "action" not in self._tensor_stats:
return action
tensor = (
@@ -190,42 +157,22 @@ class NormalizerProcessor:
else torch.as_tensor(action, dtype=torch.float32)
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
normalized_info["action"] = "MEAN_STD"
return (tensor - mean) / (std + self.eps)
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
normalized_info["action"] = "MIN_MAX"
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
# If we reach here, the required stats for the normalization mode are not available
raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization")
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
return (tensor - mean) / (std + self.eps)
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Track what was normalized
normalized_info = {}
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION), normalized_info)
action = self._normalize_action(transition.get(TransitionKey.ACTION), normalized_info)
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION))
action = self._normalize_action(transition.get(TransitionKey.ACTION))
# Create a new transition with normalized values
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = observation
new_transition[TransitionKey.ACTION] = action
# Add normalization info to complementary data
if normalized_info:
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
comp_data = {} if comp_data is None else dict(comp_data)
comp_data["normalized_keys"] = normalized_info
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
def get_config(self) -> dict[str, Any]:
@@ -257,7 +204,7 @@ class NormalizerProcessor:
def reset(self):
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -306,28 +253,14 @@ class UnnormalizerProcessor:
self.stats = self.stats or {}
self._tensor_stats = _convert_stats_to_tensors(self.stats)
def _unnormalize_obs(self, observation, unnormalized_info):
def _unnormalize_obs(self, observation):
if observation is None:
return None
keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
processed = dict(observation)
for key in keys:
if key not in processed or key not in self.features:
if key not in processed or key not in self._tensor_stats:
continue
# Check the normalization mode for this feature type
feature = self.features[key]
norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY)
# Skip unnormalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
unnormalized_info[key] = "IDENTITY"
continue
# Skip if no stats available for this key
if key not in self._tensor_stats:
continue
orig_val = processed[key]
tensor = (
orig_val.to(dtype=torch.float32)
@@ -335,80 +268,39 @@ class UnnormalizerProcessor:
else torch.as_tensor(orig_val, dtype=torch.float32)
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = tensor * std + mean
unnormalized_info[key] = "MEAN_STD"
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
unnormalized_info[key] = "MIN_MAX"
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = tensor * std + mean
elif "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
return processed
def _unnormalize_action(self, action, unnormalized_info):
if action is None:
def _unnormalize_action(self, action):
if action is None or "action" not in self._tensor_stats:
return action
# Check the normalization mode for actions
norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY)
# Skip unnormalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
unnormalized_info["action"] = "IDENTITY"
return action
# Skip if no stats available for actions
if "action" not in self._tensor_stats:
return action
tensor = (
action.to(dtype=torch.float32)
if isinstance(action, torch.Tensor)
else torch.as_tensor(action, dtype=torch.float32)
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
unnormalized_info["action"] = "MEAN_STD"
return tensor * std + mean
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
unnormalized_info["action"] = "MIN_MAX"
return (tensor + 1) / 2 * (max_val - min_val) + min_val
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
# If we reach here, the required stats for the normalization mode are not available
raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization")
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
return tensor * std + mean
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
return (tensor + 1) / 2 * (max_val - min_val) + min_val
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Track what was unnormalized
unnormalized_info = {}
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION), unnormalized_info)
action = self._unnormalize_action(transition.get(TransitionKey.ACTION), unnormalized_info)
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION))
action = self._unnormalize_action(transition.get(TransitionKey.ACTION))
# Create a new transition with unnormalized values
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = observation
new_transition[TransitionKey.ACTION] = action
# Add unnormalization info to complementary data
if unnormalized_info:
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
comp_data = {} if comp_data is None else dict(comp_data)
comp_data["unnormalized_keys"] = unnormalized_info
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
def get_config(self) -> dict[str, Any]:
@@ -435,41 +327,5 @@ class UnnormalizerProcessor:
def reset(self):
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, Any]]) -> RobotProcessor:
robot_processor = deepcopy(robot_processor)
for step in robot_processor.steps:
if isinstance(step, NormalizerProcessor) or isinstance(step, UnnormalizerProcessor):
step: NormalizerProcessor | UnnormalizerProcessor
step.stats = stats
step._tensor_stats = _convert_stats_to_tensors(stats)
return robot_processor
def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]:
"""Rename keys in the stats dictionary according to the provided mapping.
Args:
stats: The statistics dictionary with structure {feature_key: {stat_name: value}}
rename_map: Dictionary mapping old key names to new key names
Returns:
A new stats dictionary with renamed keys
Example:
>>> stats = {"observation.state": {"mean": 0.0, "std": 1.0}, "action": {"mean": 0.5, "std": 0.5}}
>>> rename_map = {"observation.state": "observation.robot_state"}
>>> new_stats = rename_stats(stats, rename_map)
>>> # new_stats will have "observation.robot_state" instead of "observation.state"
"""
renamed_stats = {}
for old_key, sub_stats in stats.items():
# Use the new key if it exists in the rename map, otherwise keep the old key
new_key = rename_map.get(old_key, old_key)
renamed_stats[new_key] = deepcopy(sub_stats)
return renamed_stats

View File

@@ -106,8 +106,9 @@ class VanillaObservationProcessor(ObservationProcessor):
def observation(self, observation):
return self._process_observation(observation)
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transforms feature keys to a standardized contract.
This method handles several renaming patterns:
- Exact matches (e.g., 'pixels' -> 'OBS_IMAGE').
- Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE').

View File

@@ -23,7 +23,7 @@ from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Protocol, TypedDict, runtime_checkable
from typing import Any, Protocol, TypedDict
import torch
from huggingface_hub import ModelHubMixin, hf_hub_download
@@ -132,7 +132,6 @@ class ProcessorStepRegistry:
cls._registry.clear()
@runtime_checkable
class ProcessorStep(Protocol):
"""Structural typing interface for a single processor step.
@@ -146,6 +145,7 @@ class ProcessorStep(Protocol):
**Required**:
- ``__call__(transition: EnvTransition) -> EnvTransition``
- ``feature_contract(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]``
Optional helper protocol:
* ``get_config() -> dict[str, Any]`` User-defined JSON-serializable
@@ -158,8 +158,6 @@ class ProcessorStep(Protocol):
* ``load_state_dict(state)`` Inverse of ``state_dict``. Receives a dict
containing torch tensors only.
* ``reset()`` Clear internal buffers at episode boundaries.
* ``transform_features(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]``
If present, this method will be called to aggregate the dataset features of all steps.
Example separation:
- get_config(): {"name": "my_step", "learning_rate": 0.01, "window_size": 10}
@@ -176,7 +174,7 @@ class ProcessorStep(Protocol):
def reset(self) -> None: ...
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ...
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ...
def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401
@@ -203,16 +201,10 @@ def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noq
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
observation = observation_keys if observation_keys else None
# Extract padding, task, index, and task_index keys for complementary data
# Extract padding and task keys for complementary data
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
complementary_data = (
{**pad_keys, **task_key, **index_key, **task_index_key}
if pad_keys or task_key or index_key or task_index_key
else {}
)
complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {}
transition: EnvTransition = {
TransitionKey.OBSERVATION: observation,
@@ -239,7 +231,7 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
"info": transition.get(TransitionKey.INFO, {}),
}
# Add padding, task, index, and task_index data from complementary_data
# Add padding and task data from complementary_data
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data:
pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k}
@@ -248,12 +240,6 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
if "task" in complementary_data:
batch["task"] = complementary_data["task"]
if "index" in complementary_data:
batch["index"] = complementary_data["index"]
if "task_index" in complementary_data:
batch["task_index"] = complementary_data["task_index"]
# Handle observation - flatten dict to observation.* keys if it's a dict
observation = transition.get(TransitionKey.OBSERVATION)
if isinstance(observation, dict):
@@ -356,10 +342,7 @@ class RobotProcessor(ModelHubMixin):
hook(idx, current_transition)
# Convert back to original format if needed
if called_with_batch or self.to_output is not _default_transition_to_batch:
return self.to_output(current_transition)
else:
return current_transition
return self.to_output(current_transition) if called_with_batch else current_transition
def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]:
"""Prepare and validate transition data for processing.
@@ -592,9 +575,10 @@ class RobotProcessor(ModelHubMixin):
if config_filename is None:
# Try common config names
common_names = [
"robot_processor.json",
"robot_preprocessor.json",
"robot_postprocessor.json",
"processor.json",
"preprocessor.json",
"postprocessor.json",
"robotprocessor.json",
]
config_path = None
for name in common_names:
@@ -824,15 +808,23 @@ class RobotProcessor(ModelHubMixin):
f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition"
)
def transform_features(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
fc = getattr(step, "feature_contract", None)
if not callable(fc):
raise TypeError(
f"Step {i} ({type(step).__name__}) must define feature_contract(features) -> dict[str, Any]"
)
def feature_contract(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""
Apply ALL steps in order. Only if a step has a features method, it will be called.
We aggregate the dataset features of all steps.
Apply ALL steps in order. Each step must implement
feature_contract(features) and return a dict (full or incremental schema).
"""
features: dict[str, PolicyFeature] = deepcopy(initial_features)
for _, step in enumerate(self.steps):
out = step.transform_features(features)
out = step.feature_contract(features)
if not isinstance(out, dict):
raise TypeError(f"{step.__class__.__name__}.feature_contract must return dict[str, Any]")
features = out
return features
@@ -892,7 +884,7 @@ class ObservationProcessor:
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -952,7 +944,7 @@ class ActionProcessor:
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -1011,7 +1003,7 @@ class RewardProcessor:
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -1075,7 +1067,7 @@ class DoneProcessor:
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -1135,7 +1127,7 @@ class TruncatedProcessor:
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -1200,7 +1192,7 @@ class InfoProcessor:
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -1246,7 +1238,7 @@ class ComplementaryDataProcessor:
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -1268,5 +1260,5 @@ class IdentityProcessor:
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features

View File

@@ -43,7 +43,7 @@ class RenameProcessor(ObservationProcessor):
def get_config(self) -> dict[str, Any]:
return {"rename_map": self.rename_map}
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transforms:
- Each key in the observation that appears in `rename_map` is renamed to its value.
- Keys not in `rename_map` remain unchanged.

View File

@@ -1,275 +0,0 @@
"""
Tokenizer processor for handling text tokenization in robot transitions.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import OBS_LANGUAGE
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers import AutoTokenizer
else:
AutoTokenizer = None
@dataclass
@ProcessorStepRegistry.register(name="tokenizer_processor")
class TokenizerProcessor:
"""Tokenizes text tasks in complementary data using a huggingface tokenizer.
This processor handles tokenization of task strings found in the complementary_data
using a specified pretrained tokenizer from Hugging Face. It adds tokenized versions
to the observation data for model processing while preserving the original task string.
The processor supports both single strings and lists of strings as task inputs.
Args:
tokenizer_name: Name of the pretrained tokenizer to load from Hugging Face Hub
(e.g., "bert-base-uncased", "microsoft/DialoGPT-medium"). This will be used
with AutoTokenizer.from_pretrained(). If tokenizer is provided, this is ignored.
tokenizer: A tokenizer object (e.g., from transformers library) that implements
the __call__ method. If provided, tokenizer_name is ignored. This parameter
is not serialized and must be provided via overrides when loading.
max_length: Maximum sequence length for tokenization. Defaults to 512.
task_key: Key in complementary_data containing the task text. Defaults to "task".
padding: Padding strategy for tokenization. Defaults to "max_length".
truncation: Whether to truncate sequences longer than max_length. Defaults to True.
Examples:
Using tokenizer name (auto-loaded):
```python
processor = TokenizerProcessor(tokenizer_name="bert-base-uncased", max_length=128)
```
Using custom tokenizer object:
```python
from transformers import AutoTokenizer
custom_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
processor = TokenizerProcessor(tokenizer=custom_tokenizer, max_length=128)
```
"""
tokenizer_name: str | None = None
tokenizer: Any | None = None # Otherwise transformers is not available in the core dependencies
max_length: int = 512
task_key: str = "task"
padding_side: str = "right"
padding: str = "max_length"
truncation: bool = True
# Internal tokenizer instance (not serialized)
_tokenizer: Any = field(default=None, init=False, repr=False)
def __post_init__(self):
"""Initialize the tokenizer from the provided tokenizer or tokenizer name."""
if not _transformers_available:
raise ImportError(
"The 'transformers' library is not installed. "
"Please install it with `pip install 'lerobot[transformers-dep]'` to use TokenizerProcessor."
)
if self.tokenizer is not None:
# Use provided tokenizer object directly
self._tokenizer = self.tokenizer
elif self.tokenizer_name is not None:
if AutoTokenizer is None:
raise ImportError("AutoTokenizer is not available")
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
else:
raise ValueError(
"Either 'tokenizer' or 'tokenizer_name' must be provided. "
"Pass a tokenizer object directly or a tokenizer name to auto-load."
)
def get_task(self, transition: EnvTransition) -> list[str] | None:
"""Extract and normalize task from complementary data.
Args:
transition: Input transition containing complementary_data.
Returns:
List of task strings if task is present, None otherwise.
"""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
return None
if self.task_key not in complementary_data:
return None
task = complementary_data[self.task_key]
if task is None:
return None
# Convert to list of strings
if isinstance(task, str):
return [task]
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
return task
return None
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Process the transition by tokenizing the task text.
Args:
transition: Input transition containing complementary_data with task text.
Returns:
Modified transition with tokenized task added to observation.
Raises:
ValueError: If tokenizer initialization failed.
"""
task = self.get_task(transition)
if task is None:
return transition
# Tokenize the task (creates CPU tensors)
tokenized_prompt = self._tokenize_text(task)
# Detect device from existing tensors in the transition
target_device = self._detect_device(transition)
# Move tokenized tensors to match the device of other data
if target_device is not None:
tokenized_prompt = {
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
for k, v in tokenized_prompt.items()
}
# Get or create observation dict
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
observation = {}
else:
observation = dict(observation) # Make a copy
# Add tokenized data to observation
observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"]
observation[f"{OBS_LANGUAGE}.attention_mask"] = tokenized_prompt["attention_mask"].to(
dtype=torch.bool
)
transition[TransitionKey.OBSERVATION.value] = observation # type: ignore[misc]
return transition
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
"""Detect device from existing tensors in the transition.
This allows the tokenized tensors to match the device of other data,
which is especially important for multi-GPU training with Accelerate.
Args:
transition: The transition to search for existing tensors.
Returns:
The device of the first tensor found, or None if no tensors exist.
"""
# Check observation tensors first (most likely to exist)
observation = transition.get(TransitionKey.OBSERVATION)
if observation:
for value in observation.values():
if isinstance(value, torch.Tensor):
return value.device
# Check action tensor
action = transition.get(TransitionKey.ACTION)
if isinstance(action, torch.Tensor):
return action.device
# Check other tensor fields
for key in [TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED]:
value = transition.get(key)
if isinstance(value, torch.Tensor):
return value.device
# Check complementary data for tensors
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data:
for value in complementary_data.values():
if isinstance(value, torch.Tensor):
return value.device
return None # No tensors found, keep on CPU
def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
"""Tokenize text using the configured tokenizer.
Args:
text: Text string or list of strings to tokenize.
Returns:
Dictionary containing tokenized output with keys like 'input_ids', 'attention_mask'.
"""
return self._tokenizer(
text,
max_length=self.max_length,
truncation=self.truncation,
padding=self.padding,
padding_side=self.padding_side,
return_tensors="pt",
)
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization.
Note: Only tokenizer_name is saved, not the tokenizer object itself.
When loading, provide the tokenizer via overrides if needed.
"""
config = {
"max_length": self.max_length,
"task_key": self.task_key,
"padding_side": self.padding_side,
"padding": self.padding,
"truncation": self.truncation,
}
# Only include tokenizer_name if it was used (not when tokenizer object was provided)
if self.tokenizer_name is not None:
config["tokenizer_name"] = self.tokenizer_name
return config
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Add tokenized task features to the feature contract.
Args:
features: Input feature dictionary.
Returns:
Updated feature dictionary with tokenized task features added.
"""
# Add features for tokenized output if they don't exist
# Standard tokenizer output includes tokens and attention_mask
tokens_key = f"{OBS_LANGUAGE}.tokens"
attention_mask_key = f"{OBS_LANGUAGE}.attention_mask"
if tokens_key not in features:
features[tokens_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
if attention_mask_key not in features:
features[attention_mask_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
return features

View File

@@ -59,7 +59,7 @@ lerobot-record \
import logging
import time
from dataclasses import asdict, dataclass, field
from dataclasses import asdict, dataclass
from pathlib import Path
from pprint import pformat
@@ -72,19 +72,10 @@ from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_processor
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import RobotProcessor
from lerobot.processor.converters import (
to_dataset_frame,
to_output_robot_action,
to_transition_robot_observation,
to_transition_teleop_action,
)
from lerobot.processor.normalize_processor import rename_stats
from lerobot.processor.pipeline import IdentityProcessor, TransitionKey
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
@@ -158,8 +149,6 @@ class DatasetRecordConfig:
# Number of episodes to record before batch encoding videos
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
video_encoding_batch_size: int = 1
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
def __post_init__(self):
if self.single_task is None:
@@ -198,36 +187,6 @@ class RecordConfig:
return ["policy"]
""" --------------- record_loop() data flow --------------------------
[ Robot ]
V
[ robot.get_observation() ] ---> raw_obs
V
[ robot_observation_processor ] ---> obs_transition
V
.-----( ACTION LOGIC )------------------.
V V
[ From Teleoperator ] [ From Policy ]
| |
| [teleop.get_action] -> raw_action | [predict_action]
| | | |
| V | V
| [teleop_action_processor] | |
| | | |
'---> teleop_transition '---> policy_transition
| |
'-------------------------.-------------'
V
[ robot_action_processor ] --> robot_action_to_send
V
[ robot.send_action() ] -- (Robot Executes)
V
( Transitions are merged & added to Dataset )
V
( Rerun Log / Loop Wait )
"""
@safe_stop_image_writer
def record_loop(
robot: Robot,
@@ -236,36 +195,28 @@ def record_loop(
dataset: LeRobotDataset | None = None,
teleop: Teleoperator | list[Teleoperator] | None = None,
policy: PreTrainedPolicy | None = None,
preprocessor: RobotProcessor | None = None,
postprocessor: RobotProcessor | None = None,
control_time_s: int | None = None,
teleop_action_processor: RobotProcessor | None = None, # runs after teleop
robot_action_processor: RobotProcessor | None = None, # runs before robot
robot_observation_processor: RobotProcessor | None = None, # runs after robot
single_task: str | None = None,
display_data: bool = False,
):
teleop_action_processor = teleop_action_processor or RobotProcessor(
steps=[IdentityProcessor()], to_transition=to_transition_teleop_action, to_output=lambda tr: tr
)
robot_action_processor = robot_action_processor or RobotProcessor(
steps=[IdentityProcessor()], to_transition=lambda tr: tr, to_output=to_output_robot_action
)
robot_observation_processor = robot_observation_processor or RobotProcessor(
steps=[IdentityProcessor()], to_transition=to_transition_robot_observation, to_output=lambda tr: tr
)
if dataset is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
teleop_arm = teleop_keyboard = None
if isinstance(teleop, list): # For LeKiwi
if isinstance(teleop, list):
teleop_keyboard = next((t for t in teleop if isinstance(t, KeyboardTeleop)), None)
teleop_arm = next(
(
t
for t in teleop
if isinstance(t, (so100_leader.SO100Leader, so101_leader.SO101Leader, koch_leader.KochLeader))
if isinstance(
t,
(
so100_leader.SO100Leader,
so101_leader.SO101Leader,
koch_leader.KochLeader,
),
)
),
None,
)
@@ -275,20 +226,9 @@ def record_loop(
"For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot."
)
# Reset policy and processor if they are provided
if policy is not None and preprocessor is not None and postprocessor is not None:
# if policy is given it needs cleaning up
if policy is not None:
policy.reset()
preprocessor.reset()
postprocessor.reset()
# Reset custom pipelines
teleop_action_processor.reset()
robot_action_processor.reset()
robot_observation_processor.reset()
policy_transition = None
teleop_transition = None
obs_transition = None
timestamp = 0
start_episode_t = time.perf_counter()
@@ -299,87 +239,51 @@ def record_loop(
events["exit_early"] = False
break
# Get robot observation
obs = robot.get_observation()
observation = robot.get_observation()
# Applies a pipeline to the raw robot observation, default is IdentityProcessor
obs_transition = robot_observation_processor(obs)
# Get action from either policy or teleop
if policy is not None and preprocessor is not None and postprocessor is not None:
if dataset is not None:
observation_frame = to_dataset_frame(
obs_transition, dataset.features
) # Convert the observation to the dataset format
if policy is not None or dataset is not None:
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
if policy is not None:
action_values = predict_action(
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
observation_frame,
policy,
get_safe_torch_device(policy.config.device),
policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
action_names = dataset.features["action"]["names"]
policy_action = {f"action.{name}": float(action_values[i]) for i, name in enumerate(action_names)}
policy_transition = {
TransitionKey.ACTION: policy_action,
TransitionKey.COMPLEMENTARY_DATA: {},
}
elif isinstance(teleop, Teleoperator):
act = teleop.get_action()
# Applies a pipeline to the raw teleop action, default is IdentityProcessor
teleop_transition = teleop_action_processor(act)
elif isinstance(teleop, list):
action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
elif policy is None and isinstance(teleop, Teleoperator):
action = teleop.get_action()
elif policy is None and isinstance(teleop, list):
# TODO(pepijn, steven): clean the record loop for use of multiple robots (possibly with pipeline)
arm_action = teleop_arm.get_action()
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
keyboard_action = teleop_keyboard.get_action()
base_action = robot._from_keyboard_to_base_action(keyboard_action)
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
teleop_transition = teleop_action_processor(act)
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
else:
logging.info(
"No policy or teleoperator provided, skipping action generation. "
"This is likely to happen during environment reset."
"No policy or teleoperator provided, skipping action generation."
"This is likely to happen when resetting the environment without a teleop device."
"The robot won't be at its rest position at the start of the next episode."
)
# Still continue to next loop to respect timing
continue
# Applies a pipeline to the action, default is IdentityProcessor
# IMPORTANT: action_pipeline.to_output must return a dict suitable for robot.send_action()
if policy_transition is not None:
robot_action_to_send = robot_action_processor(policy_transition)
else:
robot_action_to_send = robot_action_processor(teleop_transition)
# Send action to robot
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
# TODO(pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
_ = robot.send_action(robot_action_to_send)
# so action actually sent is saved in the dataset.
sent_action = robot.send_action(action)
# Write to dataset
if dataset is not None:
# If to_dataset_frame is provided, use it to merge the transitions.
merged = []
if obs_transition is not None: # The observation from the robot
merged.append(obs_transition)
if teleop_transition is not None: # The action from teleop
merged.append(teleop_transition)
if policy_transition is not None: # The action from policy
merged.append(policy_transition)
frame = to_dataset_frame(
merged if len(merged) > 1 else merged[0], dataset.features
) # Convert the observation to the dataset format
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
frame = {**observation_frame, **action_frame}
dataset.add_frame(frame, task=single_task)
if display_data:
log_rerun_data([obs_transition, teleop_transition or policy_transition])
log_rerun_data(observation, action)
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
@@ -431,18 +335,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
# Load pretrained policy
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
preprocessor = None
postprocessor = None
if cfg.policy is not None:
preprocessor, postprocessor = make_processor(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
preprocessor_overrides={
"device_processor": {"device": cfg.policy.device},
"rename_processor": {"rename_map": cfg.dataset.rename_map},
},
)
robot.connect()
if teleop is not None:
@@ -460,8 +352,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
fps=cfg.dataset.fps,
teleop=teleop,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=cfg.dataset.episode_time_s,
single_task=cfg.dataset.single_task,
@@ -510,5 +400,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
return dataset
if __name__ == "__main__":
def main():
record()
if __name__ == "__main__":
main()

View File

@@ -55,6 +55,7 @@ from lerobot.robots import ( # noqa: F401
hope_jr,
koch_follower,
make_robot_from_config,
reachy2,
so100_follower,
so101_follower,
)

View File

@@ -29,10 +29,10 @@ class BiSO100FollowerConfig(RobotConfig):
# Optional
left_arm_disable_torque_on_disconnect: bool = True
left_arm_max_relative_target: int | None = None
left_arm_max_relative_target: float | dict[str, float] | None = None
left_arm_use_degrees: bool = False
right_arm_disable_torque_on_disconnect: bool = True
right_arm_max_relative_target: int | None = None
right_arm_max_relative_target: float | dict[str, float] | None = None
right_arm_use_degrees: bool = False
# cameras (shared between both arms)

View File

@@ -44,8 +44,8 @@ class HopeJrArmConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
max_relative_target: float | dict[str, float] | None = None
cameras: dict[str, CameraConfig] = field(default_factory=dict)

View File

@@ -28,9 +28,9 @@ class KochFollowerConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
max_relative_target: float | dict[str, float] | None = None
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)

View File

@@ -110,6 +110,7 @@ class KochFollower(Robot):
return self.bus.is_calibrated
def calibrate(self) -> None:
self.bus.disable_torque()
if self.calibration:
# Calibration file exists, ask user whether to use it or run new calibration
user_input = input(
@@ -120,7 +121,6 @@ class KochFollower(Robot):
self.bus.write_calibration(self.calibration)
return
logger.info(f"\nRunning calibration of {self}")
self.bus.disable_torque()
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)

View File

@@ -39,9 +39,9 @@ class LeKiwiConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
max_relative_target: float | dict[str, float] | None = None
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)

View File

@@ -0,0 +1,25 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_reachy2 import Reachy2RobotConfig
from .robot_reachy2 import (
REACHY2_ANTENNAS_JOINTS,
REACHY2_L_ARM_JOINTS,
REACHY2_NECK_JOINTS,
REACHY2_R_ARM_JOINTS,
REACHY2_VEL,
Reachy2Robot,
)

View File

@@ -0,0 +1,107 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from lerobot.cameras.configs import ColorMode
from lerobot.cameras.reachy2_camera import Reachy2CameraConfig
from ..config import RobotConfig
@RobotConfig.register_subclass("reachy2")
@dataclass
class Reachy2RobotConfig(RobotConfig):
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors.
max_relative_target: float | None = None
# IP address of the Reachy 2 robot
ip_address: str | None = "localhost"
# If True, turn_off_smoothly() will be sent to the robot before disconnecting.
disable_torque_on_disconnect: bool = False
# Tag for external commands control
# Set to True if you use an external commands system to control the robot,
# such as the official teleoperation application: https://github.com/pollen-robotics/Reachy2Teleoperation
# If True, robot.send_action() will not send commands to the robot.
use_external_commands: bool = False
# Robot parts
# Set to False to not add the corresponding joints part to the robot list of joints.
# By default, all parts are set to True.
with_mobile_base: bool = True
with_l_arm: bool = True
with_r_arm: bool = True
with_neck: bool = True
with_antennas: bool = True
# Robot cameras
# Set to True if you want to use the corresponding cameras in the observations.
# By default, only the teleop cameras are used.
with_left_teleop_camera: bool = True
with_right_teleop_camera: bool = True
with_torso_camera: bool = False
cameras: dict[str, CameraConfig] = field(default_factory=dict)
def __post_init__(self) -> None:
# Add cameras with same ip_address as the robot
if self.with_left_teleop_camera:
self.cameras["teleop_left"] = Reachy2CameraConfig(
name="teleop",
image_type="left",
ip_address=self.ip_address,
fps=15,
width=640,
height=480,
color_mode=ColorMode.RGB,
)
if self.with_right_teleop_camera:
self.cameras["teleop_right"] = Reachy2CameraConfig(
name="teleop",
image_type="right",
ip_address=self.ip_address,
fps=15,
width=640,
height=480,
color_mode=ColorMode.RGB,
)
if self.with_torso_camera:
self.cameras["torso_rgb"] = Reachy2CameraConfig(
name="depth",
image_type="rgb",
ip_address=self.ip_address,
fps=15,
width=640,
height=480,
color_mode=ColorMode.RGB,
)
super().__post_init__()
if not (
self.with_mobile_base
or self.with_l_arm
or self.with_r_arm
or self.with_neck
or self.with_antennas
):
raise ValueError(
"No Reachy2Robot part used.\n"
"At least one part of the robot must be set to True "
"(with_mobile_base, with_l_arm, with_r_arm, with_neck, with_antennas)"
)

View File

@@ -0,0 +1,230 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import Any
import numpy as np
from reachy2_sdk import ReachySDK
from lerobot.cameras.utils import make_cameras_from_configs
from ..robot import Robot
from ..utils import ensure_safe_goal_position
from .configuration_reachy2 import Reachy2RobotConfig
# {lerobot_keys: reachy2_sdk_keys}
REACHY2_NECK_JOINTS = {
"neck_yaw.pos": "head.neck.yaw",
"neck_pitch.pos": "head.neck.pitch",
"neck_roll.pos": "head.neck.roll",
}
REACHY2_ANTENNAS_JOINTS = {
"l_antenna.pos": "head.l_antenna",
"r_antenna.pos": "head.r_antenna",
}
REACHY2_R_ARM_JOINTS = {
"r_shoulder_pitch.pos": "r_arm.shoulder.pitch",
"r_shoulder_roll.pos": "r_arm.shoulder.roll",
"r_elbow_yaw.pos": "r_arm.elbow.yaw",
"r_elbow_pitch.pos": "r_arm.elbow.pitch",
"r_wrist_roll.pos": "r_arm.wrist.roll",
"r_wrist_pitch.pos": "r_arm.wrist.pitch",
"r_wrist_yaw.pos": "r_arm.wrist.yaw",
"r_gripper.pos": "r_arm.gripper",
}
REACHY2_L_ARM_JOINTS = {
"l_shoulder_pitch.pos": "l_arm.shoulder.pitch",
"l_shoulder_roll.pos": "l_arm.shoulder.roll",
"l_elbow_yaw.pos": "l_arm.elbow.yaw",
"l_elbow_pitch.pos": "l_arm.elbow.pitch",
"l_wrist_roll.pos": "l_arm.wrist.roll",
"l_wrist_pitch.pos": "l_arm.wrist.pitch",
"l_wrist_yaw.pos": "l_arm.wrist.yaw",
"l_gripper.pos": "l_arm.gripper",
}
REACHY2_VEL = {
"mobile_base.vx": "vx",
"mobile_base.vy": "vy",
"mobile_base.vtheta": "vtheta",
}
class Reachy2Robot(Robot):
"""
[Reachy 2](https://www.pollen-robotics.com/reachy/), by Pollen Robotics.
"""
config_class = Reachy2RobotConfig
name = "reachy2"
def __init__(self, config: Reachy2RobotConfig):
super().__init__(config)
self.config = config
self.robot_type = self.config.type
self.use_external_commands = self.config.use_external_commands
self.reachy: None | ReachySDK = None
self.cameras = make_cameras_from_configs(config.cameras)
self.logs: dict[str, float] = {}
self.joints_dict: dict[str, str] = self._generate_joints_dict()
@property
def observation_features(self) -> dict[str, Any]:
return {**self.motors_features, **self.camera_features}
@property
def action_features(self) -> dict[str, type]:
return self.motors_features
@property
def camera_features(self) -> dict[str, tuple[int | None, int | None, int]]:
return {cam: (self.cameras[cam].height, self.cameras[cam].width, 3) for cam in self.cameras}
@property
def motors_features(self) -> dict[str, type]:
if self.config.with_mobile_base:
return {
**dict.fromkeys(
self.joints_dict.keys(),
float,
),
**dict.fromkeys(
REACHY2_VEL.keys(),
float,
),
}
else:
return dict.fromkeys(self.joints_dict.keys(), float)
@property
def is_connected(self) -> bool:
return self.reachy.is_connected() if self.reachy is not None else False
def connect(self, calibrate: bool = False) -> None:
self.reachy = ReachySDK(self.config.ip_address)
if not self.is_connected:
raise ConnectionError()
for cam in self.cameras.values():
cam.connect()
self.configure()
def configure(self) -> None:
if self.reachy is not None:
self.reachy.turn_on()
self.reachy.reset_default_limits()
@property
def is_calibrated(self) -> bool:
return True
def calibrate(self) -> None:
pass
def _generate_joints_dict(self) -> dict[str, str]:
joints = {}
if self.config.with_neck:
joints.update(REACHY2_NECK_JOINTS)
if self.config.with_l_arm:
joints.update(REACHY2_L_ARM_JOINTS)
if self.config.with_r_arm:
joints.update(REACHY2_R_ARM_JOINTS)
if self.config.with_antennas:
joints.update(REACHY2_ANTENNAS_JOINTS)
return joints
def _get_state(self) -> dict[str, float]:
if self.reachy is not None:
pos_dict = {k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()}
if not self.config.with_mobile_base:
return pos_dict
vel_dict = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
return {**pos_dict, **vel_dict}
else:
return {}
def get_observation(self) -> dict[str, np.ndarray]:
obs_dict: dict[str, Any] = {}
# Read Reachy 2 state
before_read_t = time.perf_counter()
obs_dict.update(self._get_state())
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
# Capture images from cameras
for cam_key, cam in self.cameras.items():
obs_dict[cam_key] = cam.async_read()
return obs_dict
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
if self.reachy is not None:
if not self.is_connected:
raise ConnectionError()
before_write_t = time.perf_counter()
vel = {}
goal_pos = {}
for key, val in action.items():
if key not in self.joints_dict:
if key not in REACHY2_VEL:
raise KeyError(f"Key '{key}' is not a valid motor key in Reachy 2.")
else:
vel[REACHY2_VEL[key]] = float(val)
else:
if not self.use_external_commands and self.config.max_relative_target is not None:
goal_pos[key] = float(val)
goal_present_pos = {
key: (
goal_pos[key],
self.reachy.joints[self.joints_dict[key]].present_position,
)
}
safe_goal_pos = ensure_safe_goal_position(
goal_present_pos, float(self.config.max_relative_target)
)
val = safe_goal_pos[key]
self.reachy.joints[self.joints_dict[key]].goal_position = float(val)
if self.config.with_mobile_base:
self.reachy.mobile_base.set_goal_speed(vel["vx"], vel["vy"], vel["vtheta"])
# We don't send the goal positions if we control Reachy 2 externally
if not self.use_external_commands:
self.reachy.send_goal_positions()
if self.config.with_mobile_base:
self.reachy.mobile_base.send_speed_command()
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
return action
def disconnect(self) -> None:
if self.reachy is not None:
for cam in self.cameras.values():
cam.disconnect()
if self.config.disable_torque_on_disconnect:
self.reachy.turn_off_smoothly()
self.reachy.disconnect()

View File

@@ -14,5 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_so100_follower import SO100FollowerConfig
from .config_so100_follower import SO100FollowerConfig, SO100FollowerEndEffectorConfig
from .so100_follower import SO100Follower
from .so100_follower_end_effector import SO100FollowerEndEffector

View File

@@ -30,12 +30,44 @@ class SO100FollowerConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
max_relative_target: float | dict[str, float] | None = None
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)
# Set to `True` for backward compatibility with previous policies/dataset
use_degrees: bool = False
@RobotConfig.register_subclass("so100_follower_end_effector")
@dataclass
class SO100FollowerEndEffectorConfig(SO100FollowerConfig):
"""Configuration for the SO100FollowerEndEffector robot."""
# Path to URDF file for kinematics
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
urdf_path: str | None = None
# End-effector frame name in the URDF
target_frame_name: str = "gripper_frame_link"
# Default bounds for the end-effector position (in meters)
end_effector_bounds: dict[str, list[float]] = field(
default_factory=lambda: {
"min": [-1.0, -1.0, -1.0], # min x, y, z
"max": [1.0, 1.0, 1.0], # max x, y, z
}
)
max_gripper_pos: float = 50
end_effector_step_sizes: dict[str, float] = field(
default_factory=lambda: {
"x": 0.02,
"y": 0.02,
"z": 0.02,
}
)

View File

@@ -1,465 +0,0 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
import numpy as np
from scipy.spatial.transform import Rotation
from lerobot.configs.types import PolicyFeature
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor.pipeline import (
ActionProcessor,
ComplementaryDataProcessor,
EnvTransition,
ObservationProcessor,
ProcessorStepRegistry,
TransitionKey,
)
from lerobot.robots.robot import Robot
@ProcessorStepRegistry.register("ee_reference_and_delta")
@dataclass
class EEReferenceAndDelta:
"""
Compute the desired end-effector pose from the target pose and the current pose.
Input ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
"complementary_data.raw_joint_positions": dict,
}
Output ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
}
"""
kinematics: RobotKinematics
end_effector_step_sizes: dict
motor_names: list[str]
use_latched_reference: bool = (
True # If True, latch reference on enable; if False, always use current pose
)
reference_ee_pose: np.ndarray | None = field(default=None, init=False, repr=False)
_prev_enabled: bool = field(default=False, init=False, repr=False)
_command_when_disabled: np.ndarray | None = field(default=None, init=False, repr=False)
def __call__(self, transition: EnvTransition) -> EnvTransition:
act = transition.get(TransitionKey.ACTION) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
# Get joint positions from complimentary data
raw = comp.get("raw_joint_positions", None)
if raw is None:
raise ValueError(
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
)
if "reference_joint_positions" in comp:
q = comp["reference_joint_positions"]
else:
q = np.array([float(raw[n]) for n in self.motor_names], dtype=float)
# Current pose from FK on measured joints
t_curr = self.kinematics.forward_kinematics(q)
enabled = bool(act.pop("action.enabled", 0))
tx = float(act.pop("action.target_x", 0.0))
ty = float(act.pop("action.target_y", 0.0))
tz = float(act.pop("action.target_z", 0.0))
wx = float(act.pop("action.target_wx", 0.0))
wy = float(act.pop("action.target_wy", 0.0))
wz = float(act.pop("action.target_wz", 0.0))
desired = None
if enabled:
ref = t_curr
if self.use_latched_reference:
# Latched reference mode: latch reference at the rising edge
if not self._prev_enabled or self.reference_ee_pose is None:
self.reference_ee_pose = t_curr.copy()
ref = self.reference_ee_pose if self.reference_ee_pose is not None else t_curr
delta_p = np.array(
[
tx * self.end_effector_step_sizes["x"],
ty * self.end_effector_step_sizes["y"],
tz * self.end_effector_step_sizes["z"],
],
dtype=float,
)
r_abs = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
desired = np.eye(4, dtype=float)
desired[:3, :3] = ref[:3, :3] @ r_abs
desired[:3, 3] = ref[:3, 3] + delta_p
self._command_when_disabled = desired.copy()
else:
# While disabled, keep sending the same command to avoid drift.
if self._command_when_disabled is None:
# If we've never had an enabled command yet, freeze current FK pose once.
self._command_when_disabled = t_curr.copy()
desired = self._command_when_disabled.copy()
# Write action fields
pos = desired[:3, 3]
tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec()
act.update(
{
"action.ee.x": float(pos[0]),
"action.ee.y": float(pos[1]),
"action.ee.z": float(pos[2]),
"action.ee.wx": float(tw[0]),
"action.ee.wy": float(tw[1]),
"action.ee.wz": float(tw[2]),
}
)
self._prev_enabled = enabled
transition[TransitionKey.ACTION] = act
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@ProcessorStepRegistry.register("ee_bounds_and_safety")
@dataclass
class EEBoundsAndSafety(ActionProcessor):
"""
Clip the end-effector pose to the bounds and check for jumps.
Input ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
}
Output ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
}
"""
end_effector_bounds: dict
max_ee_step_m: float = 0.05
max_ee_twist_step_rad: float = 0.20
_last_pos: np.ndarray | None = field(default=None, init=False, repr=False)
def action(self, act: dict | None) -> dict:
x = act.pop("action.ee.x", None)
y = act.pop("action.ee.y", None)
z = act.pop("action.ee.z", None)
wx = act.pop("action.ee.wx", None)
wy = act.pop("action.ee.wy", None)
wz = act.pop("action.ee.wz", None)
if None in (x, y, z, wx, wy, wz):
return act
pos = np.array([x, y, z], dtype=float)
twist = np.array([wx, wy, wz], dtype=float)
# Clip position
pos = np.clip(pos, self.end_effector_bounds["min"], self.end_effector_bounds["max"])
# Check for jumps in position
if self._last_pos is not None:
dpos = pos - self._last_pos
n = float(np.linalg.norm(dpos))
if n > self.max_ee_step_m and n > 0:
pos = self._last_pos + dpos * (self.max_ee_step_m / n)
raise ValueError(f"EE jump {n:.3f}m > {self.max_ee_step_m}m")
self._last_pos = pos
self._last_twist = twist
act.update(
{
"action.ee.x": float(pos[0]),
"action.ee.y": float(pos[1]),
"action.ee.z": float(pos[2]),
"action.ee.wx": float(twist[0]),
"action.ee.wy": float(twist[1]),
"action.ee.wz": float(twist[2]),
}
)
return act
def reset(self):
self._last_pos = None
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# Because this is last step we specify the dataset features of this step that we want to be stored in the dataset
features["action.ee.x"] = float
features["action.ee.y"] = float
features["action.ee.z"] = float
features["action.ee.wx"] = float
features["action.ee.wy"] = float
features["action.ee.wz"] = float
return features
@ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints")
@dataclass
class InverseKinematicsEEToJoints:
"""
Compute the desired joint positions from the desired end-effector pose.
Input ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
"complementary_data.raw_joint_positions": dict,
}
Output ACTION keys:
{
"action.joint_name_1.pos": float,
"action.joint_name_2.pos": float,
...
"action.joint_name_n.pos": float,
}
"""
kinematics: RobotKinematics
motor_names: list[str]
q_curr: np.ndarray | None = field(default=None, init=False, repr=False)
initial_guess_current_joints: bool = True
def __call__(self, transition: EnvTransition) -> EnvTransition:
act = transition.get(TransitionKey.ACTION) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
x = act.get("action.ee.x", None)
y = act.get("action.ee.y", None)
z = act.get("action.ee.z", None)
wx = act.get("action.ee.wx", None)
wy = act.get("action.ee.wy", None)
wz = act.get("action.ee.wz", None)
if None in (x, y, z, wx, wy, wz):
# Nothing to do; restore what we popped and return
act.update(
{
"action.ee.x": x,
"action.ee.y": y,
"action.ee.z": z,
"action.ee.wx": wx,
"action.ee.wy": wy,
"action.ee.wz": wz,
}
)
transition[TransitionKey.ACTION] = act
return transition
# Get joint positions from complimentary data
raw = comp.get("raw_joint_positions", None)
if raw is None:
raise ValueError(
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
)
if self.initial_guess_current_joints: # Use current joints as initial guess
self.q_curr = np.array([float(raw[n]) for n in self.motor_names], dtype=float)
else: # Use previous ik solution as initial guess
if self.q_curr is None:
self.q_curr = np.array([float(raw[n]) for n in self.motor_names], dtype=float)
# Build desired 4x4 transform from pos + rotvec (twist)
t_des = np.eye(4, dtype=float)
t_des[:3, :3] = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
t_des[:3, 3] = [x, y, z]
# Compute inverse kinematics
q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des)
self.q_curr = q_target
new_act = dict(act)
for i, name in enumerate(self.motor_names):
if name == "gripper":
new_act["observation.state.gripper.pos"] = float(raw["gripper"])
else:
new_act[f"action.{name}.pos"] = float(q_target[i])
transition[TransitionKey.ACTION] = new_act
if not self.initial_guess_current_joints:
transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We specify the dataset features of this step that we want to be stored in the dataset
features["action.ee.x"] = float
features["action.ee.y"] = float
features["action.ee.z"] = float
features["action.ee.wx"] = float
features["action.ee.wy"] = float
features["action.ee.wz"] = float
features["observation.state.gripper.pos"] = float
features["action.gripper.pos"] = float
return features
def reset(self):
self.q_curr = None
@ProcessorStepRegistry.register("gripper_velocity_to_joint")
@dataclass
class GripperVelocityToJoint:
"""
Convert the gripper velocity to a joint velocity.
Input ACTION keys:
{
"action.gripper": float,
}
Output ACTION keys:
{
"action.gripper.pos": float,
}
"""
motor_names: list[str]
speed_factor: float = 20.0
clip_min: float = 0.0
clip_max: float = 100.0
discrete_gripper: bool = False
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition.get(TransitionKey.OBSERVATION) or {}
act = transition.get(TransitionKey.ACTION) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if "action.gripper" not in act:
return transition
if "gripper" not in self.motor_names:
new_act = dict(act)
new_act.pop("action.gripper", None)
transition[TransitionKey.ACTION] = new_act
return transition
if self.discrete_gripper:
# Discrete gripper actions are in [0, 1, 2]
# 0: open, 1: close, 2: stay
# We need to shift them to [-1, 0, 1] and then scale them to clip_max
gripper_action = act.get("action.gripper", 1.0)
gripper_action = gripper_action - 1.0
gripper_action *= self.clip_max
act["action.gripper"] = gripper_action
# Get current gripper position from complementary data
raw = comp.get("raw_joint_positions") or {}
curr_pos = float(raw.get("gripper"))
# Compute desired gripper velocity
u = float(act.get("action.gripper", 0.0))
delta = u * float(self.speed_factor)
gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max))
new_act = dict(act)
new_act["action.gripper.pos"] = gripper_pos
new_act.pop("action.gripper", None)
transition[TransitionKey.ACTION] = new_act
obs.update({"observation.state.gripper.pos": curr_pos})
transition[TransitionKey.OBSERVATION] = obs
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We specify the dataset features of this step that we want to be stored in the dataset
features["observation.state.gripper.pos"] = float
features["action.gripper.pos"] = float
return features
@ProcessorStepRegistry.register("forward_kinematics_joints_to_ee")
@dataclass
class ForwardKinematicsJointsToEE(ObservationProcessor):
"""
Compute the end-effector pose from the joint positions.
Input OBSERVATION keys:
{
"observation.state.{joint_name_1,joint_name_2,...,joint_name_n}.pos": float,
}
Output OBSERVATION keys:
{
"observation.state.ee.{x,y,z,wx,wy,wz}" : float
}
"""
kinematics: RobotKinematics
motor_names: list[str]
def observation(self, obs: dict | None) -> dict:
if not all(f"observation.state.{n}.pos" in obs for n in self.motor_names):
return obs
q = np.array([obs[f"observation.state.{n}.pos"] for n in self.motor_names], dtype=float)
t = self.kinematics.forward_kinematics(q)
pos = t[:3, 3]
tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()
obs.update(
{
"observation.state.ee.x": float(pos[0]),
"observation.state.ee.y": float(pos[1]),
"observation.state.ee.z": float(pos[2]),
"observation.state.ee.wx": float(tw[0]),
"observation.state.ee.wy": float(tw[1]),
"observation.state.ee.wz": float(tw[2]),
}
)
return obs
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We specify the dataset features of this step that we want to be stored in the dataset
for k in ["x", "y", "z", "wx", "wy", "wz"]:
features[f"observation.state.ee.{k}"] = float
return features
@ProcessorStepRegistry.register("add_robot_observation")
@dataclass
class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessor):
"""
Read the robot's current observation and insert it into the transition as complementary data.
- Joint positions are added under complementary_data["raw_joint_positions"] as a dict:
{ "<motor_name>": <float position>, ... }
"""
robot: Robot
def complementary_data(self, comp: dict | None) -> dict:
comp = {} if comp is None else dict(comp)
obs = self.robot.get_observation()
comp["raw_joint_positions"] = {
k.removesuffix(".pos"): float(v)
for k, v in obs.items()
if isinstance(k, str) and k.endswith(".pos")
}
return comp
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features

View File

@@ -161,6 +161,11 @@ class SO100Follower(Robot):
self.bus.write("I_Coefficient", motor, 0)
self.bus.write("D_Coefficient", motor, 32)
if motor == "gripper":
self.bus.write("Max_Torque_Limit", motor, 500) # 50% of max torque to avoid burnout
self.bus.write("Protection_Current", motor, 250) # 50% of max current to avoid burnout
self.bus.write("Overload_Torque", motor, 25) # 25% torque when overloaded
def setup_motors(self) -> None:
for motor in reversed(self.bus.motors):
input(f"Connect the controller board to the '{motor}' motor only and press enter.")

View File

@@ -0,0 +1,200 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from typing import Any
import numpy as np
from lerobot.cameras import make_cameras_from_configs
from lerobot.errors import DeviceNotConnectedError
from lerobot.model.kinematics import RobotKinematics
from lerobot.motors import Motor, MotorNormMode
from lerobot.motors.feetech import FeetechMotorsBus
from . import SO100Follower
from .config_so100_follower import SO100FollowerEndEffectorConfig
logger = logging.getLogger(__name__)
class SO100FollowerEndEffector(SO100Follower):
"""
SO100Follower robot with end-effector space control.
This robot inherits from SO100Follower but transforms actions from
end-effector space to joint space before sending them to the motors.
"""
config_class = SO100FollowerEndEffectorConfig
name = "so100_follower_end_effector"
def __init__(self, config: SO100FollowerEndEffectorConfig):
super().__init__(config)
self.bus = FeetechMotorsBus(
port=self.config.port,
motors={
"shoulder_pan": Motor(1, "sts3215", MotorNormMode.DEGREES),
"shoulder_lift": Motor(2, "sts3215", MotorNormMode.DEGREES),
"elbow_flex": Motor(3, "sts3215", MotorNormMode.DEGREES),
"wrist_flex": Motor(4, "sts3215", MotorNormMode.DEGREES),
"wrist_roll": Motor(5, "sts3215", MotorNormMode.DEGREES),
"gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100),
},
calibration=self.calibration,
)
self.cameras = make_cameras_from_configs(config.cameras)
self.config = config
# Initialize the kinematics module for the so100 robot
if self.config.urdf_path is None:
raise ValueError(
"urdf_path must be provided in the configuration for end-effector control. "
"Please set urdf_path in your SO100FollowerEndEffectorConfig."
)
self.kinematics = RobotKinematics(
urdf_path=self.config.urdf_path,
target_frame_name=self.config.target_frame_name,
)
# Store the bounds for end-effector position
self.end_effector_bounds = self.config.end_effector_bounds
self.current_ee_pos = None
self.current_joint_pos = None
@property
def action_features(self) -> dict[str, Any]:
"""
Define action features for end-effector control.
Returns dictionary with dtype, shape, and names.
"""
return {
"dtype": "float32",
"shape": (4,),
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3},
}
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
"""
Transform action from end-effector space to joint space and send to motors.
Args:
action: Dictionary with keys 'delta_x', 'delta_y', 'delta_z' for end-effector control
or a numpy array with [delta_x, delta_y, delta_z]
Returns:
The joint-space action that was sent to the motors
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Convert action to numpy array if not already
if isinstance(action, dict):
if all(k in action for k in ["delta_x", "delta_y", "delta_z"]):
delta_ee = np.array(
[
action["delta_x"] * self.config.end_effector_step_sizes["x"],
action["delta_y"] * self.config.end_effector_step_sizes["y"],
action["delta_z"] * self.config.end_effector_step_sizes["z"],
],
dtype=np.float32,
)
if "gripper" not in action:
action["gripper"] = [1.0]
action = np.append(delta_ee, action["gripper"])
else:
logger.warning(
f"Expected action keys 'delta_x', 'delta_y', 'delta_z', got {list(action.keys())}"
)
action = np.zeros(4, dtype=np.float32)
if self.current_joint_pos is None:
# Read current joint positions
current_joint_pos = self.bus.sync_read("Present_Position")
self.current_joint_pos = np.array([current_joint_pos[name] for name in self.bus.motors])
# Calculate current end-effector position using forward kinematics
if self.current_ee_pos is None:
self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos)
# Set desired end-effector position by adding delta
desired_ee_pos = np.eye(4)
desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3] # Keep orientation
# Add delta to position and clip to bounds
desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + action[:3]
if self.end_effector_bounds is not None:
desired_ee_pos[:3, 3] = np.clip(
desired_ee_pos[:3, 3],
self.end_effector_bounds["min"],
self.end_effector_bounds["max"],
)
# Compute inverse kinematics to get joint positions
target_joint_values_in_degrees = self.kinematics.inverse_kinematics(
self.current_joint_pos, desired_ee_pos
)
# Create joint space action dictionary
joint_action = {
f"{key}.pos": target_joint_values_in_degrees[i] for i, key in enumerate(self.bus.motors.keys())
}
# Handle gripper separately if included in action
# Gripper delta action is in the range 0 - 2,
# We need to shift the action to the range -1, 1 so that we can expand it to -Max_gripper_pos, Max_gripper_pos
joint_action["gripper.pos"] = np.clip(
self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos,
5,
self.config.max_gripper_pos,
)
self.current_ee_pos = desired_ee_pos.copy()
self.current_joint_pos = target_joint_values_in_degrees.copy()
self.current_joint_pos[-1] = joint_action["gripper.pos"]
# Send joint space action to parent class
return super().send_action(joint_action)
def get_observation(self) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Read arm position
start = time.perf_counter()
obs_dict = self.bus.sync_read("Present_Position")
obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()}
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
return obs_dict
def reset(self):
self.current_ee_pos = None
self.current_joint_pos = None

Some files were not shown because too many files have changed in this diff Show More