Compare commits

..

151 Commits

Author SHA1 Message Date
Pepijn
42d35e47b2 fix port 2025-09-08 18:58:42 +02:00
Pepijn
1b9330a25a fix import 2025-09-08 18:49:26 +02:00
Pepijn
490ffa89a5 add download script 2025-09-08 17:57:04 +02:00
Pepijn
3d31f2ad53 add port rlds script 2025-09-08 13:40:47 +02:00
Michel Aractingi
af79dda8d9 fix(caching) remove cache dir when collecting a dataset with each call to load_episodes and load_hf_dataset 2025-09-08 12:44:43 +02:00
Michel Aractingi
952f455446 fix(bug) in save_episode_data 2025-09-06 00:17:46 +02:00
Michel Aractingi
0747afdba7 Optimize dataset updates by incrementally concatenating new data instead of reloading from disk, reducing memory usage and improving performance. 2025-09-05 18:37:48 +02:00
Michel Aractingi
992fb177c3 further memory optimizations needed due to calling pd.concat 2025-09-03 18:49:30 +02:00
Michel Aractingi
1db3401159 remove unused Iterable Namespace 2025-09-03 16:23:36 +02:00
pre-commit-ci[bot]
7868df27dc [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-09-03 14:18:02 +00:00
Michel Aractingi
0e04f5fbbe remove html templates and flask dependency 2025-09-03 16:17:10 +02:00
Michel Aractingi
fdccf7774b fix(memory explosion) added delete to episodes and hf_dataset everytime we reload while collecting a dataset ot avoid memroy explosion 2025-09-03 15:31:28 +02:00
Michel Aractingi
2a3d62259e visualize_dataset_html deprecated 2025-09-02 15:51:11 +02:00
Michel Aractingi
2df4e25558 added the file and video max size as arguments 2025-09-02 15:41:42 +02:00
pre-commit-ci[bot]
4062d0564a [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-09-01 19:32:19 +00:00
CarolinePascal
0a30636fc6 chore(dataset v2.0): drop support for dataset v2.0 format 2025-09-01 21:31:46 +02:00
CarolinePascal
adad3698e1 chore(dataset v1): drop support for dataset v1 format 2025-09-01 19:37:20 +02:00
Michel Aractingi
84ffc28854 moved get_video_duration_in_s to video_utils and replaced subprocess and ffmpeg with pyAV 2025-08-29 01:31:53 +02:00
Michel Aractingi
47aee1fdbe revert back video_utils.py to using pyav while keeping concat_video_files function 2025-08-29 01:06:46 +02:00
Michel Aractingi
bbd64b9ce5 fixes in datasets/utils.py 2025-08-29 00:03:13 +02:00
Michel Aractingi
000e88760d removed unused functions from tests/fixtures 2025-08-28 11:18:36 +02:00
Michel Aractingi
35f36e8fba removed outdated todos 2025-08-28 10:10:17 +02:00
pre-commit-ci[bot]
213ffe02cf [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-28 07:53:10 +00:00
Michel Aractingi
2b03dec01f Removed .item from save_dataset_to_safetensors 2025-08-28 09:52:37 +02:00
Michel Aractingi
64a9dd3763 Removed agibot files and moved port_droid to port_datasets 2025-08-28 09:44:38 +02:00
Francesco Capuano
2ca6edc19e Merge branch 'main' into user/michel-aractingi/2025_06_30_dataset_v3
Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
2025-08-25 16:34:44 +02:00
mgiac-hexagon
577cd10974 Removed dupicate lines of code (#1709) 2025-08-25 12:39:32 +02:00
lxk
b0923ab74b fix(dataset): Use provided episode_data in save_episode (#1740)
The 'episode_data' parameter was previously ignored, causing an error if provided. This change ensures it is correctly used, which allows for asynchronous episode saving by passing a copy of the episode buffer, preventing conflicts with the main data collection loop.
2025-08-22 15:24:02 +02:00
Jack Vial
7f70b78f32 Add missing encoding table entries for Koch arm (#1534) 2025-08-20 17:24:05 +02:00
Michel Aractingi
db36f01e8b add update_chunk_settings method for LeRobotDatasetMetadata. Introduce tests for chunk settings updates and validation of parameters. 2025-08-18 00:00:23 +02:00
Steven Palma
55198de096 fix(ci): rename libegl1-mesa in deb13 trixie (#1735) 2025-08-14 11:12:06 +02:00
Michel Aractingi
c7a3b02625 fixed tensor indicies in _check_cached_episode_sufficient in lerobot_dataset.py, added test 2025-08-13 16:16:32 +02:00
Michel Aractingi
267a753eda Merge branch 'main' into user/michel-aractingi/2025_06_30_dataset_v3 2025-08-13 01:39:32 +02:00
Michel Aractingi
4048b02d4a improved typing in datasets/utils.py 2025-07-31 14:32:29 +02:00
pre-commit-ci[bot]
f94092c169 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-07-30 10:12:02 +00:00
Michel Aractingi
1c79e3dec1 Added mock context manager to tests in order to avoid calls to the hub for dummy datasets 2025-07-30 12:11:39 +02:00
Francesco Capuano
527ae8e557 Add variable-size test datasets (#1610)
* fix: dummy datasets can be written to multiple files in multiple folders based on arbitrary data size

* fix: writing atomic episodes to multiple files (maybe)

* fix: moving unused write dataset function to test code
2025-07-30 11:26:28 +02:00
Michel Aractingi
890b1e473d Merge branch 'main' into user/michel-aractingi/2025_06_30_dataset_v3 2025-07-30 00:43:53 +02:00
Michel Aractingi
6447352439 added a check for comparing cached episodes in order to trigger a new download if the requested episodes dont match the cached ones 2025-07-30 00:32:28 +02:00
Michel Aractingi
788544d936 update lerobot_dataset docstring 2025-07-30 00:12:23 +02:00
Michel Aractingi
59d108a807 fix(convert_v2_v3) reverted concat data files from previous commit
fixed bug in meta data related chunk_index and file_index when concatenating video files, added clearer condition to respect conditions so that episode doesnt span multiple videos
2025-07-29 22:58:24 +02:00
Michel Aractingi
218ebed3ef feat(convert_dataset_v21_to_v3) added the use of more efficient Dataset.from_parquet and concatenate_datasets 2025-07-22 17:27:41 +02:00
Michel Aractingi
670d7f485f Merge branch 'main' into user/michel-aractingi/2025_06_30_dataset_v3 2025-07-22 11:18:58 +02:00
Michel Aractingi
c993fea8ab Merge branch 'main' into user/michel-aractingi/2025_06_30_dataset_v3 2025-07-21 14:51:05 +02:00
Michel Aractingi
066b81aec2 moved concat_video function to video_utils, cleaned some code 2025-07-21 14:47:16 +02:00
Michel Aractingi
dcb02a951d fix(convert_v1) use correct legacy path, remove comments from scripts, revert lekiwi/record.py to main 2025-07-21 11:49:15 +02:00
Michel Aractingi
ac0fd71f0a Merge branch 'main' into user/michel-aractingi/2025_06_30_dataset_v3 2025-07-21 10:21:24 +02:00
Michel Aractingi
f98f01e81d Merge branch 'main' into user/michel-aractingi/2025_06_30_dataset_v3 2025-07-20 01:40:03 +02:00
Michel Aractingi
23375cce3a fix(tests) bug in clear_episode_buffer 2025-07-20 01:39:19 +02:00
Michel Aractingi
8ffc00dbcd Removed batch_encoding_Size from record.py 2025-07-18 17:56:42 +02:00
Michel Aractingi
ec40fc41b5 Removed references to batch encoding to be added later or in another PR 2025-07-18 16:52:47 +02:00
Michel Aractingi
5ec70f704e removed check_timestamps_sync that is no longer used in the code,
removed tests in datasets related to check_timestamps_sync
added the use of `clear_episode_buffer` that was not used in `save_episode`
added the creation of the codebase_version tag that was missing in `slurm_upload`
2025-07-18 16:33:20 +02:00
Michel Aractingi
4c0ac93eb6 nit 2025-07-18 16:33:20 +02:00
pre-commit-ci[bot]
788dde3a34 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-07-18 16:33:20 +02:00
Michel Aractingi
e05d22cb7b Merge branch 'main' into user/michel-aractingi/2025_06_30_dataset_v3
Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-07-18 16:33:18 +02:00
Michel Aractingi
3483e4441e Removed examples from import path in port_datasets
removed readme from droid examples and add a tutorial in docs
2025-07-15 21:38:18 +02:00
Michel Aractingi
2a76135b82 Merge branch 'main' into user/michel-aractingi/2025_06_30_dataset_v3 2025-07-08 17:45:20 +02:00
Michel Aractingi
6a9834e8b6 Merge branch 'main' into user/michel-aractingi/2025_06_30_dataset_v3 2025-07-08 14:21:12 +02:00
Michel Aractingi
a4d3a414ca Added Francescos PRs for fixing aggregate.py 2025-07-08 14:17:01 +02:00
fracapuano
a49760e2ba fix: tests depending on various sizes, and duration is updated 2025-07-08 13:43:19 +02:00
fracapuano
4e01f87a6e add: tests forcing new file creation 2025-07-08 13:38:01 +02:00
Michel Aractingi
c8a5df963b partial fix html visualization tool: Added start_time and end_time keys 2025-07-08 00:17:00 +02:00
Michel Aractingi
18209e6194 Added the use of aggregate.py in slurm_aggregate_shards.py 2025-07-07 13:51:08 +02:00
Michel Aractingi
4a466d94b6 moved legacy functions to convert_stats.py 2025-07-06 22:32:51 +02:00
Michel Aractingi
9287c36f37 - Added missing license in the new scripts
- Added back legacy functions in conversion script of v2 to v21
 - Updated README description for dataset_v3
2025-07-06 22:29:05 +02:00
Michel Aractingi
30ffa259b7 Merge branch 'main' into user/michel-aractingi/2025_06_30_dataset_v3 2025-07-06 12:30:36 +02:00
Michel Aractingi
bee74c3eab Fix(tests) fix task index error in test_policies 2025-07-06 10:03:19 +02:00
Michel Aractingi
83bf24cc9a fix(tests) add features argument to load_nested_dataset 2025-07-05 10:16:29 +02:00
Michel Aractingi
3dbc3e60fb Added docstrings to aggregate, fix test_policies.py 2025-07-04 11:27:00 +02:00
Michel Aractingi
830a3b9f27 Merge branch 'main' into user/michel-aractingi/2025_06_30_dataset_v3
Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-07-02 18:22:59 +02:00
Michel Aractingi
69b1f7b118 nit precommit 2025-07-02 18:20:01 +02:00
Michel Aractingi
66454a0fbf Remove more references to lerobot.common 2025-07-02 18:18:19 +02:00
Michel Aractingi
012d428f7b Reverted back missing files in src/lerobot/configs/ 2025-07-02 17:33:51 +02:00
Michel Aractingi
1c17419224 Reverted back files that were changed during the rebase 2025-07-02 17:26:34 +02:00
Michel Aractingi
9dde8829e6 style nit 2025-07-02 17:10:56 +02:00
Michel Aractingi
0f66bbe2f9 Migrate PR to new folder structure introduce on 1417 2025-07-02 17:10:26 +02:00
pre-commit-ci[bot]
6de5670912 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-07-02 11:52:34 +02:00
Michel Aractingi
5e39b4ce94 fix(tests)
- Updated `lerobot_dataset.py:add_frame` to take task as key in frame
- Updated `lerobot_dataset.py` to remove robot argument from `create` function of lerobotdataset and lerobotdatasetmetadata and directly take the features
- Update `test_datasets.py` to features from Mock robot
- Update all the usage of `add_frame` in the library
- Update `dataset_factories.py`; had issues with new argument order
- Raise ValueError when no task is provided (in `datasets/utils.py` validate func)
2025-07-02 11:51:56 +02:00
Michel Aractingi
0a1da47527 fix(precommit) solve precommit issues 2025-07-02 11:51:06 +02:00
Michel Aractingi
6b482a93d6 fix(rebase) deleting media related to tutorials 2025-07-02 11:47:09 +02:00
Michel Aractingi
d9b9cc80da fix(rebase) reverting files to main 2025-07-02 11:47:07 +02:00
Michel Aractingi
c3e98db37d add missing files for porting agibot 2025-07-02 11:46:45 +02:00
fracapuano
01d0b7b102 fix: modularize tests to improve readability 2025-07-02 11:45:29 +02:00
fracapuano
848a494ff6 add: tests for aggregation code 2025-07-02 11:45:29 +02:00
fracapuano
378c147be6 fix: debug aggregation code 2025-07-02 11:45:27 +02:00
fracapuano
d4fbf6ef39 add: support for videos generation in datasets 2025-07-02 11:45:11 +02:00
Remi Cadene
8c1503dafa WIP after Francesco discussion 2025-07-02 11:45:11 +02:00
Remi Cadene
ba022dd091 Merge remote-tracking branch 'origin/user/rcadene/2025_04_11_dataset_v3' into user/rcadene/2025_04_11_dataset_v3 2025-07-02 11:44:49 +02:00
Remi Cadene
13a1f68b8e WIP aggregate 2025-07-02 11:44:29 +02:00
Remi Cadene
58795d72c8 In tests: Add use_videos=False by default, Create mp4 file if True, then fix test_datasets and test_aggregate (all passing) 2025-07-02 11:44:21 +02:00
Remi Cadene
220997ff47 Fix visualize_dataset with rerun 2025-07-02 11:44:10 +02:00
Remi Cadene
ee2566456a Uploaded droid 1.0.1 2025-07-02 11:44:08 +02:00
Remi Cadene
a231930044 Fix aggregate (num_frames, dataset_from_index, index) 2025-07-02 11:43:46 +02:00
Remi Cadene
6f0fc7f386 Aggregate: Add concatenation 2025-07-02 11:43:36 +02:00
Remi Cadene
fde67dbae7 Fix convert v30 with image datasets 2025-07-02 11:43:35 +02:00
Remi Cadene
ad1ad11eac fix hf_dataset.set_transform(hf_transform_to_torch) 2025-07-02 11:43:33 +02:00
Remi Cadene
01bc89b6f4 Merge remote-tracking branch 'origin/user/rcadene/2025_04_11_dataset_v3' into user/rcadene/2025_04_11_dataset_v3 2025-07-02 11:43:24 +02:00
Remi Cadene
8c43b3d05e Faster self.meta.episodes[...]
switch back to set_transform instead of set_format

Add video_files_size_in_mb

pre-commit run --all-files
2025-07-02 11:43:22 +02:00
Remi Cadene
d4af22418b Fix unit tests 2025-07-02 11:42:52 +02:00
Remi Cadene
eaec52a7b7 Merge remote-tracking branch 'origin/user/rcadene/2025_04_11_dataset_v3' into user/rcadene/2025_04_11_dataset_v3 2025-07-02 11:42:49 +02:00
Remi Cadene
0a390de361 Merge remote-tracking branch 'origin/main' into user/rcadene/2025_04_11_dataset_v3 2025-07-02 11:41:53 +02:00
Remi Cadene
20b74ae1eb fix 2025-04-21 13:38:29 +00:00
Remi Cadene
b9b880bd8b fix get_parquet_file_size_in_mb + DEFAULT_FILE_SIZE_IN_MB=100 2025-04-21 12:59:35 +00:00
Remi Cadene
2866d0770f small fix ffmpeg encoding 2025-04-21 10:59:06 +02:00
Remi Cadene
4375a05a9f Add push to hub for convert_dataset_v21_to_v30 2025-04-21 10:08:25 +02:00
Remi Cadene
4acf99f622 pre-commit run --all-files 2025-04-21 09:34:19 +02:00
Remi Cadene
5a6ea09248 Rename tests/test_aggregate_datasets.py -> tests/datasets/test_aggregate.py 2025-04-19 19:30:28 +05:30
Remi Cadene
9c0836c8d0 Remove legacy from datasets/utils.py 2025-04-19 19:27:14 +05:30
Remi Cadene
b0cca75e5e Progress on aggregate_datasets 2025-04-19 19:11:53 +05:30
Remi Cadene
54b5c805bf Revert mistake convert_dataset_v20_to_v21.py 2025-04-17 04:47:00 +02:00
Remi Cadene
eab5543750 Merge (No verify) 2025-04-17 04:46:09 +02:00
Remi Cadene
6b6a990f4c most unit tests passing (TODO: convert datasets) 2025-04-16 21:30:58 +02:00
Remi Cadene
c2a05a1fde Fix (Now loading all frames is possible) 2025-04-14 14:47:18 +00:00
Remi Cadene
6c4d122198 fix joints 2025-04-11 15:01:03 +02:00
Remi Cadene
34c5d4ce07 Most unit tests are passing 2025-04-11 14:04:22 +02:00
Remi Cadene
c1b28f0b58 Commit before episodes episodes_stats merging 2025-04-09 15:20:15 +02:00
Remi Cadene
53ecec5fb2 WIP v21 to v30 2025-03-31 07:38:01 +00:00
Remi Cadene
65738f0a80 Improve slurm droid 2025-03-20 14:12:46 +00:00
Remi Cadene
5d184a7811 NIT 2025-03-18 16:55:08 +00:00
Remi Cadene
1a5c1ef9c7 Rename openx to droid + Improve all (not tested) 2025-03-18 16:28:09 +00:00
Remi Cadene
7866c1f7d1 Merge remote-tracking branch 'origin/main' into user/rcadene/2025_02_19_port_openx 2025-03-01 19:17:18 +00:00
Remi Cadene
3666ac9346 WIP UploadDataset 2025-03-01 19:07:22 +00:00
Remi Cadene
3daab2acbb Add upload_large_folder 2025-02-23 18:19:12 +00:00
Remi Cadene
c36d2253d0 Aggregate works 2025-02-23 18:18:46 +00:00
Remi Cadene
e2e6f6e666 Add auto_downsample_height_width 2025-02-23 18:15:39 +00:00
Remi Cadene
ff0029f84b aggregate works 2025-02-22 15:33:47 +00:00
Remi Cadene
39ad2d16d4 let's go 2025-02-22 11:12:39 +00:00
Remi Cadene
689c5efc72 optimize shard 2025-02-22 10:13:09 +00:00
Remi Cadene
eda0b996cd new dir 2025-02-21 23:56:44 +00:00
Remi Cadene
15e7a9d541 before new launch from scratch 2025-02-21 23:14:22 +00:00
Remi Cadene
52fb4143b5 workers 2025-02-21 13:08:21 +00:00
Remi Cadene
93c80b2cb1 rm brake 2025-02-20 23:24:03 +00:00
Remi Cadene
5fbbaa1bc0 fix No such file or directory error 2025-02-20 23:04:58 +00:00
Remi Cadene
71d1f5e2c9 WIP 2025-02-20 23:04:31 +00:00
Remi Cadene
b520941cd9 Merge remote-tracking branch 'origin/user/aliberts/2025_02_10_dataset_v2.1' into user/rcadene/2025_02_19_port_openx 2025-02-20 17:34:13 +00:00
Simon Alibert
64ed5258e6 Fix batch convert 2025-02-20 09:00:14 +01:00
Simon Alibert
392a8c32a7 Improve doc 2025-02-20 08:24:41 +01:00
Simon Alibert
969ef745a2 Remove dataset consolidate (#752) 2025-02-19 16:02:54 +01:00
Simon Alibert
6fe42a72db Add tag 2025-02-19 15:01:44 +01:00
Simon Alibert
2487228ea7 Use HF_HOME env variable (#753) 2025-02-19 14:49:46 +01:00
Remi Cadene
76436ca1de Merge remote-tracking branch 'tavish9_lerobot_openx/main' into user/rcadene/2025_02_19_port_openx 2025-02-19 12:58:18 +00:00
Simon Alibert
fbf2f2222a Remove local_files_only and use codebase_version instead of branches (#734) 2025-02-19 08:36:32 +01:00
Tavish
02bc4e03e0 support openx/rlds to lerobot 2025-02-18 22:25:58 +08:00
Simon Alibert
624eaf1175 Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_10_dataset_v2.1 2025-02-17 12:06:05 +01:00
Simon Alibert
aed3eb4a94 Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_10_dataset_v2.1 2025-02-15 15:56:24 +01:00
Simon Alibert
8426c64f42 Per-episode stats (#521)
Co-authored-by: Remi Cadene <re.cadene@gmail.com>
Co-authored-by: Remi <remi.cadene@huggingface.co>
2025-02-15 15:47:16 +01:00
Remi
7c2bbee613 Validate features during add_frame + Add 2D-to-5D + Add string (#720) 2025-02-14 19:59:48 +01:00
Remi
9d6886dd08 Add frame level task (#693)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-02-14 14:22:22 +01:00
Simon Alibert
d67ca342e9 Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_10_dataset_v2.1 2025-02-11 17:17:39 +01:00
Simon Alibert
57c9c21c39 Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_10_dataset_v2.1 2025-02-10 17:22:57 +01:00
Simon Alibert
38c14571cc Bump CODEBASE_VERSION 2025-02-10 16:39:34 +01:00
165 changed files with 10518 additions and 18437 deletions

View File

@@ -233,7 +233,7 @@ Under the hood, the `LeRobotDataset` format makes use of several ways to seriali
Here are the important details and internal structure organization of a typical `LeRobotDataset` instantiated with `dataset = LeRobotDataset("lerobot/aloha_static_coffee")`. The exact features will change from dataset to dataset but not the main aspects:
```
````
dataset attributes:
├ hf_dataset: a Hugging Face dataset (backed by Arrow/parquet). Typical features example:
│ ├ observation.images.cam_high (VideoFrame):
@@ -246,20 +246,30 @@ dataset attributes:
│ ├ timestamp (float32): timestamp in the episode
│ ├ next.done (bool): indicates the end of an episode ; True for the last frame in each episode
│ └ index (int64): general index in the whole dataset
episode_data_index: contains 2 tensors with the start and end indices of each episode
│ ├ from (1D int64 tensor): first frame index for each episode — shape (num episodes,) starts with 0
└ to: (1D int64 tensor): last frame index for each episode — shape (num episodes,)
├ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance
│ ├ observation.images.cam_high: {'max': tensor with same number of dimensions (e.g. `(c, 1, 1)` for images, `(c,)` for states), etc.}
...
├ info: a dictionary of metadata on the dataset
│ ├ codebase_version (str): this is to keep track of the codebase version the dataset was created with
│ ├ fps (float): frame per second the dataset is recorded/synchronized to
video (bool): indicates if frames are encoded in mp4 video files to save space or stored as png files
encoding (dict): if video, this documents the main options that were used with ffmpeg to encode the videos
├ videos_dir (Path): where the mp4 videos or png images are stored/accessed
└ camera_keys (list of string): the keys to access camera features in the item returned by the dataset (e.g. `["observation.images.cam_high", ...]`)
```
meta: a LeRobotDatasetMetadata object containing:
│ ├ info: a dictionary of metadata on the dataset
│ ├ codebase_version (str): this is to keep track of the codebase version the dataset was created with
│ │ ├ fps (int): frame per second the dataset is recorded/synchronized to
│ ├ features (dict): all features contained in the dataset with their shapes and types
│ ├ total_episodes (int): total number of episodes in the dataset
│ │ ├ total_frames (int): total number of frames in the dataset
│ ├ robot_type (str): robot type used for recording
│ ├ data_path (str): formattable string for the parquet files
video_path (str): formattable string for the video files (if using videos)
episodes: a DataFrame containing episode metadata with columns:
│ │ ├ episode_index (int): index of the episode
│ │ ├ tasks (list): list of tasks for this episode
│ │ ├ length (int): number of frames in this episode
│ │ ├ dataset_from_index (int): start index of this episode in the dataset
│ │ └ dataset_to_index (int): end index of this episode in the dataset
│ ├ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance
│ │ ├ observation.images.front_cam: {'max': tensor with same number of dimensions (e.g. `(c, 1, 1)` for images, `(c,)` for states), etc.}
│ │ └ ...
│ └ tasks: a DataFrame containing task information with task names as index and task_index as values
├ root (Path): local directory where the dataset is stored
├ image_transforms (Callable): optional image transformations to apply to visual modalities
└ delta_timestamps (dict): optional delta timestamps for temporal queries
decoding videos (e.g., 'pyav', 'torchcodec')
A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely:
@@ -283,7 +293,7 @@ lerobot-eval \
--eval.n_episodes=10 \
--policy.use_amp=false \
--policy.device=cuda
```
````
Note: After training your own policy, you can re-evaluate the checkpoints with:

View File

@@ -108,7 +108,8 @@ def save_decoded_frames(
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
ep_num_images = dataset.episode_data_index["to"][0].item()
episode_index = 0
ep_num_images = dataset.meta.episodes["length"][episode_index]
if imgs_dir.exists() and len(list(imgs_dir.glob("frame_*.png"))) == ep_num_images:
return
@@ -265,7 +266,8 @@ def benchmark_encoding_decoding(
overwrite=True,
)
ep_num_images = dataset.episode_data_index["to"][0].item()
episode_index = 0
ep_num_images = dataset.meta.episodes["length"][episode_index]
width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
num_pixels = width * height
video_size_bytes = video_path.stat().st_size

View File

@@ -29,7 +29,7 @@ ENV DEBIAN_FRONTEND=noninteractive \
# Install system dependencies and uv (as root)
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential git curl libglib2.0-0 libegl1-mesa ffmpeg \
build-essential git curl libglib2.0-0 libegl1-mesa-dev ffmpeg \
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
&& mv /root/.local/bin/uv /usr/local/bin/uv \

View File

@@ -19,6 +19,8 @@
title: Train RL in Simulation
- local: async
title: Use Async Inference
- local: porting_datasets_v3
title: Porting Large Datasets
title: "Tutorials"
- sections:
- local: smolvla

View File

@@ -4,13 +4,7 @@ In this tutorial you will go through the full Human-in-the-Loop Sample-Efficient
HIL-SERL is a sample-efficient reinforcement learning algorithm that combines human demonstrations with online learning and human interventions. The approach starts from a small set of human demonstrations, uses them to train a reward classifier, and then employs an actor-learner architecture where humans can intervene during policy execution to guide exploration and correct unsafe behaviors. In this tutorial, you'll use a gamepad to provide interventions and control the robot during the learning process.
It combines three key ingredients:
1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point.
2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour.
3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe.
It combines three key ingredients: 1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point. 2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour. 3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe.
Together these elements let HIL-SERL reach near-perfect task success and faster cycle times than imitation-only baselines.
@@ -62,243 +56,30 @@ pip install -e ".[hilserl]"
### Understanding Configuration
The training process begins with proper configuration for the HILSerl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/scripts/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
The training process begins with proper configuration for the HILSerl environment. The configuration class of interest is `HILSerlRobotEnvConfig` in `lerobot/envs/configs.py`. Which is defined as:
<!-- prettier-ignore-start -->
```python
class GymManipulatorConfig:
env: HILSerlRobotEnvConfig # Environment configuration (nested)
dataset: DatasetConfig # Dataset recording/replay configuration (nested)
mode: str | None = None # "record", "replay", or None (for training)
device: str = "cpu" # Compute device
class HILSerlRobotEnvConfig(EnvConfig):
robot: RobotConfig | None = None # Main robot agent (defined in `lerobot/robots`)
teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm
processor: HILSerlProcessorConfig # Processing pipeline configuration (nested)
name: str = "real_robot" # Environment name
task: str | None = None # Task identifier
teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm, (defined in `lerobot/teleoperators`)
wrapper: EnvTransformConfig | None = None # Environment wrapper settings; check `lerobot/scripts/server/gym_manipulator.py`
fps: int = 10 # Control frequency
# Nested processor configuration
class HILSerlProcessorConfig:
control_mode: str = "gamepad" # Control mode
observation: ObservationConfig | None = None # Observation processing settings
image_preprocessing: ImagePreprocessingConfig | None = None # Image crop/resize settings
gripper: GripperConfig | None = None # Gripper control and penalty settings
reset: ResetConfig | None = None # Environment reset and timing settings
inverse_kinematics: InverseKinematicsConfig | None = None # IK processing settings
reward_classifier: RewardClassifierConfig | None = None # Reward classifier settings
max_gripper_pos: float | None = 100.0 # Maximum gripper position
# Sub-configuration classes
class ObservationConfig:
add_joint_velocity_to_observation: bool = False # Add joint velocities to state
add_current_to_observation: bool = False # Add motor currents to state
add_ee_pose_to_observation: bool = False # Add end-effector pose to state
display_cameras: bool = False # Display camera feeds during execution
class ImagePreprocessingConfig:
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None # Image cropping parameters
resize_size: tuple[int, int] | None = None # Target image size
class GripperConfig:
use_gripper: bool = True # Enable gripper control
gripper_penalty: float = 0.0 # Penalty for inappropriate gripper usage
gripper_penalty_in_reward: bool = False # Include gripper penalty in reward
class ResetConfig:
fixed_reset_joint_positions: Any | None = None # Joint positions for reset
reset_time_s: float = 5.0 # Time to wait during reset
control_time_s: float = 20.0 # Maximum episode duration
terminate_on_success: bool = True # Whether to terminate episodes on success detection
class InverseKinematicsConfig:
urdf_path: str | None = None # Path to robot URDF file
target_frame_name: str | None = None # End-effector frame name
end_effector_bounds: dict[str, list[float]] | None = None # EE workspace bounds
end_effector_step_sizes: dict[str, float] | None = None # EE step sizes per axis
class RewardClassifierConfig:
pretrained_path: str | None = None # Path to pretrained reward classifier
success_threshold: float = 0.5 # Success detection threshold
success_reward: float = 1.0 # Reward value for successful episodes
# Dataset configuration
class DatasetConfig:
repo_id: str # LeRobot dataset repository ID
dataset_root: str # Local dataset root directory
task: str # Task identifier
num_episodes: int # Number of episodes for recording
episode: int # Episode index for replay
push_to_hub: bool # Whether to push datasets to Hub
name: str = "real_robot" # Environment name
mode: str = None # "record", "replay", or None (for training)
repo_id: str | None = None # LeRobot dataset repository ID
dataset_root: str | None = None # Local dataset root (optional)
task: str = "" # Task identifier
num_episodes: int = 10 # Number of episodes for recording
episode: int = 0 # episode index for replay
device: str = "cuda" # Compute device
push_to_hub: bool = True # Whether to push the recorded datasets to Hub
pretrained_policy_name_or_path: str | None = None # For policy loading
reward_classifier_pretrained_path: str | None = None # For reward model
number_of_steps_after_success: int = 0 # For reward classifier, collect more positive examples after a success to train a classifier
```
<!-- prettier-ignore-end -->
### Processor Pipeline Architecture
HIL-SERL uses a modular processor pipeline architecture that processes robot observations and actions through a series of composable steps. The pipeline is divided into two main components:
#### Environment Processor Pipeline
The environment processor (`env_processor`) handles incoming observations and environment state:
1. **VanillaObservationProcessor**: Converts raw robot observations into standardized format
2. **JointVelocityProcessor** (optional): Adds joint velocity information to observations
3. **MotorCurrentProcessor** (optional): Adds motor current readings to observations
4. **ForwardKinematicsJointsToEE** (optional): Computes end-effector pose from joint positions
5. **ImageCropResizeProcessor** (optional): Crops and resizes camera images
6. **TimeLimitProcessor** (optional): Enforces episode time limits
7. **GripperPenaltyProcessor** (optional): Applies penalties for inappropriate gripper usage
8. **RewardClassifierProcessor** (optional): Automated reward detection using vision models
9. **ToBatchProcessor**: Converts data to batch format for neural network processing
10. **DeviceProcessor**: Moves data to the specified compute device (CPU/GPU)
#### Action Processor Pipeline
The action processor (`action_processor`) handles outgoing actions and human interventions:
1. **AddTeleopActionAsComplimentaryData**: Captures teleoperator actions for logging
2. **AddTeleopEventsAsInfo**: Records intervention events and episode control signals
3. **AddRobotObservationAsComplimentaryData**: Stores raw robot state for processing
4. **InterventionActionProcessor**: Handles human interventions and episode termination
5. **Inverse Kinematics Pipeline** (when enabled):
- **MapDeltaActionToRobotAction**: Converts delta actions to robot action format
- **EEReferenceAndDelta**: Computes end-effector reference and delta movements
- **EEBoundsAndSafety**: Enforces workspace safety bounds
- **InverseKinematicsEEToJoints**: Converts end-effector actions to joint targets
- **GripperVelocityToJoint**: Handles gripper control commands
#### Configuration Examples
**Basic Observation Processing**:
```json
{
"env": {
"processor": {
"observation": {
"add_joint_velocity_to_observation": true,
"add_current_to_observation": false,
"display_cameras": false
}
}
}
}
```
**Image Processing**:
```json
{
"env": {
"processor": {
"image_preprocessing": {
"crop_params_dict": {
"observation.images.front": [180, 250, 120, 150],
"observation.images.side": [180, 207, 180, 200]
},
"resize_size": [128, 128]
}
}
}
}
```
**Inverse Kinematics Setup**:
```json
{
"env": {
"processor": {
"inverse_kinematics": {
"urdf_path": "path/to/robot.urdf",
"target_frame_name": "end_effector",
"end_effector_bounds": {
"min": [0.16, -0.08, 0.03],
"max": [0.24, 0.2, 0.1]
},
"end_effector_step_sizes": {
"x": 0.02,
"y": 0.02,
"z": 0.02
}
}
}
}
}
```
### Advanced Observation Processing
The HIL-SERL framework supports additional observation processing features that can improve policy learning:
#### Joint Velocity Processing
Enable joint velocity estimation to provide the policy with motion information:
```json
{
"env": {
"processor": {
"observation": {
"add_joint_velocity_to_observation": true
}
}
}
}
```
This processor:
- Estimates joint velocities using finite differences between consecutive joint position readings
- Adds velocity information to the observation state vector
- Useful for policies that need motion awareness for dynamic tasks
#### Motor Current Processing
Monitor motor currents to detect contact forces and load conditions:
```json
{
"env": {
"processor": {
"observation": {
"add_current_to_observation": true
}
}
}
}
```
This processor:
- Reads motor current values from the robot's control system
- Adds current measurements to the observation state vector
- Helps detect contact events, object weights, and mechanical resistance
- Useful for contact-rich manipulation tasks
#### Combined Observation Processing
You can enable multiple observation processing features simultaneously:
```json
{
"env": {
"processor": {
"observation": {
"add_joint_velocity_to_observation": true,
"add_current_to_observation": true,
"add_ee_pose_to_observation": false,
"display_cameras": false
}
}
}
}
```
**Note**: Enabling additional observation features increases the state space dimensionality, which may require adjusting your policy network architecture and potentially collecting more training data.
### Finding Robot Workspace Bounds
Before collecting demonstrations, you need to determine the appropriate operational bounds for your robot.
@@ -349,56 +130,22 @@ With the bounds defined, you can safely collect demonstrations for training. Tra
Create a configuration file for recording demonstrations (or edit an existing one like [env_config_so100.json](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_so100.json)):
1. Set `mode` to `"record"` at the root level
2. Specify a unique `repo_id` for your dataset in the `dataset` section (e.g., "username/task_name")
3. Set `num_episodes` in the `dataset` section to the number of demonstrations you want to collect
4. Set `env.processor.image_preprocessing.crop_params_dict` to `{}` initially (we'll determine crops later)
5. Configure `env.robot`, `env.teleop`, and other hardware settings in the `env` section
1. Set `mode` to `"record"`
2. Specify a unique `repo_id` for your dataset (e.g., "username/task_name")
3. Set `num_episodes` to the number of demonstrations you want to collect
4. Set `crop_params_dict` to `null` initially (we'll determine crops later)
5. Configure `robot`, `cameras`, and other hardware settings
Example configuration section:
```json
{
"env": {
"type": "gym_manipulator",
"name": "real_robot",
"fps": 10,
"processor": {
"control_mode": "gamepad",
"observation": {
"display_cameras": false
},
"image_preprocessing": {
"crop_params_dict": {},
"resize_size": [128, 128]
},
"gripper": {
"use_gripper": true,
"gripper_penalty": 0.0
},
"reset": {
"reset_time_s": 5.0,
"control_time_s": 20.0
}
},
"robot": {
// ... robot configuration ...
},
"teleop": {
// ... teleoperator configuration ...
}
},
"dataset": {
"repo_id": "username/pick_lift_cube",
"dataset_root": null,
"task": "pick_and_lift",
"num_episodes": 15,
"episode": 0,
"push_to_hub": true
},
"mode": "record",
"device": "cpu"
}
"mode": "record",
"repo_id": "username/pick_lift_cube",
"dataset_root": null,
"task": "pick_and_lift",
"num_episodes": 15,
"episode": 0,
"push_to_hub": true
```
### Using a Teleoperation Device
@@ -444,20 +191,10 @@ The gamepad provides a very convenient way to control the robot and the episode
To setup the gamepad, you need to set the `control_mode` to `"gamepad"` and define the `teleop` section in the configuration file.
```json
{
"env": {
"teleop": {
"type": "gamepad",
"use_gripper": true
},
"processor": {
"control_mode": "gamepad",
"gripper": {
"type": "gamepad",
"use_gripper": true
}
}
}
}
},
```
<p align="center">
@@ -479,21 +216,11 @@ The SO101 leader arm has reduced gears that allows it to move and track the foll
To setup the SO101 leader, you need to set the `control_mode` to `"leader"` and define the `teleop` section in the configuration file.
```json
{
"env": {
"teleop": {
"type": "so101_leader",
"port": "/dev/tty.usbmodem585A0077921",
"use_degrees": true
"type": "so101_leader",
"port": "/dev/tty.usbmodem585A0077921", # check your port number
"use_degrees": true
},
"processor": {
"control_mode": "leader",
"gripper": {
"use_gripper": true
}
}
}
}
```
In order to annotate the success/failure of the episode, **you will need** to use a keyboard to press `s` for success, `esc` for failure.
@@ -524,7 +251,7 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/e
During recording:
1. The robot will reset to the initial position defined in the configuration file `env.processor.reset.fixed_reset_joint_positions`
1. The robot will reset to the initial position defined in the configuration file `fixed_reset_joint_positions`
2. Complete the task successfully
3. The episode ends with a reward of 1 when you press the "success" button
4. If the time limit is reached, or the fail button is pressed, the episode ends with a reward of 0
@@ -583,19 +310,11 @@ observation.images.front: [180, 250, 120, 150]
Add these crop parameters to your training configuration:
```json
{
"env": {
"processor": {
"image_preprocessing": {
"crop_params_dict": {
"observation.images.side": [180, 207, 180, 200],
"observation.images.front": [180, 250, 120, 150]
},
"resize_size": [128, 128]
}
}
}
}
"crop_params_dict": {
"observation.images.side": [180, 207, 180, 200],
"observation.images.front": [180, 250, 120, 150]
},
"resize_size": [128, 128]
```
**Recommended image resolution**
@@ -624,52 +343,26 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/r
**Key Parameters for Data Collection**
- **mode**: set it to `"record"` to collect a dataset (at root level)
- **dataset.repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub
- **dataset.num_episodes**: Number of episodes to record
- **env.processor.reset.terminate_on_success**: Whether to automatically terminate episodes when success is detected (default: `true`)
- **env.fps**: Number of frames per second to record
- **dataset.push_to_hub**: Whether to push the dataset to the hub
- **mode**: set it to `"record"` to collect a dataset
- **repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub
- **num_episodes**: Number of episodes to record
- **number_of_steps_after_success**: Number of additional frames to record after a success (reward=1) is detected
- **fps**: Number of frames per second to record
- **push_to_hub**: Whether to push the dataset to the hub
The `env.processor.reset.terminate_on_success` parameter allows you to control episode termination behavior. When set to `false`, episodes will continue even after success is detected, allowing you to collect more positive examples with the reward=1 label. This is crucial for training reward classifiers as it provides more success state examples in your dataset. When set to `true` (default), episodes terminate immediately upon success detection.
**Important**: For reward classifier training, set `terminate_on_success: false` to collect sufficient positive examples. For regular HIL-SERL training, keep it as `true` to enable automatic episode termination when the task is completed successfully.
The `number_of_steps_after_success` parameter is crucial as it allows you to collect more positive examples. When a success is detected, the system will continue recording for the specified number of steps while maintaining the reward=1 label. Otherwise, there won't be enough states in the dataset labeled to 1 to train a good classifier.
Example configuration section for data collection:
```json
{
"env": {
"type": "gym_manipulator",
"name": "real_robot",
"fps": 10,
"processor": {
"reset": {
"reset_time_s": 5.0,
"control_time_s": 20.0,
"terminate_on_success": false
},
"gripper": {
"use_gripper": true
}
},
"robot": {
// ... robot configuration ...
},
"teleop": {
// ... teleoperator configuration ...
}
},
"dataset": {
"repo_id": "hf_username/dataset_name",
"dataset_root": "data/your_dataset",
"task": "reward_classifier_task",
"num_episodes": 20,
"episode": 0,
"push_to_hub": true
},
"mode": "record",
"device": "cpu"
"repo_id": "hf_username/dataset_name",
"dataset_root": "data/your_dataset",
"num_episodes": 20,
"push_to_hub": true,
"fps": 10,
"number_of_steps_after_success": 15
}
```
@@ -728,17 +421,9 @@ To use your trained reward classifier, configure the `HILSerlRobotEnvConfig` to
<!-- prettier-ignore-start -->
```python
config = GymManipulatorConfig(
env=HILSerlRobotEnvConfig(
processor=HILSerlProcessorConfig(
reward_classifier=RewardClassifierConfig(
pretrained_path="path_to_your_pretrained_trained_model"
)
),
# Other environment parameters
),
dataset=DatasetConfig(...),
mode=None # For training
env_config = HILSerlRobotEnvConfig(
reward_classifier_pretrained_path="path_to_your_pretrained_trained_model",
# Other environment parameters
)
```
<!-- prettier-ignore-end -->
@@ -747,18 +432,7 @@ or set the argument in the json config file.
```json
{
"env": {
"processor": {
"reward_classifier": {
"pretrained_path": "path_to_your_pretrained_model",
"success_threshold": 0.7,
"success_reward": 1.0
},
"reset": {
"terminate_on_success": true
}
}
}
"reward_classifier_pretrained_path": "path_to_your_pretrained_model"
}
```

View File

@@ -32,12 +32,9 @@ To use `gym_hil` with LeRobot, you need to create a configuration file. An examp
```json
{
"env": {
"type": "gym_manipulator",
"name": "gym_hil",
"task": "PandaPickCubeGamepad-v0",
"fps": 10
},
"type": "hil",
"name": "franka_sim",
"task": "PandaPickCubeGamepad-v0",
"device": "cuda"
}
```
@@ -48,40 +45,28 @@ Available tasks:
- `PandaPickCubeGamepad-v0`: With gamepad control
- `PandaPickCubeKeyboard-v0`: With keyboard control
### Processor Configuration
### Gym Wrappers Configuration
```json
{
"env": {
"processor": {
"control_mode": "gamepad",
"gripper": {
"use_gripper": true,
"gripper_penalty": -0.02
},
"reset": {
"control_time_s": 15.0,
"fixed_reset_joint_positions": [
0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785
]
},
"inverse_kinematics": {
"end_effector_step_sizes": {
"x": 0.025,
"y": 0.025,
"z": 0.025
}
}
"wrapper": {
"gripper_penalty": -0.02,
"control_time_s": 15.0,
"use_gripper": true,
"fixed_reset_joint_positions": [0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785],
"end_effector_step_sizes": {
"x": 0.025,
"y": 0.025,
"z": 0.025
},
"control_mode": "gamepad"
}
}
}
```
Important parameters:
- `gripper.gripper_penalty`: Penalty for excessive gripper movement
- `gripper.use_gripper`: Whether to enable gripper control
- `inverse_kinematics.end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector
- `gripper_penalty`: Penalty for excessive gripper movement
- `use_gripper`: Whether to enable gripper control
- `end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector
- `control_mode`: Set to `"gamepad"` to use a gamepad controller
## Running with HIL RL of LeRobot
@@ -90,50 +75,39 @@ Important parameters:
To run the environment, set mode to null:
```bash
<!-- prettier-ignore-start -->
```python
python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json
```
<!-- prettier-ignore-end -->
### Recording a Dataset
To collect a dataset, set the mode to `record` whilst defining the repo_id and number of episodes to record:
```json
{
"env": {
"type": "gym_manipulator",
"name": "gym_hil",
"task": "PandaPickCubeGamepad-v0"
},
"dataset": {
"repo_id": "username/sim_dataset",
"dataset_root": null,
"task": "pick_cube",
"num_episodes": 10,
"episode": 0,
"push_to_hub": true
},
"mode": "record"
}
```
```bash
<!-- prettier-ignore-start -->
```python
python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json
```
<!-- prettier-ignore-end -->
### Training a Policy
To train a policy, checkout the configuration example available [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/train_gym_hil_env.json) and run the actor and learner servers:
```bash
<!-- prettier-ignore-start -->
```python
python -m lerobot.scripts.rl.actor --config_path path/to/train_gym_hil_env.json
```
<!-- prettier-ignore-end -->
In a different terminal, run the learner server:
```bash
<!-- prettier-ignore-start -->
```python
python -m lerobot.scripts.rl.learner --config_path path/to/train_gym_hil_env.json
```
<!-- prettier-ignore-end -->
The simulation environment provides a safe and repeatable way to develop and test your Human-In-the-Loop reinforcement learning components before deploying to real robots.

View File

@@ -519,14 +519,11 @@ from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun
from lerobot.record import record_loop
from lerobot.policies.factory import make_processor
NUM_EPISODES = 5
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
# Create the robot configuration
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
@@ -538,7 +535,7 @@ robot_config = SO100FollowerConfig(
robot = SO100Follower(robot_config)
# Initialize the policy
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
policy = ACTPolicy.from_pretrained("<hf_username>/<my_policy_repo_id>")
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
@@ -547,7 +544,7 @@ dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_DATASET_ID,
repo_id="<hf_username>/eval_<dataset_repo_id>",
fps=FPS,
features=dataset_features,
robot_type=robot.name,
@@ -562,12 +559,6 @@ _init_rerun(session_name="recording")
# Connect the robot
robot.connect()
preprocessor, postprocessor = make_processor(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
)
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
@@ -577,8 +568,6 @@ for episode_idx in range(NUM_EPISODES):
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,

View File

@@ -24,36 +24,11 @@ pip install -e ".[hilserl]"
To use `gym_hil` with LeRobot, you need to use a configuration file. An example config file can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_gym_hil_il.json).
To teleoperate and collect a dataset, we need to modify this config file. Here's an example configuration for imitation learning data collection:
To teleoperate and collect a dataset, we need to modify this config file and you should add your `repo_id` here: `"repo_id": "il_gym",` and `"num_episodes": 30,` and make sure you set `mode` to `record`, "mode": "record".
```json
{
"env": {
"type": "gym_manipulator",
"name": "gym_hil",
"task": "PandaPickCubeGamepad-v0",
"fps": 10
},
"dataset": {
"repo_id": "your_username/il_gym",
"dataset_root": null,
"task": "pick_cube",
"num_episodes": 30,
"episode": 0,
"push_to_hub": true
},
"mode": "record",
"device": "cuda"
}
```
If you do not have a Nvidia GPU also change `"device": "cuda"` parameter in the config file (for example to `mps` for MacOS).
Key configuration points:
- Set your `repo_id` in the `dataset` section: `"repo_id": "your_username/il_gym"`
- Set `num_episodes: 30` to collect 30 demonstration episodes
- Ensure `mode` is set to `"record"`
- If you don't have an NVIDIA GPU, change `"device": "cuda"` to `"mps"` for macOS or `"cpu"`
- To use keyboard instead of gamepad, change `"task"` to `"PandaPickCubeKeyboard-v0"`
By default the config file assumes you use a controller. To use your keyboard please change the envoirment specified at `"task"` in the config file and set it to `"PandaPickCubeKeyboard-v0"`.
Then we can run this command to start:
@@ -165,32 +140,9 @@ huggingface-cli upload ${HF_USER}/il_sim_test${CKPT} \
## Evaluate your policy in Sim
To evaluate your policy we have to use a configuration file. An example can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/eval_config_gym_hil.json).
To evaluate your policy we have to use the config file that can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/eval_config_gym_hil.json).
Here's an example evaluation configuration:
```json
{
"env": {
"type": "gym_manipulator",
"name": "gym_hil",
"task": "PandaPickCubeGamepad-v0",
"fps": 10
},
"dataset": {
"repo_id": "your_username/il_sim_dataset",
"dataset_root": null,
"task": "pick_cube"
},
"pretrained_policy_name_or_path": "your_username/il_sim_model",
"device": "cuda"
}
```
Make sure to replace:
- `repo_id` with the dataset you trained on (e.g., `your_username/il_sim_dataset`)
- `pretrained_policy_name_or_path` with your model ID (e.g., `your_username/il_sim_model`)
Make sure to replace the `repo_id` with the dataset you trained on, for example `pepijn223/il_sim_dataset` and replace the `pretrained_policy_name_or_path` with your model id, for example `pepijn223/il_sim_model`
Then you can run this command to visualize your trained policy

View File

@@ -0,0 +1,321 @@
# Porting Large Datasets to LeRobot Dataset v3.0
This tutorial explains how to port large-scale robotic datasets to the LeRobot Dataset v3.0 format. We'll use the **DROID 1.0.1** dataset as our primary example, which demonstrates handling multi-terabyte datasets with thousands of shards across SLURM clusters.
## File Organization: v2.1 vs v3.0
Dataset v3.0 fundamentally changes how data is organized and stored:
**v2.1 Structure (Episode-based)**:
```
dataset/
├── data/chunk-000/episode_000000.parquet
├── data/chunk-000/episode_000001.parquet
├── videos/chunk-000/camera/episode_000000.mp4
└── meta/episodes.jsonl
```
**v3.0 Structure (File-based)**:
```
dataset/
├── data/chunk-000/file-000.parquet # Multiple episodes per file
├── videos/camera/chunk-000/file-000.mp4 # Consolidated video chunks
└── meta/episodes/chunk-000/file-000.parquet # Structured metadata
```
This transition from individual episode files to file-based chunks dramatically improves performance and reduces storage overhead.
## What's New in Dataset v3.0
Dataset v3.0 introduces significant improvements for handling large datasets:
### 🏗️ **Enhanced File Organization**
- **File-based structure**: Episodes are now grouped into chunked files rather than individual episode files
- **Configurable file sizes**: for data and video files
- **Improved storage efficiency**: Better compression and reduced overhead
### 📊 **Modern Metadata Management**
- **Parquet-based metadata**: Replaced JSON Lines with efficient parquet format
- **Structured episode access**: Direct pandas DataFrame access via `dataset.meta.episodes`
- **Per-episode statistics**: Enhanced statistics tracking at episode level
### 🚀 **Performance Enhancements**
- **Memory-mapped access**: Improved RAM usage through PyArrow memory mapping
- **Faster loading**: Significantly reduced dataset initialization time
- **Better scalability**: Designed for datasets with millions of episodes
## Prerequisites
Before porting large datasets, ensure you have:
- **LeRobot installed** with v3.0 support. Follow our [Installation Guide](./installation).
- **Sufficient storage**: Raw datasets can be very large (e.g., DROID requires 2TB)
- **Cluster access** (recommended for large datasets): SLURM or similar job scheduler
- **Dataset-specific dependencies**: For DROID, you'll need TensorFlow Dataset utilities
## Understanding the DROID Dataset
[DROID 1.0.1](https://droid-dataset.github.io/droid/the-droid-dataset) is an excellent example of a large-scale robotic dataset:
- **Size**: 1.7TB (RLDS format), 8.7TB (raw data)
- **Structure**: 2048 pre-defined TensorFlow dataset shards
- **Content**: 76,000+ robot manipulation trajectories from Franka Emika Panda robots
- **Scope**: Real-world manipulation tasks across multiple environments and objects
- **Format**: Originally in TensorFlow Records/RLDS format, requiring conversion to LeRobot format
- **Hosting**: Google Cloud Storage with public access via `gsutil`
The dataset contains diverse manipulation demonstrations with:
- Multiple camera views (wrist camera, exterior cameras)
- Natural language task descriptions
- Robot proprioceptive state and actions
- Success/failure annotations
### DROID Features Schema
```python
DROID_FEATURES = {
# Episode markers
"is_first": {"dtype": "bool", "shape": (1,)},
"is_last": {"dtype": "bool", "shape": (1,)},
"is_terminal": {"dtype": "bool", "shape": (1,)},
# Language instructions
"language_instruction": {"dtype": "string", "shape": (1,)},
"language_instruction_2": {"dtype": "string", "shape": (1,)},
"language_instruction_3": {"dtype": "string", "shape": (1,)},
# Robot state
"observation.state.gripper_position": {"dtype": "float32", "shape": (1,)},
"observation.state.cartesian_position": {"dtype": "float32", "shape": (6,)},
"observation.state.joint_position": {"dtype": "float32", "shape": (7,)},
# Camera observations
"observation.images.wrist_left": {"dtype": "image"},
"observation.images.exterior_1_left": {"dtype": "image"},
"observation.images.exterior_2_left": {"dtype": "image"},
# Actions
"action.gripper_position": {"dtype": "float32", "shape": (1,)},
"action.cartesian_position": {"dtype": "float32", "shape": (6,)},
"action.joint_position": {"dtype": "float32", "shape": (7,)},
# Standard LeRobot format
"observation.state": {"dtype": "float32", "shape": (8,)}, # joints + gripper
"action": {"dtype": "float32", "shape": (8,)}, # joints + gripper
}
```
## Approach 1: Single Computer Porting
### Step 1: Install Dependencies
For DROID specifically:
```bash
pip install tensorflow
pip install tensorflow_datasets
```
For other datasets, install the appropriate readers for your source format.
### Step 2: Download Raw Data
Download DROID from Google Cloud Storage using `gsutil`:
```bash
# Install Google Cloud SDK if not already installed
# https://cloud.google.com/sdk/docs/install
# Download the full RLDS dataset (1.7TB)
gsutil -m cp -r gs://gresearch/robotics/droid/1.0.1 /your/data/
# Or download just the 100-episode sample (2GB) for testing
gsutil -m cp -r gs://gresearch/robotics/droid_100 /your/data/
```
> [!WARNING]
> Large datasets require substantial time and storage:
>
> - **Full DROID (1.7TB)**: Several days to download depending on bandwidth
> - **Processing time**: 7+ days for local porting of full dataset
> - **Upload time**: 3+ days to push to Hugging Face Hub
> - **Local storage**: ~400GB for processed LeRobot format
### Step 3: Port the Dataset
```bash
python examples/port_datasets/port_droid_rlds.py \
--raw-dir /your/data/droid/1.0.1 \
--repo-id your_id/droid_1.0.1 \
--push-to-hub
```
### Development and Testing
For development, you can port a single shard:
```bash
python examples/port_datasets/port_droid_rlds.py \
--raw-dir /your/data/droid/1.0.1 \
--repo-id your_id/droid_1.0.1_test \
--num-shards 2048 \
--shard-index 0
```
This approach works for smaller datasets or testing, but large datasets require cluster computing.
## Approach 2: SLURM Cluster Porting (Recommended)
For large datasets like DROID, parallel processing across multiple nodes dramatically reduces processing time.
### Step 1: Install Cluster Dependencies
```bash
pip install datatrove # Hugging Face's distributed processing library
```
### Step 2: Configure Your SLURM Environment
Find your partition information:
```bash
sinfo --format="%R" # List available partitions
sinfo -N -p your_partition -h -o "%N cpus=%c mem=%m" # Check resources
```
Choose a **CPU partition** - no GPU needed for dataset porting.
### Step 3: Launch Parallel Porting Jobs
```bash
python examples/port_datasets/slurm_port_shards.py \
--raw-dir /your/data/droid/1.0.1 \
--repo-id your_id/droid_1.0.1 \
--logs-dir /your/logs \
--job-name port_droid \
--partition your_partition \
--workers 2048 \
--cpus-per-task 8 \
--mem-per-cpu 1950M
```
#### Parameter Guidelines
- **`--workers`**: Number of parallel jobs (max 2048 for DROID's shard count)
- **`--cpus-per-task`**: 8 CPUs recommended for frame encoding parallelization
- **`--mem-per-cpu`**: ~16GB total RAM (8×1950M) for loading raw frames
> [!TIP]
> Start with fewer workers (e.g., 100) to test your cluster configuration before launching thousands of jobs.
### Step 4: Monitor Progress
Check running jobs:
```bash
squeue -u $USER
```
Monitor overall progress:
```bash
jobs_status /your/logs
```
Inspect individual job logs:
```bash
less /your/logs/port_droid/slurm_jobs/JOB_ID_WORKER_ID.out
```
Debug failed jobs:
```bash
failed_logs /your/logs/port_droid
```
### Step 5: Aggregate Shards
Once all porting jobs complete:
```bash
python examples/port_datasets/slurm_aggregate_shards.py \
--repo-id your_id/droid_1.0.1 \
--logs-dir /your/logs \
--job-name aggr_droid \
--partition your_partition \
--workers 2048 \
--cpus-per-task 8 \
--mem-per-cpu 1950M
```
### Step 6: Upload to Hub
```bash
python examples/port_datasets/slurm_upload.py \
--repo-id your_id/droid_1.0.1 \
--logs-dir /your/logs \
--job-name upload_droid \
--partition your_partition \
--workers 50 \
--cpus-per-task 4 \
--mem-per-cpu 1950M
```
> [!NOTE]
> Upload uses fewer workers (50) since it's network-bound rather than compute-bound.
## Dataset v3.0 File Structure
Your completed dataset will have this modern structure:
```
dataset/
├── meta/
│ ├── episodes/
│ │ └── chunk-000/
│ │ └── file-000.parquet # Episode metadata
│ ├── tasks.parquet # Task definitions
│ ├── stats.json # Aggregated statistics
│ └── info.json # Dataset information
├── data/
│ └── chunk-000/
│ └── file-000.parquet # Consolidated episode data
└── videos/
└── camera_key/
└── chunk-000/
└── file-000.mp4 # Consolidated video files
```
This replaces the old episode-per-file structure with efficient, optimally-sized chunks.
## Migrating from Dataset v2.1
If you have existing datasets in v2.1 format, use the migration tool:
```bash
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
--repo-id your_id/existing_dataset
```
This automatically:
- Converts file structure to v3.0 format
- Migrates metadata from JSON Lines to parquet
- Aggregates statistics and creates per-episode stats
- Updates version information
## Performance Benefits
Dataset v3.0 provides significant improvements for large datasets:
- **Faster loading**: 3-5x reduction in initialization time
- **Memory efficiency**: Better RAM usage through memory mapping
- **Scalable processing**: Handles millions of episodes efficiently
- **Storage optimization**: Reduced file count and improved compression

View File

@@ -92,11 +92,11 @@ print(dataset.hf_dataset)
# LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
# with the latter, like iterating through the dataset.
# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
# episodes, you can access the frame indices of any episode using the episode_data_index. Here, we access
# episodes, you can access the frame indices of any episode using dataset.meta.episodes. Here, we access
# frame indices associated to the first episode:
episode_index = 0
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
# Then we grab all the image frames from the first camera:
camera_key = dataset.meta.camera_keys[0]

View File

@@ -1,7 +1,6 @@
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_processor
from lerobot.record import record_loop
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.utils.control_utils import init_keyboard_listener
@@ -12,14 +11,12 @@ NUM_EPISODES = 2
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
# Create the robot and teleoperator configurations
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
robot = LeKiwiClient(robot_config)
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
policy = ACTPolicy.from_pretrained("<hf_username>/<policy_repo_id>")
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
@@ -28,7 +25,7 @@ dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_DATASET_ID,
repo_id="<hf_username>/<eval_dataset_repo_id>",
fps=FPS,
features=dataset_features,
robot_type=robot.name,
@@ -46,12 +43,6 @@ listener, events = init_keyboard_listener()
if not robot.is_connected:
raise ValueError("Robot is not connected!")
preprocessor, postprocessor = make_processor(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
)
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
@@ -62,8 +53,6 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,

View File

@@ -38,7 +38,7 @@ while True:
keyboard_keys = keyboard.get_action()
base_action = robot._from_keyboard_to_base_action(keyboard_keys)
log_rerun_data(observation=observation, action={**arm_action, **base_action})
log_rerun_data(observation, {**arm_action, **base_action})
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action

View File

@@ -1,158 +0,0 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
from lerobot.datasets.utils import merge_features
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_processor
from lerobot.processor.converters import (
to_output_robot_action,
to_transition_robot_observation,
)
from lerobot.processor.pipeline import RobotProcessor
from lerobot.record import record_loop
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
AddRobotObservationAsComplimentaryData,
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun
NUM_EPISODES = 5
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
# Initialize the robot with degrees
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471",
id="my_awesome_follower_arm",
cameras=camera_config,
use_degrees=True,
)
# Initialize the robot
robot = SO100Follower(robot_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints = RobotProcessor(
steps=[
AddRobotObservationAsComplimentaryData(robot=robot),
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
initial_guess_current_joints=True,
),
],
to_transition=lambda tr: tr,
to_output=to_output_robot_action,
)
# Build pipeline to convert joint observation to ee pose observation
robot_joints_to_ee_pose = RobotProcessor(
steps=[
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
],
to_transition=to_transition_robot_observation,
to_output=lambda tr: tr,
)
# Build dataset action and gripper features
action_ee_and_gripper = aggregate_pipeline_dataset_features(
pipeline=robot_ee_to_joints,
initial_features={},
use_videos=True,
patterns=["action.ee", "action.gripper.pos", "observation.state.gripper.pos"],
) # Get all ee action features + gripper pos action features
# Build dataset observation features
obs_ee = aggregate_pipeline_dataset_features(
pipeline=robot_joints_to_ee_pose,
initial_features=robot.observation_features,
use_videos=True,
patterns=["observation.state.ee"],
) # Get all ee observation features
dataset_features = merge_features(obs_ee, action_ee_and_gripper)
print("All dataset features: ", dataset_features)
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_DATASET_ID,
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
_init_rerun(session_name="recording_phone")
# Connect the robot and teleoperator
robot.connect()
episode_idx = 0
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
preprocessor, postprocessor = make_processor(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
)
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
robot_action_processor=robot_ee_to_joints,
robot_observation_processor=robot_joints_to_ee_pose,
)
dataset.save_episode()
# Clean up
log_say("Stop recording")
robot.disconnect()
dataset.push_to_hub()

View File

@@ -1,215 +0,0 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
from lerobot.datasets.utils import merge_features
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor.converters import (
to_output_robot_action,
to_transition_robot_observation,
to_transition_teleop_action,
)
from lerobot.processor.pipeline import RobotProcessor
from lerobot.record import record_loop
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
AddRobotObservationAsComplimentaryData,
EEBoundsAndSafety,
EEReferenceAndDelta,
ForwardKinematicsJointsToEE,
GripperVelocityToJoint,
InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone import Phone
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun
NUM_EPISODES = 10
FPS = 30
EPISODE_TIME_SEC = 60
RESET_TIME_SEC = 30
TASK_DESCRIPTION = "My task description"
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
# Initialize the robot and teleoperator
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471",
id="my_awesome_follower_arm",
cameras=camera_config,
use_degrees=True,
)
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
# Initialize the robot and teleoperator
robot = SO100Follower(robot_config)
phone = Phone(teleop_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# Build pipeline to convert phone action to ee pose action
phone_to_robot_ee_pose = RobotProcessor(
steps=[
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
AddRobotObservationAsComplimentaryData(robot=robot),
EEReferenceAndDelta(
kinematics=kinematics_solver,
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
motor_names=list(robot.bus.motors.keys()),
),
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.20,
max_ee_twist_step_rad=0.50,
),
],
to_transition=to_transition_teleop_action,
to_output=lambda tr: tr,
)
# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints = RobotProcessor(
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
initial_guess_current_joints=True,
),
GripperVelocityToJoint(
motor_names=list(robot.bus.motors.keys()),
speed_factor=20.0,
),
],
to_transition=lambda tr: tr,
to_output=to_output_robot_action,
)
# Build pipeline to convert joint observation to ee pose observation
robot_joints_to_ee_pose = RobotProcessor(
steps=[
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
],
to_transition=to_transition_robot_observation,
to_output=lambda tr: tr,
)
# Build dataset ee action features
action_ee = aggregate_pipeline_dataset_features(
pipeline=phone_to_robot_ee_pose,
initial_features=phone.action_features,
use_videos=True,
patterns=["action.ee"],
)
# Get gripper pos action features
gripper = aggregate_pipeline_dataset_features(
pipeline=robot_ee_to_joints,
initial_features={},
use_videos=True,
patterns=["action.gripper.pos", "observation.state.gripper.pos"],
)
# Build dataset ee observation features
observation_ee = aggregate_pipeline_dataset_features(
pipeline=robot_joints_to_ee_pose,
initial_features=robot.observation_features,
use_videos=True,
patterns=["observation.state.ee"],
)
dataset_features = merge_features(action_ee, gripper, observation_ee)
print("All dataset features: ", dataset_features)
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_REPO_ID,
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
_init_rerun(session_name="recording_phone")
# Connect the robot and teleoperator
robot.connect()
phone.connect()
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=phone,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose,
robot_action_processor=robot_ee_to_joints,
robot_observation_processor=robot_joints_to_ee_pose,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=phone,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose,
robot_action_processor=robot_ee_to_joints,
robot_observation_processor=robot_joints_to_ee_pose,
)
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
dataset.save_episode()
episode_idx += 1
# Clean up
log_say("Stop recording")
robot.disconnect()
phone.disconnect()
dataset.push_to_hub()

View File

@@ -1,106 +0,0 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor.converters import to_output_robot_action
from lerobot.processor.pipeline import RobotProcessor
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
AddRobotObservationAsComplimentaryData,
InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
EPISODE_IDX = 0
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", use_degrees=True
)
robot = SO100Follower(robot_config)
robot.connect()
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
actions = dataset.hf_dataset.select_columns("action")
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# This method converts the action from the dataset to a transition for pipeline
def action_to_transition(action: dict):
act = {}
# EE pose
for k in ("ee.x", "ee.y", "ee.z", "ee.wx", "ee.wy", "ee.wz"):
if k in action:
act[f"action.{k}"] = float(action[k])
# Gripper: your dataset has absolute position
if "gripper.pos" in action:
act["action.gripper.pos"] = float(action["gripper.pos"])
return {
"observation": None,
"action": act,
"reward": None,
"done": False,
"truncated": False,
"info": {},
"complementary_data": {},
}
# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints = RobotProcessor(
steps=[
AddRobotObservationAsComplimentaryData(robot=robot),
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
initial_guess_current_joints=False, # Because replay is open loop
),
],
to_transition=action_to_transition,
to_output=to_output_robot_action,
)
robot_ee_to_joints.reset()
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(dataset.num_frames):
t0 = time.perf_counter()
ee_action = {
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
}
joint_action = robot_ee_to_joints(ee_action)
action_sent = robot.send_action(joint_action)
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
robot.disconnect()

View File

@@ -1,109 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specif
import time
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotProcessor
from lerobot.processor.converters import to_output_robot_action, to_transition_teleop_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
AddRobotObservationAsComplimentaryData,
EEBoundsAndSafety,
EEReferenceAndDelta,
GripperVelocityToJoint,
InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone import Phone
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
# Initialize the robot and teleoperator
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", use_degrees=True
)
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
# Initialize the robot and teleoperator
robot = SO100Follower(robot_config)
teleop_device = Phone(teleop_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# Build pipeline to convert phone action to ee pose action
phone_to_robot_ee_pose = RobotProcessor(
steps=[
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
AddRobotObservationAsComplimentaryData(robot=robot),
EEReferenceAndDelta(
kinematics=kinematics_solver,
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
motor_names=list(robot.bus.motors.keys()),
),
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10,
max_ee_twist_step_rad=0.50,
),
],
to_transition=to_transition_teleop_action,
to_output=lambda tr: tr,
)
# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints = RobotProcessor(
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
),
GripperVelocityToJoint(
motor_names=list(robot.bus.motors.keys()),
speed_factor=20.0,
),
],
to_transition=lambda tr: tr,
to_output=to_output_robot_action,
)
robot.connect()
teleop_device.connect()
print("Starting teleop loop. Move your phone to teleoperate the robot.")
while True:
phone_obs = teleop_device.get_action()
if not phone_obs:
time.sleep(0.01)
continue
# Get teleop observation
phone_obs = teleop_device.get_action()
# Phone to EE pose transition
ee_transition = phone_to_robot_ee_pose(phone_obs)
# EE pose to Joints transition
joint_action = robot_ee_to_joints(ee_transition)
if joint_action:
robot.send_action(joint_action)
time.sleep(0.01)

View File

@@ -0,0 +1,47 @@
#!/bin/bash
# Example script for converting RT-1 dataset using SLURM
# Make sure to modify the paths and parameters according to your setup
# Configuration
RAW_DIR="/path/to/datasets/fractal20220817_data/0.1.0"
REPO_ID="your_username/rt1_lerobot"
LOGS_DIR="/path/to/logs"
PARTITION="cpu" # Your SLURM partition name
# Step 1: Convert dataset using distributed processing
echo "Starting RT-1 dataset conversion..."
python examples/port_datasets/slurm_port_shards.py \
--raw-dir "$RAW_DIR" \
--repo-id "$REPO_ID" \
--dataset-type rlds \
--logs-dir "$LOGS_DIR" \
--job-name rt1_conversion \
--workers 32 \
--num-shards 32 \
--partition "$PARTITION" \
--cpus-per-task 4 \
--mem-per-cpu 2G \
--slurm 1
# Step 2: Wait for jobs to complete (you can monitor with squeue)
echo "Conversion jobs submitted. Monitor with 'squeue -u \$USER'"
echo "Once all jobs complete, run the aggregation step:"
echo ""
echo "python examples/port_datasets/slurm_aggregate_shards.py \\"
echo " --repo-id $REPO_ID \\"
echo " --push-to-hub"
# Uncomment the following lines if you want to automatically aggregate
# (but make sure all shards are complete first)
# echo "Waiting for jobs to complete..."
# while [ $(squeue -u $USER -h | wc -l) -gt 0 ]; do
# echo "Jobs still running, waiting 60 seconds..."
# sleep 60
# done
# echo "All jobs completed. Starting aggregation..."
# python examples/port_datasets/slurm_aggregate_shards.py \
# --repo-id "$REPO_ID" \
# --push-to-hub

View File

@@ -0,0 +1,85 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
from pathlib import Path
def find_missing_workers(completions_dir, world_size):
"""Find workers that are not completed and returns their indices."""
full = list(range(world_size))
completed = []
for path in completions_dir.glob("*"):
if path.name in [".", ".."]:
continue
index = path.name.lstrip("0")
index = 0 if index == "" else int(index)
completed.append(index)
missing_workers = set(full) - set(completed)
return missing_workers
def find_output_files(slurm_dir, worker_indices):
"""Find output files associated to worker indices, and return tuples
of (worker index, output file path)
"""
out_files = []
for path in slurm_dir.glob("*.out"):
_, worker_id = path.name.replace(".out", "").split("_")
worker_id = int(worker_id)
if worker_id in worker_indices:
out_files.append((worker_id, path))
return out_files
def display_error_files(logs_dir, job_name):
executor_path = Path(logs_dir) / job_name / "executor.json"
completions_dir = Path(logs_dir) / job_name / "completions"
with open(executor_path) as f:
executor = json.load(f)
missing_workers = find_missing_workers(completions_dir, executor["world_size"])
for missing in sorted(missing_workers)[::-1]:
print(missing)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--logs-dir",
type=str,
help="Path to logs directory for `datatrove`.",
)
parser.add_argument(
"--job-name",
type=str,
default="port_droid",
help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
)
args = parser.parse_args()
display_error_files(**vars(args))
if __name__ == "__main__":
main()

View File

@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,5 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_phone import PhoneConfig
from .phone import Phone
"""Open X-Embodiment utilities for dataset conversion."""

View File

@@ -0,0 +1,854 @@
"""
Adapt from https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/oxe/configs.py
configs.py
Defines per-dataset configuration (kwargs) for each dataset in Open-X Embodiment.
Configuration adopts the following structure:
image_obs_keys:
primary: primary external RGB
secondary: secondary external RGB
wrist: wrist RGB
depth_obs_keys:
primary: primary external depth
secondary: secondary external depth
wrist: wrist depth
# Always 8-dim =>> changes based on `StateEncoding`
state_obs_keys:
StateEncoding.POS_EULER: EEF XYZ (3) + Roll-Pitch-Yaw (3) + <PAD> (1) + Gripper Open/Close (1)
StateEncoding.POS_QUAT: EEF XYZ (3) + Quaternion (4) + Gripper Open/Close (1)
StateEncoding.JOINT: Joint Angles (7, <PAD> if fewer) + Gripper Open/Close (1)
state_encoding: Type of `StateEncoding`
action_encoding: Type of action encoding (e.g., EEF Position vs. Joint Position)
"""
from enum import IntEnum
import tensorflow as tf
def zero_action_filter(traj: dict) -> bool:
"""
Filters transitions whose actions are all-0 (only relative actions, no gripper action).
Note: this filter is applied *after* action normalization, so need to compare to "normalized 0".
"""
DROID_Q01 = tf.convert_to_tensor( # NOQA: N806
[
-0.7776297926902771,
-0.5803514122962952,
-0.5795090794563293,
-0.6464047729969025,
-0.7041108310222626,
-0.8895104378461838,
]
)
DROID_Q99 = tf.convert_to_tensor( # NOQA: N806
[
0.7597932070493698,
0.5726242214441299,
0.7351000607013702,
0.6705610305070877,
0.6464948207139969,
0.8897542208433151,
]
)
DROID_NORM_0_ACT = ( # NOQA: N806
2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1
)
return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - DROID_NORM_0_ACT) > 1e-5)
# Defines Proprioceptive State Encoding Schemes
class StateEncoding(IntEnum):
# fmt: off
NONE = -1 # No Proprioceptive State
POS_EULER = 1 # EEF XYZ (3) + Roll-Pitch-Yaw (3) + <PAD> (1) + Gripper Open/Close (1)
POS_QUAT = 2 # EEF XYZ (3) + Quaternion (4) + Gripper Open/Close (1)
JOINT = 3 # Joint Angles (7, <PAD> if fewer) + Gripper Open/Close (1)
JOINT_BIMANUAL = 4 # Joint Angles (2 x [ Joint Angles (6) + Gripper Open/Close (1) ])
# fmt: on
# Defines Action Encoding Schemes
class ActionEncoding(IntEnum):
# fmt: off
EEF_POS = 1 # EEF Delta XYZ (3) + Roll-Pitch-Yaw (3) + Gripper Open/Close (1)
EEF_POS_QUAT = 5 # EEF Delta XYZ (3) + Quaternion (4) + Gripper Open/Close (1)
JOINT_POS = 2 # Joint Delta Position (7) + Gripper Open/Close (1)
JOINT_POS_BIMANUAL = 3 # Joint Delta Position (2 x [ Joint Delta Position (6) + Gripper Open/Close (1) ])
EEF_R6 = 4 # EEF Delta XYZ (3) + R6 (6) + Gripper Open/Close (1)
# fmt: on
# === Individual Dataset Configs ===
OXE_DATASET_CONFIGS = {
"fractal20220817_data": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["base_pose_tool_reached", "gripper_closed"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 3,
"robot_type": "Google Robot",
},
"kuka": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": [
"clip_function_input/base_pose_tool_reached",
"gripper_closed",
],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Kuka iiwa",
},
"bridge_oxe": { # Version of Bridge V2 in Open X-Embodiment mixture
"image_obs_keys": {"primary": "image", "secondary": "image_1", "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "WidowX",
},
"bridge_orig": { # Original version of Bridge V2 from project website
"image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "WidowX",
},
"bridge_dataset": { # Original version of Bridge V2 from project website
"image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "WidowX",
},
"taco_play": {
"image_obs_keys": {
"primary": "rgb_static",
"secondary": None,
"wrist": "rgb_gripper",
},
"depth_obs_keys": {
"primary": "depth_static",
"secondary": None,
"wrist": "depth_gripper",
},
"state_obs_keys": ["state_eef", None, "state_gripper"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 15,
"robot_type": "Franka",
},
"jaco_play": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "image_wrist",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state_eef", None, "state_gripper"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Jaco 2",
},
"berkeley_cable_routing": {
"image_obs_keys": {
"primary": "image",
"secondary": "top_image",
"wrist": "wrist45_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["robot_state", None],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Franka",
},
"roboturk": {
"image_obs_keys": {"primary": "front_rgb", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": [None, None, None, None, None, None, None, None],
"state_encoding": StateEncoding.NONE,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Sawyer",
},
"nyu_door_opening_surprising_effectiveness": {
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": [None, None, None, None, None, None, None, None],
"state_encoding": StateEncoding.NONE,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 3,
"robot_type": "Hello Stretch",
},
"viola": {
"image_obs_keys": {
"primary": "agentview_rgb",
"secondary": None,
"wrist": "eye_in_hand_rgb",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["joint_states", "gripper_states"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"berkeley_autolab_ur5": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "hand_image",
},
"depth_obs_keys": {"primary": "depth", "secondary": None, "wrist": None},
"state_obs_keys": ["state"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "UR5",
},
"toto": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 30,
"robot_type": "Franka",
},
"language_table": {
"image_obs_keys": {"primary": "rgb", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["effector_translation", None, None, None, None, None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "xArm",
},
"columbia_cairlab_pusht_real": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["robot_state", None, None, None, None, None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "UR5",
},
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": "depth_image", "secondary": None, "wrist": None},
"state_obs_keys": ["ee_position", "ee_orientation", None],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Kuka iiwa",
},
"nyu_rot_dataset_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 3,
"robot_type": "xArm",
},
"stanford_hydra_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Franka",
},
"austin_buds_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"nyu_franka_play_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": "image_additional_view",
"wrist": None,
},
"depth_obs_keys": {
"primary": "depth",
"secondary": "depth_additional_view",
"wrist": None,
},
"state_obs_keys": ["eef_state", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 3,
"robot_type": "Franka",
},
"maniskill_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {
"primary": "depth",
"secondary": None,
"wrist": "wrist_depth",
},
"state_obs_keys": ["tcp_pose", "gripper_state"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"furniture_bench_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Franka",
},
"cmu_franka_exploration_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "highres_image",
"secondary": None,
"wrist": None,
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": [None, None, None, None, None, None, None, None],
"state_encoding": StateEncoding.NONE,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Franka",
},
"ucsd_kitchen_dataset_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["joint_state", None],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 2,
"robot_type": "xArm",
},
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 3,
"robot_type": "xArm",
},
"austin_sailor_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"austin_sirius_dataset_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"bc_z": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": [
"present/xyz",
"present/axis_angle",
None,
"present/sensed_close",
],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Google Robot",
},
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "PR2",
},
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "PR2",
},
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": "image2",
"wrist": "hand_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["end_effector_pose", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "xArm",
},
"utokyo_xarm_bimanual_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["pose_r", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "xArm Bimanual",
},
"robo_net": {
"image_obs_keys": {"primary": "image", "secondary": "image1", "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 1,
"robot_type": "Multi-Robot",
},
"berkeley_mvp_converted_externally_to_rlds": {
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "hand_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["pose", "gripper"],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.JOINT_POS,
"control_frequency": 5,
"robot_type": "xArm",
},
"berkeley_rpt_converted_externally_to_rlds": {
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "hand_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["joint_pos", "gripper"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.JOINT_POS,
"control_frequency": 30,
"robot_type": "Franka",
},
"kaist_nonprehensile_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None],
"state_encoding": StateEncoding.POS_QUAT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Franka",
},
"stanford_mask_vit_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": None,
"robot_type": "Sawyer",
},
"tokyo_u_lsmo_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Cobotta",
},
"dlr_sara_pour_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "DLR SARA",
},
"dlr_sara_grid_clamp_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "DLR SARA",
},
"dlr_edan_shared_control_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "DLR EDAN",
},
"asu_table_top_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 12.5,
"robot_type": "UR5",
},
"stanford_robocook_converted_externally_to_rlds": {
"image_obs_keys": {"primary": "image_1", "secondary": "image_2", "wrist": None},
"depth_obs_keys": {"primary": "depth_1", "secondary": "depth_2", "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"imperialcollege_sawyer_wrist_cam": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": [None, None, None, None, None, None, None, "state"],
"state_encoding": StateEncoding.NONE,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Sawyer",
},
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["joint_state", "gripper_state"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"uiuc_d3field": {
"image_obs_keys": {"primary": "image_1", "secondary": "image_2", "wrist": None},
"depth_obs_keys": {"primary": "depth_1", "secondary": "depth_2", "wrist": None},
"state_obs_keys": [None, None, None, None, None, None, None, None],
"state_encoding": StateEncoding.NONE,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 1,
"robot_type": "Kinova Gen3",
},
"utaustin_mutex": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"berkeley_fanuc_manipulation": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "wrist_image",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["joint_state", None, "gripper_state"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Fanuc Mate",
},
"cmu_playing_with_food": {
"image_obs_keys": {
"primary": "image",
"secondary": None,
"wrist": "finger_vision_1",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Franka",
},
"cmu_play_fusion": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"cmu_stretch": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["eef_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Hello Stretch",
},
"berkeley_gnm_recon": {
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 3,
"robot_type": "Jackal",
},
"berkeley_gnm_cory_hall": {
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "RC Car",
},
"berkeley_gnm_sac_son": {
"image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["state", None, None],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "TurtleBot 2",
},
# NOTE: modified
"droid": {
"image_obs_keys": {
"primary": "exterior_image_1_left",
"secondary": "exterior_image_2_left",
"wrist": "wrist_image_left",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 15,
"robot_type": "Franka",
"aux_kwargs": {
"dataset_frame_transform_kwargs": {
"chunk_filter_fn": zero_action_filter,
},
},
},
"fmb_dataset": {
"image_obs_keys": {
"primary": "image_side_1",
"secondary": "image_side_2",
"wrist": "image_wrist_1",
},
"depth_obs_keys": {
"primary": "image_side_1_depth",
"secondary": "image_side_2_depth",
"wrist": "image_wrist_1_depth",
},
"state_obs_keys": ["proprio"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Franka",
},
# NOTE: modified
"dobbe": {
"image_obs_keys": {"primary": "wrist_image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 3.75,
"robot_type": "Hello Stretch",
},
"roboset": {
"image_obs_keys": {
"primary": "image_left",
"secondary": "image_right",
"wrist": "image_wrist",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["proprio"],
"state_encoding": StateEncoding.JOINT,
"action_encoding": ActionEncoding.JOINT_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"rh20t": {
"image_obs_keys": {
"primary": "image_front",
"secondary": "image_side_right",
"wrist": "image_wrist",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["proprio"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 10,
"robot_type": "Flexiv",
},
### T-DROID datasets
"tdroid_carrot_in_bowl": { # "put carrot in bowl" task, 50 demos @ 5 Hz control
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"tdroid_pour_corn_in_pot": { # "pour corn from red bonawl into steel pot" task, 50 demos @ 5 Hz control
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"tdroid_flip_pot_upright": { # "flip pot upright" task, 10 demos @ 5 Hz control
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"tdroid_move_object_onto_plate": { # "move <object> onto plate" task, 150 demos @ 5 Hz control
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"tdroid_knock_object_over": { # "knock <object> over" task, 70 demos @ 5 Hz control
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
"tdroid_cover_object_with_towel": { # "cover <object> with towel" task, 45 demos @ 5 Hz control
"image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 5,
"robot_type": "Franka",
},
### DROID Finetuning datasets
"droid_wipe": {
"image_obs_keys": {
"primary": "exterior_image_2_left",
"secondary": None,
"wrist": "wrist_image_left",
},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["proprio"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 15,
"robot_type": "Franka",
},
# NOTE: modified
### LIBERO datasets (modified versions)
"libero_spatial_no_noops": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"libero_object_no_noops": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"libero_goal_no_noops": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
"libero_10_no_noops": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
"control_frequency": 20,
"robot_type": "Franka",
},
}

View File

@@ -0,0 +1,76 @@
"""
Copied from https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/utils/data_utils.py
"""
from typing import Any
import tensorflow as tf
def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
"""
Converts gripper actions from continuous to binary values (0 and 1).
We exploit that fact that most of the time, the gripper is fully open (near 1.0) or fully closed (near 0.0). As it
transitions between the two, it sometimes passes through a few intermediate values. We relabel those intermediate
values based on the state that is reached _after_ those intermediate values.
In the edge case that the trajectory ends with an intermediate value, we give up on binarizing and relabel that
chunk of intermediate values as the last action in the trajectory.
The `scan_fn` implements the following logic:
new_actions = np.empty_like(actions)
carry = actions[-1]
for i in reversed(range(actions.shape[0])):
if in_between_mask[i]:
carry = carry
else:
carry = float(open_mask[i])
new_actions[i] = carry
"""
open_mask, closed_mask = actions > 0.95, actions < 0.05
in_between_mask = tf.logical_not(tf.logical_or(open_mask, closed_mask))
is_open_float = tf.cast(open_mask, tf.float32)
def scan_fn(carry, i):
return tf.cond(in_between_mask[i], lambda: tf.cast(carry, tf.float32), lambda: is_open_float[i])
return tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), actions[-1], reverse=True)
def invert_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
return 1 - actions
def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
"""
Converts relative gripper actions (+1 for closing, -1 for opening) to absolute actions (0 = closed; 1 = open).
Assumes that the first relative gripper is not redundant (i.e. close when already closed)!
"""
# Note =>> -1 for closing, 1 for opening, 0 for no change
opening_mask, closing_mask = actions < -0.1, actions > 0.1
thresholded_actions = tf.where(opening_mask, 1, tf.where(closing_mask, -1, 0))
def scan_fn(carry, i):
return tf.cond(thresholded_actions[i] == 0, lambda: carry, lambda: thresholded_actions[i])
# If no relative grasp, assumes open for whole trajectory
start = -1 * thresholded_actions[tf.argmax(thresholded_actions != 0, axis=0)]
start = tf.cond(start == 0, lambda: 1, lambda: start)
# Note =>> -1 for closed, 1 for open
new_actions = tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), start)
new_actions = tf.cast(new_actions, tf.float32) / 2 + 0.5
return new_actions
# === Bridge-V2 =>> Dataset-Specific Transform ===
def relabel_bridge_actions(traj: dict[str, Any]) -> dict[str, Any]:
"""Relabels actions to use reached proprioceptive state; discards last timestep (no-action)."""
movement_actions = traj["observation"]["state"][1:, :6] - traj["observation"]["state"][:-1, :6]
traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj)
traj_truncated["action"] = tf.concat([movement_actions, traj["action"][:-1, -1:]], axis=1)
return traj_truncated

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,430 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import time
from pathlib import Path
import numpy as np
import tensorflow_datasets as tfds
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
DROID_SHARDS = 2048
DROID_FPS = 15
DROID_ROBOT_TYPE = "Franka"
# Dataset schema slightly adapted from: https://droid-dataset.github.io/droid/the-droid-dataset.html#-dataset-schema
DROID_FEATURES = {
# true on first step of the episode
"is_first": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
# true on last step of the episode
"is_last": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
# true on last step of the episode if it is a terminal step, True for demos
"is_terminal": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
# language_instruction is also stored as "task" to follow LeRobot standard
"language_instruction": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"language_instruction_2": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"language_instruction_3": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"observation.state.gripper_position": {
"dtype": "float32",
"shape": (1,),
"names": {
"axes": ["gripper"],
},
},
"observation.state.cartesian_position": {
"dtype": "float32",
"shape": (6,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
},
},
"observation.state.joint_position": {
"dtype": "float32",
"shape": (7,),
"names": {
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"],
},
},
# Add this new feature to follow LeRobot standard of using joint position + gripper
"observation.state": {
"dtype": "float32",
"shape": (8,),
"names": {
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6", "gripper"],
},
},
# Initially called wrist_image_left
"observation.images.wrist_left": {
"dtype": "video",
"shape": (180, 320, 3),
"names": [
"height",
"width",
"channels",
],
},
# Initially called exterior_image_1_left
"observation.images.exterior_1_left": {
"dtype": "video",
"shape": (180, 320, 3),
"names": [
"height",
"width",
"channels",
],
},
# Initially called exterior_image_2_left
"observation.images.exterior_2_left": {
"dtype": "video",
"shape": (180, 320, 3),
"names": [
"height",
"width",
"channels",
],
},
"action.gripper_position": {
"dtype": "float32",
"shape": (1,),
"names": {
"axes": ["gripper"],
},
},
"action.gripper_velocity": {
"dtype": "float32",
"shape": (1,),
"names": {
"axes": ["gripper"],
},
},
"action.cartesian_position": {
"dtype": "float32",
"shape": (6,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
},
},
"action.cartesian_velocity": {
"dtype": "float32",
"shape": (6,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
},
},
"action.joint_position": {
"dtype": "float32",
"shape": (7,),
"names": {
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"],
},
},
"action.joint_velocity": {
"dtype": "float32",
"shape": (7,),
"names": {
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"],
},
},
# This feature was called "action" in RLDS dataset and consists of [6x joint velocities, 1x gripper position]
"action.original": {
"dtype": "float32",
"shape": (7,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
},
},
# Add this new feature to follow LeRobot standard of using joint position + gripper
"action": {
"dtype": "float32",
"shape": (8,),
"names": {
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6", "gripper"],
},
},
"discount": {
"dtype": "float32",
"shape": (1,),
"names": None,
},
"reward": {
"dtype": "float32",
"shape": (1,),
"names": None,
},
# Meta data that are the same for all frames in the episode
"task_category": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"building": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"collector_id": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"date": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"camera_extrinsics.wrist_left": {
"dtype": "float32",
"shape": (6,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
},
},
"camera_extrinsics.exterior_1_left": {
"dtype": "float32",
"shape": (6,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
},
},
"camera_extrinsics.exterior_2_left": {
"dtype": "float32",
"shape": (6,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
},
},
"is_episode_successful": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
}
def is_episode_successful(tf_episode_metadata):
# Adapted from: https://github.com/droid-dataset/droid_policy_learning/blob/dd1020eb20d981f90b5ff07dc80d80d5c0cb108b/robomimic/utils/rlds_utils.py#L8
return "/success/" in tf_episode_metadata["file_path"].numpy().decode()
def generate_lerobot_frames(tf_episode):
m = tf_episode["episode_metadata"]
frame_meta = {
"task_category": m["building"].numpy().decode(),
"building": m["building"].numpy().decode(),
"collector_id": m["collector_id"].numpy().decode(),
"date": m["date"].numpy().decode(),
"camera_extrinsics.wrist_left": m["extrinsics_wrist_cam"].numpy(),
"camera_extrinsics.exterior_1_left": m["extrinsics_exterior_cam_1"].numpy(),
"camera_extrinsics.exterior_2_left": m["extrinsics_exterior_cam_2"].numpy(),
"is_episode_successful": np.array([is_episode_successful(m)]),
}
for f in tf_episode["steps"]:
# Dataset schema slightly adapted from: https://droid-dataset.github.io/droid/the-droid-dataset.html#-dataset-schema
frame = {
"is_first": np.array([f["is_first"].numpy()]),
"is_last": np.array([f["is_last"].numpy()]),
"is_terminal": np.array([f["is_terminal"].numpy()]),
"language_instruction": f["language_instruction"].numpy().decode(),
"language_instruction_2": f["language_instruction_2"].numpy().decode(),
"language_instruction_3": f["language_instruction_3"].numpy().decode(),
"observation.state.gripper_position": f["observation"]["gripper_position"].numpy(),
"observation.state.cartesian_position": f["observation"]["cartesian_position"].numpy(),
"observation.state.joint_position": f["observation"]["joint_position"].numpy(),
"observation.images.wrist_left": f["observation"]["wrist_image_left"].numpy(),
"observation.images.exterior_1_left": f["observation"]["exterior_image_1_left"].numpy(),
"observation.images.exterior_2_left": f["observation"]["exterior_image_2_left"].numpy(),
"action.gripper_position": f["action_dict"]["gripper_position"].numpy(),
"action.gripper_velocity": f["action_dict"]["gripper_velocity"].numpy(),
"action.cartesian_position": f["action_dict"]["cartesian_position"].numpy(),
"action.cartesian_velocity": f["action_dict"]["cartesian_velocity"].numpy(),
"action.joint_position": f["action_dict"]["joint_position"].numpy(),
"action.joint_velocity": f["action_dict"]["joint_velocity"].numpy(),
"discount": np.array([f["discount"].numpy()]),
"reward": np.array([f["reward"].numpy()]),
"action.original": f["action"].numpy(),
}
# language_instruction is also stored as "task" to follow LeRobot standard
frame["task"] = frame["language_instruction"]
# Add this new feature to follow LeRobot standard of using joint position + gripper
frame["observation.state"] = np.concatenate(
[frame["observation.state.joint_position"], frame["observation.state.gripper_position"]]
)
frame["action"] = np.concatenate([frame["action.joint_position"], frame["action.gripper_position"]])
# Meta data that are the same for all frames in the episode
frame.update(frame_meta)
# Cast fp64 to fp32
for key in frame:
if isinstance(frame[key], np.ndarray) and frame[key].dtype == np.float64:
frame[key] = frame[key].astype(np.float32)
yield frame
def port_droid(
raw_dir: Path,
repo_id: str,
push_to_hub: bool = False,
num_shards: int | None = None,
shard_index: int | None = None,
):
dataset_name = raw_dir.parent.name
version = raw_dir.name
data_dir = raw_dir.parent.parent
builder = tfds.builder(f"{dataset_name}/{version}", data_dir=data_dir, version="")
if num_shards is not None:
tfds_num_shards = builder.info.splits["train"].num_shards
if tfds_num_shards != DROID_SHARDS:
raise ValueError(
f"Number of shards of Droid dataset is expected to be {DROID_SHARDS} but is {tfds_num_shards}."
)
if num_shards != tfds_num_shards:
raise ValueError(
f"We only shard over the fixed number of shards provided by tensorflow dataset ({tfds_num_shards}), but {num_shards} shards provided instead."
)
if shard_index >= tfds_num_shards:
raise ValueError(
f"Shard index is greater than the num of shards ({shard_index} >= {num_shards})."
)
raw_dataset = builder.as_dataset(split=f"train[{shard_index}shard]")
else:
raw_dataset = builder.as_dataset(split="train")
lerobot_dataset = LeRobotDataset.create(
repo_id=repo_id,
robot_type=DROID_ROBOT_TYPE,
fps=DROID_FPS,
features=DROID_FEATURES,
)
start_time = time.time()
num_episodes = raw_dataset.cardinality().numpy().item()
logging.info(f"Number of episodes {num_episodes}")
for episode_index, episode in enumerate(raw_dataset):
elapsed_time = time.time() - start_time
d, h, m, s = get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time)
logging.info(
f"{episode_index} / {num_episodes} episodes processed (after {d} days, {h} hours, {m} minutes, {s:.3f} seconds)"
)
for frame in generate_lerobot_frames(episode):
lerobot_dataset.add_frame(frame)
lerobot_dataset.save_episode()
logging.info("Save_episode")
if push_to_hub:
lerobot_dataset.push_to_hub(
# Add openx tag, since it belongs to the openx collection of datasets
tags=["openx"],
private=False,
)
def validate_dataset(repo_id):
"""Sanity check that ensure meta data can be loaded and all files are present."""
meta = LeRobotDatasetMetadata(repo_id)
if meta.total_episodes == 0:
raise ValueError("Number of episodes is 0.")
for ep_idx in range(meta.total_episodes):
data_path = meta.root / meta.get_data_file_path(ep_idx)
if not data_path.exists():
raise ValueError(f"Parquet file is missing in: {data_path}")
for vid_key in meta.video_keys:
vid_path = meta.root / meta.get_video_file_path(ep_idx, vid_key)
if not vid_path.exists():
raise ValueError(f"Video file is missing in: {vid_path}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--raw-dir",
type=Path,
required=True,
help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).",
)
parser.add_argument(
"--repo-id",
type=str,
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True",
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Upload to hub.",
)
parser.add_argument(
"--num-shards",
type=int,
default=None,
help="Number of shards. Can be either None to load the full dataset, or 2048 to load one of the 2048 tensorflow dataset files.",
)
parser.add_argument(
"--shard-index",
type=int,
default=None,
help="Index of the shard. Can be either None to load the full dataset, or in [0,2047] to load one of the 2048 tensorflow dataset files.",
)
args = parser.parse_args()
port_droid(**vars(args))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,359 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import re
import time
from functools import partial
from pathlib import Path
from typing import Any
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from oxe_utils.configs import OXE_DATASET_CONFIGS, ActionEncoding, StateEncoding
from oxe_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
# Default FPS for datasets without specific config
DEFAULT_FPS = 10
DEFAULT_ROBOT_TYPE = "unknown"
def determine_dataset_info(raw_dir: Path):
"""Determine dataset name and version from directory structure."""
last_part = raw_dir.name
if re.match(r"^\d+\.\d+\.\d+$", last_part):
version = last_part
dataset_name = raw_dir.parent.name
data_dir = raw_dir.parent.parent
else:
version = ""
dataset_name = last_part
data_dir = raw_dir.parent
return dataset_name, version, data_dir
def generate_features_from_builder(builder, dataset_name: str) -> dict[str, Any]:
"""Generate LeRobot features schema from TFDS builder and dataset config."""
# Generate state names based on encoding type
state_names = [f"motor_{i}" for i in range(8)]
if dataset_name in OXE_DATASET_CONFIGS:
state_encoding = OXE_DATASET_CONFIGS[dataset_name]["state_encoding"]
if state_encoding == StateEncoding.POS_EULER:
state_names = ["x", "y", "z", "roll", "pitch", "yaw", "pad", "gripper"]
if "libero" in dataset_name:
state_names = [
"x",
"y",
"z",
"roll",
"pitch",
"yaw",
"gripper",
"gripper",
] # 2D gripper state
elif state_encoding == StateEncoding.POS_QUAT:
state_names = ["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"]
elif state_encoding == StateEncoding.JOINT:
state_names = [f"motor_{i}" for i in range(7)] + ["gripper"]
state_obs_keys = OXE_DATASET_CONFIGS[dataset_name]["state_obs_keys"]
pad_count = state_obs_keys[:-1].count(None)
state_names[-pad_count - 1 : -1] = ["pad"] * pad_count
state_names[-1] = "pad" if state_obs_keys[-1] is None else state_names[-1]
# Generate action names based on encoding type
action_names = [f"motor_{i}" for i in range(8)]
if dataset_name in OXE_DATASET_CONFIGS:
action_encoding = OXE_DATASET_CONFIGS[dataset_name]["action_encoding"]
if action_encoding == ActionEncoding.EEF_POS:
action_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper"]
elif action_encoding == ActionEncoding.JOINT_POS:
action_names = [f"motor_{i}" for i in range(7)] + ["gripper"]
# Base features (state and action)
features = {
"observation.state": {
"dtype": "float32",
"shape": (len(state_names),),
"names": {"axes": state_names},
},
"action": {
"dtype": "float32",
"shape": (len(action_names),),
"names": {"axes": action_names},
},
}
# Add image features from TFDS builder info
obs_features = builder.info.features["steps"]["observation"]
for key, value in obs_features.items():
# Skip depth images and non-image features
if "depth" in key or not any(x in key for x in ["image", "rgb"]):
continue
features[f"observation.images.{key}"] = {
"dtype": "video",
"shape": tuple(value.shape),
"names": ["height", "width", "channels"],
}
return features
def transform_raw_dataset(episode, dataset_name: str):
"""Apply OXE standardization transforms to raw TFDS episode."""
# Batch all steps in the episode
traj = next(iter(episode["steps"].batch(episode["steps"].cardinality())))
# Apply dataset-specific transform if available
if dataset_name in OXE_STANDARDIZATION_TRANSFORMS:
traj = OXE_STANDARDIZATION_TRANSFORMS[dataset_name](traj)
# Create consolidated state vector
if dataset_name in OXE_DATASET_CONFIGS:
state_obs_keys = OXE_DATASET_CONFIGS[dataset_name]["state_obs_keys"]
else:
state_obs_keys = [None for _ in range(8)]
# Build proprio (proprioceptive state) vector
proprio_components = []
for key in state_obs_keys:
if key is None:
# Add padding for missing state components
component = tf.zeros((tf.shape(traj["action"])[0], 1), dtype=tf.float32)
else:
component = tf.cast(traj["observation"][key], tf.float32)
# Ensure component has right shape (add dimension if needed)
if len(component.shape) == 1:
component = component[:, None]
proprio_components.append(component)
proprio = tf.concat(proprio_components, axis=1)
# Update trajectory with standardized format
traj.update(
{
"proprio": proprio,
"task": traj.get("language_instruction", ""),
"action": tf.cast(traj["action"], tf.float32),
}
)
episode["steps"] = traj
return episode
def generate_lerobot_frames(tf_episode):
"""Generate LeRobot frames from transformed TFDS episode."""
traj = tf_episode["steps"]
# Get the task/language instruction
if isinstance(traj["task"], tf.Tensor):
if traj["task"].dtype == tf.string:
task = traj["task"][0].numpy().decode() if len(traj["task"]) > 0 else ""
else:
task = str(traj["task"][0].numpy()) if len(traj["task"]) > 0 else ""
else:
task = str(traj["task"]) if traj["task"] else ""
# Iterate through each timestep
num_steps = tf.shape(traj["action"])[0].numpy()
for i in range(num_steps):
frame = {}
# Add observation state
frame["observation.state"] = traj["proprio"][i].numpy()
# Add action
frame["action"] = traj["action"][i].numpy()
# Add images
for key, value in traj["observation"].items():
if any(x in key for x in ["image", "rgb"]) and "depth" not in key:
frame[f"observation.images.{key}"] = value[i].numpy()
# Add task
frame["task"] = task
# Cast fp64 to fp32
for key in frame:
if isinstance(frame[key], np.ndarray) and frame[key].dtype == np.float64:
frame[key] = frame[key].astype(np.float32)
yield frame
def port_rlds(
raw_dir: Path,
repo_id: str,
push_to_hub: bool = False,
num_shards: int | None = None,
shard_index: int | None = None,
):
"""Port RLDS dataset to LeRobot format."""
# Determine dataset info
dataset_name, version, data_dir = determine_dataset_info(raw_dir)
# Build TFDS dataset
builder = tfds.builder(
f"{dataset_name}/{version}" if version else dataset_name, data_dir=data_dir, version=version
)
# Handle sharding if specified
if num_shards is not None and shard_index is not None:
if shard_index >= num_shards:
raise ValueError(f"Shard index {shard_index} >= num_shards {num_shards}")
# Calculate shard splits
total_episodes = builder.info.splits["train"].num_examples
episodes_per_shard = total_episodes // num_shards
start_idx = shard_index * episodes_per_shard
if shard_index == num_shards - 1:
# Last shard gets remaining episodes
end_idx = total_episodes
else:
end_idx = start_idx + episodes_per_shard
split_str = f"train[{start_idx}:{end_idx}]"
raw_dataset = builder.as_dataset(split=split_str)
else:
raw_dataset = builder.as_dataset(split="train")
# Apply filtering (e.g., success filter for kuka)
if dataset_name == "kuka":
raw_dataset = raw_dataset.filter(lambda e: e["success"])
# Apply transformations
raw_dataset = raw_dataset.map(partial(transform_raw_dataset, dataset_name=dataset_name))
# Get dataset configuration
fps = DEFAULT_FPS
robot_type = DEFAULT_ROBOT_TYPE
if dataset_name in OXE_DATASET_CONFIGS:
config = OXE_DATASET_CONFIGS[dataset_name]
fps = config.get("control_frequency", DEFAULT_FPS)
robot_type = config.get("robot_type", DEFAULT_ROBOT_TYPE)
robot_type = robot_type.lower().replace(" ", "_").replace("-", "_")
# Generate features schema
features = generate_features_from_builder(builder, dataset_name)
# Create LeRobot dataset
lerobot_dataset = LeRobotDataset.create(
repo_id=repo_id,
robot_type=robot_type,
fps=int(fps),
features=features,
)
# Process episodes
start_time = time.time()
num_episodes = raw_dataset.cardinality().numpy().item()
logging.info(f"Number of episodes: {num_episodes}")
for episode_index, episode in enumerate(raw_dataset):
elapsed_time = time.time() - start_time
d, h, m, s = get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time)
logging.info(
f"{episode_index} / {num_episodes} episodes processed "
f"(after {d} days, {h} hours, {m} minutes, {s:.3f} seconds)"
)
# Generate and add frames
for frame in generate_lerobot_frames(episode):
lerobot_dataset.add_frame(frame)
lerobot_dataset.save_episode()
logging.info("Save_episode")
# Push to hub if requested
if push_to_hub:
tags = ["openx", dataset_name]
if robot_type != "unknown":
tags.append(robot_type)
lerobot_dataset.push_to_hub(
tags=tags,
private=False,
)
def validate_dataset(repo_id):
"""Sanity check that ensures metadata can be loaded and all files are present."""
meta = LeRobotDatasetMetadata(repo_id)
if meta.total_episodes == 0:
raise ValueError("Number of episodes is 0.")
for ep_idx in range(meta.total_episodes):
data_path = meta.root / meta.get_data_file_path(ep_idx)
if not data_path.exists():
raise ValueError(f"Parquet file is missing in: {data_path}")
for vid_key in meta.video_keys:
vid_path = meta.root / meta.get_video_file_path(ep_idx, vid_key)
if not vid_path.exists():
raise ValueError(f"Video file is missing in: {vid_path}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--raw-dir",
type=Path,
required=True,
help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).",
)
parser.add_argument(
"--repo-id",
type=str,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True",
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Upload to hub.",
)
parser.add_argument(
"--num-shards",
type=int,
default=None,
help="Number of shards to split the dataset into for parallel processing.",
)
parser.add_argument(
"--shard-index",
type=int,
default=None,
help="Index of the shard to process (0-indexed).",
)
args = parser.parse_args()
port_rlds(**vars(args))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,148 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.utils.utils import init_logging
class AggregateDatasets(PipelineStep):
def __init__(
self,
repo_ids: list[str],
aggregated_repo_id: str,
):
super().__init__()
self.repo_ids = repo_ids
self.aggr_repo_id = aggregated_repo_id
def run(self, data=None, rank: int = 0, world_size: int = 1):
init_logging()
# Since aggregate_datasets already handles parallel processing internally,
# we only need one worker to run the entire aggregation
if rank == 0:
logging.info(f"Starting aggregation of {len(self.repo_ids)} datasets into {self.aggr_repo_id}")
aggregate_datasets(self.repo_ids, self.aggr_repo_id)
logging.info("Aggregation complete!")
else:
logging.info(f"Worker {rank} skipping - only worker 0 performs aggregation")
def make_aggregate_executor(
repo_ids, repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
):
kwargs = {
"pipeline": [
AggregateDatasets(repo_ids, repo_id),
],
"logging_dir": str(logs_dir / job_name),
}
if slurm:
# For aggregation, we only need 1 task since aggregate_datasets handles everything
kwargs.update(
{
"job_name": job_name,
"tasks": 1, # Only need 1 task for aggregation
"workers": 1, # Only need 1 worker
"time": "08:00:00",
"partition": partition,
"cpus_per_task": cpus_per_task,
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
}
)
executor = SlurmPipelineExecutor(**kwargs)
else:
kwargs.update(
{
"tasks": 1,
"workers": 1,
}
)
executor = LocalPipelineExecutor(**kwargs)
return executor
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
)
parser.add_argument(
"--logs-dir",
type=Path,
help="Path to logs directory for `datatrove`.",
)
parser.add_argument(
"--job-name",
type=str,
default="aggr_droid",
help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
)
parser.add_argument(
"--slurm",
type=int,
default=1,
help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
)
parser.add_argument(
"--workers",
type=int,
default=1, # Changed default to 1 since aggregation doesn't need multiple workers
help="Number of slurm workers. For aggregation, this should be 1.",
)
parser.add_argument(
"--partition",
type=str,
help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
)
parser.add_argument(
"--cpus-per-task",
type=int,
default=8,
help="Number of cpus that each slurm worker will use.",
)
parser.add_argument(
"--mem-per-cpu",
type=str,
default="1950M",
help="Memory per cpu that each worker will use.",
)
args = parser.parse_args()
kwargs = vars(args)
kwargs["slurm"] = kwargs.pop("slurm") == 1
repo_ids = [f"{args.repo_id}_world_{DROID_SHARDS}_rank_{rank}" for rank in range(DROID_SHARDS)]
aggregate_executor = make_aggregate_executor(repo_ids, **kwargs)
aggregate_executor.run()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,247 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from port_droid import DROID_SHARDS
class PortDroidShards(PipelineStep):
def __init__(
self,
raw_dir: Path | str,
repo_id: str = None,
):
super().__init__()
self.raw_dir = Path(raw_dir)
self.repo_id = repo_id
def run(self, data=None, rank: int = 0, world_size: int = 1):
from datasets.utils.tqdm import disable_progress_bars
from port_droid import port_droid, validate_dataset
from lerobot.utils.utils import init_logging
init_logging()
disable_progress_bars()
shard_repo_id = f"{self.repo_id}_world_{world_size}_rank_{rank}"
try:
validate_dataset(shard_repo_id)
return
except Exception:
pass # nosec B110 - Dataset doesn't exist yet, continue with porting
port_droid(
self.raw_dir,
shard_repo_id,
push_to_hub=False,
num_shards=world_size,
shard_index=rank,
)
validate_dataset(shard_repo_id)
class PortRLDSShards(PipelineStep):
def __init__(
self,
raw_dir: Path | str,
repo_id: str = None,
num_shards: int = None,
):
super().__init__()
self.raw_dir = Path(raw_dir)
self.repo_id = repo_id
self.num_shards = num_shards
def run(self, data=None, rank: int = 0, world_size: int = 1):
from datasets.utils.tqdm import disable_progress_bars
from port_rlds import port_rlds, validate_dataset
from lerobot.utils.utils import init_logging
init_logging()
disable_progress_bars()
shard_repo_id = f"{self.repo_id}_world_{world_size}_rank_{rank}"
try:
validate_dataset(shard_repo_id)
return
except Exception:
pass # nosec B110 - Dataset doesn't exist yet, continue with porting
port_rlds(
self.raw_dir,
shard_repo_id,
push_to_hub=False,
num_shards=world_size,
shard_index=rank,
)
validate_dataset(shard_repo_id)
def make_port_executor(
raw_dir,
repo_id,
job_name,
logs_dir,
workers,
partition,
cpus_per_task,
mem_per_cpu,
slurm=True,
dataset_type="droid",
num_shards=None,
):
# Select appropriate pipeline step based on dataset type
if dataset_type.lower() == "droid":
pipeline_step = PortDroidShards(raw_dir, repo_id)
default_shards = DROID_SHARDS
elif dataset_type.lower() == "rlds":
pipeline_step = PortRLDSShards(raw_dir, repo_id, num_shards)
default_shards = num_shards or workers # Use num_shards or fallback to workers
else:
raise ValueError(f"Unsupported dataset type: {dataset_type}")
kwargs = {
"pipeline": [pipeline_step],
"logging_dir": str(logs_dir / job_name),
}
if slurm:
kwargs.update(
{
"job_name": job_name,
"tasks": default_shards,
"workers": workers,
"time": "08:00:00",
"partition": partition,
"cpus_per_task": cpus_per_task,
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
}
)
executor = SlurmPipelineExecutor(**kwargs)
else:
kwargs.update(
{
"tasks": 1,
"workers": 1,
}
)
executor = LocalPipelineExecutor(**kwargs)
return executor
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--raw-dir",
type=Path,
required=True,
help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).",
)
parser.add_argument(
"--repo-id",
type=str,
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
)
parser.add_argument(
"--logs-dir",
type=Path,
default=Path("./logs"),
help="Path to logs directory for `datatrove` (default: ./logs).",
)
parser.add_argument(
"--dataset-type",
type=str,
choices=["droid", "rlds"],
default="droid",
help="Type of dataset to process: 'droid' for DROID datasets or 'rlds' for RLDS/OpenX datasets.",
)
parser.add_argument(
"--job-name",
type=str,
default=None,
help="Job name used in slurm, and name of the directory created inside the provided logs directory. Defaults to 'port_{dataset_type}'.",
)
parser.add_argument(
"--slurm",
type=int,
default=1,
help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
)
parser.add_argument(
"--workers",
type=int,
default=None,
help="Number of slurm workers. Defaults: 2048 for DROID, 64 for RLDS datasets.",
)
parser.add_argument(
"--num-shards",
type=int,
default=None,
help="Number of shards to split the dataset into. For DROID datasets, this is fixed at 2048. For RLDS datasets, defaults to number of workers.",
)
parser.add_argument(
"--partition",
type=str,
help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
)
parser.add_argument(
"--cpus-per-task",
type=int,
default=8,
help="Number of cpus that each slurm worker will use.",
)
parser.add_argument(
"--mem-per-cpu",
type=str,
default="1950M",
help="Memory per cpu that each worker will use.",
)
args = parser.parse_args()
# Set defaults based on dataset type
if args.job_name is None:
args.job_name = f"port_{args.dataset_type}"
if args.workers is None:
if args.dataset_type == "droid":
args.workers = 2048
else: # rlds
args.workers = 64
# Convert args to kwargs and process
kwargs = vars(args)
kwargs["slurm"] = kwargs.pop("slurm") == 1
port_executor = make_port_executor(**kwargs)
port_executor.run()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,281 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from huggingface_hub import HfApi
from huggingface_hub.constants import REPOCARD_NAME
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.utils import create_lerobot_dataset_card
from lerobot.utils.utils import init_logging
class UploadDataset(PipelineStep):
def __init__(
self,
repo_id: str,
branch: str | None = None,
revision: str | None = None,
tags: list | None = None,
license: str | None = "apache-2.0",
private: bool = False,
distant_repo_id: str | None = None,
**card_kwargs,
):
super().__init__()
self.repo_id = repo_id
self.distant_repo_id = self.repo_id if distant_repo_id is None else distant_repo_id
self.branch = branch
self.tags = tags
self.license = license
self.private = private
self.card_kwargs = card_kwargs
self.revision = revision if revision else CODEBASE_VERSION
if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") != "1":
logging.warning(
'HF_HUB_ENABLE_HF_TRANSFER is not set to "1". Install hf_transfer and set the env '
"variable for faster uploads:\npip install hf-transfer\nexport HF_HUB_ENABLE_HF_TRANSFER=1"
)
self.create_repo()
def create_repo(self):
logging.info(f"Loading meta data from {self.repo_id}...")
meta = LeRobotDatasetMetadata(self.repo_id)
logging.info(f"Creating repo {self.distant_repo_id}...")
hub_api = HfApi()
hub_api.create_repo(
repo_id=self.distant_repo_id,
private=self.private,
repo_type="dataset",
exist_ok=True,
)
if self.branch:
hub_api.create_branch(
repo_id=self.distant_repo_id,
branch=self.branch,
revision=self.revision,
repo_type="dataset",
exist_ok=True,
)
if not hub_api.file_exists(
self.distant_repo_id, REPOCARD_NAME, repo_type="dataset", revision=self.branch
):
card = create_lerobot_dataset_card(
tags=self.tags, dataset_info=meta.info, license=self.license, **self.card_kwargs
)
card.push_to_hub(repo_id=self.distant_repo_id, repo_type="dataset", revision=self.branch)
hub_api.create_tag(self.distant_repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
def list_files_recursively(directory):
base_path = Path(directory)
return [str(file.relative_to(base_path)) for file in base_path.rglob("*") if file.is_file()]
logging.info(f"Listing all local files from {self.repo_id}...")
self.file_paths = list_files_recursively(meta.root)
self.file_paths = sorted(self.file_paths)
def create_chunks(self, lst, n):
from itertools import islice
it = iter(lst)
return [list(islice(it, size)) for size in [len(lst) // n + (i < len(lst) % n) for i in range(n)]]
def create_commits(self, additions):
import logging
import math
import random
import time
from huggingface_hub import create_commit
from huggingface_hub.utils import HfHubHTTPError
FILES_BETWEEN_COMMITS = 10 # noqa: N806
BASE_DELAY = 0.1 # noqa: N806
MAX_RETRIES = 12 # noqa: N806
# Split the files into smaller chunks for faster commit
# and avoiding "A commit has happened since" error
num_chunks = math.ceil(len(additions) / FILES_BETWEEN_COMMITS)
chunks = self.create_chunks(additions, num_chunks)
for chunk in chunks:
retries = 0
while True:
try:
create_commit(
self.distant_repo_id,
repo_type="dataset",
operations=chunk,
commit_message=f"DataTrove upload ({len(chunk)} files)",
revision=self.branch,
)
# TODO: every 100 chunks super_squach_commits()
logging.info("create_commit completed!")
break
except HfHubHTTPError as e:
if "A commit has happened since" in e.server_message:
if retries >= MAX_RETRIES:
logging.error(f"Failed to create commit after {MAX_RETRIES=}. Giving up.")
raise e
logging.info("Commit creation race condition issue. Waiting...")
time.sleep(BASE_DELAY * 2**retries + random.uniform(0, 2))
retries += 1
else:
raise e
def run(self, data=None, rank: int = 0, world_size: int = 1):
import logging
from datasets.utils.tqdm import disable_progress_bars
from huggingface_hub import CommitOperationAdd, preupload_lfs_files
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.utils.utils import init_logging
init_logging()
disable_progress_bars()
chunks = self.create_chunks(self.file_paths, world_size)
file_paths = chunks[rank]
if len(file_paths) == 0:
raise ValueError(file_paths)
logging.info("Pre-uploading LFS files...")
for i, path in enumerate(file_paths):
logging.info(f"{i}: {path}")
meta = LeRobotDatasetMetadata(self.repo_id)
additions = [
CommitOperationAdd(path_in_repo=path, path_or_fileobj=meta.root / path) for path in file_paths
]
preupload_lfs_files(
repo_id=self.distant_repo_id, repo_type="dataset", additions=additions, revision=self.branch
)
logging.info("Creating commits...")
self.create_commits(additions)
logging.info("Done!")
def make_upload_executor(
repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
):
kwargs = {
"pipeline": [
UploadDataset(repo_id),
],
"logging_dir": str(logs_dir / job_name),
}
if slurm:
kwargs.update(
{
"job_name": job_name,
"tasks": DROID_SHARDS,
"workers": workers,
"time": "08:00:00",
"partition": partition,
"cpus_per_task": cpus_per_task,
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
}
)
executor = SlurmPipelineExecutor(**kwargs)
else:
kwargs.update(
{
"tasks": DROID_SHARDS,
"workers": 1,
}
)
executor = LocalPipelineExecutor(**kwargs)
return executor
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
)
parser.add_argument(
"--logs-dir",
type=Path,
help="Path to logs directory for `datatrove`.",
)
parser.add_argument(
"--job-name",
type=str,
default="upload_droid",
help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
)
parser.add_argument(
"--slurm",
type=int,
default=1,
help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
)
parser.add_argument(
"--workers",
type=int,
default=50,
help="Number of slurm workers. It should be less than the maximum number of shards.",
)
parser.add_argument(
"--partition",
type=str,
help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
)
parser.add_argument(
"--cpus-per-task",
type=int,
default=8,
help="Number of cpus that each slurm worker will use.",
)
parser.add_argument(
"--mem-per-cpu",
type=str,
default="1950M",
help="Memory per cpu that each worker will use.",
)
init_logging()
args = parser.parse_args()
kwargs = vars(args)
kwargs["slurm"] = kwargs.pop("slurm") == 1
upload_executor = make_upload_executor(**kwargs)
upload_executor.run()
if __name__ == "__main__":
main()

View File

@@ -73,7 +73,6 @@ dependencies = [
"pynput>=1.7.7",
"pyserial>=3.5",
"wandb>=0.20.0",
"scipy>=1.15.2",
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
@@ -85,7 +84,6 @@ dependencies = [
# Support dependencies
"deepdiff>=7.0.1,<9.0.0",
"flask>=3.0.3,<4.0.0",
"imageio[ffmpeg]>=2.34.0,<3.0.0",
"termcolor>=2.4.0,<4.0.0",
]
@@ -96,7 +94,7 @@ dependencies = [
# Common
pygame-dep = ["pygame>=2.5.1"]
placo-dep = ["placo>=0.9.6"]
transformers-dep = ["transformers<=4.52.0"]
transformers-dep = ["transformers>=4.50.3,<4.52.0"] # TODO: Bumb dependency
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
# Motors
@@ -112,7 +110,6 @@ intelrealsense = [
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
"pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'",
]
phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"]
# stretch = [
# "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'",
# "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
@@ -154,8 +151,7 @@ all = [
"lerobot[video_benchmark]",
"lerobot[aloha]",
"lerobot[pusht]",
"lerobot[xarm]",
"lerobot[phone]",
"lerobot[xarm]"
]
[project.scripts]

View File

@@ -26,7 +26,7 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.optim.optimizers import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
@@ -53,6 +53,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
"""
n_obs_steps: int = 1
normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)

View File

@@ -24,7 +24,6 @@ class FeatureType(str, Enum):
ENV = "ENV"
ACTION = "ACTION"
REWARD = "REWARD"
LANGUAGE = "LANGUAGE"
class NormalizationMode(str, Enum):

View File

@@ -21,7 +21,6 @@ OBS_ENV_STATE = "observation.environment_state"
OBS_STATE = "observation.state"
OBS_IMAGE = "observation.image"
OBS_IMAGES = "observation.images"
OBS_LANGUAGE = "observation.language"
ACTION = "action"
REWARD = "next.reward"
@@ -40,9 +39,6 @@ OPTIMIZER_STATE = "optimizer_state.safetensors"
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
SCHEDULER_STATE = "scheduler_state.json"
PREPROCESSOR_DEFAULT_NAME = "robot_preprocessor"
POSTPROCESSOR_DEFAULT_NAME = "robot_postprocessor"
if "LEROBOT_HOME" in os.environ:
raise ValueError(
f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"

View File

@@ -0,0 +1,505 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import shutil
from pathlib import Path
import pandas as pd
import tqdm
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
DEFAULT_EPISODES_PATH,
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
get_parquet_file_size_in_mb,
get_video_size_in_mb,
to_parquet_with_hf_images,
update_chunk_file_indices,
write_info,
write_stats,
write_tasks,
)
from lerobot.datasets.video_utils import concat_video_files
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
"""Validates that all dataset metadata have consistent properties.
Ensures all datasets have the same fps, robot_type, and features to guarantee
compatibility when aggregating them into a single dataset.
Args:
all_metadata: List of LeRobotDatasetMetadata objects to validate.
Returns:
tuple: A tuple containing (fps, robot_type, features) from the first metadata.
Raises:
ValueError: If any metadata has different fps, robot_type, or features
than the first metadata in the list.
"""
fps = all_metadata[0].fps
robot_type = all_metadata[0].robot_type
features = all_metadata[0].features
for meta in tqdm.tqdm(all_metadata, desc="Validate all meta data"):
if fps != meta.fps:
raise ValueError(f"Same fps is expected, but got fps={meta.fps} instead of {fps}.")
if robot_type != meta.robot_type:
raise ValueError(
f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}."
)
if features != meta.features:
raise ValueError(
f"Same features is expected, but got features={meta.features} instead of {features}."
)
return fps, robot_type, features
def update_data_df(df, src_meta, dst_meta):
"""Updates a data DataFrame with new indices and task mappings for aggregation.
Adjusts episode indices, frame indices, and task indices to account for
previously aggregated data in the destination dataset.
Args:
df: DataFrame containing the data to be updated.
src_meta: Source dataset metadata.
dst_meta: Destination dataset metadata.
Returns:
pd.DataFrame: Updated DataFrame with adjusted indices.
"""
def _update(row):
row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"]
row["index"] = row["index"] + dst_meta.info["total_frames"]
task = src_meta.tasks.iloc[row["task_index"]].name
row["task_index"] = dst_meta.tasks.loc[task].task_index.item()
return row
return df.apply(_update, axis=1)
def update_meta_data(
df,
dst_meta,
meta_idx,
data_idx,
videos_idx,
):
"""Updates metadata DataFrame with new chunk, file, and timestamp indices.
Adjusts all indices and timestamps to account for previously aggregated
data and videos in the destination dataset.
Args:
df: DataFrame containing the metadata to be updated.
dst_meta: Destination dataset metadata.
meta_idx: Dictionary containing current metadata chunk and file indices.
data_idx: Dictionary containing current data chunk and file indices.
videos_idx: Dictionary containing current video indices and timestamps.
Returns:
pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
"""
def _update(row):
row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_idx["chunk"]
row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_idx["file"]
row["data/chunk_index"] = row["data/chunk_index"] + data_idx["chunk"]
row["data/file_index"] = row["data/file_index"] + data_idx["file"]
for key, video_idx in videos_idx.items():
row[f"videos/{key}/chunk_index"] = row[f"videos/{key}/chunk_index"] + video_idx["chunk"]
row[f"videos/{key}/file_index"] = row[f"videos/{key}/file_index"] + video_idx["file"]
row[f"videos/{key}/from_timestamp"] = (
row[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
)
row[f"videos/{key}/to_timestamp"] = (
row[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
)
row["dataset_from_index"] = row["dataset_from_index"] + dst_meta.info["total_frames"]
row["dataset_to_index"] = row["dataset_to_index"] + dst_meta.info["total_frames"]
row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"]
return row
return df.apply(_update, axis=1)
def aggregate_datasets(
repo_ids: list[str],
aggr_repo_id: str,
roots: list[Path] | None = None,
aggr_root: Path | None = None,
data_files_size_in_mb: float | None = None,
video_files_size_in_mb: float | None = None,
chunk_size: int | None = None,
):
"""Aggregates multiple LeRobot datasets into a single unified dataset.
This is the main function that orchestrates the aggregation process by:
1. Loading and validating all source dataset metadata
2. Creating a new destination dataset with unified tasks
3. Aggregating videos, data, and metadata from all source datasets
4. Finalizing the aggregated dataset with proper statistics
Args:
repo_ids: List of repository IDs for the datasets to aggregate.
aggr_repo_id: Repository ID for the aggregated output dataset.
roots: Optional list of root paths for the source datasets.
aggr_root: Optional root path for the aggregated dataset.
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
"""
logging.info("Start aggregate_datasets")
if data_files_size_in_mb is None:
data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
if video_files_size_in_mb is None:
video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
if chunk_size is None:
chunk_size = DEFAULT_CHUNK_SIZE
all_metadata = (
[LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
if roots is None
else [
LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
]
)
fps, robot_type, features = validate_all_metadata(all_metadata)
video_keys = [key for key in features if features[key]["dtype"] == "video"]
dst_meta = LeRobotDatasetMetadata.create(
repo_id=aggr_repo_id,
fps=fps,
robot_type=robot_type,
features=features,
root=aggr_root,
)
logging.info("Find all tasks")
unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique()
dst_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks)
meta_idx = {"chunk": 0, "file": 0}
data_idx = {"chunk": 0, "file": 0}
videos_idx = {
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys
}
dst_meta.episodes = {}
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
dst_meta.info["total_episodes"] += src_meta.total_episodes
dst_meta.info["total_frames"] += src_meta.total_frames
finalize_aggregation(dst_meta, all_metadata)
logging.info("Aggregation complete.")
def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size):
"""Aggregates video chunks from a source dataset into the destination dataset.
Handles video file concatenation and rotation based on file size limits.
Creates new video files when size limits are exceeded.
Args:
src_meta: Source dataset metadata.
dst_meta: Destination dataset metadata.
videos_idx: Dictionary tracking video chunk and file indices.
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
Returns:
dict: Updated videos_idx with current chunk and file indices.
"""
for key, video_idx in videos_idx.items():
unique_chunk_file_pairs = {
(chunk, file)
for chunk, file in zip(
src_meta.episodes[f"videos/{key}/chunk_index"],
src_meta.episodes[f"videos/{key}/file_index"],
strict=False,
)
}
unique_chunk_file_pairs = sorted(unique_chunk_file_pairs)
chunk_idx = video_idx["chunk"]
file_idx = video_idx["file"]
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
video_key=key,
chunk_index=src_chunk_idx,
file_index=src_file_idx,
)
dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
video_key=key,
chunk_index=chunk_idx,
file_index=file_idx,
)
# If a new file is created, we don't want to increment the latest_duration
update_latest_duration = False
if not dst_path.exists():
# First write to this destination file
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(str(src_path), str(dst_path))
continue # not accumulating further, already copied the file in place
# Check file sizes before appending
src_size = get_video_size_in_mb(src_path)
dst_size = get_video_size_in_mb(dst_path)
if dst_size + src_size >= video_files_size_in_mb:
# Rotate to a new chunk/file
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
video_key=key,
chunk_index=chunk_idx,
file_index=file_idx,
)
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(str(src_path), str(dst_path))
else:
# Get the timestamps shift for this video
timestamps_shift_s = dst_meta.info["total_frames"] / dst_meta.info["fps"]
# Append to existing video file
concat_video_files(
[dst_path, src_path],
dst_meta.root,
key,
chunk_idx,
file_idx,
)
# Update the latest_duration when appending (shifts timestamps!)
update_latest_duration = not update_latest_duration
# Update the videos_idx with the final chunk and file indices for this key
videos_idx[key]["chunk"] = chunk_idx
videos_idx[key]["file"] = file_idx
if update_latest_duration:
videos_idx[key]["latest_duration"] += timestamps_shift_s
return videos_idx
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
"""Aggregates data chunks from a source dataset into the destination dataset.
Reads source data files, updates indices to match the aggregated dataset,
and writes them to the destination with proper file rotation.
Args:
src_meta: Source dataset metadata.
dst_meta: Destination dataset metadata.
data_idx: Dictionary tracking data chunk and file indices.
Returns:
dict: Updated data_idx with current chunk and file indices.
"""
unique_chunk_file_ids = {
(c, f)
for c, f in zip(
src_meta.episodes["data/chunk_index"], src_meta.episodes["data/file_index"], strict=False
)
}
unique_chunk_file_ids = sorted(unique_chunk_file_ids)
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
chunk_index=src_chunk_idx, file_index=src_file_idx
)
df = pd.read_parquet(src_path)
df = update_data_df(df, src_meta, dst_meta)
data_idx = append_or_create_parquet_file(
df,
src_path,
data_idx,
data_files_size_in_mb,
chunk_size,
DEFAULT_DATA_PATH,
contains_images=len(dst_meta.image_keys) > 0,
aggr_root=dst_meta.root,
)
return data_idx
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
"""Aggregates metadata from a source dataset into the destination dataset.
Reads source metadata files, updates all indices and timestamps,
and writes them to the destination with proper file rotation.
Args:
src_meta: Source dataset metadata.
dst_meta: Destination dataset metadata.
meta_idx: Dictionary tracking metadata chunk and file indices.
data_idx: Dictionary tracking data chunk and file indices.
videos_idx: Dictionary tracking video indices and timestamps.
Returns:
dict: Updated meta_idx with current chunk and file indices.
"""
chunk_file_ids = {
(c, f)
for c, f in zip(
src_meta.episodes["meta/episodes/chunk_index"],
src_meta.episodes["meta/episodes/file_index"],
strict=False,
)
}
chunk_file_ids = sorted(chunk_file_ids)
for chunk_idx, file_idx in chunk_file_ids:
src_path = src_meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
df = pd.read_parquet(src_path)
df = update_meta_data(
df,
dst_meta,
meta_idx,
data_idx,
videos_idx,
)
for k in videos_idx:
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
meta_idx = append_or_create_parquet_file(
df,
src_path,
meta_idx,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_CHUNK_SIZE,
DEFAULT_EPISODES_PATH,
contains_images=False,
aggr_root=dst_meta.root,
)
return meta_idx
def append_or_create_parquet_file(
df: pd.DataFrame,
src_path: Path,
idx: dict[str, int],
max_mb: float,
chunk_size: int,
default_path: str,
contains_images: bool = False,
aggr_root: Path = None,
):
"""Appends data to an existing parquet file or creates a new one based on size constraints.
Manages file rotation when size limits are exceeded to prevent individual files
from becoming too large. Handles both regular parquet files and those containing images.
Args:
df: DataFrame to write to the parquet file.
src_path: Path to the source file (used for size estimation).
idx: Dictionary containing current 'chunk' and 'file' indices.
max_mb: Maximum allowed file size in MB before rotation.
chunk_size: Maximum number of files per chunk before incrementing chunk index.
default_path: Format string for generating file paths.
contains_images: Whether the data contains images requiring special handling.
aggr_root: Root path for the aggregated dataset.
Returns:
dict: Updated index dictionary with current chunk and file indices.
"""
dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
if not dst_path.exists():
dst_path.parent.mkdir(parents=True, exist_ok=True)
if contains_images:
to_parquet_with_hf_images(df, dst_path)
else:
df.to_parquet(dst_path)
return idx
src_size = get_parquet_file_size_in_mb(src_path)
dst_size = get_parquet_file_size_in_mb(dst_path)
if dst_size + src_size >= max_mb:
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
new_path.parent.mkdir(parents=True, exist_ok=True)
final_df = df
target_path = new_path
else:
existing_df = pd.read_parquet(dst_path)
final_df = pd.concat([existing_df, df], ignore_index=True)
target_path = dst_path
if contains_images:
to_parquet_with_hf_images(final_df, target_path)
else:
final_df.to_parquet(target_path)
return idx
def finalize_aggregation(aggr_meta, all_metadata):
"""Finalizes the dataset aggregation by writing summary files and statistics.
Writes the tasks file, info file with total counts and splits, and
aggregated statistics from all source datasets.
Args:
aggr_meta: Aggregated dataset metadata.
all_metadata: List of all source dataset metadata objects.
"""
logging.info("write tasks")
write_tasks(aggr_meta.tasks, aggr_meta.root)
logging.info("write info")
aggr_meta.info.update(
{
"total_tasks": len(aggr_meta.tasks),
"total_episodes": sum(m.total_episodes for m in all_metadata),
"total_frames": sum(m.total_frames for m in all_metadata),
"splits": {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"},
}
)
write_info(aggr_meta.info, aggr_meta.root)
logging.info("write stats")
aggr_meta.stats = aggregate_stats([m.stats for m in all_metadata])
write_stats(aggr_meta.stats, aggr_meta.root)

View File

@@ -47,6 +47,18 @@ If you encounter a problem, contact LeRobot maintainers on [Discord](https://dis
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
"""
V30_MESSAGE = """
The dataset you requested ({repo_id}) is in {version} format.
While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
```
python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id={repo_id}
```
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
"""
FUTURE_MESSAGE = """
The dataset you requested ({repo_id}) is only available in {version} format.
As we cannot ensure forward compatibility with it, please update your current version of lerobot.
@@ -58,7 +70,12 @@ class CompatibilityError(Exception): ...
class BackwardCompatibilityError(CompatibilityError):
def __init__(self, repo_id: str, version: packaging.version.Version):
message = V2_MESSAGE.format(repo_id=repo_id, version=version)
if version.major == 2 and version.minor == 1:
message = V30_MESSAGE.format(repo_id=repo_id, version=version)
else:
raise NotImplementedError(
"Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb)."
)
super().__init__(message)

File diff suppressed because it is too large Load Diff

View File

@@ -337,13 +337,11 @@ def compute_sampler_weights(
if len(offline_dataset) > 0:
offline_data_mask_indices = []
for start_index, end_index in zip(
offline_dataset.episode_data_index["from"],
offline_dataset.episode_data_index["to"],
offline_dataset.meta.episodes["dataset_from_index"],
offline_dataset.meta.episodes["dataset_to_index"],
strict=True,
):
offline_data_mask_indices.extend(
range(start_index.item(), end_index.item() - offline_drop_n_last_frames)
)
offline_data_mask_indices.extend(range(start_index, end_index - offline_drop_n_last_frames))
offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool)
offline_data_mask[torch.tensor(offline_data_mask_indices)] = True
weights.append(

View File

@@ -1,94 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import Any
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.processor.pipeline import RobotProcessor
def aggregate_pipeline_dataset_features(
pipeline: RobotProcessor,
initial_features: dict[str, Any],
*,
use_videos: bool = True,
patterns: Sequence[str] | None = None,
) -> dict[str, dict]:
"""
Aggregates the pipeline's features and returns a features dict ready for the dataset,
filtered to only those keys matching any of the given patterns (for action/state only).
- `initial_features`: raw camera specs, e.g. {"front": (h,w,c), ...}
- `use_videos`: whether to treat image features as video streams
- `patterns`: regexes to filter action & state features; images are included
whenever use_videos=True, regardless of patterns.
"""
import re
# Gather everything the pipeline features specifies, seeded with hardware cams:
all_features = pipeline.transform_features(initial_features)
# Helper to decide which action/state keys survive the `patterns` filter:
def keep(key: str) -> bool:
if patterns is None:
return True
return any(re.search(pat, key) for pat in patterns)
# Start with hardware dict, injecting initial cameras if videos are ON:
hw: dict[str, dict[str, Any]] = {}
if use_videos:
cams = {
name: shape
for name, shape in initial_features.items()
if isinstance(shape, tuple) and len(shape) == 3
}
if cams:
hw["observation"] = dict(cams)
# Go over every feature from the pipeline and merge:
for full_key, ty in all_features.items():
if full_key.startswith("action."):
# action.<feat>
if not keep(full_key):
continue
name = full_key[len("action.") :]
hw.setdefault("action", {})[name] = ty
elif full_key.startswith("observation.state."):
# observation.state.<feat>
if not keep(full_key):
continue
name = full_key[len("observation.state.") :]
hw.setdefault("observation", {})[name] = ty
elif full_key.startswith("observation.images."):
# observation.images.<cam>
# images obey ONLY the use_videos flag, not patterns
if not use_videos:
continue
name = full_key[len("observation.images.") :]
hw.setdefault("observation", {})[name] = ty
else:
# anything else (e.g. policy-only features) is ignored here
continue
out: dict[str, dict] = {}
if "action" in hw:
out.update(hw_to_dataset_features(hw["action"], "action", use_videos))
if "observation" in hw:
out.update(hw_to_dataset_features(hw["observation"], "observation", use_videos))
return out

View File

@@ -21,7 +21,8 @@ import torch
class EpisodeAwareSampler:
def __init__(
self,
episode_data_index: dict,
dataset_from_indices: list[int],
dataset_to_indices: list[int],
episode_indices_to_use: list | None = None,
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
@@ -30,7 +31,8 @@ class EpisodeAwareSampler:
"""Sampler that optionally incorporates episode boundary information.
Args:
episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode.
dataset_from_indices: List of indices containing the start of each episode in the dataset.
dataset_to_indices: List of indices containing the end of each episode in the dataset.
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
Assumes that episodes are indexed from 0 to N-1.
drop_n_first_frames: Number of frames to drop from the start of each episode.
@@ -39,12 +41,10 @@ class EpisodeAwareSampler:
"""
indices = []
for episode_idx, (start_index, end_index) in enumerate(
zip(episode_data_index["from"], episode_data_index["to"], strict=True)
zip(dataset_from_indices, dataset_to_indices, strict=True)
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
indices.extend(
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
)
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
self.indices = indices
self.shuffle = shuffle

View File

@@ -18,42 +18,55 @@ import importlib.resources
import json
import logging
from collections.abc import Iterator
from itertools import accumulate
from pathlib import Path
from pprint import pformat
from types import SimpleNamespace
from typing import Any
import datasets
import jsonlines
import numpy as np
import packaging.version
import pandas
import pandas as pd
import pyarrow.parquet as pq
import torch
from datasets import Dataset, concatenate_datasets
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.backward_compatibility import (
V21_MESSAGE,
FUTURE_MESSAGE,
BackwardCompatibilityError,
ForwardCompatibilityError,
)
from lerobot.utils.utils import is_valid_numpy_dtype_string
DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 500 # Max size per file
INFO_PATH = "meta/info.json"
EPISODES_PATH = "meta/episodes.jsonl"
STATS_PATH = "meta/stats.json"
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
TASKS_PATH = "meta/tasks.jsonl"
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
EPISODES_DIR = "meta/episodes"
DATA_DIR = "data"
VIDEO_DIR = "videos"
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
LEGACY_DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
LEGACY_DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
DATASET_CARD_TEMPLATE = """
---
@@ -74,6 +87,65 @@ DEFAULT_FEATURES = {
}
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
metadata = pq.read_metadata(parquet_path)
total_uncompressed_size = 0
for row_group in range(metadata.num_row_groups):
rg_metadata = metadata.row_group(row_group)
for column in range(rg_metadata.num_columns):
col_metadata = rg_metadata.column(column)
total_uncompressed_size += col_metadata.total_uncompressed_size
return total_uncompressed_size / (1024**2)
def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int:
return hf_ds.data.nbytes // (1024**2)
def get_hf_dataset_cache_dir(hf_ds: Dataset) -> Path | None:
if hf_ds.cache_files is None or len(hf_ds.cache_files) == 0:
return None
return Path(hf_ds.cache_files[0]["filename"]).parents[2]
def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]:
if file_idx == chunks_size - 1:
file_idx = 0
chunk_idx += 1
else:
file_idx += 1
return chunk_idx, file_idx
def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) -> Dataset:
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
Concatenate all pyarrow references to return HF Dataset format
Args:
pq_dir: Directory containing parquet files
features: Optional features schema to ensure consistent loading of complex types like images
"""
paths = sorted(pq_dir.glob("*/*.parquet"))
if len(paths) == 0:
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
datasets = [Dataset.from_parquet(str(path), features=features) for path in paths]
return concatenate_datasets(datasets)
def get_parquet_num_frames(parquet_path: str | Path) -> int:
metadata = pq.read_metadata(parquet_path)
return metadata.num_rows
def get_video_size_in_mb(mp4_path: Path) -> float:
file_size_bytes = mp4_path.stat().st_size
file_size_mb = file_size_bytes / (1024**2)
return file_size_mb
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
@@ -82,6 +154,7 @@ def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}`
>>> print(flatten_dict(dct))
{"a/b": 1, "a/c/d": 2, "e": 3}
```
"""
items = []
for k, v in d.items():
@@ -106,23 +179,13 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
return outdict
def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
split_keys = flattened_key.split(sep)
getter = obj[split_keys[0]]
if len(split_keys) == 1:
return getter
for key in split_keys[1:]:
getter = getter[key]
return getter
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
serialized_dict = {}
for key, value in flatten_dict(stats).items():
if isinstance(value, (torch.Tensor, np.ndarray)):
serialized_dict[key] = value.tolist()
elif isinstance(value, list) and isinstance(value[0], (int, float, list)):
serialized_dict[key] = value
elif isinstance(value, np.generic):
serialized_dict[key] = value.item()
elif isinstance(value, (int, float)):
@@ -152,24 +215,7 @@ def write_json(data: dict, fpath: Path) -> None:
json.dump(data, f, indent=4, ensure_ascii=False)
def load_jsonlines(fpath: Path) -> list[Any]:
with jsonlines.open(fpath, "r") as reader:
return list(reader)
def write_jsonlines(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(data)
def append_jsonlines(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True)
with jsonlines.open(fpath, "a") as writer:
writer.write(data)
def write_info(info: dict, local_dir: Path):
def write_info(info: dict, local_dir: Path) -> None:
write_json(info, local_dir / INFO_PATH)
@@ -180,65 +226,68 @@ def load_info(local_dir: Path) -> dict:
return info
def write_stats(stats: dict, local_dir: Path):
def write_stats(stats: dict, local_dir: Path) -> None:
serialized_stats = serialize_dict(stats)
write_json(serialized_stats, local_dir / STATS_PATH)
def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]] | None:
if not (local_dir / STATS_PATH).exists():
return None
stats = load_json(local_dir / STATS_PATH)
return cast_stats_to_numpy(stats)
def write_task(task_index: int, task: dict, local_dir: Path):
task_dict = {
"task_index": task_index,
"task": task,
}
append_jsonlines(task_dict, local_dir / TASKS_PATH)
def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None:
path = local_dir / DEFAULT_TASKS_PATH
path.parent.mkdir(parents=True, exist_ok=True)
tasks.to_parquet(path)
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
tasks = load_jsonlines(local_dir / TASKS_PATH)
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
return tasks, task_to_task_index
def load_tasks(local_dir: Path) -> pandas.DataFrame:
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
return tasks
def write_episode(episode: dict, local_dir: Path):
append_jsonlines(episode, local_dir / EPISODES_PATH)
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
This function writes episode-level metadata to a single parquet file.
Used primarily during dataset conversion (v2.1 → v3.0) and in test fixtures.
Args:
episodes: HuggingFace Dataset containing episode metadata
local_dir: Root directory where the dataset will be stored
"""
episode_size_mb = get_hf_dataset_size_in_mb(episodes)
if episode_size_mb > DEFAULT_DATA_FILE_SIZE_IN_MB:
raise NotImplementedError(
f"Episodes dataset is too large ({episode_size_mb} MB) to write to a single file. "
f"The current limit is {DEFAULT_DATA_FILE_SIZE_IN_MB} MB. "
"This function only supports single-file episode metadata. "
)
fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
fpath.parent.mkdir(parents=True, exist_ok=True)
episodes.to_parquet(fpath)
def load_episodes(local_dir: Path) -> dict:
episodes = load_jsonlines(local_dir / EPISODES_PATH)
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
# is a dictionary of stats and not an integer.
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
def load_episodes_stats(local_dir: Path) -> dict:
episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
return {
item["episode_index"]: cast_stats_to_numpy(item["stats"])
for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
}
def load_episodes(local_dir: Path) -> datasets.Dataset:
episodes = load_nested_dataset(local_dir / EPISODES_DIR)
# Select episode features/columns containing references to episode data and videos
# (e.g. tasks, dataset_from_index, dataset_to_index, data/chunk_index, data/file_index, etc.)
# This is to speedup access to these data, instead of having to load episode stats.
episodes = episodes.select_columns([key for key in episodes.features if not key.startswith("stats/")])
return episodes
def backward_compatible_episodes_stats(
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
) -> dict[str, dict[str, np.ndarray]]:
) -> dict[int, dict[str, dict[str, np.ndarray]]]:
return dict.fromkeys(episodes, stats)
@@ -254,7 +303,7 @@ def load_image_as_numpy(
return img_array
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to
a channel last representation (h w c) of uint8 type, to a torch image representation
@@ -299,7 +348,7 @@ def check_version_compatibility(
if v_check.major < v_current.major and enforce_breaking_major:
raise BackwardCompatibilityError(repo_id, v_check)
elif v_check.minor < v_current.minor:
logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
logging.warning(FUTURE_MESSAGE.format(repo_id=repo_id, version=v_check))
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
@@ -470,56 +519,15 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
return policy_features
def merge_features(*dicts: dict) -> dict:
"""
Merge LeRobot grouped feature dicts.
- For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape.
- For others (observation.images.*), last one wins (if they are identical).
"""
out: dict = {}
for d in dicts:
for key, value in d.items():
if not isinstance(value, dict):
out[key] = value
continue
dtype = value.get("dtype")
shape = value.get("shape")
is_vector = (
dtype not in ("image", "video", "string")
and isinstance(shape, tuple)
and len(shape) == 1
and "names" in value
)
if is_vector:
# Initialize or retrieve the accumulating dict for this feature key
target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)})
# Ensure consistent data types across merged entries
if "dtype" in target and dtype != target["dtype"]:
raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")
# Merge feature names: append only new ones to preserve order without duplicates
seen = set(target["names"])
for n in value["names"]:
if n not in seen:
target["names"].append(n)
seen.add(n)
# Recompute the shape to reflect the updated number of features
target["shape"] = (len(target["names"]),)
else:
# For images/videos and non-1D entries: override with the latest definition
out[key] = value
return out
def create_empty_dataset_info(
codebase_version: str,
fps: int,
features: dict,
use_videos: bool,
robot_type: str | None = None,
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
) -> dict:
return {
"codebase_version": codebase_version,
@@ -527,104 +535,17 @@ def create_empty_dataset_info(
"total_episodes": 0,
"total_frames": 0,
"total_tasks": 0,
"total_videos": 0,
"total_chunks": 0,
"chunks_size": DEFAULT_CHUNK_SIZE,
"chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
"data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
"video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
"fps": fps,
"splits": {},
"data_path": DEFAULT_PARQUET_PATH,
"data_path": DEFAULT_DATA_PATH,
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
"features": features,
}
def get_episode_data_index(
episode_dicts: dict[dict], episodes: list[int] | None = None
) -> dict[str, torch.Tensor]:
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
if episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
cumulative_lengths = list(accumulate(episode_lengths.values()))
return {
"from": torch.LongTensor([0] + cumulative_lengths[:-1]),
"to": torch.LongTensor(cumulative_lengths),
}
def check_timestamps_sync(
timestamps: np.ndarray,
episode_indices: np.ndarray,
episode_data_index: dict[str, np.ndarray],
fps: int,
tolerance_s: float,
raise_value_error: bool = True,
) -> bool:
"""
This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
to account for possible numerical error.
Args:
timestamps (np.ndarray): Array of timestamps in seconds.
episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
which identifies indices for the end of each episode.
fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
raise_value_error (bool): Whether to raise a ValueError if the check fails.
Returns:
bool: True if all checked timestamp differences lie within tolerance, False otherwise.
Raises:
ValueError: If the check fails and `raise_value_error` is True.
"""
if timestamps.shape != episode_indices.shape:
raise ValueError(
"timestamps and episode_indices should have the same shape. "
f"Found {timestamps.shape=} and {episode_indices.shape=}."
)
# Consecutive differences
diffs = np.diff(timestamps)
within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s
# Mask to ignore differences at the boundaries between episodes
mask = np.ones(len(diffs), dtype=bool)
ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode
mask[ignored_diffs] = False
filtered_within_tolerance = within_tolerance[mask]
# Check if all remaining diffs are within tolerance
if not np.all(filtered_within_tolerance):
# Track original indices before masking
original_indices = np.arange(len(diffs))
filtered_indices = original_indices[mask]
outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0]
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
outside_tolerances = []
for idx in outside_tolerance_indices:
entry = {
"timestamps": [timestamps[idx], timestamps[idx + 1]],
"diff": diffs[idx],
"episode_index": episode_indices[idx].item()
if hasattr(episode_indices[idx], "item")
else episode_indices[idx],
}
outside_tolerances.append(entry)
if raise_value_error:
raise ValueError(
f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
This might be due to synchronization issues during data collection.
\n{pformat(outside_tolerances)}"""
)
return False
return True
def check_delta_timestamps(
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
) -> bool:
@@ -663,7 +584,7 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic
return delta_indices
def cycle(iterable):
def cycle(iterable: Any) -> Iterator[Any]:
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.
@@ -676,7 +597,7 @@ def cycle(iterable):
iterator = iter(iterable)
def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
def create_branch(repo_id: str, *, branch: str, repo_type: str | None = None) -> None:
"""Create a branch on a existing Hugging Face repo. Delete the branch if it already
exists before creating it.
"""
@@ -729,76 +650,28 @@ def create_lerobot_dataset_card(
)
class IterableNamespace(SimpleNamespace):
"""
A namespace object that supports both dictionary-like iteration and dot notation access.
Automatically converts nested dictionaries into IterableNamespaces.
This class extends SimpleNamespace to provide:
- Dictionary-style iteration over keys
- Access to items via both dot notation (obj.key) and brackets (obj["key"])
- Dictionary-like methods: items(), keys(), values()
- Recursive conversion of nested dictionaries
Args:
dictionary: Optional dictionary to initialize the namespace
**kwargs: Additional keyword arguments passed to SimpleNamespace
Examples:
>>> data = {"name": "Alice", "details": {"age": 25}}
>>> ns = IterableNamespace(data)
>>> ns.name
'Alice'
>>> ns.details.age
25
>>> list(ns.keys())
['name', 'details']
>>> for key, value in ns.items():
... print(f"{key}: {value}")
name: Alice
details: IterableNamespace(age=25)
"""
def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
super().__init__(**kwargs)
if dictionary is not None:
for key, value in dictionary.items():
if isinstance(value, dict):
setattr(self, key, IterableNamespace(value))
else:
setattr(self, key, value)
def __iter__(self) -> Iterator[str]:
return iter(vars(self))
def __getitem__(self, key: str) -> Any:
return vars(self)[key]
def items(self):
return vars(self).items()
def values(self):
return vars(self).values()
def keys(self):
return vars(self).keys()
def validate_frame(frame: dict, features: dict):
def validate_frame(frame: dict, features: dict) -> None:
expected_features = set(features) - set(DEFAULT_FEATURES)
actual_features = set(frame)
error_message = validate_features_presence(actual_features, expected_features)
# task is a special required field that's not part of regular features
if "task" not in actual_features:
raise ValueError("Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n")
common_features = actual_features & expected_features
for name in common_features - {"task"}:
# Remove task from actual_features for regular feature validation
actual_features_for_validation = actual_features - {"task"}
error_message = validate_features_presence(actual_features_for_validation, expected_features)
common_features = actual_features_for_validation & expected_features
for name in common_features:
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
if error_message:
raise ValueError(error_message)
def validate_features_presence(actual_features: set[str], expected_features: set[str]):
def validate_features_presence(actual_features: set[str], expected_features: set[str]) -> str:
error_message = ""
missing_features = expected_features - actual_features
extra_features = actual_features - expected_features
@@ -813,7 +686,9 @@ def validate_features_presence(actual_features: set[str], expected_features: set
return error_message
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
def validate_feature_dtype_and_shape(
name: str, feature: dict, value: np.ndarray | PILImage.Image | str
) -> str:
expected_dtype = feature["dtype"]
expected_shape = feature["shape"]
if is_valid_numpy_dtype_string(expected_dtype):
@@ -828,7 +703,7 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray
def validate_feature_numpy_array(
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
):
) -> str:
error_message = ""
if isinstance(value, np.ndarray):
actual_dtype = value.dtype
@@ -845,7 +720,9 @@ def validate_feature_numpy_array(
return error_message
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
def validate_feature_image_or_video(
name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image
) -> str:
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
error_message = ""
if isinstance(value, np.ndarray):
@@ -861,13 +738,13 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value:
return error_message
def validate_feature_string(name: str, value: str):
def validate_feature_string(name: str, value: str) -> str:
if not isinstance(value, str):
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
return ""
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None:
if "size" not in episode_buffer:
raise ValueError("size key not found in episode_buffer")
@@ -891,3 +768,11 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
f"In episode_buffer not in features: {buffer_keys - set(features)}"
f"In features not in episode_buffer: {set(features) - buffer_keys}"
)
def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None:
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
This way, it can be loaded by HF dataset and correctly formatted images are returned.
"""
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)

View File

@@ -1,884 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.
Note: Since the original Aloha datasets don't use shadow motors, you need to comment those out in
lerobot/configs/robot/aloha.yaml before running this script.
"""
import traceback
from pathlib import Path
from textwrap import dedent
from lerobot import available_datasets
from lerobot.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset
from lerobot.robots.aloha.configuration_aloha import AlohaRobotConfig
LOCAL_DIR = Path("data/")
# spellchecker:off
ALOHA_MOBILE_INFO = {
"robot_config": AlohaRobotConfig(),
"license": "mit",
"url": "https://mobile-aloha.github.io/",
"paper": "https://huggingface.co/papers/2401.02117",
"citation_bibtex": dedent(r"""
@inproceedings{fu2024mobile,
author = {Fu, Zipeng and Zhao, Tony Z. and Finn, Chelsea},
title = {Mobile ALOHA: Learning Bimanual Mobile Manipulation with Low-Cost Whole-Body Teleoperation},
booktitle = {arXiv},
year = {2024},
}""").lstrip(),
}
ALOHA_STATIC_INFO = {
"robot_config": AlohaRobotConfig(),
"license": "mit",
"url": "https://tonyzhaozh.github.io/aloha/",
"paper": "https://huggingface.co/papers/2304.13705",
"citation_bibtex": dedent(r"""
@article{Zhao2023LearningFB,
title={Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware},
author={Tony Zhao and Vikash Kumar and Sergey Levine and Chelsea Finn},
journal={RSS},
year={2023},
volume={abs/2304.13705},
url={https://huggingface.co/papers/2304.13705}
}""").lstrip(),
}
PUSHT_INFO = {
"license": "mit",
"url": "https://diffusion-policy.cs.columbia.edu/",
"paper": "https://huggingface.co/papers/2303.04137",
"citation_bibtex": dedent(r"""
@article{chi2024diffusionpolicy,
author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song},
title ={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
journal = {The International Journal of Robotics Research},
year = {2024},
}""").lstrip(),
}
XARM_INFO = {
"license": "mit",
"url": "https://www.nicklashansen.com/td-mpc/",
"paper": "https://huggingface.co/papers/2203.04955",
"citation_bibtex": dedent(r"""
@inproceedings{Hansen2022tdmpc,
title={Temporal Difference Learning for Model Predictive Control},
author={Nicklas Hansen and Xiaolong Wang and Hao Su},
booktitle={ICML},
year={2022}
}
"""),
}
UNITREEH_INFO = {
"license": "apache-2.0",
}
DATASETS = {
"aloha_mobile_cabinet": {
"single_task": "Open the top cabinet, store the pot inside it then close the cabinet.",
**ALOHA_MOBILE_INFO,
},
"aloha_mobile_chair": {
"single_task": "Push the chairs in front of the desk to place them against it.",
**ALOHA_MOBILE_INFO,
},
"aloha_mobile_elevator": {
"single_task": "Take the elevator to the 1st floor.",
**ALOHA_MOBILE_INFO,
},
"aloha_mobile_shrimp": {
"single_task": "Sauté the raw shrimp on both sides, then serve it in the bowl.",
**ALOHA_MOBILE_INFO,
},
"aloha_mobile_wash_pan": {
"single_task": "Pick up the pan, rinse it in the sink and then place it in the drying rack.",
**ALOHA_MOBILE_INFO,
},
"aloha_mobile_wipe_wine": {
"single_task": "Pick up the wet cloth on the faucet and use it to clean the spilled wine on the table and underneath the glass.",
**ALOHA_MOBILE_INFO,
},
"aloha_static_battery": {
"single_task": "Place the battery into the slot of the remote controller.",
**ALOHA_STATIC_INFO,
},
"aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO},
"aloha_static_coffee": {
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
**ALOHA_STATIC_INFO,
},
"aloha_static_coffee_new": {
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray.",
**ALOHA_STATIC_INFO,
},
"aloha_static_cups_open": {
"single_task": "Pick up the plastic cup and open its lid.",
**ALOHA_STATIC_INFO,
},
"aloha_static_fork_pick_up": {
"single_task": "Pick up the fork and place it on the plate.",
**ALOHA_STATIC_INFO,
},
"aloha_static_pingpong_test": {
"single_task": "Transfer one of the two balls in the right glass into the left glass, then transfer it back to the right glass.",
**ALOHA_STATIC_INFO,
},
"aloha_static_pro_pencil": {
"single_task": "Pick up the pencil with the right arm, hand it over to the left arm then place it back onto the table.",
**ALOHA_STATIC_INFO,
},
"aloha_static_screw_driver": {
"single_task": "Pick up the screwdriver with the right arm, hand it over to the left arm then place it into the cup.",
**ALOHA_STATIC_INFO,
},
"aloha_static_tape": {
"single_task": "Cut a small piece of tape from the tape dispenser then place it on the cardboard box's edge.",
**ALOHA_STATIC_INFO,
},
"aloha_static_thread_velcro": {
"single_task": "Pick up the velcro cable tie with the left arm, then insert the end of the velcro tie into the other end's loop with the right arm.",
**ALOHA_STATIC_INFO,
},
"aloha_static_towel": {
"single_task": "Pick up a piece of paper towel and place it on the spilled liquid.",
**ALOHA_STATIC_INFO,
},
"aloha_static_vinh_cup": {
"single_task": "Pick up the plastic cup with the right arm, then pop its lid open with the left arm.",
**ALOHA_STATIC_INFO,
},
"aloha_static_vinh_cup_left": {
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
**ALOHA_STATIC_INFO,
},
"aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
"aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
"aloha_sim_insertion_scripted_image": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
"aloha_sim_insertion_human_image": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_transfer_cube_scripted": {
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_transfer_cube_scripted_image": {
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_transfer_cube_human": {
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_transfer_cube_human_image": {
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
**ALOHA_STATIC_INFO,
},
"pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
"pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
"unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
"unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO},
"unitreeh1_two_robot_greeting": {
"single_task": "Greet the other robot with a high five.",
**UNITREEH_INFO,
},
"unitreeh1_warehouse": {
"single_task": "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.",
**UNITREEH_INFO,
},
"xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"umi_cup_in_the_wild": {
"single_task": "Put the cup on the plate.",
"license": "apache-2.0",
},
"asu_table_top": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://link.springer.com/article/10.1007/s10514-023-10129-1",
"citation_bibtex": dedent(r"""
@inproceedings{zhou2023modularity,
title={Modularity through Attention: Efficient Training and Transfer of Language-Conditioned Policies for Robot Manipulation},
author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Stepputtis, Simon and Amor, Heni},
booktitle={Conference on Robot Learning},
pages={1684--1695},
year={2023},
organization={PMLR}
}
@article{zhou2023learning,
title={Learning modular language-conditioned robot policies through attention},
author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Ben Amor, Heni and Stepputtis, Simon},
journal={Autonomous Robots},
pages={1--21},
year={2023},
publisher={Springer}
}""").lstrip(),
},
"austin_buds_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://ut-austin-rpl.github.io/BUDS-website/",
"paper": "https://huggingface.co/papers/2109.13841",
"citation_bibtex": dedent(r"""
@article{zhu2022bottom,
title={Bottom-Up Skill Discovery From Unsegmented Demonstrations for Long-Horizon Robot Manipulation},
author={Zhu, Yifeng and Stone, Peter and Zhu, Yuke},
journal={IEEE Robotics and Automation Letters},
volume={7},
number={2},
pages={4126--4133},
year={2022},
publisher={IEEE}
}""").lstrip(),
},
"austin_sailor_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://ut-austin-rpl.github.io/sailor/",
"paper": "https://huggingface.co/papers/2210.11435",
"citation_bibtex": dedent(r"""
@inproceedings{nasiriany2022sailor,
title={Learning and Retrieval from Prior Data for Skill-based Imitation Learning},
author={Soroush Nasiriany and Tian Gao and Ajay Mandlekar and Yuke Zhu},
booktitle={Conference on Robot Learning (CoRL)},
year={2022}
}""").lstrip(),
},
"austin_sirius_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://ut-austin-rpl.github.io/sirius/",
"paper": "https://huggingface.co/papers/2211.08416",
"citation_bibtex": dedent(r"""
@inproceedings{liu2022robot,
title = {Robot Learning on the Job: Human-in-the-Loop Autonomy and Learning During Deployment},
author = {Huihan Liu and Soroush Nasiriany and Lance Zhang and Zhiyao Bao and Yuke Zhu},
booktitle = {Robotics: Science and Systems (RSS)},
year = {2023}
}""").lstrip(),
},
"berkeley_autolab_ur5": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://sites.google.com/view/berkeley-ur5/home",
"citation_bibtex": dedent(r"""
@misc{BerkeleyUR5Website,
title = {Berkeley {UR5} Demonstration Dataset},
author = {Lawrence Yunliang Chen and Simeon Adebola and Ken Goldberg},
howpublished = {https://sites.google.com/view/berkeley-ur5/home},
}""").lstrip(),
},
"berkeley_cable_routing": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://sites.google.com/view/cablerouting/home",
"paper": "https://huggingface.co/papers/2307.08927",
"citation_bibtex": dedent(r"""
@article{luo2023multistage,
author = {Jianlan Luo and Charles Xu and Xinyang Geng and Gilbert Feng and Kuan Fang and Liam Tan and Stefan Schaal and Sergey Levine},
title = {Multi-Stage Cable Routing through Hierarchical Imitation Learning},
journal = {arXiv pre-print},
year = {2023},
url = {https://huggingface.co/papers/2307.08927},
}""").lstrip(),
},
"berkeley_fanuc_manipulation": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/berkeley.edu/fanuc-manipulation",
"citation_bibtex": dedent(r"""
@article{fanuc_manipulation2023,
title={Fanuc Manipulation: A Dataset for Learning-based Manipulation with FANUC Mate 200iD Robot},
author={Zhu, Xinghao and Tian, Ran and Xu, Chenfeng and Ding, Mingyu and Zhan, Wei and Tomizuka, Masayoshi},
year={2023},
}""").lstrip(),
},
"berkeley_gnm_cory_hall": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://huggingface.co/papers/1709.10489",
"citation_bibtex": dedent(r"""
@inproceedings{kahn2018self,
title={Self-supervised deep reinforcement learning with generalized computation graphs for robot navigation},
author={Kahn, Gregory and Villaflor, Adam and Ding, Bosen and Abbeel, Pieter and Levine, Sergey},
booktitle={2018 IEEE international conference on robotics and automation (ICRA)},
pages={5129--5136},
year={2018},
organization={IEEE}
}""").lstrip(),
},
"berkeley_gnm_recon": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/view/recon-robot",
"paper": "https://huggingface.co/papers/2104.05859",
"citation_bibtex": dedent(r"""
@inproceedings{shah2021rapid,
title={Rapid Exploration for Open-World Navigation with Latent Goal Models},
author={Dhruv Shah and Benjamin Eysenbach and Nicholas Rhinehart and Sergey Levine},
booktitle={5th Annual Conference on Robot Learning },
year={2021},
url={https://openreview.net/forum?id=d_SWJhyKfVw}
}""").lstrip(),
},
"berkeley_gnm_sac_son": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/view/SACSoN-review",
"paper": "https://huggingface.co/papers/2306.01874",
"citation_bibtex": dedent(r"""
@article{hirose2023sacson,
title={SACSoN: Scalable Autonomous Data Collection for Social Navigation},
author={Hirose, Noriaki and Shah, Dhruv and Sridhar, Ajay and Levine, Sergey},
journal={arXiv preprint arXiv:2306.01874},
year={2023}
}""").lstrip(),
},
"berkeley_mvp": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://huggingface.co/papers/2203.06173",
"citation_bibtex": dedent(r"""
@InProceedings{Radosavovic2022,
title = {Real-World Robot Learning with Masked Visual Pre-training},
author = {Ilija Radosavovic and Tete Xiao and Stephen James and Pieter Abbeel and Jitendra Malik and Trevor Darrell},
booktitle = {CoRL},
year = {2022}
}""").lstrip(),
},
"berkeley_rpt": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://huggingface.co/papers/2306.10007",
"citation_bibtex": dedent(r"""
@article{Radosavovic2023,
title={Robot Learning with Sensorimotor Pre-training},
author={Ilija Radosavovic and Baifeng Shi and Letian Fu and Ken Goldberg and Trevor Darrell and Jitendra Malik},
year={2023},
journal={arXiv:2306.10007}
}""").lstrip(),
},
"cmu_franka_exploration_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://human-world-model.github.io/",
"paper": "https://huggingface.co/papers/2308.10901",
"citation_bibtex": dedent(r"""
@inproceedings{mendonca2023structured,
title={Structured World Models from Human Videos},
author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
journal={RSS},
year={2023}
}""").lstrip(),
},
"cmu_play_fusion": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://play-fusion.github.io/",
"paper": "https://huggingface.co/papers/2312.04549",
"citation_bibtex": dedent(r"""
@inproceedings{chen2023playfusion,
title={PlayFusion: Skill Acquisition via Diffusion from Language-Annotated Play},
author={Chen, Lili and Bahl, Shikhar and Pathak, Deepak},
booktitle={CoRL},
year={2023}
}""").lstrip(),
},
"cmu_stretch": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://robo-affordances.github.io/",
"paper": "https://huggingface.co/papers/2304.08488",
"citation_bibtex": dedent(r"""
@inproceedings{bahl2023affordances,
title={Affordances from Human Videos as a Versatile Representation for Robotics},
author={Bahl, Shikhar and Mendonca, Russell and Chen, Lili and Jain, Unnat and Pathak, Deepak},
booktitle={CVPR},
year={2023}
}
@article{mendonca2023structured,
title={Structured World Models from Human Videos},
author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
journal={CoRL},
year={2023}
}""").lstrip(),
},
"columbia_cairlab_pusht_real": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://diffusion-policy.cs.columbia.edu/",
"paper": "https://huggingface.co/papers/2303.04137",
"citation_bibtex": dedent(r"""
@inproceedings{chi2023diffusionpolicy,
title={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
author={Chi, Cheng and Feng, Siyuan and Du, Yilun and Xu, Zhenjia and Cousineau, Eric and Burchfiel, Benjamin and Song, Shuran},
booktitle={Proceedings of Robotics: Science and Systems (RSS)},
year={2023}
}""").lstrip(),
},
"conq_hose_manipulation": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/view/conq-hose-manipulation-dataset/home",
"citation_bibtex": dedent(r"""
@misc{ConqHoseManipData,
author={Peter Mitrano and Dmitry Berenson},
title={Conq Hose Manipulation Dataset, v1.15.0},
year={2024},
howpublished={https://sites.google.com/view/conq-hose-manipulation-dataset}
}""").lstrip(),
},
"dlr_edan_shared_control": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://ieeexplore.ieee.org/document/9341156",
"citation_bibtex": dedent(r"""
@inproceedings{vogel_edan_2020,
title = {EDAN - an EMG-Controlled Daily Assistant to Help People with Physical Disabilities},
language = {en},
booktitle = {2020 {IEEE}/{RSJ} {International} {Conference} on {Intelligent} {Robots} and {Systems} ({IROS})},
author = {Vogel, Jörn and Hagengruber, Annette and Iskandar, Maged and Quere, Gabriel and Leipscher, Ulrike and Bustamante, Samuel and Dietrich, Alexander and Hoeppner, Hannes and Leidner, Daniel and Albu-Schäffer, Alin},
year = {2020}
}
@inproceedings{quere_shared_2020,
address = {Paris, France},
title = {Shared {Control} {Templates} for {Assistive} {Robotics}},
language = {en},
booktitle = {2020 {IEEE} {International} {Conference} on {Robotics} and {Automation} ({ICRA})},
author = {Quere, Gabriel and Hagengruber, Annette and Iskandar, Maged and Bustamante, Samuel and Leidner, Daniel and Stulp, Freek and Vogel, Joern},
year = {2020},
pages = {7},
}""").lstrip(),
},
"dlr_sara_grid_clamp": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://www.researchsquare.com/article/rs-3289569/v1",
"citation_bibtex": dedent(r"""
@article{padalkar2023guided,
title={A guided reinforcement learning approach using shared control templates for learning manipulation skills in the real world},
author={Padalkar, Abhishek and Quere, Gabriel and Raffin, Antonin and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
journal={Research square preprint rs-3289569/v1},
year={2023}
}""").lstrip(),
},
"dlr_sara_pour": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://elib.dlr.de/193739/1/padalkar2023rlsct.pdf",
"citation_bibtex": dedent(r"""
@inproceedings{padalkar2023guiding,
title={Guiding Reinforcement Learning with Shared Control Templates},
author={Padalkar, Abhishek and Quere, Gabriel and Steinmetz, Franz and Raffin, Antonin and Nieuwenhuisen, Matthias and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
booktitle={40th IEEE International Conference on Robotics and Automation, ICRA 2023},
year={2023},
organization={IEEE}
}""").lstrip(),
},
"droid_100": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://droid-dataset.github.io/",
"paper": "https://huggingface.co/papers/2403.12945",
"citation_bibtex": dedent(r"""
@article{khazatsky2024droid,
title = {DROID: A Large-Scale In-The-Wild Robot Manipulation Dataset},
author = {Alexander Khazatsky and Karl Pertsch and Suraj Nair and Ashwin Balakrishna and Sudeep Dasari and Siddharth Karamcheti and Soroush Nasiriany and Mohan Kumar Srirama and Lawrence Yunliang Chen and Kirsty Ellis and Peter David Fagan and Joey Hejna and Masha Itkina and Marion Lepert and Yecheng Jason Ma and Patrick Tree Miller and Jimmy Wu and Suneel Belkhale and Shivin Dass and Huy Ha and Arhan Jain and Abraham Lee and Youngwoon Lee and Marius Memmel and Sungjae Park and Ilija Radosavovic and Kaiyuan Wang and Albert Zhan and Kevin Black and Cheng Chi and Kyle Beltran Hatch and Shan Lin and Jingpei Lu and Jean Mercat and Abdul Rehman and Pannag R Sanketi and Archit Sharma and Cody Simpson and Quan Vuong and Homer Rich Walke and Blake Wulfe and Ted Xiao and Jonathan Heewon Yang and Arefeh Yavary and Tony Z. Zhao and Christopher Agia and Rohan Baijal and Mateo Guaman Castro and Daphne Chen and Qiuyu Chen and Trinity Chung and Jaimyn Drake and Ethan Paul Foster and Jensen Gao and David Antonio Herrera and Minho Heo and Kyle Hsu and Jiaheng Hu and Donovon Jackson and Charlotte Le and Yunshuang Li and Kevin Lin and Roy Lin and Zehan Ma and Abhiram Maddukuri and Suvir Mirchandani and Daniel Morton and Tony Nguyen and Abigail O'Neill and Rosario Scalise and Derick Seale and Victor Son and Stephen Tian and Emi Tran and Andrew E. Wang and Yilin Wu and Annie Xie and Jingyun Yang and Patrick Yin and Yunchu Zhang and Osbert Bastani and Glen Berseth and Jeannette Bohg and Ken Goldberg and Abhinav Gupta and Abhishek Gupta and Dinesh Jayaraman and Joseph J Lim and Jitendra Malik and Roberto Martín-Martín and Subramanian Ramamoorthy and Dorsa Sadigh and Shuran Song and Jiajun Wu and Michael C. Yip and Yuke Zhu and Thomas Kollar and Sergey Levine and Chelsea Finn},
year = {2024},
}""").lstrip(),
},
"fmb": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://functional-manipulation-benchmark.github.io/",
"paper": "https://huggingface.co/papers/2401.08553",
"citation_bibtex": dedent(r"""
@article{luo2024fmb,
title={FMB: a Functional Manipulation Benchmark for Generalizable Robotic Learning},
author={Luo, Jianlan and Xu, Charles and Liu, Fangchen and Tan, Liam and Lin, Zipeng and Wu, Jeffrey and Abbeel, Pieter and Levine, Sergey},
journal={arXiv preprint arXiv:2401.08553},
year={2024}
}""").lstrip(),
},
"iamlab_cmu_pickup_insert": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://openreview.net/forum?id=WuBv9-IGDUA",
"paper": "https://huggingface.co/papers/2401.14502",
"citation_bibtex": dedent(r"""
@inproceedings{saxena2023multiresolution,
title={Multi-Resolution Sensing for Real-Time Control with Vision-Language Models},
author={Saumya Saxena and Mohit Sharma and Oliver Kroemer},
booktitle={7th Annual Conference on Robot Learning},
year={2023},
url={https://openreview.net/forum?id=WuBv9-IGDUA}
}""").lstrip(),
},
"imperialcollege_sawyer_wrist_cam": {
"tasks_col": "language_instruction",
"license": "mit",
},
"jaco_play": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://github.com/clvrai/clvr_jaco_play_dataset",
"citation_bibtex": dedent(r"""
@software{dass2023jacoplay,
author = {Dass, Shivin and Yapeter, Jullian and Zhang, Jesse and Zhang, Jiahui
and Pertsch, Karl and Nikolaidis, Stefanos and Lim, Joseph J.},
title = {CLVR Jaco Play Dataset},
url = {https://github.com/clvrai/clvr_jaco_play_dataset},
version = {1.0.0},
year = {2023}
}""").lstrip(),
},
"kaist_nonprehensile": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://github.com/JaeHyung-Kim/rlds_dataset_builder",
"citation_bibtex": dedent(r"""
@article{kimpre,
title={Pre-and post-contact policy decomposition for non-prehensile manipulation with zero-shot sim-to-real transfer},
author={Kim, Minchan and Han, Junhyek and Kim, Jaehyung and Kim, Beomjoon},
booktitle={2023 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
year={2023},
organization={IEEE}
}""").lstrip(),
},
"nyu_door_opening_surprising_effectiveness": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://jyopari.github.io/VINN/",
"paper": "https://huggingface.co/papers/2112.01511",
"citation_bibtex": dedent(r"""
@misc{pari2021surprising,
title={The Surprising Effectiveness of Representation Learning for Visual Imitation},
author={Jyothish Pari and Nur Muhammad Shafiullah and Sridhar Pandian Arunachalam and Lerrel Pinto},
year={2021},
eprint={2112.01511},
archivePrefix={arXiv},
primaryClass={cs.RO}
}""").lstrip(),
},
"nyu_franka_play_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://play-to-policy.github.io/",
"paper": "https://huggingface.co/papers/2210.10047",
"citation_bibtex": dedent(r"""
@article{cui2022play,
title = {From Play to Policy: Conditional Behavior Generation from Uncurated Robot Data},
author = {Cui, Zichen Jeff and Wang, Yibin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel},
journal = {arXiv preprint arXiv:2210.10047},
year = {2022}
}""").lstrip(),
},
"nyu_rot_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://rot-robot.github.io/",
"paper": "https://huggingface.co/papers/2206.15469",
"citation_bibtex": dedent(r"""
@inproceedings{haldar2023watch,
title={Watch and match: Supercharging imitation with regularized optimal transport},
author={Haldar, Siddhant and Mathur, Vaibhav and Yarats, Denis and Pinto, Lerrel},
booktitle={Conference on Robot Learning},
pages={32--43},
year={2023},
organization={PMLR}
}""").lstrip(),
},
"roboturk": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://roboturk.stanford.edu/dataset_real.html",
"paper": "PAPER",
"citation_bibtex": dedent(r"""
@inproceedings{mandlekar2019scaling,
title={Scaling robot supervision to hundreds of hours with roboturk: Robotic manipulation dataset through human reasoning and dexterity},
author={Mandlekar, Ajay and Booher, Jonathan and Spero, Max and Tung, Albert and Gupta, Anchit and Zhu, Yuke and Garg, Animesh and Savarese, Silvio and Fei-Fei, Li},
booktitle={2019 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
pages={1048--1055},
year={2019},
organization={IEEE}
}""").lstrip(),
},
"stanford_hydra_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/view/hydra-il-2023",
"paper": "https://huggingface.co/papers/2306.17237",
"citation_bibtex": dedent(r"""
@article{belkhale2023hydra,
title={HYDRA: Hybrid Robot Actions for Imitation Learning},
author={Belkhale, Suneel and Cui, Yuchen and Sadigh, Dorsa},
journal={arxiv},
year={2023}
}""").lstrip(),
},
"stanford_kuka_multimodal_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/view/visionandtouch",
"paper": "https://huggingface.co/papers/1810.10191",
"citation_bibtex": dedent(r"""
@inproceedings{lee2019icra,
title={Making sense of vision and touch: Self-supervised learning of multimodal representations for contact-rich tasks},
author={Lee, Michelle A and Zhu, Yuke and Srinivasan, Krishnan and Shah, Parth and Savarese, Silvio and Fei-Fei, Li and Garg, Animesh and Bohg, Jeannette},
booktitle={2019 IEEE International Conference on Robotics and Automation (ICRA)},
year={2019},
url={https://huggingface.co/papers/1810.10191}
}""").lstrip(),
},
"stanford_robocook": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://hshi74.github.io/robocook/",
"paper": "https://huggingface.co/papers/2306.14447",
"citation_bibtex": dedent(r"""
@article{shi2023robocook,
title={RoboCook: Long-Horizon Elasto-Plastic Object Manipulation with Diverse Tools},
author={Shi, Haochen and Xu, Huazhe and Clarke, Samuel and Li, Yunzhu and Wu, Jiajun},
journal={arXiv preprint arXiv:2306.14447},
year={2023}
}""").lstrip(),
},
"taco_play": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://www.kaggle.com/datasets/oiermees/taco-robot",
"paper": "https://huggingface.co/papers/2209.08959, https://huggingface.co/papers/2210.01911",
"citation_bibtex": dedent(r"""
@inproceedings{rosete2022tacorl,
author = {Erick Rosete-Beas and Oier Mees and Gabriel Kalweit and Joschka Boedecker and Wolfram Burgard},
title = {Latent Plans for Task Agnostic Offline Reinforcement Learning},
journal = {Proceedings of the 6th Conference on Robot Learning (CoRL)},
year = {2022}
}
@inproceedings{mees23hulc2,
title={Grounding Language with Visual Affordances over Unstructured Data},
author={Oier Mees and Jessica Borja-Diaz and Wolfram Burgard},
booktitle = {Proceedings of the IEEE International Conference on Robotics and Automation (ICRA)},
year={2023},
address = {London, UK}
}""").lstrip(),
},
"tokyo_u_lsmo": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "URL",
"paper": "https://huggingface.co/papers/2107.05842",
"citation_bibtex": dedent(r"""
@Article{Osa22,
author = {Takayuki Osa},
journal = {The International Journal of Robotics Research},
title = {Motion Planning by Learning the Solution Manifold in Trajectory Optimization},
year = {2022},
number = {3},
pages = {291--311},
volume = {41},
}""").lstrip(),
},
"toto": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://toto-benchmark.org/",
"paper": "https://huggingface.co/papers/2306.00942",
"citation_bibtex": dedent(r"""
@inproceedings{zhou2023train,
author={Zhou, Gaoyue and Dean, Victoria and Srirama, Mohan Kumar and Rajeswaran, Aravind and Pari, Jyothish and Hatch, Kyle and Jain, Aryan and Yu, Tianhe and Abbeel, Pieter and Pinto, Lerrel and Finn, Chelsea and Gupta, Abhinav},
booktitle={2023 IEEE International Conference on Robotics and Automation (ICRA)},
title={Train Offline, Test Online: A Real Robot Learning Benchmark},
year={2023},
}""").lstrip(),
},
"ucsd_kitchen_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"citation_bibtex": dedent(r"""
@ARTICLE{ucsd_kitchens,
author = {Ge Yan, Kris Wu, and Xiaolong Wang},
title = {{ucsd kitchens Dataset}},
year = {2023},
month = {August}
}""").lstrip(),
},
"ucsd_pick_and_place_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://owmcorl.github.io/#",
"paper": "https://huggingface.co/papers/2310.16029",
"citation_bibtex": dedent(r"""
@preprint{Feng2023Finetuning,
title={Finetuning Offline World Models in the Real World},
author={Yunhai Feng, Nicklas Hansen, Ziyan Xiong, Chandramouli Rajagopalan, Xiaolong Wang},
year={2023}
}""").lstrip(),
},
"uiuc_d3field": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://robopil.github.io/d3fields/",
"paper": "https://huggingface.co/papers/2309.16118",
"citation_bibtex": dedent(r"""
@article{wang2023d3field,
title={D^3Field: Dynamic 3D Descriptor Fields for Generalizable Robotic Manipulation},
author={Wang, Yixuan and Li, Zhuoran and Zhang, Mingtong and Driggs-Campbell, Katherine and Wu, Jiajun and Fei-Fei, Li and Li, Yunzhu},
journal={arXiv preprint arXiv:},
year={2023},
}""").lstrip(),
},
"usc_cloth_sim": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://uscresl.github.io/dmfd/",
"paper": "https://huggingface.co/papers/2207.10148",
"citation_bibtex": dedent(r"""
@article{salhotra2022dmfd,
author={Salhotra, Gautam and Liu, I-Chun Arthur and Dominguez-Kuhne, Marcus and Sukhatme, Gaurav S.},
journal={IEEE Robotics and Automation Letters},
title={Learning Deformable Object Manipulation From Expert Demonstrations},
year={2022},
volume={7},
number={4},
pages={8775-8782},
doi={10.1109/LRA.2022.3187843}
}""").lstrip(),
},
"utaustin_mutex": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://ut-austin-rpl.github.io/MUTEX/",
"paper": "https://huggingface.co/papers/2309.14320",
"citation_bibtex": dedent(r"""
@inproceedings{shah2023mutex,
title={{MUTEX}: Learning Unified Policies from Multimodal Task Specifications},
author={Rutav Shah and Roberto Mart{\'\i}n-Mart{\'\i}n and Yuke Zhu},
booktitle={7th Annual Conference on Robot Learning},
year={2023},
url={https://openreview.net/forum?id=PwqiqaaEzJ}
}""").lstrip(),
},
"utokyo_pr2_opening_fridge": {
"tasks_col": "language_instruction",
"license": "mit",
"citation_bibtex": dedent(r"""
@misc{oh2023pr2utokyodatasets,
author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
title={X-Embodiment U-Tokyo PR2 Datasets},
year={2023},
url={https://github.com/ojh6404/rlds_dataset_builder},
}""").lstrip(),
},
"utokyo_pr2_tabletop_manipulation": {
"tasks_col": "language_instruction",
"license": "mit",
"citation_bibtex": dedent(r"""
@misc{oh2023pr2utokyodatasets,
author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
title={X-Embodiment U-Tokyo PR2 Datasets},
year={2023},
url={https://github.com/ojh6404/rlds_dataset_builder},
}""").lstrip(),
},
"utokyo_saytap": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://saytap.github.io/",
"paper": "https://huggingface.co/papers/2306.07580",
"citation_bibtex": dedent(r"""
@article{saytap2023,
author = {Yujin Tang and Wenhao Yu and Jie Tan and Heiga Zen and Aleksandra Faust and
Tatsuya Harada},
title = {SayTap: Language to Quadrupedal Locomotion},
eprint = {arXiv:2306.07580},
url = {https://saytap.github.io},
note = {https://saytap.github.io},
year = {2023}
}""").lstrip(),
},
"utokyo_xarm_bimanual": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"citation_bibtex": dedent(r"""
@misc{matsushima2023weblab,
title={Weblab xArm Dataset},
author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
year={2023},
}""").lstrip(),
},
"utokyo_xarm_pick_and_place": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"citation_bibtex": dedent(r"""
@misc{matsushima2023weblab,
title={Weblab xArm Dataset},
author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
year={2023},
}""").lstrip(),
},
"viola": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://ut-austin-rpl.github.io/VIOLA/",
"paper": "https://huggingface.co/papers/2210.11339",
"citation_bibtex": dedent(r"""
@article{zhu2022viola,
title={VIOLA: Imitation Learning for Vision-Based Manipulation with Object Proposal Priors},
author={Zhu, Yifeng and Joshi, Abhishek and Stone, Peter and Zhu, Yuke},
journal={6th Annual Conference on Robot Learning (CoRL)},
year={2022}
}""").lstrip(),
},
}
# spellchecker:on
def batch_convert():
status = {}
logfile = LOCAL_DIR / "conversion_log.txt"
assert set(DATASETS) == {id_.split("/")[1] for id_ in available_datasets}
for num, (name, kwargs) in enumerate(DATASETS.items()):
repo_id = f"lerobot/{name}"
print(f"\nConverting {repo_id} ({num}/{len(DATASETS)})")
print("---------------------------------------------------------")
try:
convert_dataset(repo_id, LOCAL_DIR, **kwargs)
status = f"{repo_id}: success."
with open(logfile, "a") as file:
file.write(status + "\n")
except Exception:
status = f"{repo_id}: failed\n {traceback.format_exc()}"
with open(logfile, "a") as file:
file.write(status + "\n")
continue
if __name__ == "__main__":
batch_convert()

View File

@@ -1,687 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
for each of the task performed in the dataset. This will allow to easily train models with task-conditioning.
We support 3 different scenarios for these tasks (see instructions below):
1. Single task dataset: all episodes of your dataset have the same single task.
2. Single task episodes: the episodes of your dataset each contain a single task but they can differ from
one episode to the next.
3. Multi task episodes: episodes of your dataset may each contain several different tasks.
Can you can also provide a robot config .yaml file (not mandatory) to this script via the option
'--robot-config' so that it writes information about the robot (robot type, motors names) this dataset was
recorded with. For now, only Aloha/Koch type robots are supported with this option.
# 1. Single task dataset
If your dataset contains a single task, you can simply provide it directly via the CLI with the
'--single-task' option.
Examples:
```bash
python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \
--repo-id lerobot/aloha_sim_insertion_human_image \
--single-task "Insert the peg into the socket." \
--robot-config lerobot/configs/robot/aloha.yaml \
--local-dir data
```
```bash
python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \
--repo-id aliberts/koch_tutorial \
--single-task "Pick the Lego block and drop it in the box on the right." \
--robot-config lerobot/configs/robot/koch.yaml \
--local-dir data
```
# 2. Single task episodes
If your dataset is a multi-task dataset, you have two options to provide the tasks to this script:
- If your dataset already contains a language instruction column in its parquet file, you can simply provide
this column's name with the '--tasks-col' arg.
Example:
```bash
python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \
--repo-id lerobot/stanford_kuka_multimodal_dataset \
--tasks-col "language_instruction" \
--local-dir data
```
- If your dataset doesn't contain a language instruction, you should provide the path to a .json file with the
'--tasks-path' arg. This file should have the following structure where keys correspond to each
episode_index in the dataset, and values are the language instruction for that episode.
Example:
```json
{
"0": "Do something",
"1": "Do something else",
"2": "Do something",
"3": "Go there",
...
}
```
# 3. Multi task episodes
If you have multiple tasks per episodes, your dataset should contain a language instruction column in its
parquet file, and you must provide this column's name with the '--tasks-col' arg.
Example:
```bash
python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \
--repo-id lerobot/stanford_kuka_multimodal_dataset \
--tasks-col "language_instruction" \
--local-dir data
```
"""
import argparse
import contextlib
import filecmp
import json
import logging
import math
import shutil
import subprocess
import tempfile
from pathlib import Path
import datasets
import pyarrow.compute as pc
import pyarrow.parquet as pq
import torch
from datasets import Dataset
from huggingface_hub import HfApi
from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError
from safetensors.torch import load_file
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_PARQUET_PATH,
DEFAULT_VIDEO_PATH,
EPISODES_PATH,
INFO_PATH,
STATS_PATH,
TASKS_PATH,
create_branch,
create_lerobot_dataset_card,
flatten_dict,
get_safe_version,
load_json,
unflatten_dict,
write_json,
write_jsonlines,
)
from lerobot.datasets.video_utils import (
VideoFrame, # noqa: F401
get_image_pixel_channels,
get_video_info,
)
from lerobot.robots import RobotConfig
V16 = "v1.6"
V20 = "v2.0"
GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
V1_VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
V1_INFO_PATH = "meta_data/info.json"
V1_STATS_PATH = "meta_data/stats.safetensors"
def parse_robot_config(robot_cfg: RobotConfig) -> tuple[str, dict]:
if robot_cfg.type in ["aloha", "koch"]:
state_names = [
f"{arm}_{motor}" if len(robot_cfg.follower_arms) > 1 else motor
for arm in robot_cfg.follower_arms
for motor in robot_cfg.follower_arms[arm].motors
]
action_names = [
# f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
f"{arm}_{motor}" if len(robot_cfg.leader_arms) > 1 else motor
for arm in robot_cfg.leader_arms
for motor in robot_cfg.leader_arms[arm].motors
]
# elif robot_cfg["robot_type"] == "stretch3": TODO
else:
raise NotImplementedError(
"Please provide robot_config={'robot_type': ..., 'names': ...} directly to convert_dataset()."
)
return {
"robot_type": robot_cfg.type,
"names": {
"observation.state": state_names,
"observation.effort": state_names,
"action": action_names,
},
}
def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
safetensor_path = v1_dir / V1_STATS_PATH
stats = load_file(safetensor_path)
serialized_stats = {key: value.tolist() for key, value in stats.items()}
serialized_stats = unflatten_dict(serialized_stats)
json_path = v2_dir / STATS_PATH
json_path.parent.mkdir(exist_ok=True, parents=True)
with open(json_path, "w") as f:
json.dump(serialized_stats, f, indent=4)
# Sanity check
with open(json_path) as f:
stats_json = json.load(f)
stats_json = flatten_dict(stats_json)
stats_json = {key: torch.tensor(value) for key, value in stats_json.items()}
for key in stats:
torch.testing.assert_close(stats_json[key], stats[key])
def get_features_from_hf_dataset(
dataset: Dataset, robot_config: RobotConfig | None = None
) -> dict[str, list]:
robot_config = parse_robot_config(robot_config)
features = {}
for key, ft in dataset.features.items():
if isinstance(ft, datasets.Value):
dtype = ft.dtype
shape = (1,)
names = None
if isinstance(ft, datasets.Sequence):
assert isinstance(ft.feature, datasets.Value)
dtype = ft.feature.dtype
shape = (ft.length,)
motor_names = (
robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
)
assert len(motor_names) == shape[0]
names = {"motors": motor_names}
elif isinstance(ft, datasets.Image):
dtype = "image"
image = dataset[0][key] # Assuming first row
channels = get_image_pixel_channels(image)
shape = (image.height, image.width, channels)
names = ["height", "width", "channels"]
elif ft._type == "VideoFrame":
dtype = "video"
shape = None # Add shape later
names = ["height", "width", "channels"]
features[key] = {
"dtype": dtype,
"shape": shape,
"names": names,
}
return features
def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
df = dataset.to_pandas()
tasks = list(set(tasks_by_episodes.values()))
tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
features = dataset.features
features["task_index"] = datasets.Value(dtype="int64")
dataset = Dataset.from_pandas(df, features=features, split="train")
return dataset, tasks
def add_task_index_from_tasks_col(
dataset: Dataset, tasks_col: str
) -> tuple[Dataset, dict[str, list[str]], list[str]]:
df = dataset.to_pandas()
# HACK: This is to clean some of the instructions in our version of Open X datasets
prefix_to_clean = "tf.Tensor(b'"
suffix_to_clean = "', shape=(), dtype=string)"
df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)
# Create task_index col
tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
tasks = df[tasks_col].unique().tolist()
tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
# Build the dataset back from df
features = dataset.features
features["task_index"] = datasets.Value(dtype="int64")
dataset = Dataset.from_pandas(df, features=features, split="train")
dataset = dataset.remove_columns(tasks_col)
return dataset, tasks, tasks_by_episode
def split_parquet_by_episodes(
dataset: Dataset,
total_episodes: int,
total_chunks: int,
output_dir: Path,
) -> list:
table = dataset.data.table
episode_lengths = []
for ep_chunk in range(total_chunks):
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
for ep_idx in range(ep_chunk_start, ep_chunk_end):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
episode_lengths.insert(ep_idx, len(ep_table))
output_file = output_dir / DEFAULT_PARQUET_PATH.format(
episode_chunk=ep_chunk, episode_index=ep_idx
)
pq.write_table(ep_table, output_file)
return episode_lengths
def move_videos(
repo_id: str,
video_keys: list[str],
total_episodes: int,
total_chunks: int,
work_dir: Path,
clean_gittatributes: Path,
branch: str = "main",
) -> None:
"""
HACK: Since HfApi() doesn't provide a way to move files directly in a repo, this function will run git
commands to fetch git lfs video files references to move them into subdirectories without having to
actually download them.
"""
_lfs_clone(repo_id, work_dir, branch)
videos_moved = False
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
if len(video_files) == 0:
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
videos_moved = True # Videos have already been moved
assert len(video_files) == total_episodes * len(video_keys)
lfs_untracked_videos = _get_lfs_untracked_videos(work_dir, video_files)
current_gittatributes = work_dir / ".gitattributes"
if not filecmp.cmp(current_gittatributes, clean_gittatributes, shallow=False):
fix_gitattributes(work_dir, current_gittatributes, clean_gittatributes)
if lfs_untracked_videos:
fix_lfs_video_files_tracking(work_dir, video_files)
if videos_moved:
return
video_dirs = sorted(work_dir.glob("videos*/"))
for ep_chunk in range(total_chunks):
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
for vid_key in video_keys:
chunk_dir = "/".join(DEFAULT_VIDEO_PATH.split("/")[:-1]).format(
episode_chunk=ep_chunk, video_key=vid_key
)
(work_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
for ep_idx in range(ep_chunk_start, ep_chunk_end):
target_path = DEFAULT_VIDEO_PATH.format(
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
)
video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
if len(video_dirs) == 1:
video_path = video_dirs[0] / video_file
else:
for dir in video_dirs:
if (dir / video_file).is_file():
video_path = dir / video_file
break
video_path.rename(work_dir / target_path)
commit_message = "Move video files into chunk subdirectories"
subprocess.run(["git", "add", "."], cwd=work_dir, check=True)
subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
"""
HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
there's no other option than to download the actual files and reupload them with lfs tracking.
"""
for i in range(0, len(lfs_untracked_videos), 100):
files = lfs_untracked_videos[i : i + 100]
try:
subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
except subprocess.CalledProcessError as e:
print("git rm --cached ERROR:")
print(e.stderr)
subprocess.run(["git", "add", *files], cwd=work_dir, check=True)
commit_message = "Track video files with git lfs"
subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
shutil.copyfile(clean_gittatributes, current_gittatributes)
subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
subprocess.run(["git", "lfs", "install"], cwd=work_dir, check=True)
repo_url = f"https://huggingface.co/datasets/{repo_id}"
env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
subprocess.run(
["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
check=True,
env=env,
)
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
lfs_tracked_files = subprocess.run(
["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
)
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
return [f for f in video_files if f not in lfs_tracked_files]
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
# Assumes first episode
video_files = [
DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
for vid_key in video_keys
]
hub_api = HfApi()
hub_api.snapshot_download(
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
)
videos_info_dict = {}
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
videos_info_dict[vid_key] = get_video_info(local_dir / vid_path)
return videos_info_dict
def convert_dataset(
repo_id: str,
local_dir: Path,
single_task: str | None = None,
tasks_path: Path | None = None,
tasks_col: Path | None = None,
robot_config: RobotConfig | None = None,
test_branch: str | None = None,
**card_kwargs,
):
v1 = get_safe_version(repo_id, V16)
v1x_dir = local_dir / V16 / repo_id
v20_dir = local_dir / V20 / repo_id
v1x_dir.mkdir(parents=True, exist_ok=True)
v20_dir.mkdir(parents=True, exist_ok=True)
hub_api = HfApi()
hub_api.snapshot_download(
repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
)
branch = "main"
if test_branch:
branch = test_branch
create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")
metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
features = get_features_from_hf_dataset(dataset, robot_config)
video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
if single_task and "language_instruction" in dataset.column_names:
logging.warning(
"'single_task' provided but 'language_instruction' tasks_col found. Using 'language_instruction'.",
)
single_task = None
tasks_col = "language_instruction"
# Episodes & chunks
episode_indices = sorted(dataset.unique("episode_index"))
total_episodes = len(episode_indices)
assert episode_indices == list(range(total_episodes))
total_videos = total_episodes * len(video_keys)
total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
if total_episodes % DEFAULT_CHUNK_SIZE != 0:
total_chunks += 1
# Tasks
if single_task:
tasks_by_episodes = dict.fromkeys(episode_indices, single_task)
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
elif tasks_path:
tasks_by_episodes = load_json(tasks_path)
tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
elif tasks_col:
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
else:
raise ValueError
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
write_jsonlines(tasks, v20_dir / TASKS_PATH)
features["task_index"] = {
"dtype": "int64",
"shape": (1,),
"names": None,
}
# Videos
if video_keys:
assert metadata_v1.get("video", False)
dataset = dataset.remove_columns(video_keys)
clean_gitattr = Path(
hub_api.hf_hub_download(
repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
)
).absolute()
with tempfile.TemporaryDirectory() as tmp_video_dir:
move_videos(
repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
)
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
for key in video_keys:
features[key]["shape"] = (
videos_info[key].pop("video.height"),
videos_info[key].pop("video.width"),
videos_info[key].pop("video.channels"),
)
features[key]["video_info"] = videos_info[key]
assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
if "encoding" in metadata_v1:
assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
else:
assert metadata_v1.get("video", 0) == 0
videos_info = None
# Split data into 1 parquet file by episode
episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
if robot_config is not None:
robot_type = robot_config.type
repo_tags = [robot_type]
else:
robot_type = "unknown"
repo_tags = None
# Episodes
episodes = [
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
for ep_idx in episode_indices
]
write_jsonlines(episodes, v20_dir / EPISODES_PATH)
# Assemble metadata v2.0
metadata_v2_0 = {
"codebase_version": V20,
"robot_type": robot_type,
"total_episodes": total_episodes,
"total_frames": len(dataset),
"total_tasks": len(tasks),
"total_videos": total_videos,
"total_chunks": total_chunks,
"chunks_size": DEFAULT_CHUNK_SIZE,
"fps": metadata_v1["fps"],
"splits": {"train": f"0:{total_episodes}"},
"data_path": DEFAULT_PARQUET_PATH,
"video_path": DEFAULT_VIDEO_PATH if video_keys else None,
"features": features,
}
write_json(metadata_v2_0, v20_dir / INFO_PATH)
convert_stats_to_json(v1x_dir, v20_dir)
card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
hub_api.upload_folder(
repo_id=repo_id,
path_in_repo="data",
folder_path=v20_dir / "data",
repo_type="dataset",
revision=branch,
)
hub_api.upload_folder(
repo_id=repo_id,
path_in_repo="meta",
folder_path=v20_dir / "meta",
repo_type="dataset",
revision=branch,
)
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=branch)
if not test_branch:
create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")
def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
if robot_type == "aloha":
raise NotImplementedError # TODO
elif robot_type == "koch_follower":
from lerobot.robots.koch_follower import KochFollowerConfig
return KochFollowerConfig(**kwargs)
elif robot_type == "so100_follower":
from lerobot.robots.so100_follower import SO100FollowerConfig
return SO100FollowerConfig(**kwargs)
elif robot_type == "stretch":
from lerobot.robots.stretch3 import Stretch3RobotConfig
return Stretch3RobotConfig(**kwargs)
elif robot_type == "lekiwi":
from lerobot.robots.lekiwi import LeKiwiConfig
return LeKiwiConfig(**kwargs)
else:
raise ValueError(f"Robot type '{robot_type}' is not available.")
def main():
parser = argparse.ArgumentParser()
task_args = parser.add_mutually_exclusive_group(required=True)
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
)
task_args.add_argument(
"--single-task",
type=str,
help="A short but accurate description of the single task performed in the dataset.",
)
task_args.add_argument(
"--tasks-col",
type=str,
help="The name of the column containing language instructions",
)
task_args.add_argument(
"--tasks-path",
type=Path,
help="The path to a .json file containing one language instruction for each episode_index",
)
parser.add_argument(
"--robot",
type=str,
default=None,
help="Robot config used for the dataset during conversion (e.g. 'koch', 'aloha', 'so100', etc.)",
)
parser.add_argument(
"--local-dir",
type=Path,
default=None,
help="Local directory to store the dataset during conversion. Defaults to /tmp/lerobot_dataset_v2",
)
parser.add_argument(
"--license",
type=str,
default="apache-2.0",
help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to mit.",
)
parser.add_argument(
"--test-branch",
type=str,
default=None,
help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
)
args = parser.parse_args()
if not args.local_dir:
args.local_dir = Path("/tmp/lerobot_dataset_v2")
if args.robot is not None:
robot_config = make_robot_config(args.robot)
del args.robot
convert_dataset(**vars(args), robot_config=robot_config)
if __name__ == "__main__":
main()

View File

@@ -1,87 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import traceback
from pathlib import Path
from datasets import get_dataset_config_info
from huggingface_hub import HfApi
from lerobot import available_datasets
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.utils import INFO_PATH, write_info
from lerobot.datasets.v21.convert_dataset_v20_to_v21 import V20, SuppressWarnings
LOCAL_DIR = Path("data/")
hub_api = HfApi()
def fix_dataset(repo_id: str) -> str:
if not hub_api.revision_exists(repo_id, V20, repo_type="dataset"):
return f"{repo_id}: skipped (not in {V20})."
dataset_info = get_dataset_config_info(repo_id, "default")
with SuppressWarnings():
lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
parquet_features = set(dataset_info.features)
diff_parquet_meta = parquet_features - meta_features
diff_meta_parquet = meta_features - parquet_features
if diff_parquet_meta:
raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
if not diff_meta_parquet:
return f"{repo_id}: skipped (no diff)"
if diff_meta_parquet:
logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
assert diff_meta_parquet == {"language_instruction"}
lerobot_metadata.features.pop("language_instruction")
write_info(lerobot_metadata.info, lerobot_metadata.root)
commit_info = hub_api.upload_file(
path_or_fileobj=lerobot_metadata.root / INFO_PATH,
path_in_repo=INFO_PATH,
repo_id=repo_id,
repo_type="dataset",
revision=V20,
commit_message="Remove 'language_instruction'",
create_pr=True,
)
return f"{repo_id}: success - PR: {commit_info.pr_url}"
def batch_fix():
status = {}
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
logfile = LOCAL_DIR / "fix_features_v20.txt"
for num, repo_id in enumerate(available_datasets):
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
print("---------------------------------------------------------")
try:
status = fix_dataset(repo_id)
except Exception:
status = f"{repo_id}: failed\n {traceback.format_exc()}"
logging.info(status)
with open(logfile, "a") as file:
file.write(status + "\n")
if __name__ == "__main__":
batch_fix()

View File

@@ -1,54 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1.
"""
import traceback
from pathlib import Path
from huggingface_hub import HfApi
from lerobot import available_datasets
from lerobot.datasets.v21.convert_dataset_v20_to_v21 import V21, convert_dataset
LOCAL_DIR = Path("data/")
def batch_convert():
status = {}
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
logfile = LOCAL_DIR / "conversion_log_v21.txt"
hub_api = HfApi()
for num, repo_id in enumerate(available_datasets):
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
print("---------------------------------------------------------")
try:
if hub_api.revision_exists(repo_id, V21, repo_type="dataset"):
status = f"{repo_id}: success (already in {V21})."
else:
convert_dataset(repo_id)
status = f"{repo_id}: success."
except Exception:
status = f"{repo_id}: failed\n {traceback.format_exc()}"
with open(logfile, "a") as file:
file.write(status + "\n")
if __name__ == "__main__":
batch_convert()

View File

@@ -1,114 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
2.1. It will:
- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
- Check consistency between these new stats and the old ones.
- Remove the deprecated `stats.json`.
- Update codebase_version in `info.json`.
- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
Usage:
```bash
python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 \
--repo-id=aliberts/koch_tutorial
```
"""
import argparse
import logging
from huggingface_hub import HfApi
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
from lerobot.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
V20 = "v2.0"
V21 = "v2.1"
class SuppressWarnings:
def __enter__(self):
self.previous_level = logging.getLogger().getEffectiveLevel()
logging.getLogger().setLevel(logging.ERROR)
def __exit__(self, exc_type, exc_val, exc_tb):
logging.getLogger().setLevel(self.previous_level)
def convert_dataset(
repo_id: str,
branch: str | None = None,
num_workers: int = 4,
):
with SuppressWarnings():
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
if (dataset.root / EPISODES_STATS_PATH).is_file():
(dataset.root / EPISODES_STATS_PATH).unlink()
convert_stats(dataset, num_workers=num_workers)
ref_stats = load_stats(dataset.root)
check_aggregate_stats(dataset, ref_stats)
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
write_info(dataset.meta.info, dataset.root)
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
# delete old stats.json file
if (dataset.root / STATS_PATH).is_file:
(dataset.root / STATS_PATH).unlink()
hub_api = HfApi()
if hub_api.file_exists(
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
):
hub_api.delete_file(
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
)
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
)
parser.add_argument(
"--branch",
type=str,
default=None,
help="Repo branch to push your dataset. Defaults to the main branch.",
)
parser.add_argument(
"--num-workers",
type=int,
default=4,
help="Number of workers for parallelizing stats compute. Defaults to 4.",
)
args = parser.parse_args()
convert_dataset(**vars(args))

View File

@@ -1,99 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
from tqdm import tqdm
from lerobot.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import write_episode_stats
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
ep_len = dataset.meta.episodes[episode_index]["length"]
sampled_indices = sample_indices(ep_len)
query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
video_frames = dataset._query_videos(query_timestamps, episode_index)
return video_frames[ft_key].numpy()
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
ep_start_idx = dataset.episode_data_index["from"][ep_idx]
ep_end_idx = dataset.episode_data_index["to"][ep_idx]
ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
ep_stats = {}
for key, ft in dataset.features.items():
if ft["dtype"] == "video":
# We sample only for videos
ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
else:
ep_ft_data = np.array(ep_data[key])
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
if ft["dtype"] in ["image", "video"]: # remove batch dim
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
}
dataset.meta.episodes_stats[ep_idx] = ep_stats
def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
assert dataset.episodes is None
print("Computing episodes stats")
total_episodes = dataset.meta.total_episodes
if num_workers > 0:
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = {
executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx
for ep_idx in range(total_episodes)
}
for future in tqdm(as_completed(futures), total=total_episodes):
future.result()
else:
for ep_idx in tqdm(range(total_episodes)):
convert_episode_stats(dataset, ep_idx)
for ep_idx in tqdm(range(total_episodes)):
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
def check_aggregate_stats(
dataset: LeRobotDataset,
reference_stats: dict[str, dict[str, np.ndarray]],
video_rtol_atol: tuple[float] = (1e-2, 1e-2),
default_rtol_atol: tuple[float] = (5e-6, 6e-5),
):
"""Verifies that the aggregated stats from episodes_stats are close to reference stats."""
agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
for key, ft in dataset.features.items():
# These values might need some fine-tuning
if ft["dtype"] == "video":
# to account for image sub-sampling
rtol, atol = video_rtol_atol
else:
rtol, atol = default_rtol_atol
for stat, val in agg_stats[key].items():
if key in reference_stats and stat in reference_stats[key]:
err_msg = f"feature='{key}' stats='{stat}'"
np.testing.assert_allclose(
val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
)

View File

@@ -0,0 +1,491 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.1 to
3.0. It will:
- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
- Check consistency between these new stats and the old ones.
- Remove the deprecated `stats.json`.
- Update codebase_version in `info.json`.
- Push this new version to the hub on the 'main' branch and tags it with "v3.0".
Usage:
```bash
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
--repo-id=lerobot/pusht
```
"""
import argparse
import shutil
from pathlib import Path
from typing import Any
import jsonlines
import pandas as pd
import pyarrow as pa
import tqdm
from datasets import Dataset, Features, Image
from huggingface_hub import HfApi, snapshot_download
from requests import HTTPError
from lerobot.constants import HF_LEROBOT_HOME
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
LEGACY_EPISODES_PATH,
LEGACY_EPISODES_STATS_PATH,
LEGACY_TASKS_PATH,
cast_stats_to_numpy,
flatten_dict,
get_parquet_file_size_in_mb,
get_parquet_num_frames,
get_video_size_in_mb,
load_info,
update_chunk_file_indices,
write_episodes,
write_info,
write_stats,
write_tasks,
)
from lerobot.datasets.video_utils import concat_video_files, get_video_duration_in_s
V21 = "v2.1"
"""
-------------------------
OLD
data/chunk-000/episode_000000.parquet
NEW
data/chunk-000/file_000.parquet
-------------------------
OLD
videos/chunk-000/CAMERA/episode_000000.mp4
NEW
videos/chunk-000/file_000.mp4
-------------------------
OLD
episodes.jsonl
{"episode_index": 1, "tasks": ["Put the blue block in the green bowl"], "length": 266}
NEW
meta/episodes/chunk-000/episodes_000.parquet
episode_index | video_chunk_index | video_file_index | data_chunk_index | data_file_index | tasks | length
-------------------------
OLD
tasks.jsonl
{"task_index": 1, "task": "Put the blue block in the green bowl"}
NEW
meta/tasks/chunk-000/file_000.parquet
task_index | task
-------------------------
OLD
episodes_stats.jsonl
NEW
meta/episodes_stats/chunk-000/file_000.parquet
episode_index | mean | std | min | max
-------------------------
UPDATE
meta/info.json
-------------------------
"""
def load_jsonlines(fpath: Path) -> list[Any]:
with jsonlines.open(fpath, "r") as reader:
return list(reader)
def legacy_load_episodes(local_dir: Path) -> dict:
episodes = load_jsonlines(local_dir / LEGACY_EPISODES_PATH)
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
def legacy_load_episodes_stats(local_dir: Path) -> dict:
episodes_stats = load_jsonlines(local_dir / LEGACY_EPISODES_STATS_PATH)
return {
item["episode_index"]: cast_stats_to_numpy(item["stats"])
for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
}
def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]:
tasks = load_jsonlines(local_dir / LEGACY_TASKS_PATH)
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
return tasks, task_to_task_index
def convert_tasks(root, new_root):
tasks, _ = legacy_load_tasks(root)
task_indices = tasks.keys()
task_strings = tasks.values()
df_tasks = pd.DataFrame({"task_index": task_indices}, index=task_strings)
write_tasks(df_tasks, new_root)
def concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys):
# TODO(rcadene): to save RAM use Dataset.from_parquet(file) and concatenate_datasets
dataframes = [pd.read_parquet(file) for file in paths_to_cat]
# Concatenate all DataFrames along rows
concatenated_df = pd.concat(dataframes, ignore_index=True)
path = new_root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
path.parent.mkdir(parents=True, exist_ok=True)
if len(image_keys) > 0:
schema = pa.Schema.from_pandas(concatenated_df)
features = Features.from_arrow_schema(schema)
for key in image_keys:
features[key] = Image()
schema = features.arrow_schema
else:
schema = None
concatenated_df.to_parquet(path, index=False, schema=schema)
def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
data_dir = root / "data"
ep_paths = sorted(data_dir.glob("*/*.parquet"))
image_keys = get_image_keys(root)
ep_idx = 0
chunk_idx = 0
file_idx = 0
size_in_mb = 0
num_frames = 0
paths_to_cat = []
episodes_metadata = []
for ep_path in ep_paths:
ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
ep_num_frames = get_parquet_num_frames(ep_path)
ep_metadata = {
"episode_index": ep_idx,
"data/chunk_index": chunk_idx,
"data/file_index": file_idx,
"dataset_from_index": num_frames,
"dataset_to_index": num_frames + ep_num_frames,
}
size_in_mb += ep_size_in_mb
num_frames += ep_num_frames
episodes_metadata.append(ep_metadata)
ep_idx += 1
if size_in_mb < data_file_size_in_mb:
paths_to_cat.append(ep_path)
continue
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
# Reset for the next file
size_in_mb = ep_size_in_mb
num_frames = ep_num_frames
paths_to_cat = [ep_path]
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
# Write remaining data if any
if paths_to_cat:
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
return episodes_metadata
def get_video_keys(root):
info = load_info(root)
features = info["features"]
video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
return video_keys
def get_image_keys(root):
info = load_info(root)
features = info["features"]
image_keys = [key for key, ft in features.items() if ft["dtype"] == "image"]
return image_keys
def convert_videos(root: Path, new_root: Path, video_file_size_in_mb: int):
video_keys = get_video_keys(root)
if len(video_keys) == 0:
return None
video_keys = sorted(video_keys)
eps_metadata_per_cam = []
for camera in video_keys:
eps_metadata = convert_videos_of_camera(root, new_root, camera, video_file_size_in_mb)
eps_metadata_per_cam.append(eps_metadata)
num_eps_per_cam = [len(eps_cam_map) for eps_cam_map in eps_metadata_per_cam]
if len(set(num_eps_per_cam)) != 1:
raise ValueError(f"All cams dont have same number of episodes ({num_eps_per_cam}).")
episods_metadata = []
num_cameras = len(video_keys)
num_episodes = num_eps_per_cam[0]
for ep_idx in range(num_episodes):
# Sanity check
ep_ids = [eps_metadata_per_cam[cam_idx][ep_idx]["episode_index"] for cam_idx in range(num_cameras)]
ep_ids += [ep_idx]
if len(set(ep_ids)) != 1:
raise ValueError(f"All episode indices need to match ({ep_ids}).")
ep_dict = {}
for cam_idx in range(num_cameras):
ep_dict.update(eps_metadata_per_cam[cam_idx][ep_idx])
episods_metadata.append(ep_dict)
return episods_metadata
def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_file_size_in_mb: int):
# Access old paths to mp4
videos_dir = root / "videos"
ep_paths = sorted(videos_dir.glob(f"*/{video_key}/*.mp4"))
ep_idx = 0
chunk_idx = 0
file_idx = 0
size_in_mb = 0
duration_in_s = 0.0
paths_to_cat = []
episodes_metadata = []
for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
ep_size_in_mb = get_video_size_in_mb(ep_path)
ep_duration_in_s = get_video_duration_in_s(ep_path)
# Check if adding this episode would exceed the limit
if size_in_mb + ep_size_in_mb >= video_file_size_in_mb and len(paths_to_cat) > 0:
# Size limit would be exceeded, save current accumulation WITHOUT this episode
concat_video_files(paths_to_cat, new_root, video_key, chunk_idx, file_idx)
# Update episodes metadata for the file we just saved
for i, _ in enumerate(paths_to_cat):
past_ep_idx = ep_idx - len(paths_to_cat) + i
episodes_metadata[past_ep_idx][f"videos/{video_key}/chunk_index"] = chunk_idx
episodes_metadata[past_ep_idx][f"videos/{video_key}/file_index"] = file_idx
# Move to next file and start fresh with current episode
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
size_in_mb = 0
duration_in_s = 0.0
paths_to_cat = []
# Add current episode metadata
ep_metadata = {
"episode_index": ep_idx,
f"videos/{video_key}/chunk_index": chunk_idx, # Will be updated when file is saved
f"videos/{video_key}/file_index": file_idx, # Will be updated when file is saved
f"videos/{video_key}/from_timestamp": duration_in_s,
f"videos/{video_key}/to_timestamp": duration_in_s + ep_duration_in_s,
}
episodes_metadata.append(ep_metadata)
# Add current episode to accumulation
paths_to_cat.append(ep_path)
size_in_mb += ep_size_in_mb
duration_in_s += ep_duration_in_s
ep_idx += 1
# Write remaining videos if any
if paths_to_cat:
concat_video_files(paths_to_cat, new_root, video_key, chunk_idx, file_idx)
# Update episodes metadata for the final file
for i, _ in enumerate(paths_to_cat):
past_ep_idx = ep_idx - len(paths_to_cat) + i
episodes_metadata[past_ep_idx][f"videos/{video_key}/chunk_index"] = chunk_idx
episodes_metadata[past_ep_idx][f"videos/{video_key}/file_index"] = file_idx
return episodes_metadata
def generate_episode_metadata_dict(
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_videos=None
):
num_episodes = len(episodes_metadata)
episodes_legacy_metadata_vals = list(episodes_legacy_metadata.values())
episodes_stats_vals = list(episodes_stats.values())
episodes_stats_keys = list(episodes_stats.keys())
for i in range(num_episodes):
ep_legacy_metadata = episodes_legacy_metadata_vals[i]
ep_metadata = episodes_metadata[i]
ep_stats = episodes_stats_vals[i]
ep_ids_set = {
ep_legacy_metadata["episode_index"],
ep_metadata["episode_index"],
episodes_stats_keys[i],
}
if episodes_videos is None:
ep_video = {}
else:
ep_video = episodes_videos[i]
ep_ids_set.add(ep_video["episode_index"])
if len(ep_ids_set) != 1:
raise ValueError(f"Number of episodes is not the same ({ep_ids_set}).")
ep_dict = {**ep_metadata, **ep_video, **ep_legacy_metadata, **flatten_dict({"stats": ep_stats})}
ep_dict["meta/episodes/chunk_index"] = 0
ep_dict["meta/episodes/file_index"] = 0
yield ep_dict
def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata=None):
episodes_legacy_metadata = legacy_load_episodes(root)
episodes_stats = legacy_load_episodes_stats(root)
num_eps_set = {len(episodes_legacy_metadata), len(episodes_metadata)}
if episodes_video_metadata is not None:
num_eps_set.add(len(episodes_video_metadata))
if len(num_eps_set) != 1:
raise ValueError(f"Number of episodes is not the same ({num_eps_set}).")
ds_episodes = Dataset.from_generator(
lambda: generate_episode_metadata_dict(
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_video_metadata
)
)
write_episodes(ds_episodes, new_root)
stats = aggregate_stats(list(episodes_stats.values()))
write_stats(stats, new_root)
def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb):
info = load_info(root)
info["codebase_version"] = "v3.0"
del info["total_chunks"]
del info["total_videos"]
info["data_files_size_in_mb"] = data_file_size_in_mb
info["video_files_size_in_mb"] = video_file_size_in_mb
info["data_path"] = DEFAULT_DATA_PATH
info["video_path"] = DEFAULT_VIDEO_PATH
info["fps"] = float(info["fps"])
for key in info["features"]:
if info["features"][key]["dtype"] == "video":
# already has fps in video_info
continue
info["features"][key]["fps"] = info["fps"]
write_info(info, new_root)
def convert_dataset(
repo_id: str,
branch: str | None = None,
data_file_size_in_mb: int | None = None,
video_file_size_in_mb: int | None = None,
):
root = HF_LEROBOT_HOME / repo_id
old_root = HF_LEROBOT_HOME / f"{repo_id}_old"
new_root = HF_LEROBOT_HOME / f"{repo_id}_v30"
if data_file_size_in_mb is None:
data_file_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
if video_file_size_in_mb is None:
video_file_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
if old_root.is_dir() and root.is_dir():
shutil.rmtree(str(root))
shutil.move(str(old_root), str(root))
if new_root.is_dir():
shutil.rmtree(new_root)
snapshot_download(
repo_id,
repo_type="dataset",
revision=V21,
local_dir=root,
)
convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb)
convert_tasks(root, new_root)
episodes_metadata = convert_data(root, new_root, data_file_size_in_mb)
episodes_videos_metadata = convert_videos(root, new_root, video_file_size_in_mb)
convert_episodes_metadata(root, new_root, episodes_metadata, episodes_videos_metadata)
shutil.move(str(root), str(old_root))
shutil.move(str(new_root), str(root))
hub_api = HfApi()
try:
hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
except HTTPError as e:
print(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
pass
hub_api.delete_files(
delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*"],
repo_id=repo_id,
revision=branch,
repo_type="dataset",
)
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
LeRobotDataset(repo_id).push_to_hub()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
)
parser.add_argument(
"--branch",
type=str,
default=None,
help="Repo branch to push your dataset. Defaults to the main branch.",
)
parser.add_argument(
"--data-file-size-in-mb",
type=int,
default=None,
help="File size in MB. Defaults to 100 for data and 500 for videos.",
)
parser.add_argument(
"--video-file-size-in-mb",
type=int,
default=None,
help="File size in MB. Defaults to 100 for data and 500 for videos.",
)
args = parser.parse_args()
convert_dataset(**vars(args))

View File

@@ -17,6 +17,8 @@ import glob
import importlib
import logging
import shutil
import subprocess
import tempfile
import warnings
from dataclasses import dataclass, field
from pathlib import Path
@@ -29,6 +31,8 @@ import torchvision
from datasets.features.features import register_feature
from PIL import Image
from lerobot.datasets.utils import DEFAULT_VIDEO_PATH
def get_safe_default_codec():
if importlib.util.find_spec("torchcodec"):
@@ -263,7 +267,7 @@ def encode_video_frames(
video_path = Path(video_path)
imgs_dir = Path(imgs_dir)
video_path.parent.mkdir(parents=True, exist_ok=overwrite)
video_path.parent.mkdir(parents=True, exist_ok=True)
# Encoders/pixel formats incompatibility check
if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
@@ -273,9 +277,9 @@ def encode_video_frames(
pix_fmt = "yuv420p"
# Get input frames
template = "frame_" + ("[0-9]" * 6) + ".png"
template = "frame-" + ("[0-9]" * 6) + ".png"
input_list = sorted(
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("_")[-1].split(".")[0])
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
)
# Define video output frame size (assuming all input frames are the same size)
@@ -300,7 +304,7 @@ def encode_video_frames(
# Set logging level
if log_level is not None:
# "While less efficient, it is generally preferable to modify logging with Pythons logging"
# "While less efficient, it is generally preferable to modify logging with Python's logging"
logging.getLogger("libav").setLevel(log_level)
# Create and open output file (overwrite by default)
@@ -331,6 +335,62 @@ def encode_video_frames(
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
def concat_video_files(paths_to_cat: list[Path], root: Path, video_key: str, chunk_idx: int, file_idx: int):
"""
Concatenate multiple video files into a single video file using ffmpeg.
This function takes a list of video file paths and concatenates them into a single
output video file. It uses ffmpeg's concat demuxer with stream copy mode for fast
concatenation without re-encoding.
Args:
paths_to_cat: List of video file paths to concatenate, in order.
root: Root directory where temporary files and output will be created.
video_key: Video key identifier (e.g., camera name) used in output path.
chunk_idx: Chunk index for organizing output files.
file_idx: File index within the chunk.
Note:
- Creates a temporary directory for intermediate files that is cleaned up after use.
- Uses ffmpeg's concat demuxer which requires all input videos to have the same
codec, resolution, and frame rate for proper concatenation.
- Output path follows the DEFAULT_VIDEO_PATH pattern with video_key, chunk_idx,
and file_idx parameters.
- This function uses subprocess to call ffmpeg directly because PyAV doesn't have
built-in support for video concatenation. The concat demuxer in ffmpeg handles
all the complex timestamp adjustments automatically.
"""
tmp_dir = Path(tempfile.mkdtemp(dir=root))
path_concat_video_files = tmp_dir / "concat_video_files.txt"
with open(path_concat_video_files, "w") as f:
for ep_path in paths_to_cat:
f.write(f"file '{str(ep_path)}'\n")
path_tmp_output = tmp_dir / "tmp_output.mp4"
command = [
"ffmpeg",
"-y",
"-f",
"concat",
"-safe",
"0",
"-i",
str(path_concat_video_files),
"-c",
"copy",
str(path_tmp_output),
]
subprocess.run(command, check=True)
output_path = root / DEFAULT_VIDEO_PATH.format(
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
)
output_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(path_tmp_output), str(output_path))
shutil.rmtree(str(tmp_dir))
@dataclass
class VideoFrame:
# TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
@@ -454,64 +514,23 @@ def get_image_pixel_channels(image: Image):
raise ValueError("Unknown format")
class VideoEncodingManager:
def get_video_duration_in_s(video_path: Path | str) -> float:
"""
Context manager that ensures proper video encoding and data cleanup even if exceptions occur.
This manager handles:
- Batch encoding for any remaining episodes when recording interrupted
- Cleaning up temporary image files from interrupted episodes
- Removing empty image directories
Get the duration of a video file in seconds using PyAV.
Args:
dataset: The LeRobotDataset instance
video_path: Path to the video file.
Returns:
Duration of the video in seconds.
"""
def __init__(self, dataset):
self.dataset = dataset
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Handle any remaining episodes that haven't been batch encoded
if self.dataset.episodes_since_last_encoding > 0:
if exc_type is not None:
logging.info("Exception occurred. Encoding remaining episodes before exit...")
else:
logging.info("Recording stopped. Encoding remaining episodes...")
start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding
end_ep = self.dataset.num_episodes
logging.info(
f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, "
f"from episode {start_ep} to {end_ep - 1}"
)
self.dataset.batch_encode_videos(start_ep, end_ep)
# Clean up episode images if recording was interrupted
if exc_type is not None:
interrupted_episode_index = self.dataset.num_episodes
for key in self.dataset.meta.video_keys:
img_dir = self.dataset._get_image_file_path(
episode_index=interrupted_episode_index, image_key=key, frame_index=0
).parent
if img_dir.exists():
logging.debug(
f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
)
shutil.rmtree(img_dir)
# Clean up any remaining images directory if it's empty
img_dir = self.dataset.root / "images"
# Check for any remaining PNG files
png_files = list(img_dir.rglob("*.png"))
if len(png_files) == 0:
# Only remove the images directory if no PNG files remain
if img_dir.exists():
shutil.rmtree(img_dir)
logging.debug("Cleaned up empty images directory")
with av.open(str(video_path)) as container:
# Get the first video stream
video_stream = container.streams.video[0]
# Calculate duration: stream.duration * stream.time_base gives duration in seconds
if video_stream.duration is not None:
duration = float(video_stream.duration * video_stream.time_base)
else:
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
return False # Don't suppress the original exception
# Fallback to container duration if stream duration is not available
duration = float(container.duration / av.time_base)
return duration

View File

@@ -161,71 +161,33 @@ class XarmEnv(EnvConfig):
@dataclass
class ImagePreprocessingConfig:
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
resize_size: tuple[int, int] | None = None
class VideoRecordConfig:
"""Configuration for video recording in ManiSkill environments."""
enabled: bool = False
record_dir: str = "videos"
trajectory_name: str = "trajectory"
@dataclass
class RewardClassifierConfig:
"""Configuration for reward classification."""
pretrained_path: str | None = None
success_threshold: float = 0.5
success_reward: float = 1.0
@dataclass
class InverseKinematicsConfig:
"""Configuration for inverse kinematics processing."""
urdf_path: str | None = None
target_frame_name: str | None = None
end_effector_bounds: dict[str, list[float]] | None = None
end_effector_step_sizes: dict[str, float] | None = None
@dataclass
class ObservationConfig:
"""Configuration for observation processing."""
class EnvTransformConfig:
"""Configuration for environment wrappers."""
# ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
control_mode: str = "gamepad"
display_cameras: bool = False
add_joint_velocity_to_observation: bool = False
add_current_to_observation: bool = False
add_ee_pose_to_observation: bool = False
display_cameras: bool = False
@dataclass
class GripperConfig:
"""Configuration for gripper control and penalties."""
use_gripper: bool = True
gripper_penalty: float = 0.0
gripper_penalty_in_reward: bool = False
@dataclass
class ResetConfig:
"""Configuration for environment reset behavior."""
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
resize_size: tuple[int, int] | None = None
control_time_s: float = 20.0
fixed_reset_joint_positions: Any | None = None
reset_time_s: float = 5.0
control_time_s: float = 20.0
terminate_on_success: bool = True
@dataclass
class HILSerlProcessorConfig:
"""Configuration for environment processing pipeline."""
control_mode: str = "gamepad"
observation: ObservationConfig | None = None
image_preprocessing: ImagePreprocessingConfig | None = None
gripper: GripperConfig | None = None
reset: ResetConfig | None = None
inverse_kinematics: InverseKinematicsConfig | None = None
reward_classifier: RewardClassifierConfig | None = None
max_gripper_pos: float | None = 100.0
use_gripper: bool = True
gripper_quantization_threshold: float | None = 0.8
gripper_penalty: float = 0.0
gripper_penalty_in_reward: bool = False
@EnvConfig.register_subclass(name="gym_manipulator")
@@ -235,10 +197,77 @@ class HILSerlRobotEnvConfig(EnvConfig):
robot: RobotConfig | None = None
teleop: TeleoperatorConfig | None = None
processor: HILSerlProcessorConfig = field(default_factory=HILSerlProcessorConfig)
wrapper: EnvTransformConfig | None = None
fps: int = 10
name: str = "real_robot"
mode: str | None = None # Either "record", "replay", None
repo_id: str | None = None
dataset_root: str | None = None
task: str | None = ""
num_episodes: int = 10 # only for record mode
episode: int = 0
device: str = "cuda"
push_to_hub: bool = True
pretrained_policy_name_or_path: str | None = None
reward_classifier_pretrained_path: str | None = None
# For the reward classifier, to record more positive examples after a success
number_of_steps_after_success: int = 0
@property
def gym_kwargs(self) -> dict:
return {}
@EnvConfig.register_subclass("hil")
@dataclass
class HILEnvConfig(EnvConfig):
"""Configuration for the HIL environment."""
name: str = "PandaPickCube"
task: str | None = "PandaPickCubeKeyboard-v0"
use_viewer: bool = True
gripper_penalty: float = 0.0
use_gamepad: bool = True
state_dim: int = 18
action_dim: int = 4
fps: int = 100
episode_length: int = 100
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"observation.image": OBS_IMAGE,
"observation.state": OBS_STATE,
}
)
################# args from hilserlrobotenv
reward_classifier_pretrained_path: str | None = None
robot_config: RobotConfig | None = None
teleop_config: TeleoperatorConfig | None = None
wrapper: EnvTransformConfig | None = None
mode: str | None = None # Either "record", "replay", None
repo_id: str | None = None
dataset_root: str | None = None
num_episodes: int = 10 # only for record mode
episode: int = 0
device: str = "cuda"
push_to_hub: bool = True
pretrained_policy_name_or_path: str | None = None
# For the reward classifier, to record more positive examples after a success
number_of_steps_after_success: int = 0
############################
@property
def gym_kwargs(self) -> dict:
return {
"use_viewer": self.use_viewer,
"use_gamepad": self.use_gamepad,
"gripper_penalty": self.gripper_penalty,
}

View File

@@ -17,7 +17,7 @@ import importlib
import gymnasium as gym
from lerobot.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -27,6 +27,8 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
return PushtEnv(**kwargs)
elif env_type == "xarm":
return XarmEnv(**kwargs)
elif env_type == "hil":
return HILEnvConfig(**kwargs)
else:
raise ValueError(f"Policy type '{env_type}' is not available.")

View File

@@ -107,6 +107,8 @@ X_SERIES_ENCODINGS_TABLE = {
"Goal_PWM": X_SERIES_CONTROL_TABLE["Goal_PWM"][1],
"Goal_Current": X_SERIES_CONTROL_TABLE["Goal_Current"][1],
"Goal_Velocity": X_SERIES_CONTROL_TABLE["Goal_Velocity"][1],
"Goal_Position": X_SERIES_CONTROL_TABLE["Goal_Position"][1],
"Present_Position": X_SERIES_CONTROL_TABLE["Present_Position"][1],
"Present_PWM": X_SERIES_CONTROL_TABLE["Present_PWM"][1],
"Present_Current": X_SERIES_CONTROL_TABLE["Present_Current"][1],
"Present_Velocity": X_SERIES_CONTROL_TABLE["Present_Velocity"][1],

View File

@@ -15,17 +15,6 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0.processor_pi0 import Pi0NewLineProcessor
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
__all__ = [
"ACTConfig",
"DiffusionConfig",
"PI0Config",
"SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",
]

View File

@@ -35,6 +35,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.constants import ACTION, OBS_IMAGES
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
@@ -50,16 +51,27 @@ class ACTPolicy(PreTrainedPolicy):
def __init__(
self,
config: ACTConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.model = ACT(config)
if config.temporal_ensemble_coeff is not None:
@@ -125,19 +137,23 @@ class ACTPolicy(PreTrainedPolicy):
"""Predict a chunk of actions given environment observations."""
self.eval()
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
actions = self.model(batch)[0]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
@@ -287,7 +303,7 @@ class ACT(nn.Module):
└───────────────────────┘
"""
def __init__(self, config: ACTConfig, dataset_stats=None):
def __init__(self, config: ACTConfig):
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
super().__init__()

View File

@@ -1,51 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
)
def make_act_processor(
config: ACTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)

View File

@@ -35,6 +35,7 @@ from torch import Tensor, nn
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import (
get_device_from_parameters,
@@ -56,6 +57,7 @@ class DiffusionPolicy(PreTrainedPolicy):
def __init__(
self,
config: DiffusionConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
@@ -68,6 +70,14 @@ class DiffusionPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
@@ -96,6 +106,9 @@ class DiffusionPolicy(PreTrainedPolicy):
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.diffusion.generate_actions(batch)
# TODO(rcadene): make above methods return output dictionary?
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
@torch.no_grad()
@@ -124,6 +137,7 @@ class DiffusionPolicy(PreTrainedPolicy):
if ACTION in batch:
batch.pop(ACTION)
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
@@ -139,9 +153,11 @@ class DiffusionPolicy(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
# no output_dict so returning None
return loss, None

View File

@@ -1,52 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
)
def make_diffusion_processor(
config: DiffusionConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)

View File

@@ -14,14 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from typing import Any, TypedDict, cast
import torch
from torch import nn
from typing_extensions import Unpack
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType
@@ -39,10 +34,9 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor.pipeline import RobotProcessor
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
def get_policy_class(name: str) -> PreTrainedPolicy:
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
if name == "tdmpc":
from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
@@ -107,123 +101,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
raise ValueError(f"Policy type '{policy_type}' is not available.")
class ProcessorConfigKwargs(TypedDict, total=False):
"""Keyword arguments for the processor config."""
preprocessor_config_filename: str | None
postprocessor_config_filename: str | None
preprocessor_overrides: dict[str, Any] | None
postprocessor_overrides: dict[str, Any] | None
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
def make_processor(
policy_cfg: PreTrainedConfig,
pretrained_path: str | None = None,
**kwargs: Unpack[ProcessorConfigKwargs],
) -> tuple[RobotProcessor, RobotProcessor]:
"""Make a processor instance for a given policy type.
This function creates the appropriate processor configuration based on the policy type.
Each policy type has its own processor with specific preprocessing steps.
Args:
policy_cfg: The config of the policy to create a processor for (e.g., "act", "diffusion", etc.)
pretrained_path: Optional path to load a pretrained processor from. If provided, loads
the processor from this path instead of creating a new one.
**kwargs: Additional keyword arguments passed to the processor creation.
Returns:
Tuple of (input_processor, output_processor) for the policy.
Raises:
NotImplementedError: If the policy type doesn't have a processor implemented.
"""
if pretrained_path:
return (
RobotProcessor.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get("preprocessor_config_filename", "robot_preprocessor.json"),
overrides=kwargs.get("preprocessor_overrides", {}),
),
RobotProcessor.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get("postprocessor_config_filename", "robot_postprocessor.json"),
overrides=kwargs.get("postprocessor_overrides", {}),
),
)
# Create a new processor based on policy type
if policy_cfg.type == "tdmpc":
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_processor
processors = make_tdmpc_processor(
config=cast(TDMPCConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "diffusion":
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_processor
processors = make_diffusion_processor(
cast(DiffusionConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "act":
from lerobot.policies.act.processor_act import make_act_processor
processors = make_act_processor(
config=cast(ACTConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "vqbet":
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_processor
processors = make_vqbet_processor(
config=cast(VQBeTConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "pi0":
from lerobot.policies.pi0.processor_pi0 import make_pi0_processor
processors = make_pi0_processor(
config=cast(PI0Config, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "pi0fast":
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_processor
processors = make_pi0fast_processor(
cast(PI0Config, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "sac":
from lerobot.policies.sac.processor_sac import make_sac_processor
processors = make_sac_processor(
cast(SACConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "reward_classifier":
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
processors = make_classifier_processor(
cast(RewardClassifierConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "smolvla":
from lerobot.policies.smolvla.processor_smolvla import make_smolvla_processor
processors = make_smolvla_processor(
cast(SmolVLAConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
else:
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
return processors
def make_policy(
cfg: PreTrainedConfig,
ds_meta: LeRobotDatasetMetadata | None = None,
@@ -270,6 +147,7 @@ def make_policy(
kwargs = {}
if ds_meta is not None:
features = dataset_to_policy_features(ds_meta.features)
kwargs["dataset_stats"] = ds_meta.stats
else:
if not cfg.pretrained_path:
logging.warning(
@@ -277,8 +155,6 @@ def make_policy(
"rather than a dataset. Normalization modules inside the policy will have infinite values "
"by default without stats from a dataset."
)
if env_cfg is None:
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
features = env_to_policy_features(env_cfg)
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}

View File

@@ -0,0 +1,420 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from torch import Tensor, nn
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
def create_stats_buffers(
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
) -> dict[str, dict[str, nn.ParameterDict]]:
"""
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
statistics.
Args: (see Normalize and Unnormalize)
Returns:
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
`nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.
"""
stats_buffers = {}
for key, ft in features.items():
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
assert isinstance(norm_mode, NormalizationMode)
shape = tuple(ft.shape)
if ft.type is FeatureType.VISUAL:
# sanity checks
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
c, h, w = shape
assert c < h and c < w, f"{key} is not channel first ({shape=})"
# override image shape to be invariant to height and width
shape = (c, 1, 1)
# Note: we initialize mean, std, min, max to infinity. They should be overwritten
# downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
# we assert they are not infinity anymore.
buffer = {}
if norm_mode is NormalizationMode.MEAN_STD:
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
std = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict(
{
"mean": nn.Parameter(mean, requires_grad=False),
"std": nn.Parameter(std, requires_grad=False),
}
)
elif norm_mode is NormalizationMode.MIN_MAX:
min = torch.ones(shape, dtype=torch.float32) * torch.inf
max = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict(
{
"min": nn.Parameter(min, requires_grad=False),
"max": nn.Parameter(max, requires_grad=False),
}
)
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
if stats:
if isinstance(stats[key]["mean"], np.ndarray):
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
elif isinstance(stats[key]["mean"], torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
else:
type_ = type(stats[key]["mean"])
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
stats_buffers[key] = buffer
return stats_buffers
def _no_stats_error_str(name: str) -> str:
return (
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
"pretrained model."
)
class Normalize(nn.Module):
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
def __init__(
self,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
are their normalization modes among:
- "mean_std": subtract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
and values are dictionaries of statistic types and their values (e.g.
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
training the model for the first time, these statistics will overwrite the default buffers. If
not provided, as expected for finetuning or evaluation, the default buffers should to be
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
dataset is not needed to get the stats, since they are already in the policy state_dict.
"""
super().__init__()
self.features = features
self.norm_map = norm_map
self.stats = stats
stats_buffers = create_stats_buffers(features, norm_map, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# TODO: Remove this shallow copy
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():
if key not in batch:
# FIXME(aliberts, rcadene): This might lead to silent fail!
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if norm_mode is NormalizationMode.MEAN_STD:
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = (batch[key] - mean) / (std + 1e-8)
elif norm_mode is NormalizationMode.MIN_MAX:
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(min).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max")
# normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min + 1e-8)
# normalize to [-1, 1]
batch[key] = batch[key] * 2 - 1
else:
raise ValueError(norm_mode)
return batch
class Unnormalize(nn.Module):
"""
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
original range used by the environment.
"""
def __init__(
self,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
are their normalization modes among:
- "mean_std": subtract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
and values are dictionaries of statistic types and their values (e.g.
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
training the model for the first time, these statistics will overwrite the default buffers. If
not provided, as expected for finetuning or evaluation, the default buffers should to be
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
dataset is not needed to get the stats, since they are already in the policy state_dict.
"""
super().__init__()
self.features = features
self.norm_map = norm_map
self.stats = stats
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
stats_buffers = create_stats_buffers(features, norm_map, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():
if key not in batch:
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if norm_mode is NormalizationMode.MEAN_STD:
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = batch[key] * std + mean
elif norm_mode is NormalizationMode.MIN_MAX:
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(min).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max")
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min
else:
raise ValueError(norm_mode)
return batch
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
# and remove the `Normalize` and `Unnormalize` classes.
def _initialize_stats_buffers(
module: nn.Module,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
) -> None:
"""Register statistics buffers (mean/std or min/max) on the given *module*.
The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`,
but is factored out so it can be reused by both classes and stay in sync.
"""
for key, ft in features.items():
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
shape: tuple[int, ...] = tuple(ft.shape)
if ft.type is FeatureType.VISUAL:
# reduce spatial dimensions, keep channel dimension only
c, *_ = shape
shape = (c, 1, 1)
prefix = key.replace(".", "_")
if norm_mode is NormalizationMode.MEAN_STD:
mean = torch.full(shape, torch.inf, dtype=torch.float32)
std = torch.full(shape, torch.inf, dtype=torch.float32)
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
mean_data = stats[key]["mean"]
std_data = stats[key]["std"]
if isinstance(mean_data, torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
mean = mean_data.clone().to(dtype=torch.float32)
std = std_data.clone().to(dtype=torch.float32)
else:
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
module.register_buffer(f"{prefix}_mean", mean)
module.register_buffer(f"{prefix}_std", std)
continue
if norm_mode is NormalizationMode.MIN_MAX:
min_val = torch.full(shape, torch.inf, dtype=torch.float32)
max_val = torch.full(shape, torch.inf, dtype=torch.float32)
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
min_data = stats[key]["min"]
max_data = stats[key]["max"]
if isinstance(min_data, torch.Tensor):
min_val = min_data.clone().to(dtype=torch.float32)
max_val = max_data.clone().to(dtype=torch.float32)
else:
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
module.register_buffer(f"{prefix}_min", min_val)
module.register_buffer(f"{prefix}_max", max_val)
continue
raise ValueError(norm_mode)
class NormalizeBuffer(nn.Module):
"""Same as `Normalize` but statistics are stored as registered buffers rather than parameters."""
def __init__(
self,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
super().__init__()
self.features = features
self.norm_map = norm_map
_initialize_stats_buffers(self, features, norm_map, stats)
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch)
for key, ft in self.features.items():
if key not in batch:
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
prefix = key.replace(".", "_")
if norm_mode is NormalizationMode.MEAN_STD:
mean = getattr(self, f"{prefix}_mean")
std = getattr(self, f"{prefix}_std")
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = (batch[key] - mean) / (std + 1e-8)
continue
if norm_mode is NormalizationMode.MIN_MAX:
min_val = getattr(self, f"{prefix}_min")
max_val = getattr(self, f"{prefix}_max")
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8)
batch[key] = batch[key] * 2 - 1
continue
raise ValueError(norm_mode)
return batch
class UnnormalizeBuffer(nn.Module):
"""Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics."""
def __init__(
self,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
super().__init__()
self.features = features
self.norm_map = norm_map
_initialize_stats_buffers(self, features, norm_map, stats)
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# batch = dict(batch)
for key, ft in self.features.items():
if key not in batch:
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
prefix = key.replace(".", "_")
if norm_mode is NormalizationMode.MEAN_STD:
mean = getattr(self, f"{prefix}_mean")
std = getattr(self, f"{prefix}_std")
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = batch[key] * std + mean
continue
if norm_mode is NormalizationMode.MIN_MAX:
min_val = getattr(self, f"{prefix}_min")
max_val = getattr(self, f"{prefix}_max")
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max_val - min_val) + min_val
continue
raise ValueError(norm_mode)
return batch

View File

@@ -56,15 +56,18 @@ from collections import deque
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from transformers import AutoTokenizer
from lerobot.constants import ACTION, OBS_LANGUAGE, OBS_STATE
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.paligemma_with_expert import (
PaliGemmaWithExpertConfig,
PaliGemmaWithExpertModel,
)
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.utils import get_safe_dtype
from lerobot.policies.utils import log_model_loading_keys
from lerobot.utils.utils import get_safe_dtype, init_logging
def create_sinusoidal_pos_embedding(
@@ -220,17 +223,28 @@ class PI0Policy(PreTrainedPolicy):
def __init__(
self,
config: PI0Config,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
self.model = PI0FlowMatching(config)
self.reset()
@@ -239,6 +253,99 @@ class PI0Policy(PreTrainedPolicy):
"""This should be called whenever the environment is reset."""
self._action_queue = deque([], maxlen=self.config.n_action_steps)
@classmethod
def _transform_state_dict_keys(cls, state_dict: dict) -> dict:
"""
Transform state dict keys to match expected model structure.
Transformations:
- model.paligemma_with_expert.paligemma.language_model.lm_head ->
model.paligemma_with_expert.paligemma.lm_head
- model.paligemma_with_expert.paligemma.language_model.model ->
model.paligemma_with_expert.paligemma.model.language_model
- model.paligemma_with_expert.paligemma.vision_tower ->
model.paligemma_with_expert.paligemma.model.vision_tower
- model.paligemma_with_expert.paligemma.multi_modal_projector ->
model.paligemma_with_expert.paligemma.model.multi_modal_projector
Also handles tied weights between lm_head.weight and
embed_tokens.weight.
"""
import re
transformed_dict = {}
transformations = [
(
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.lm_head"),
".paligemma_with_expert.paligemma.lm_head",
),
(
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.model"),
".paligemma_with_expert.paligemma.model.language_model",
),
(
re.compile(r"\.paligemma_with_expert\.paligemma\.vision_tower"),
".paligemma_with_expert.paligemma.model.vision_tower",
),
(
re.compile(r"\.paligemma_with_expert\.paligemma\.multi_modal_projector"),
".paligemma_with_expert.paligemma.model.multi_modal_projector",
),
]
for key, value in state_dict.items():
new_key = key
for pattern, replacement in transformations:
new_key = pattern.sub(replacement, new_key)
transformed_dict[new_key] = value
# Handle tied weights: lm_head.weight and embed_tokens.weight share memory
lm_head_key = None
embed_tokens_key = None
for key in transformed_dict:
if key.endswith(".paligemma_with_expert.paligemma.lm_head.weight"):
lm_head_key = key
elif key.endswith(".paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"):
embed_tokens_key = key
if lm_head_key and embed_tokens_key:
break
if lm_head_key and not embed_tokens_key:
embed_tokens_key = lm_head_key.replace(
".lm_head.weight", ".model.language_model.embed_tokens.weight"
)
transformed_dict[embed_tokens_key] = transformed_dict[lm_head_key]
elif embed_tokens_key and not lm_head_key:
lm_head_key = embed_tokens_key.replace(
".model.language_model.embed_tokens.weight", ".lm_head.weight"
)
transformed_dict[lm_head_key] = transformed_dict[embed_tokens_key]
return transformed_dict
@classmethod
def _load_as_safetensor(
cls, model: "PI0Policy", model_file: str, map_location: str, strict: bool
) -> "PI0Policy":
"""Override to apply key transformations before loading."""
from safetensors.torch import load_file
init_logging()
# Load the state dict from file safely
state_dict = load_file(model_file, device=map_location)
# Apply key transformations
transformed_state_dict = cls._transform_state_dict_keys(state_dict)
# Load the transformed state dict
msg = model.load_state_dict(transformed_state_dict, strict=strict)
# Log message
log_model_loading_keys(msg.missing_keys, msg.unexpected_keys)
return model
def get_optim_params(self) -> dict:
return self.parameters()
@@ -270,13 +377,14 @@ class PI0Policy(PreTrainedPolicy):
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch = self.normalize_inputs(batch)
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.model.sample_actions(
images, img_masks, lang_tokens, lang_masks, state, noise=noise
@@ -286,6 +394,8 @@ class PI0Policy(PreTrainedPolicy):
original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
@@ -300,10 +410,12 @@ class PI0Policy(PreTrainedPolicy):
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.prepare_action(batch)
actions_is_pad = batch.get("action_is_pad")
@@ -370,6 +482,26 @@ class PI0Policy(PreTrainedPolicy):
return images, img_masks
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
"""Tokenize the text input"""
device = batch[OBS_STATE].device
tasks = batch["task"]
# PaliGemma prompt has to end with a new line
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
tokenized_prompt = self.language_tokenizer.__call__(
tasks,
padding="max_length",
padding_side="right",
max_length=self.config.tokenizer_max_length,
return_tensors="pt",
)
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
return lang_tokens, lang_masks
def _pi_aloha_decode_state(self, state):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
@@ -435,7 +567,7 @@ class PI0FlowMatching(nn.Module):
└──────────────────────────────┘
"""
def __init__(self, config: PI0Config):
def __init__(self, config):
super().__init__()
self.config = config

View File

@@ -1,121 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RobotProcessor,
ToBatchProcessor,
TokenizerProcessor,
UnnormalizerProcessor,
)
from lerobot.processor.pipeline import (
EnvTransition,
ProcessorStep,
ProcessorStepRegistry,
TransitionKey,
)
from lerobot.processor.rename_processor import RenameProcessor
@ProcessorStepRegistry.register(name="pi0_new_line_processor")
class Pi0NewLineProcessor(ProcessorStep):
"""Add a new line to the end of the task if it doesn't have one.
This is required for the PaliGemma tokenizer.
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Check if complementary_data exists
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None or "task" not in complementary_data:
return transition
task = complementary_data["task"]
if task is None:
return transition
# Handle both string and list of strings
if isinstance(task, str):
# Single string: add newline if not present
if not task.endswith("\n"):
complementary_data["task"] = f"{task}\n"
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
# List of strings: add newline to each if not present
complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
# If task is neither string nor list of strings, leave unchanged
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Add tokenized task features to the features."""
return features
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {}
def make_pi0_processor(
config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
# Add remaining processors
input_steps: list[ProcessorStep] = [
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
NormalizerProcessor(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
Pi0NewLineProcessor(), # Add newlines before tokenization for PaliGemma
TokenizerProcessor(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
),
DeviceProcessor(device=config.device),
]
output_steps: list[ProcessorStep] = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)

View File

@@ -58,6 +58,7 @@ from transformers.cache_utils import HybridCache, StaticCache
from transformers.models.auto import CONFIG_MAPPING
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pretrained import PreTrainedPolicy
@@ -145,6 +146,14 @@ class PI0FASTPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
self.model = PI0FAST(config)
@@ -212,6 +221,8 @@ class PI0FASTPolicy(PreTrainedPolicy):
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch = self.normalize_inputs(batch)
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
@@ -224,6 +235,8 @@ class PI0FASTPolicy(PreTrainedPolicy):
] # self.config.max_action_dim # self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
@@ -236,6 +249,8 @@ class PI0FASTPolicy(PreTrainedPolicy):
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
loss_dict = self.model.forward(batch)
return loss_dict["loss"], loss_dict

View File

@@ -1,52 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
)
def make_pi0fast_processor(
config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
NormalizerProcessor(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)

View File

@@ -28,6 +28,7 @@ import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
from lerobot.policies.normalize import NormalizeBuffer
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature
from lerobot.policies.utils import get_device_from_parameters
@@ -44,6 +45,7 @@ class SACPolicy(
def __init__(
self,
config: SACConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
super().__init__(config)
config.validate_features()
@@ -51,6 +53,7 @@ class SACPolicy(
# Determine action dimension and initialize all components
continuous_action_dim = config.output_features["action"].shape[0]
self._init_normalization(dataset_stats)
self._init_encoders()
self._init_critics(continuous_action_dim)
self._init_actor(continuous_action_dim)
@@ -85,7 +88,8 @@ class SACPolicy(
observations_features = None
if self.shared_encoder and self.actor.encoder.has_images:
observations_features = self.actor.encoder.get_cached_image_features(batch)
# Cache and normalize image features
observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True)
actions, _, _ = self.actor(batch, observations_features)
@@ -387,12 +391,28 @@ class SACPolicy(
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
return actor_loss
def _init_normalization(self, dataset_stats):
"""Initialize input/output normalization modules."""
self.normalize_inputs = nn.Identity()
self.normalize_targets = nn.Identity()
if self.config.dataset_stats is not None:
params = _convert_normalization_params_to_tensor(self.config.dataset_stats)
self.normalize_inputs = NormalizeBuffer(
self.config.input_features, self.config.normalization_mapping, params
)
stats = dataset_stats or params
self.normalize_targets = NormalizeBuffer(
self.config.output_features, self.config.normalization_mapping, stats
)
def _init_encoders(self):
"""Initialize shared or separate encoders for actor and critic."""
self.shared_encoder = self.config.shared_encoder
self.encoder_critic = SACObservationEncoder(self.config)
self.encoder_critic = SACObservationEncoder(self.config, self.normalize_inputs)
self.encoder_actor = (
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
self.encoder_critic
if self.shared_encoder
else SACObservationEncoder(self.config, self.normalize_inputs)
)
def _init_critics(self, continuous_action_dim):
@@ -404,7 +424,9 @@ class SACPolicy(
)
for _ in range(self.config.num_critics)
]
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
self.critic_ensemble = CriticEnsemble(
encoder=self.encoder_critic, ensemble=heads, output_normalization=self.normalize_targets
)
target_heads = [
CriticHead(
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
@@ -412,7 +434,9 @@ class SACPolicy(
)
for _ in range(self.config.num_critics)
]
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
self.critic_target = CriticEnsemble(
encoder=self.encoder_critic, ensemble=target_heads, output_normalization=self.normalize_targets
)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
if self.config.use_torch_compile:
@@ -466,9 +490,10 @@ class SACPolicy(
class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: SACConfig) -> None:
def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
super().__init__()
self.config = config
self.input_normalization = input_normalizer
self._init_image_layers()
self._init_state_layers()
self._compute_output_dim()
@@ -543,10 +568,11 @@ class SACObservationEncoder(nn.Module):
def forward(
self, obs: dict[str, Tensor], cache: dict[str, Tensor] | None = None, detach: bool = False
) -> Tensor:
obs = self.input_normalization(obs)
parts = []
if self.has_images:
if cache is None:
cache = self.get_cached_image_features(obs)
cache = self.get_cached_image_features(obs, normalize=False)
parts.append(self._encode_images(cache, detach))
if self.has_env:
parts.append(self.env_encoder(obs["observation.environment_state"]))
@@ -559,7 +585,7 @@ class SACObservationEncoder(nn.Module):
"No parts to concatenate, you should have at least one image or environment state or state"
)
def get_cached_image_features(self, obs: dict[str, Tensor]) -> dict[str, Tensor]:
def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]:
"""Extract and optionally cache image features from observations.
This function processes image observations through the vision encoder once and returns
@@ -571,17 +597,26 @@ class SACObservationEncoder(nn.Module):
- The vision encoder forward pass is typically the main computational bottleneck during training and inference
- Caching these features can provide 2-4x speedup in training and inference
Normalization behavior:
- When called from inside forward(): set normalize=False since inputs are already normalized
- When called from outside forward(): set normalize=True to ensure proper input normalization
Usage patterns:
- Called in select_action()
- Called in select_action() with normalize=True
- Called in learner.py's get_observation_features() to pre-compute features for all policy components
- Called internally by forward()
- Called internally by forward() with normalize=False
Args:
obs: Dictionary of observation tensors containing image keys
normalize: Whether to normalize observations before encoding
Set to True when calling directly from outside the encoder's forward method
Set to False when calling from within forward() where inputs are already normalized
Returns:
Dictionary mapping image keys to their corresponding encoded features
"""
if normalize:
obs = self.input_normalization(obs)
batched = torch.cat([obs[k] for k in self.image_keys], dim=0)
out = self.image_encoder(batched)
chunks = torch.chunk(out, len(self.image_keys), dim=0)
@@ -712,6 +747,7 @@ class CriticEnsemble(nn.Module):
Args:
encoder (SACObservationEncoder): encoder for observations.
ensemble (List[CriticHead]): list of critic heads.
output_normalization (nn.Module): normalization layer for actions.
init_final (float | None): optional initializer scale for final layers.
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
@@ -721,11 +757,13 @@ class CriticEnsemble(nn.Module):
self,
encoder: SACObservationEncoder,
ensemble: list[CriticHead],
output_normalization: nn.Module,
init_final: float | None = None,
):
super().__init__()
self.encoder = encoder
self.init_final = init_final
self.output_normalization = output_normalization
self.critics = nn.ModuleList(ensemble)
def forward(
@@ -737,6 +775,11 @@ class CriticEnsemble(nn.Module):
device = get_device_from_parameters(self)
# Move each tensor in observations to device
observations = {k: v.to(device) for k, v in observations.items()}
# NOTE: We normalize actions it helps for sample efficiency
actions: dict[str, torch.tensor] = {"action": actions}
# NOTE: Normalization layer took dict in input and outputs a dict that why
actions = self.output_normalization(actions)["action"]
actions = actions.to(device)
obs_enc = self.encoder(observations, cache=observation_features)

View File

@@ -1,53 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
)
def make_sac_processor(
config: SACConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)

View File

@@ -20,6 +20,7 @@ import torch
from torch import Tensor, nn
from lerobot.constants import OBS_IMAGE, REWARD
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
@@ -107,12 +108,22 @@ class Classifier(PreTrainedPolicy):
def __init__(
self,
config: RewardClassifierConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
from transformers import AutoModel
super().__init__(config)
self.config = config
# Initialize normalization (standardized with the policy framework)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
# Set up encoder
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
# Extract vision model if we're given a multimodal model
@@ -236,6 +247,10 @@ class Classifier(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
"""Standard forward pass for training compatible with train.py."""
# Normalize inputs if needed
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
# Extract images and labels
images, labels = self.extract_images_and_labels(batch)

View File

@@ -1,42 +0,0 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.processor import (
DeviceProcessor,
IdentityProcessor,
NormalizerProcessor,
RobotProcessor,
)
def make_classifier_processor(
config: RewardClassifierConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
NormalizerProcessor(
features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
NormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
DeviceProcessor(device=config.device),
]
output_steps = [DeviceProcessor(device="cpu"), IdentityProcessor()]
return RobotProcessor(steps=input_steps, name="classifier_preprocessor"), RobotProcessor(
steps=output_steps, name="classifier_postprocessor"
)

View File

@@ -53,13 +53,21 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
"""
import math
import os
import re
from collections import deque
import safetensors
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from transformers import AutoProcessor
from lerobot.constants import ACTION, OBS_LANGUAGE, OBS_STATE
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import (
Normalize,
Unnormalize,
)
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
@@ -68,6 +76,102 @@ from lerobot.policies.utils import (
)
from lerobot.utils.utils import get_safe_dtype
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
def canonicalise(k: str) -> str:
"""
Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
normalisation-buffer key.
"""
return _VARIANT_RE.sub(".buffer_", k)
def standardise_state_dict(
checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
) -> tuple[dict[str, torch.Tensor], list[str]]:
"""
• Re-keys `checkpoint ` so that every entry matches the *reference* key set.
• If several variant keys collapse to the same canonical name we keep the
first one and log the collision.
• Returns the new dict + a list of entries that could not be matched.
"""
out, collisions, unmatched = {}, {}, []
for k, v in checkpoint.items():
canon = canonicalise(k)
if canon in ref_keys:
if canon in out: # duplicate after collapsing
collisions.setdefault(canon, []).append(k)
else:
out[canon] = v
else:
unmatched.append(k)
if verbose:
for canon, variants in collisions.items():
print(f"[standardise_state_dict] '{canon}'{variants}")
if unmatched:
print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys")
out.update({k: checkpoint[k] for k in unmatched})
return out, unmatched
def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
"""
Renames keys in a checkpoint dictionary based on the given rename string.
Args:
checkpoint (dict): The checkpoint dictionary.
rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2".
Returns:
dict: The modified checkpoint with renamed keys.
"""
rename_dict = dict(pair.split("//") for pair in rename_str.split(","))
new_checkpoint = {}
for k, v in checkpoint.items():
for old_key, new_key in rename_dict.items():
if old_key in k:
k = k.replace(old_key, new_key)
new_checkpoint[k] = v
return new_checkpoint
def load_smolvla(
model: torch.nn.Module,
filename: str | os.PathLike,
*,
device: str = "cpu",
checkpoint_keys_mapping: str = "",
) -> torch.nn.Module:
state_dict = safetensors.torch.load_file(filename, device=device)
# Optional user-supplied renames (e.g. "model._orig_mod.//model.")
if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)
state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))
# HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset
norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")
state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)}
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if not all(key.startswith(norm_keys) for key in missing) or unexpected:
raise RuntimeError(
"SmolVLA %d missing / %d unexpected keys",
len(missing),
len(unexpected),
)
return model
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
@@ -222,17 +326,28 @@ class SmolVLAPolicy(PreTrainedPolicy):
def __init__(
self,
config: SmolVLAConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
self.model = VLAFlowMatching(config)
self.reset()
@@ -242,6 +357,23 @@ class SmolVLAPolicy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps),
}
# HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues
@classmethod
def _load_as_safetensor(
cls,
model: "SmolVLAPolicy",
model_file: str,
map_location: str,
strict: bool,
):
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
return load_smolvla(
model,
model_file,
device=map_location,
checkpoint_keys_mapping="model._orig_mod.//model.",
)
def get_optim_params(self) -> dict:
return self.parameters()
@@ -257,8 +389,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
@@ -266,6 +397,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
@@ -275,6 +408,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch = self.normalize_inputs(batch)
return batch
@torch.no_grad()
@@ -315,11 +450,11 @@ class SmolVLAPolicy(PreTrainedPolicy):
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.prepare_action(batch)
actions_is_pad = batch.get("actions_id_pad")
loss_dict = {}
@@ -383,6 +518,30 @@ class SmolVLAPolicy(PreTrainedPolicy):
img_masks.append(mask)
return images, img_masks
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
"""Tokenize the text input"""
device = batch[OBS_STATE].device
tasks = batch["task"]
if isinstance(tasks, str):
tasks = [tasks]
if len(tasks) == 1:
tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
tokenized_prompt = self.language_tokenizer.__call__(
tasks,
padding=self.config.pad_language_to,
padding_side="right",
max_length=self.config.tokenizer_max_length,
return_tensors="pt",
)
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
return lang_tokens, lang_masks
def _pi_aloha_decode_state(self, state):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:

View File

@@ -1,110 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
TokenizerProcessor,
UnnormalizerProcessor,
)
from lerobot.processor.pipeline import EnvTransition, ProcessorStep, ProcessorStepRegistry, TransitionKey
def make_smolvla_processor(
config: SmolVLAConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
NormalizerProcessor(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
SmolVLANewLineProcessor(),
TokenizerProcessor(
tokenizer_name=config.vlm_model_name,
padding=config.pad_language_to,
padding_side="right",
max_length=config.tokenizer_max_length,
),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)
@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
class SmolVLANewLineProcessor(ProcessorStep):
"""Add a new line to the end of the task if it doesn't have one."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Check if complementary_data exists
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None or "task" not in complementary_data:
return transition
task = complementary_data["task"]
if task is None:
return transition
# Handle both string and list of strings
if isinstance(task, str):
# Single string: add newline if not present
if not task.endswith("\n"):
complementary_data["task"] = f"{task}\n"
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
# List of strings: add newline to each if not present
complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
# If task is neither string nor list of strings, leave unchanged
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Adds nothing to the features."""
return features
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {}

View File

@@ -36,6 +36,7 @@ import torch.nn.functional as F # noqa: N812
from torch import Tensor
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
@@ -62,19 +63,26 @@ class TDMPCPolicy(PreTrainedPolicy):
config_class = TDMPCConfig
name = "tdmpc"
def __init__(
self,
config: TDMPCConfig,
):
def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.model = TDMPCTOLD(config)
self.model_target = deepcopy(self.model)
for param in self.model_target.parameters():
@@ -129,6 +137,7 @@ class TDMPCPolicy(PreTrainedPolicy):
actions = torch.clamp(actions, -1, +1)
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
@torch.no_grad()
@@ -138,12 +147,11 @@ class TDMPCPolicy(PreTrainedPolicy):
if ACTION in batch:
batch.pop(ACTION)
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
if ACTION in batch:
batch.pop(ACTION)
self._queues = populate_queues(self._queues, batch)
@@ -312,9 +320,11 @@ class TDMPCPolicy(PreTrainedPolicy):
"""
device = get_device_from_parameters(self)
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
batch = self.normalize_targets(batch)
info = {}

View File

@@ -1,52 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
)
def make_tdmpc_processor(
config: TDMPCConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)

View File

@@ -28,6 +28,7 @@ import torchvision
from torch import Tensor, nn
from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
@@ -47,6 +48,7 @@ class VQBeTPolicy(PreTrainedPolicy):
def __init__(
self,
config: VQBeTConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
@@ -59,6 +61,14 @@ class VQBeTPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.vqbet = VQBeTModel(config)
self.reset()
@@ -118,6 +128,7 @@ class VQBeTPolicy(PreTrainedPolicy):
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
@torch.no_grad()
@@ -131,12 +142,10 @@ class VQBeTPolicy(PreTrainedPolicy):
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
if ACTION in batch:
batch.pop(ACTION)
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
# NOTE: It's important that this happens after stacking the images into a single key.
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
if ACTION in batch:
batch.pop(ACTION)
self._queues = populate_queues(self._queues, batch)
@@ -156,8 +165,10 @@ class VQBeTPolicy(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
batch = self.normalize_targets(batch)
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://huggingface.co/papers/2403.03181)
if not self.vqbet.action_head.vqvae_model.discretized.item():
# loss: total loss of training RVQ

View File

@@ -1,53 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
)
def make_vqbet_processor(
config: VQBeTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
RenameProcessor(rename_map={}), # Let the possibility to the user to rename the keys
NormalizerProcessor(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)

View File

@@ -14,22 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .batch_processor import ToBatchProcessor
from .delta_action_processor import MapDeltaActionToRobotAction
from .device_processor import DeviceProcessor
from .hil_processor import (
AddTeleopActionAsComplimentaryData,
AddTeleopEventsAsInfo,
GripperPenaltyProcessor,
ImageCropResizeProcessor,
InterventionActionProcessor,
Numpy2TorchActionProcessor,
RewardClassifierProcessor,
TimeLimitProcessor,
Torch2NumpyActionProcessor,
)
from .joint_observations_processor import JointVelocityProcessor, MotorCurrentProcessor
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor, hotswap_stats
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
from .observation_processor import VanillaObservationProcessor
from .pipeline import (
ActionProcessor,
@@ -46,39 +32,22 @@ from .pipeline import (
TruncatedProcessor,
)
from .rename_processor import RenameProcessor
from .tokenizer_processor import TokenizerProcessor
__all__ = [
"ActionProcessor",
"AddTeleopActionAsComplimentaryData",
"AddTeleopEventsAsInfo",
"DeviceProcessor",
"DoneProcessor",
"MapDeltaActionToRobotAction",
"EnvTransition",
"GripperPenaltyProcessor",
"IdentityProcessor",
"ImageCropResizeProcessor",
"InfoProcessor",
"InterventionActionProcessor",
"JointVelocityProcessor",
"MapDeltaActionToRobotAction",
"MotorCurrentProcessor",
"NormalizerProcessor",
"UnnormalizerProcessor",
"hotswap_stats",
"ObservationProcessor",
"ProcessorStep",
"ProcessorStepRegistry",
"RenameProcessor",
"RewardClassifierProcessor",
"RewardProcessor",
"RobotProcessor",
"ToBatchProcessor",
"TokenizerProcessor",
"TimeLimitProcessor",
"Numpy2TorchActionProcessor",
"Torch2NumpyActionProcessor",
"TransitionKey",
"TruncatedProcessor",
"VanillaObservationProcessor",

View File

@@ -1,139 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any
import torch
from torch import Tensor
from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor")
class ToBatchProcessor:
"""Processor that adds batch dimensions to observations and actions when needed.
This processor ensures that observations and actions have proper batch dimensions for model processing:
- For state observations (observation.state, observation.environment_state):
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
- For image observations (observation.image, observation.images.*):
Adds batch dimension (unsqueeze at dim=0) if tensor is 3-dimensional (H, W, C)
- For actions:
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
- For task field in complementary data:
Wraps string task in a list to add batch dimension
(task must be a string or list of strings)
This is useful when processing single transitions that need to be batched for
model inference or when converting from unbatched environment outputs to
batched model inputs.
The processor only modifies tensors that need batching and leaves already
batched tensors unchanged.
Example:
```python
# State: (7,) -> (1, 7)
# Image: (224, 224, 3) -> (1, 224, 224, 3)
# Action: (4,) -> (1, 4)
# Task: "pick_cube" -> ["pick_cube"]
# Already batched: (1, 7) -> (1, 7) [unchanged]
```
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
self._process_observation(transition)
self._process_action(transition)
self._process_complementary_data(transition)
return transition
def _process_observation(self, transition: EnvTransition) -> None:
"""Process observation component in-place, adding batch dimensions where needed."""
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
return
# Process state observations - add batch dim if 1D
for state_key in [OBS_STATE, OBS_ENV_STATE]:
if state_key in observation:
state_value = observation[state_key]
if isinstance(state_value, Tensor) and state_value.dim() == 1:
observation[state_key] = state_value.unsqueeze(0)
# Process single image observation - add batch dim if 3D
if OBS_IMAGE in observation:
image_value = observation[OBS_IMAGE]
if isinstance(image_value, Tensor) and image_value.dim() == 3:
observation[OBS_IMAGE] = image_value.unsqueeze(0)
# Process multiple image observations - add batch dim if 3D
for key, value in observation.items():
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
observation[key] = value.unsqueeze(0)
def _process_action(self, transition: EnvTransition) -> None:
"""Process action component in-place, adding batch dimension if needed."""
action = transition.get(TransitionKey.ACTION)
if action is not None and isinstance(action, Tensor) and action.dim() == 1:
transition[TransitionKey.ACTION] = action.unsqueeze(0)
def _process_complementary_data(self, transition: EnvTransition) -> None:
"""Process complementary data in-place, handling task field batching."""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
return
# Process task field - wrap string in list to add batch dimension
if "task" in complementary_data:
task_value = complementary_data["task"]
if isinstance(task_value, str):
complementary_data["task"] = [task_value]
# Process index field - add batch dim if 0D
if "index" in complementary_data:
index_value = complementary_data["index"]
if isinstance(index_value, Tensor) and index_value.dim() == 0:
complementary_data["index"] = index_value.unsqueeze(0)
# Process task_index field - add batch dim if 0D
if "task_index" in complementary_data:
task_index_value = complementary_data["task_index"]
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
complementary_data["task_index"] = task_index_value.unsqueeze(0)
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features

View File

@@ -1,225 +0,0 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Iterable, Sequence
from copy import deepcopy
from typing import Any
import numpy as np
import torch
from scipy.spatial.transform import Rotation
from .pipeline import EnvTransition, TransitionKey
def _to_tensor(x: torch.Tensor | np.ndarray | Sequence[int | float]):
if isinstance(x, torch.Tensor):
return x
if isinstance(x, np.ndarray):
# Keep images (uint8 HWC) and python objects as-is
if x.dtype == np.uint8 or x.dtype == np.object_:
return x
# Scalars/arrays to float32 tensor
return torch.as_tensor(x, dtype=torch.float32)
# Anything else to float32 tensor
return torch.as_tensor(x, dtype=torch.float32)
def _from_tensor(x: Any):
if isinstance(x, torch.Tensor):
return x.item() if x.numel() == 1 else x.detach().cpu().numpy()
return x
def _is_image(arr: Any) -> bool:
return isinstance(arr, np.ndarray) and arr.dtype == np.uint8 and arr.ndim == 3
def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
state, images = {}, {}
for k, v in obs.items():
if _is_image(v):
images[k] = v
else:
state[k] = v
return state, images
def make_obs_act_transition(
*, obs: dict[str, Any] | None = None, act: dict[str, Any] | None = None
) -> EnvTransition:
return {
TransitionKey.OBSERVATION: {} if obs is None else obs,
TransitionKey.ACTION: {} if act is None else act,
TransitionKey.INFO: {},
TransitionKey.COMPLEMENTARY_DATA: {},
TransitionKey.REWARD: None,
TransitionKey.DONE: None,
TransitionKey.TRUNCATED: None,
}
def to_transition_teleop_action(action: dict[str, Any]) -> EnvTransition:
"""
Convert a raw teleop action dict into an EnvTransition under the ACTION TransitionKey.
"""
act_dict: dict[str, Any] = {}
for k, v in action.items():
# Check if the value is a type that should not be converted to a tensor.
if isinstance(v, (Rotation, dict)):
act_dict[f"action.{k}"] = v
continue
arr = np.array(v) if np.isscalar(v) else v
act_dict[f"action.{k}"] = _to_tensor(arr)
return make_obs_act_transition(act=act_dict)
# TODO(Adil, Pepijn): Overtime we can maybe add these converters to pipeline.py itself
def to_transition_robot_observation(observation: dict[str, Any]) -> EnvTransition:
"""
Convert a raw robot observation dict into an EnvTransition under the OBSERVATION TransitionKey.
"""
state, images = _split_obs_to_state_and_images(observation)
obs_dict: dict[str, Any] = {}
for k, v in state.items():
arr = np.array(v) if np.isscalar(v) else v
obs_dict[f"observation.state.{k}"] = _to_tensor(arr)
for cam, img in images.items():
obs_dict[f"observation.images.{cam}"] = img
return make_obs_act_transition(obs=obs_dict)
def to_output_robot_action(transition: EnvTransition) -> dict[str, Any]:
"""
Converts a EnvTransition under the ACTION TransitionKey to a dict with keys ending in '.pos' for raw robot actions.
"""
out: dict[str, Any] = {}
action_dict = transition.get(TransitionKey.ACTION) or {}
for k, v in action_dict.items():
if isinstance(k, str) and k.startswith("action.") and k.endswith((".pos", ".vel")):
out_key = k[len("action.") :] # Strip the 'action.' prefix.
out[out_key] = float(v)
return out
def to_dataset_frame(
transitions_or_transition: EnvTransition | Iterable[EnvTransition], features: dict[str, dict]
) -> dict[str, any]:
"""
Converts a single EnvTransition or an iterable of them into a flat,
dataset-friendly dictionary for training or evaluation, according to
the provided `features` spec.
Args:
transitions_or_transition: Either a single EnvTransition dict
or an iterable of them (which will be merged).
features (dict[str, dict]):
A feature specification dictionary:
- 'action': dict with 'names': list of action feature names
- 'observation.state': dict with 'names': list of state feature names
- keys starting with 'observation.images.' are passed through
Returns:
batch (dict[str, any]): Flat dictionary containing:
- numpy arrays for "observation.state" and "action"
- any image tensors defined in features
- next.{reward,done,truncated}
- info dict
- *_is_pad flags and task from complementary_data
"""
action_names = features.get("action", {}).get("names", [])
obs_state_names = features.get("observation.state", {}).get("names", [])
image_keys = [k for k in features if k.startswith("observation.images.")]
def _merge(base: EnvTransition, other: EnvTransition) -> EnvTransition:
out = deepcopy(base)
for key in (
TransitionKey.OBSERVATION,
TransitionKey.ACTION,
TransitionKey.INFO,
TransitionKey.COMPLEMENTARY_DATA,
):
if other.get(key):
out.setdefault(key, {}).update(deepcopy(other[key]))
for k in (TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED):
if k in other:
out[k] = other[k]
return out
def _ensure_transition(obj) -> EnvTransition:
# single transition
if isinstance(obj, dict) and any(isinstance(k, TransitionKey) for k in obj):
return obj
# iterable of transitions
if isinstance(obj, Iterable):
items = list(obj)
if not items:
return {}
acc = items[0]
for t in items[1:]:
acc = _merge(acc, t)
return acc
raise TypeError("Expected EnvTransition or iterable of them")
tr = _ensure_transition(transitions_or_transition)
obs = tr.get(TransitionKey.OBSERVATION, {}) or {}
act = tr.get(TransitionKey.ACTION, {}) or {}
batch: dict[str, any] = {}
# Images passthrough
for k in image_keys:
if k in obs:
batch[k] = obs[k]
# Observation.state vector
if obs_state_names:
vals = [_from_tensor(obs.get(f"observation.state.{n}", 0.0)) for n in obs_state_names]
batch["observation.state"] = np.asarray(vals, dtype=np.float32)
# Action vector
if action_names:
vals = [_from_tensor(act.get(f"action.{n}", 0.0)) for n in action_names]
batch["action"] = np.asarray(vals, dtype=np.float32)
# Next.* fields
if tr.get(TransitionKey.REWARD) is not None:
batch["next.reward"] = _from_tensor(tr[TransitionKey.REWARD])
if tr.get(TransitionKey.DONE) is not None:
batch["next.done"] = _from_tensor(tr[TransitionKey.DONE])
if tr.get(TransitionKey.TRUNCATED) is not None:
batch["next.truncated"] = _from_tensor(tr[TransitionKey.TRUNCATED])
# Complementary data flags and task
comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if comp:
# pad flags
for k, v in comp.items():
if k.endswith("_is_pad"):
batch[k] = v
# task label
if comp.get("task") is not None:
batch["task"] = comp["task"]
return batch

View File

@@ -1,125 +0,0 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
@ProcessorStepRegistry.register("map_delta_action_to_robot_action")
@dataclass
class MapDeltaActionToRobotAction(ActionProcessor):
"""
Map delta actions from teleoperators (gamepad, keyboard) to robot target actions
for use with inverse kinematics processors.
Expected input ACTION keys:
{
"action.delta_x": float,
"action.delta_y": float,
"action.delta_z": float,
"action.gripper": float (optional),
}
Output ACTION keys:
{
"action.enabled": bool,
"action.target_x": float,
"action.target_y": float,
"action.target_z": float,
"action.target_wx": float,
"action.target_wy": float,
"action.target_wz": float,
"action.gripper": float,
}
"""
# Scale factors for delta movements
position_scale: float = 1.0
rotation_scale: float = 0.0 # No rotation deltas for gamepad/keyboard
gripper_deadzone: float = 0.1 # Threshold for gripper activation
_prev_enabled: bool = field(default=False, init=False, repr=False)
def action(self, action: dict | Tensor | None) -> dict:
if action is None:
return {}
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
# TODO (maractingi): changing this target_xyz naming convention from the teleop_devices
if isinstance(action, dict):
delta_x = action.pop("action.delta_x", 0.0)
delta_y = action.pop("action.delta_y", 0.0)
delta_z = action.pop("action.delta_z", 0.0)
gripper = action.pop("action.gripper", 1.0) # Default to "stay" (1.0)
else:
delta_x = action[0].item()
delta_y = action[1].item()
delta_z = action[2].item()
gripper = action[3].item()
# Determine if the teleoperator is actively providing input
# Consider enabled if any significant movement delta is detected
position_magnitude = abs(delta_x) + abs(delta_y) + abs(delta_z)
enabled = position_magnitude > 1e-6 # Small threshold to avoid noise
# Scale the deltas appropriately
scaled_delta_x = float(delta_x) * self.position_scale
scaled_delta_y = float(delta_y) * self.position_scale
scaled_delta_z = float(delta_z) * self.position_scale
# For gamepad/keyboard, we don't have rotation input, so set to 0
# These could be extended in the future for more sophisticated teleoperators
target_wx = 0.0
target_wy = 0.0
target_wz = 0.0
# Update action with robot target format
action = {
"action.enabled": enabled,
"action.target_x": scaled_delta_x,
"action.target_y": scaled_delta_y,
"action.target_z": scaled_delta_z,
"action.target_wx": target_wx,
"action.target_wy": target_wy,
"action.target_wz": target_wz,
"action.gripper": float(gripper),
}
self._prev_enabled = enabled
return action
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transform features to match output format."""
# Update features to reflect the new action format
features.update(
{
"action.enabled": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
"action.target_x": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
"action.target_y": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
"action.target_z": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
"action.target_wx": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
"action.target_wy": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
"action.target_wz": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
"action.gripper": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
}
)
return features
def reset(self):
self._prev_enabled = False

View File

@@ -19,80 +19,24 @@ from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
from lerobot.processor.pipeline import EnvTransition, TransitionKey
from lerobot.utils.utils import get_safe_torch_device
@ProcessorStepRegistry.register("device_processor")
@dataclass
class DeviceProcessor:
"""Processes transitions by moving tensors to the specified device and optionally converting float dtypes.
"""Processes transitions by moving tensors to the specified device.
This processor ensures that all tensors in the transition are moved to the
specified device (CPU or GPU) before they are returned. It can also convert
floating-point tensors to a specified dtype while preserving non-float types
(int, long, bool, etc.).
specified device (CPU or GPU) before they are returned.
"""
device: str = "cpu"
float_dtype: str | None = None
_device: torch.device | None = None
device: torch.device = "cpu"
def __post_init__(self):
self._device = get_safe_torch_device(self.device)
self.device = self._device.type
self.device = get_safe_torch_device(self.device)
self.non_blocking = "cuda" in str(self.device)
# Validate and convert float_dtype string to torch dtype
if self.float_dtype is not None:
dtype_mapping = {
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
"bfloat16": torch.bfloat16,
"half": torch.float16,
"float": torch.float32,
"double": torch.float64,
}
if self.float_dtype not in dtype_mapping:
available_dtypes = list(dtype_mapping.keys())
raise ValueError(
f"Invalid float_dtype '{self.float_dtype}'. Available options: {available_dtypes}"
)
self._target_float_dtype = dtype_mapping[self.float_dtype]
else:
self._target_float_dtype = None
def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
"""Process a tensor by moving to device and optionally converting float dtype.
If the tensor is already on a GPU and we're configured for a GPU, it preserves
that GPU placement (useful for multi-GPU training with Accelerate).
Otherwise, it moves to the configured device.
"""
# Determine target device
if tensor.is_cuda and self._device.type == "cuda":
# Both tensor and target are on GPU - preserve tensor's GPU placement
# This handles multi-GPU scenarios where Accelerate has already placed
# tensors on the correct GPU for each process
target_device = tensor.device
else:
# Either tensor is on CPU, or we're configured for CPU
# In both cases, use the configured device
target_device = self._device
# Only move if necessary
if tensor.device != target_device:
tensor = tensor.to(target_device, non_blocking=self.non_blocking)
# Convert float dtype if specified and tensor is floating point
if self._target_float_dtype is not None and tensor.is_floating_point():
tensor = tensor.to(dtype=self._target_float_dtype)
return tensor
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Create a copy of the transition
new_transition = transition.copy()
@@ -101,7 +45,7 @@ class DeviceProcessor:
observation = transition.get(TransitionKey.OBSERVATION)
if observation is not None:
new_observation = {
k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v
k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v
for k, v in observation.items()
}
new_transition[TransitionKey.OBSERVATION] = new_observation
@@ -109,54 +53,30 @@ class DeviceProcessor:
# Process action tensor
action = transition.get(TransitionKey.ACTION)
if action is not None and isinstance(action, torch.Tensor):
new_transition[TransitionKey.ACTION] = self._process_tensor(action)
new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking)
# Process reward tensor
reward = transition.get(TransitionKey.REWARD)
if reward is not None and isinstance(reward, torch.Tensor):
new_transition[TransitionKey.REWARD] = self._process_tensor(reward)
new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking)
# Process done tensor
done = transition.get(TransitionKey.DONE)
if done is not None and isinstance(done, torch.Tensor):
new_transition[TransitionKey.DONE] = self._process_tensor(done)
new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking)
# Process truncated tensor
truncated = transition.get(TransitionKey.TRUNCATED)
if truncated is not None and isinstance(truncated, torch.Tensor):
new_transition[TransitionKey.TRUNCATED] = self._process_tensor(truncated)
# Process complementary data tensors
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is not None:
new_complementary_data = {}
# Process all items in complementary_data
for key, value in complementary_data.items():
if isinstance(value, torch.Tensor):
new_complementary_data[key] = self._process_tensor(value)
else:
new_complementary_data[key] = value
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
new_transition[TransitionKey.TRUNCATED] = truncated.to(
self.device, non_blocking=self.non_blocking
)
return new_transition
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {"device": self.device, "float_dtype": self.float_dtype}
return {"device": self.device}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features

View File

@@ -1,418 +0,0 @@
import time
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
import torchvision.transforms.functional as F # noqa: N812
from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import (
ActionProcessor,
ComplementaryDataProcessor,
EnvTransition,
InfoProcessor,
ObservationProcessor,
ProcessorStepRegistry,
TransitionKey,
)
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents
GRIPPER_KEY = "gripper"
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
@dataclass
class AddTeleopActionAsComplimentaryData(ComplementaryDataProcessor):
"""Add teleoperator action to transition complementary data."""
teleop_device: Teleoperator
def complementary_data(self, complementary_data: dict | None) -> dict:
complementary_data = {} if complementary_data is None else dict(complementary_data)
complementary_data["teleop_action"] = self.teleop_device.get_action()
return complementary_data
@ProcessorStepRegistry.register("add_teleop_action_as_info")
@dataclass
class AddTeleopEventsAsInfo(InfoProcessor):
"""Add teleoperator control events to transition info."""
teleop_device: Teleoperator
def info(self, info: dict | None) -> dict:
info = {} if info is None else dict(info)
teleop_events = getattr(self.teleop_device, "get_teleop_events", lambda: {})()
info.update(teleop_events)
return info
@ProcessorStepRegistry.register("torch2numpy_action_processor")
@dataclass
class Torch2NumpyActionProcessor(ActionProcessor):
"""Convert PyTorch tensor actions to NumPy arrays."""
squeeze_batch_dim: bool = True
def action(self, action: torch.Tensor | None) -> np.ndarray | None:
if action is None:
return None
if not isinstance(action, torch.Tensor):
raise TypeError(
f"Expected torch.Tensor or None, got {type(action).__name__}. "
"Use appropriate processor for non-tensor actions."
)
numpy_action = action.detach().cpu().numpy()
# Remove batch dimensions but preserve action dimensions
# Only squeeze if there's a batch dimension (first dim == 1)
if (
self.squeeze_batch_dim
and numpy_action.shape
and len(numpy_action.shape) > 1
and numpy_action.shape[0] == 1
):
numpy_action = numpy_action.squeeze(0)
return numpy_action
@ProcessorStepRegistry.register("numpy2torch_action_processor")
@dataclass
class Numpy2TorchActionProcessor(ActionProcessor):
"""Convert NumPy array action to PyTorch tensor."""
def action(self, action: np.ndarray | None) -> torch.Tensor | None:
if action is None:
return None
if not isinstance(action, np.ndarray):
raise TypeError(
f"Expected np.ndarray or None, got {type(action).__name__}. "
"Use appropriate processor for non-tensor actions."
)
torch_action = torch.from_numpy(action)
return torch_action
@ProcessorStepRegistry.register("image_crop_resize_processor")
@dataclass
class ImageCropResizeProcessor(ObservationProcessor):
"""Crop and resize image observations."""
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
resize_size: tuple[int, int] | None = None
def observation(self, observation: dict | None) -> dict | None:
if observation is None:
return None
if self.resize_size is None and not self.crop_params_dict:
return observation
new_observation = dict(observation)
# Process all image keys in the observation
for key in observation:
if "image" not in key:
continue
image = observation[key]
device = image.device
# NOTE (maractingi): No mps kernel for crop and resize, so we need to move to cpu
if device.type == "mps":
image = image.cpu()
# Crop if crop params are provided for this key
if self.crop_params_dict is not None and key in self.crop_params_dict:
crop_params = self.crop_params_dict[key]
image = F.crop(image, *crop_params)
if self.resize_size is not None:
image = F.resize(image, self.resize_size)
image = image.clamp(0.0, 1.0)
new_observation[key] = image.to(device)
return new_observation
def get_config(self) -> dict[str, Any]:
return {
"crop_params_dict": self.crop_params_dict,
"resize_size": self.resize_size,
}
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
if self.resize_size is None:
return features
for key in features:
if "image" in key:
features[key] = PolicyFeature(type=features[key].type, shape=self.resize_size)
return features
@dataclass
@ProcessorStepRegistry.register("time_limit_processor")
class TimeLimitProcessor:
"""Track episode steps and enforce time limits."""
max_episode_steps: int
current_step: int = 0
def __call__(self, transition: EnvTransition) -> EnvTransition:
truncated = transition.get(TransitionKey.TRUNCATED)
if truncated is None:
return transition
self.current_step += 1
if self.current_step >= self.max_episode_steps:
truncated = True
new_transition = transition.copy()
new_transition[TransitionKey.TRUNCATED] = truncated
return new_transition
def get_config(self) -> dict[str, Any]:
return {
"max_episode_steps": self.max_episode_steps,
}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
self.current_step = 0
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessor:
"""Apply penalty for inappropriate gripper usage."""
penalty: float = -0.01
max_gripper_pos: float = 30.0
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Calculate gripper penalty and add to complementary data."""
action = transition.get(TransitionKey.ACTION)
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None or action is None:
return transition
current_gripper_pos = complementary_data.get("raw_joint_positions", None).get(GRIPPER_KEY, None)
if current_gripper_pos is None:
return transition
gripper_action = action[f"action.{GRIPPER_KEY}.pos"]
gripper_action_normalized = gripper_action / self.max_gripper_pos
# Normalize gripper state and action
gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
# Calculate penalty boolean as in original
gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
)
gripper_penalty = self.penalty * int(gripper_penalty_bool)
# Add penalty information to complementary data
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
# Create new complementary data with penalty info
new_complementary_data = dict(complementary_data)
new_complementary_data["discrete_penalty"] = gripper_penalty
# Create new transition with updated complementary data
new_transition = transition.copy()
existing_comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
existing_comp_data.update(new_complementary_data)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = existing_comp_data # type: ignore[misc]
return new_transition
def get_config(self) -> dict[str, Any]:
return {
"penalty": self.penalty,
"max_gripper_pos": self.max_gripper_pos,
}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
"""Reset the processor state."""
self.last_gripper_state = None
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register("intervention_action_processor")
class InterventionActionProcessor:
"""Handle human intervention actions and episode termination."""
use_gripper: bool = False
terminate_on_success: bool = True
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is None:
return transition
# Get intervention signals from complementary data
info = transition.get(TransitionKey.INFO, {})
teleop_action = info.get("teleop_action", {})
is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False)
terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False)
success = info.get(TeleopEvents.SUCCESS, False)
rerecord_episode = info.get(TeleopEvents.RERECORD_EPISODE, False)
new_transition = transition.copy()
# Override action if intervention is active
if is_intervention and teleop_action is not None:
if isinstance(teleop_action, dict):
# Convert teleop_action dict to tensor format
action_list = [
teleop_action.get("action.delta_x", 0.0),
teleop_action.get("action.delta_y", 0.0),
teleop_action.get("action.delta_z", 0.0),
]
if self.use_gripper:
action_list.append(teleop_action.get("gripper", 1.0))
elif isinstance(teleop_action, np.ndarray):
action_list = teleop_action.tolist()
else:
action_list = teleop_action
teleop_action_tensor = torch.tensor(action_list, dtype=action.dtype, device=action.device)
new_transition[TransitionKey.ACTION] = teleop_action_tensor
# Handle episode termination
new_transition[TransitionKey.DONE] = bool(terminate_episode) or (
self.terminate_on_success and success
)
new_transition[TransitionKey.REWARD] = float(success)
# Update info with intervention metadata
info = new_transition.get(TransitionKey.INFO, {})
info[TeleopEvents.IS_INTERVENTION] = is_intervention
info[TeleopEvents.RERECORD_EPISODE] = rerecord_episode
info[TeleopEvents.SUCCESS] = success
new_transition[TransitionKey.INFO] = info
# Update complementary data with teleop action
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
complementary_data["teleop_action"] = new_transition.get(TransitionKey.ACTION)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return new_transition
def get_config(self) -> dict[str, Any]:
return {
"use_gripper": self.use_gripper,
}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register("reward_classifier_processor")
class RewardClassifierProcessor:
"""Apply reward classification to image observations."""
pretrained_path: str | None = None
device: str = "cpu"
success_threshold: float = 0.5
success_reward: float = 1.0
terminate_on_success: bool = True
reward_classifier: Any = None
def __post_init__(self):
"""Initialize the reward classifier after dataclass initialization."""
if self.pretrained_path is not None:
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
self.reward_classifier = Classifier.from_pretrained(self.pretrained_path)
self.reward_classifier.to(self.device)
self.reward_classifier.eval()
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None or self.reward_classifier is None:
return transition
# Extract images from observation
images = {key: value for key, value in observation.items() if "image" in key}
if not images:
return transition
# Run reward classifier
start_time = time.perf_counter()
with torch.inference_mode():
success = self.reward_classifier.predict_reward(images, threshold=self.success_threshold)
classifier_frequency = 1 / (time.perf_counter() - start_time)
# Calculate reward and termination
reward = transition.get(TransitionKey.REWARD, 0.0)
terminated = transition.get(TransitionKey.DONE, False)
if success == 1.0:
reward = self.success_reward
if self.terminate_on_success:
terminated = True
# Update transition
new_transition = transition.copy()
new_transition[TransitionKey.REWARD] = reward
new_transition[TransitionKey.DONE] = terminated
# Update info with classifier frequency
info = new_transition.get(TransitionKey.INFO, {})
info["reward_classifier_frequency"] = classifier_frequency
new_transition[TransitionKey.INFO] = info
return new_transition
def get_config(self) -> dict[str, Any]:
return {
"device": self.device,
"success_threshold": self.success_threshold,
"success_reward": self.success_reward,
"terminate_on_success": self.terminate_on_success,
}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features

View File

@@ -1,116 +0,0 @@
from dataclasses import dataclass
from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import (
ObservationProcessor,
ProcessorStepRegistry,
)
from lerobot.robots import Robot
@dataclass
@ProcessorStepRegistry.register("joint_velocity_processor")
class JointVelocityProcessor:
"""Add joint velocity information to observations."""
joint_velocity_limits: float = 100.0
dt: float = 1.0 / 10
num_dof: int | None = None
last_joint_positions: torch.Tensor | None = None
def observation(self, observation: dict | None) -> dict | None:
if observation is None:
return None
# Get current joint positions (assuming they're in observation.state)
current_positions = observation.get("observation.state")
if current_positions is None:
return observation
# Initialize last joint positions if not already set
if self.last_joint_positions is None:
self.last_joint_positions = current_positions.clone()
# Compute velocities
joint_velocities = (current_positions - self.last_joint_positions) / self.dt
self.last_joint_positions = current_positions.clone()
# Extend observation with velocities
extended_state = torch.cat([current_positions, joint_velocities], dim=-1)
# Create new observation dict
new_observation = dict(observation)
new_observation["observation.state"] = extended_state
return new_observation
def get_config(self) -> dict[str, Any]:
return {
"joint_velocity_limits": self.joint_velocity_limits,
"dt": self.dt,
}
def reset(self) -> None:
self.last_joint_positions = None
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
if "observation.state" in features and self.num_dof is not None:
from lerobot.configs.types import PolicyFeature
original_feature = features["observation.state"]
# Double the shape to account for positions + velocities
new_shape = (original_feature.shape[0] + self.num_dof,) + original_feature.shape[1:]
features["observation.state"] = PolicyFeature(type=original_feature.type, shape=new_shape)
return features
@dataclass
@ProcessorStepRegistry.register("current_processor")
class MotorCurrentProcessor(ObservationProcessor):
"""Add motor current information to observations."""
robot: Robot | None = None
def observation(self, observation: dict | None) -> dict | None:
if observation is None:
return None
# Get current values from robot state
if self.robot is None:
return observation
present_current_dict = self.robot.bus.sync_read("Present_Current") # type: ignore[attr-defined]
motor_currents = torch.tensor(
[present_current_dict[name] for name in self.robot.bus.motors], # type: ignore[attr-defined]
dtype=torch.float32,
).unsqueeze(0)
current_state = observation.get("observation.state")
if current_state is None:
return observation
extended_state = torch.cat([current_state, motor_currents], dim=-1)
# Create new observation dict
new_observation = dict(observation)
new_observation["observation.state"] = extended_state
return new_observation
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
if "observation.state" in features and self.robot is not None:
from lerobot.configs.types import PolicyFeature
original_feature = features["observation.state"]
# Add motor current dimensions to the original state shape
num_motors = 0
if hasattr(self.robot, "bus") and hasattr(self.robot.bus, "motors"): # type: ignore[attr-defined]
num_motors = len(self.robot.bus.motors) # type: ignore[attr-defined]
if num_motors > 0:
new_shape = (original_feature.shape[0] + num_motors,) + original_feature.shape[1:]
features["observation.state"] = PolicyFeature(type=original_feature.type, shape=new_shape)
return features

View File

@@ -1,502 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generic script to migrate any policy model with normalization layers to the new pipeline-based system.
This script:
1. Loads an existing pretrained policy model
2. Extracts normalization statistics from the model
3. Creates both preprocessor and postprocessor:
- Preprocessor: normalizes both inputs (observations) and outputs (actions) for training
- Postprocessor: unnormalizes outputs (actions) for inference
4. Removes normalization layers from the model state_dict
5. Saves the new model and both processors
Usage:
python src/lerobot/processor/migrate_policy_normalization.py \
--pretrained-path lerobot/act_aloha_sim_transfer_cube_human \
--policy-type act \
--push-to-hub
"""
import argparse
import importlib
import json
import os
from copy import deepcopy
from pathlib import Path
from typing import Any
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_safetensors
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor.batch_processor import ToBatchProcessor
from lerobot.processor.device_processor import DeviceProcessor
from lerobot.processor.normalize_processor import NormalizerProcessor, UnnormalizerProcessor
from lerobot.processor.pipeline import RobotProcessor
from lerobot.processor.rename_processor import RenameProcessor
# Policy type to class mapping
POLICY_CLASSES = {
"act": "lerobot.policies.act.modeling_act.ACTPolicy",
"diffusion": "lerobot.policies.diffusion.modeling_diffusion.DiffusionPolicy",
"pi0": "lerobot.policies.pi0.modeling_pi0.PI0Policy",
"pi0fast": "lerobot.policies.pi0fast.modeling_pi0fast.PI0FASTPolicy",
"smolvla": "lerobot.policies.smolvla.modeling_smolvla.SmolVLAPolicy",
"tdmpc": "lerobot.policies.tdmpc.modeling_tdmpc.TDMPCPolicy",
"vqbet": "lerobot.policies.vqbet.modeling_vqbet.VQBeTPolicy",
"sac": "lerobot.policies.sac.modeling_sac.SACPolicy",
"classifier": "lerobot.policies.classifier.modeling_classifier.ClassifierPolicy",
}
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
"""Extract normalization statistics from model state_dict."""
stats = {}
# Define patterns to match and their prefixes to remove
normalization_patterns = [
"normalize_inputs.buffer_",
"unnormalize_outputs.buffer_",
"normalize_targets.buffer_",
"normalize.", # Must come after normalize_* patterns
"unnormalize.", # Must come after unnormalize_* patterns
"input_normalizer.",
"output_normalizer.",
]
# Process each key in state_dict
for key, tensor in state_dict.items():
# Try each pattern
for pattern in normalization_patterns:
if key.startswith(pattern):
# Extract the remaining part after the pattern
remaining = key[len(pattern) :]
parts = remaining.split(".")
# Need at least feature name and stat type
if len(parts) >= 2:
# Last part is the stat type (mean, std, min, max, etc.)
stat_type = parts[-1]
# Everything else is the feature name
feature_name = ".".join(parts[:-1]).replace("_", ".")
# Add to stats
if feature_name not in stats:
stats[feature_name] = {}
stats[feature_name][stat_type] = tensor.clone()
# Only process the first matching pattern
break
return stats
def detect_features_and_norm_modes(
config: dict[str, Any], stats: dict[str, dict[str, torch.Tensor]]
) -> tuple[dict[str, PolicyFeature], dict[FeatureType, NormalizationMode]]:
"""Detect features and normalization modes from config and stats."""
features = {}
norm_modes = {}
# First, check if there's a normalization_mapping in the config
if "normalization_mapping" in config:
print(f"Found normalization_mapping in config: {config['normalization_mapping']}")
# Extract normalization modes from config
for feature_name, mode_str in config["normalization_mapping"].items():
# Convert string to NormalizationMode enum
if mode_str == "mean_std":
mode = NormalizationMode.MEAN_STD
elif mode_str == "min_max":
mode = NormalizationMode.MIN_MAX
else:
print(f"Warning: Unknown normalization mode '{mode_str}' for feature '{feature_name}'")
continue
# Determine feature type from feature name
if "image" in feature_name or "visual" in feature_name:
feature_type = FeatureType.VISUAL
elif "state" in feature_name:
feature_type = FeatureType.STATE
elif "action" in feature_name:
feature_type = FeatureType.ACTION
else:
feature_type = FeatureType.STATE
norm_modes[feature_type] = mode
# Try to extract from config
if "features" in config:
for key, feature_config in config["features"].items():
shape = feature_config.get("shape", feature_config.get("dim"))
shape = (shape,) if isinstance(shape, int) else tuple(shape)
# Determine feature type
if "image" in key or "visual" in key:
feature_type = FeatureType.VISUAL
elif "state" in key:
feature_type = FeatureType.STATE
elif "action" in key:
feature_type = FeatureType.ACTION
else:
feature_type = FeatureType.STATE # Default
features[key] = PolicyFeature(feature_type, shape)
# If no features in config, infer from stats
if not features:
for key, stat_dict in stats.items():
# Get shape from any stat tensor
tensor = next(iter(stat_dict.values()))
shape = tuple(tensor.shape)
# Determine feature type based on key
if "image" in key or "visual" in key or "pixels" in key:
feature_type = FeatureType.VISUAL
elif "state" in key or "joint" in key or "position" in key:
feature_type = FeatureType.STATE
elif "action" in key:
feature_type = FeatureType.ACTION
else:
feature_type = FeatureType.STATE
features[key] = PolicyFeature(feature_type, shape)
# If normalization modes weren't in config, determine based on available stats
if not norm_modes:
for key, stat_dict in stats.items():
if key in features:
if "mean" in stat_dict and "std" in stat_dict:
feature_type = features[key].type
if feature_type not in norm_modes:
norm_modes[feature_type] = NormalizationMode.MEAN_STD
elif "min" in stat_dict and "max" in stat_dict:
feature_type = features[key].type
if feature_type not in norm_modes:
norm_modes[feature_type] = NormalizationMode.MIN_MAX
# Default normalization modes if not detected
if FeatureType.VISUAL not in norm_modes:
norm_modes[FeatureType.VISUAL] = NormalizationMode.MEAN_STD
if FeatureType.STATE not in norm_modes:
norm_modes[FeatureType.STATE] = NormalizationMode.MIN_MAX
if FeatureType.ACTION not in norm_modes:
norm_modes[FeatureType.ACTION] = NormalizationMode.MEAN_STD
return features, norm_modes
def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Remove normalization layers from state_dict."""
new_state_dict = {}
# Patterns to remove
remove_patterns = [
"normalize_inputs.",
"unnormalize_outputs.",
"normalize_targets.", # Added pattern for target normalization
"normalize.",
"unnormalize.",
"input_normalizer.",
"output_normalizer.",
"normalizer.",
]
for key, tensor in state_dict.items():
should_remove = any(pattern in key for pattern in remove_patterns)
if not should_remove:
new_state_dict[key] = tensor
return new_state_dict
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
"""Convert features from old format to PolicyFeature objects."""
converted_features = {}
for key, feature_dict in features_dict.items():
# Determine feature type based on key
if "image" in key or "visual" in key:
feature_type = FeatureType.VISUAL
elif "state" in key:
feature_type = FeatureType.STATE
elif "action" in key:
feature_type = FeatureType.ACTION
else:
feature_type = FeatureType.STATE
# Get shape from feature dict
shape = feature_dict.get("shape", feature_dict.get("dim"))
shape = (shape,) if isinstance(shape, int) else tuple(shape)
converted_features[key] = PolicyFeature(feature_type, shape)
return converted_features
def load_model_from_hub(
repo_id: str, revision: str = None
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
"""Load model state_dict and config from hub."""
# Download files
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)
# Load state_dict
state_dict = load_safetensors(safetensors_path)
# Load config
with open(config_path) as f:
config = json.load(f)
with open(train_config_path) as f:
train_config = json.load(f)
return state_dict, config, train_config
def main():
parser = argparse.ArgumentParser(
description="Migrate policy models with normalization layers to new pipeline system"
)
parser.add_argument(
"--pretrained-path",
type=str,
required=True,
help="Path to pretrained model (hub repo or local directory)",
)
parser.add_argument(
"--output-dir",
type=str,
default=None,
help="Output directory for migrated model (default: same as pretrained-path)",
)
parser.add_argument("--push-to-hub", action="store_true", help="Push migrated model to hub")
parser.add_argument(
"--hub-repo-id",
type=str,
default=None,
help="Hub repository ID for pushing (default: same as pretrained-path)",
)
parser.add_argument("--revision", type=str, default=None, help="Revision of the model to load")
parser.add_argument("--private", action="store_true", help="Make the hub repository private")
args = parser.parse_args()
# Load model and config
print(f"Loading model from {args.pretrained_path}...")
if os.path.isdir(args.pretrained_path):
# Local directory
state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors"))
with open(os.path.join(args.pretrained_path, "config.json")) as f:
config = json.load(f)
with open(os.path.join(args.pretrained_path, "train_config.json")) as f:
train_config = json.load(f)
else:
# Hub repository
state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision)
# Extract normalization statistics
print("Extracting normalization statistics...")
stats = extract_normalization_stats(state_dict)
print(f"Found normalization statistics for: {list(stats.keys())}")
# Detect input features and normalization modes
print("Detecting features and normalization modes...")
features, norm_map = detect_features_and_norm_modes(config, stats)
print(f"Detected features: {list(features.keys())}")
print(f"Normalization modes: {norm_map}")
# Remove normalization layers from state_dict
print("Removing normalization layers from model...")
new_state_dict = remove_normalization_layers(state_dict)
removed_keys = set(state_dict.keys()) - set(new_state_dict.keys())
if removed_keys:
print(f"Removed {len(removed_keys)} normalization layer keys")
# Determine output path
if args.output_dir:
output_dir = Path(args.output_dir)
else:
if os.path.isdir(args.pretrained_path):
output_dir = Path(args.pretrained_path).parent / f"{Path(args.pretrained_path).name}_migrated"
else:
output_dir = Path(f"./{args.pretrained_path.replace('/', '_')}_migrated")
output_dir.mkdir(parents=True, exist_ok=True)
# Clean up config - remove normalization_mapping field
cleaned_config = dict(config)
if "normalization_mapping" in cleaned_config:
print("Removing 'normalization_mapping' field from config")
del cleaned_config["normalization_mapping"]
policy_type = deepcopy(cleaned_config["type"])
del cleaned_config["type"]
# Instantiate the policy model with cleaned config and load the cleaned state dict
print(f"Instantiating {policy_type} policy model...")
policy_class_path = POLICY_CLASSES[policy_type]
module_path, class_name = policy_class_path.rsplit(".", 1)
module = importlib.import_module(module_path)
policy_class = getattr(module, class_name)
# Create config class instance
config_module_path = module_path.replace("modeling", "configuration")
config_module = importlib.import_module(config_module_path)
# Handle special cases for config class names
config_class_names = {
"act": "ACTConfig",
"diffusion": "DiffusionConfig",
"pi0": "PI0Config",
"pi0fast": "PI0FASTConfig",
"smolvla": "SmolVLAConfig",
"tdmpc": "TDMPCConfig",
"vqbet": "VQBeTConfig",
"sac": "SACConfig",
"classifier": "ClassifierConfig",
}
config_class_name = config_class_names.get(policy_type, f"{policy_type.upper()}Config")
config_class = getattr(config_module, config_class_name)
# Convert input_features and output_features to PolicyFeature objects - these are mandatory
if "input_features" not in cleaned_config:
raise ValueError("Missing mandatory 'input_features' in config")
if "output_features" not in cleaned_config:
raise ValueError("Missing mandatory 'output_features' in config")
cleaned_config["input_features"] = convert_features_to_policy_features(cleaned_config["input_features"])
cleaned_config["output_features"] = convert_features_to_policy_features(cleaned_config["output_features"])
# Create config instance from cleaned config dict
policy_config = config_class(**cleaned_config)
# Create policy instance - some policies expect dataset_stats
policy = policy_class(policy_config)
# Load the cleaned state dict
policy.load_state_dict(new_state_dict, strict=True)
print("Successfully loaded cleaned state dict into policy model")
# Now create preprocessor and postprocessor with cleaned_config available
print("Creating preprocessor and postprocessor...")
# The pattern from existing processor factories:
# - Preprocessor has two NormalizerProcessors: one for input_features, one for output_features
# - Postprocessor has one UnnormalizerProcessor for output_features only
# Get features from cleaned_config (now they're PolicyFeature objects)
input_features = cleaned_config.get("input_features", {})
output_features = cleaned_config.get("output_features", {})
# Create preprocessor with two normalizers (following the pattern from processor factories)
preprocessor_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
features={**input_features, **output_features},
norm_map=norm_map,
stats=stats,
),
ToBatchProcessor(),
DeviceProcessor(device=policy_config.device),
]
preprocessor = RobotProcessor(steps=preprocessor_steps, name="robot_preprocessor")
# Create postprocessor with unnormalizer for outputs only
postprocessor_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
]
postprocessor = RobotProcessor(steps=postprocessor_steps, name="robot_postprocessor")
# Determine hub repo ID if pushing to hub
if args.push_to_hub:
if args.hub_repo_id:
hub_repo_id = args.hub_repo_id
else:
if not os.path.isdir(args.pretrained_path):
# Use same repo with "_migrated" suffix
hub_repo_id = f"{args.pretrained_path}_migrated"
else:
raise ValueError("--hub-repo-id must be specified when pushing local model to hub")
else:
hub_repo_id = None
# Save preprocessor and postprocessor to root directory
print(f"Saving preprocessor to {output_dir}...")
preprocessor.save_pretrained(output_dir)
if args.push_to_hub:
preprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
print(f"Saving postprocessor to {output_dir}...")
postprocessor.save_pretrained(output_dir)
if args.push_to_hub:
postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
# Save model using the policy's save_pretrained method
print(f"Saving model to {output_dir}...")
policy.save_pretrained(
output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private
)
# Generate and save model card
print("Generating model card...")
# Get metadata from original config
dataset_repo_id = train_config.get("repo_id", "unknown")
license = config.get("license", "apache-2.0")
tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type]
tags = set(tags).union({"robotics", "lerobot", policy_type})
tags = list(tags)
# Generate model card
card = policy.generate_model_card(
dataset_repo_id=dataset_repo_id, model_type=policy_type, license=license, tags=tags
)
# Save model card locally
card.save(str(output_dir / "README.md"))
print(f"Model card saved to {output_dir / 'README.md'}")
# Push model card to hub if requested
if args.push_to_hub:
from huggingface_hub import HfApi
api = HfApi()
api.upload_file(
path_or_fileobj=str(output_dir / "README.md"),
path_in_repo="README.md",
repo_id=hub_repo_id,
repo_type="model",
commit_message="Add model card for migrated model",
)
print("Model card pushed to hub")
print("\nMigration complete!")
print(f"Migrated model saved to: {output_dir}")
if args.push_to_hub:
print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")
if __name__ == "__main__":
main()

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
from collections.abc import Mapping
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any
@@ -11,7 +10,7 @@ from torch import Tensor
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, RobotProcessor, TransitionKey
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
@@ -116,7 +115,7 @@ class NormalizerProcessor:
if self.normalize_keys is not None and not isinstance(self.normalize_keys, set):
self.normalize_keys = set(self.normalize_keys)
def _normalize_obs(self, observation, normalized_info):
def _normalize_obs(self, observation):
if observation is None:
return None
@@ -129,20 +128,7 @@ class NormalizerProcessor:
processed = dict(observation)
for key in keys_to_norm:
if key not in processed or key not in self.features:
continue
# Check the normalization mode for this feature type
feature = self.features[key]
norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY)
# Skip normalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
normalized_info[key] = "IDENTITY"
continue
# Skip if no stats available for this key
if key not in self._tensor_stats:
if key not in processed or key not in self._tensor_stats:
continue
orig_val = processed[key]
@@ -153,35 +139,16 @@ class NormalizerProcessor:
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = (tensor - mean) / (std + self.eps)
normalized_info[key] = "MEAN_STD"
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
normalized_info[key] = "MIN_MAX"
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = (tensor - mean) / (std + self.eps)
elif "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
return processed
def _normalize_action(self, action, normalized_info):
if action is None:
return action
# Check the normalization mode for actions
norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY)
# Skip normalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
normalized_info["action"] = "IDENTITY"
return action
# Skip if no stats available for actions
if "action" not in self._tensor_stats:
def _normalize_action(self, action):
if action is None or "action" not in self._tensor_stats:
return action
tensor = (
@@ -190,42 +157,22 @@ class NormalizerProcessor:
else torch.as_tensor(action, dtype=torch.float32)
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
normalized_info["action"] = "MEAN_STD"
return (tensor - mean) / (std + self.eps)
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
normalized_info["action"] = "MIN_MAX"
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
# If we reach here, the required stats for the normalization mode are not available
raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization")
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
return (tensor - mean) / (std + self.eps)
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Track what was normalized
normalized_info = {}
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION), normalized_info)
action = self._normalize_action(transition.get(TransitionKey.ACTION), normalized_info)
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION))
action = self._normalize_action(transition.get(TransitionKey.ACTION))
# Create a new transition with normalized values
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = observation
new_transition[TransitionKey.ACTION] = action
# Add normalization info to complementary data
if normalized_info:
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
comp_data = {} if comp_data is None else dict(comp_data)
comp_data["normalized_keys"] = normalized_info
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
def get_config(self) -> dict[str, Any]:
@@ -257,7 +204,7 @@ class NormalizerProcessor:
def reset(self):
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -306,28 +253,14 @@ class UnnormalizerProcessor:
self.stats = self.stats or {}
self._tensor_stats = _convert_stats_to_tensors(self.stats)
def _unnormalize_obs(self, observation, unnormalized_info):
def _unnormalize_obs(self, observation):
if observation is None:
return None
keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
processed = dict(observation)
for key in keys:
if key not in processed or key not in self.features:
if key not in processed or key not in self._tensor_stats:
continue
# Check the normalization mode for this feature type
feature = self.features[key]
norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY)
# Skip unnormalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
unnormalized_info[key] = "IDENTITY"
continue
# Skip if no stats available for this key
if key not in self._tensor_stats:
continue
orig_val = processed[key]
tensor = (
orig_val.to(dtype=torch.float32)
@@ -335,80 +268,39 @@ class UnnormalizerProcessor:
else torch.as_tensor(orig_val, dtype=torch.float32)
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = tensor * std + mean
unnormalized_info[key] = "MEAN_STD"
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
unnormalized_info[key] = "MIN_MAX"
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = tensor * std + mean
elif "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
return processed
def _unnormalize_action(self, action, unnormalized_info):
if action is None:
def _unnormalize_action(self, action):
if action is None or "action" not in self._tensor_stats:
return action
# Check the normalization mode for actions
norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY)
# Skip unnormalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
unnormalized_info["action"] = "IDENTITY"
return action
# Skip if no stats available for actions
if "action" not in self._tensor_stats:
return action
tensor = (
action.to(dtype=torch.float32)
if isinstance(action, torch.Tensor)
else torch.as_tensor(action, dtype=torch.float32)
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
unnormalized_info["action"] = "MEAN_STD"
return tensor * std + mean
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
unnormalized_info["action"] = "MIN_MAX"
return (tensor + 1) / 2 * (max_val - min_val) + min_val
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
# If we reach here, the required stats for the normalization mode are not available
raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization")
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
return tensor * std + mean
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
return (tensor + 1) / 2 * (max_val - min_val) + min_val
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Track what was unnormalized
unnormalized_info = {}
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION), unnormalized_info)
action = self._unnormalize_action(transition.get(TransitionKey.ACTION), unnormalized_info)
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION))
action = self._unnormalize_action(transition.get(TransitionKey.ACTION))
# Create a new transition with unnormalized values
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = observation
new_transition[TransitionKey.ACTION] = action
# Add unnormalization info to complementary data
if unnormalized_info:
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
comp_data = {} if comp_data is None else dict(comp_data)
comp_data["unnormalized_keys"] = unnormalized_info
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
def get_config(self) -> dict[str, Any]:
@@ -435,41 +327,5 @@ class UnnormalizerProcessor:
def reset(self):
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, Any]]) -> RobotProcessor:
robot_processor = deepcopy(robot_processor)
for step in robot_processor.steps:
if isinstance(step, NormalizerProcessor) or isinstance(step, UnnormalizerProcessor):
step: NormalizerProcessor | UnnormalizerProcessor
step.stats = stats
step._tensor_stats = _convert_stats_to_tensors(stats)
return robot_processor
def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]:
"""Rename keys in the stats dictionary according to the provided mapping.
Args:
stats: The statistics dictionary with structure {feature_key: {stat_name: value}}
rename_map: Dictionary mapping old key names to new key names
Returns:
A new stats dictionary with renamed keys
Example:
>>> stats = {"observation.state": {"mean": 0.0, "std": 1.0}, "action": {"mean": 0.5, "std": 0.5}}
>>> rename_map = {"observation.state": "observation.robot_state"}
>>> new_stats = rename_stats(stats, rename_map)
>>> # new_stats will have "observation.robot_state" instead of "observation.state"
"""
renamed_stats = {}
for old_key, sub_stats in stats.items():
# Use the new key if it exists in the rename map, otherwise keep the old key
new_key = rename_map.get(old_key, old_key)
renamed_stats[new_key] = deepcopy(sub_stats)
return renamed_stats

View File

@@ -106,8 +106,9 @@ class VanillaObservationProcessor(ObservationProcessor):
def observation(self, observation):
return self._process_observation(observation)
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transforms feature keys to a standardized contract.
This method handles several renaming patterns:
- Exact matches (e.g., 'pixels' -> 'OBS_IMAGE').
- Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE').

View File

@@ -23,7 +23,7 @@ from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Protocol, TypedDict, runtime_checkable
from typing import Any, Protocol, TypedDict
import torch
from huggingface_hub import ModelHubMixin, hf_hub_download
@@ -132,7 +132,6 @@ class ProcessorStepRegistry:
cls._registry.clear()
@runtime_checkable
class ProcessorStep(Protocol):
"""Structural typing interface for a single processor step.
@@ -146,6 +145,7 @@ class ProcessorStep(Protocol):
**Required**:
- ``__call__(transition: EnvTransition) -> EnvTransition``
- ``feature_contract(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]``
Optional helper protocol:
* ``get_config() -> dict[str, Any]`` User-defined JSON-serializable
@@ -158,8 +158,6 @@ class ProcessorStep(Protocol):
* ``load_state_dict(state)`` Inverse of ``state_dict``. Receives a dict
containing torch tensors only.
* ``reset()`` Clear internal buffers at episode boundaries.
* ``transform_features(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]``
If present, this method will be called to aggregate the dataset features of all steps.
Example separation:
- get_config(): {"name": "my_step", "learning_rate": 0.01, "window_size": 10}
@@ -176,7 +174,7 @@ class ProcessorStep(Protocol):
def reset(self) -> None: ...
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ...
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ...
def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401
@@ -203,16 +201,10 @@ def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noq
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
observation = observation_keys if observation_keys else None
# Extract padding, task, index, and task_index keys for complementary data
# Extract padding and task keys for complementary data
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
complementary_data = (
{**pad_keys, **task_key, **index_key, **task_index_key}
if pad_keys or task_key or index_key or task_index_key
else {}
)
complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {}
transition: EnvTransition = {
TransitionKey.OBSERVATION: observation,
@@ -239,7 +231,7 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
"info": transition.get(TransitionKey.INFO, {}),
}
# Add padding, task, index, and task_index data from complementary_data
# Add padding and task data from complementary_data
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data:
pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k}
@@ -248,12 +240,6 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
if "task" in complementary_data:
batch["task"] = complementary_data["task"]
if "index" in complementary_data:
batch["index"] = complementary_data["index"]
if "task_index" in complementary_data:
batch["task_index"] = complementary_data["task_index"]
# Handle observation - flatten dict to observation.* keys if it's a dict
observation = transition.get(TransitionKey.OBSERVATION)
if isinstance(observation, dict):
@@ -356,10 +342,7 @@ class RobotProcessor(ModelHubMixin):
hook(idx, current_transition)
# Convert back to original format if needed
if called_with_batch or self.to_output is not _default_transition_to_batch:
return self.to_output(current_transition)
else:
return current_transition
return self.to_output(current_transition) if called_with_batch else current_transition
def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]:
"""Prepare and validate transition data for processing.
@@ -592,9 +575,10 @@ class RobotProcessor(ModelHubMixin):
if config_filename is None:
# Try common config names
common_names = [
"robot_processor.json",
"robot_preprocessor.json",
"robot_postprocessor.json",
"processor.json",
"preprocessor.json",
"postprocessor.json",
"robotprocessor.json",
]
config_path = None
for name in common_names:
@@ -824,15 +808,23 @@ class RobotProcessor(ModelHubMixin):
f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition"
)
def transform_features(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
fc = getattr(step, "feature_contract", None)
if not callable(fc):
raise TypeError(
f"Step {i} ({type(step).__name__}) must define feature_contract(features) -> dict[str, Any]"
)
def feature_contract(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""
Apply ALL steps in order. Only if a step has a features method, it will be called.
We aggregate the dataset features of all steps.
Apply ALL steps in order. Each step must implement
feature_contract(features) and return a dict (full or incremental schema).
"""
features: dict[str, PolicyFeature] = deepcopy(initial_features)
for _, step in enumerate(self.steps):
out = step.transform_features(features)
out = step.feature_contract(features)
if not isinstance(out, dict):
raise TypeError(f"{step.__class__.__name__}.feature_contract must return dict[str, Any]")
features = out
return features
@@ -892,7 +884,7 @@ class ObservationProcessor:
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -952,7 +944,7 @@ class ActionProcessor:
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -1011,7 +1003,7 @@ class RewardProcessor:
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -1075,7 +1067,7 @@ class DoneProcessor:
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -1135,7 +1127,7 @@ class TruncatedProcessor:
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -1200,7 +1192,7 @@ class InfoProcessor:
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -1246,7 +1238,7 @@ class ComplementaryDataProcessor:
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -1268,5 +1260,5 @@ class IdentityProcessor:
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features

View File

@@ -43,7 +43,7 @@ class RenameProcessor(ObservationProcessor):
def get_config(self) -> dict[str, Any]:
return {"rename_map": self.rename_map}
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transforms:
- Each key in the observation that appears in `rename_map` is renamed to its value.
- Keys not in `rename_map` remain unchanged.

View File

@@ -1,275 +0,0 @@
"""
Tokenizer processor for handling text tokenization in robot transitions.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import OBS_LANGUAGE
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers import AutoTokenizer
else:
AutoTokenizer = None
@dataclass
@ProcessorStepRegistry.register(name="tokenizer_processor")
class TokenizerProcessor:
"""Tokenizes text tasks in complementary data using a huggingface tokenizer.
This processor handles tokenization of task strings found in the complementary_data
using a specified pretrained tokenizer from Hugging Face. It adds tokenized versions
to the observation data for model processing while preserving the original task string.
The processor supports both single strings and lists of strings as task inputs.
Args:
tokenizer_name: Name of the pretrained tokenizer to load from Hugging Face Hub
(e.g., "bert-base-uncased", "microsoft/DialoGPT-medium"). This will be used
with AutoTokenizer.from_pretrained(). If tokenizer is provided, this is ignored.
tokenizer: A tokenizer object (e.g., from transformers library) that implements
the __call__ method. If provided, tokenizer_name is ignored. This parameter
is not serialized and must be provided via overrides when loading.
max_length: Maximum sequence length for tokenization. Defaults to 512.
task_key: Key in complementary_data containing the task text. Defaults to "task".
padding: Padding strategy for tokenization. Defaults to "max_length".
truncation: Whether to truncate sequences longer than max_length. Defaults to True.
Examples:
Using tokenizer name (auto-loaded):
```python
processor = TokenizerProcessor(tokenizer_name="bert-base-uncased", max_length=128)
```
Using custom tokenizer object:
```python
from transformers import AutoTokenizer
custom_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
processor = TokenizerProcessor(tokenizer=custom_tokenizer, max_length=128)
```
"""
tokenizer_name: str | None = None
tokenizer: Any | None = None # Otherwise transformers is not available in the core dependencies
max_length: int = 512
task_key: str = "task"
padding_side: str = "right"
padding: str = "max_length"
truncation: bool = True
# Internal tokenizer instance (not serialized)
_tokenizer: Any = field(default=None, init=False, repr=False)
def __post_init__(self):
"""Initialize the tokenizer from the provided tokenizer or tokenizer name."""
if not _transformers_available:
raise ImportError(
"The 'transformers' library is not installed. "
"Please install it with `pip install 'lerobot[transformers-dep]'` to use TokenizerProcessor."
)
if self.tokenizer is not None:
# Use provided tokenizer object directly
self._tokenizer = self.tokenizer
elif self.tokenizer_name is not None:
if AutoTokenizer is None:
raise ImportError("AutoTokenizer is not available")
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
else:
raise ValueError(
"Either 'tokenizer' or 'tokenizer_name' must be provided. "
"Pass a tokenizer object directly or a tokenizer name to auto-load."
)
def get_task(self, transition: EnvTransition) -> list[str] | None:
"""Extract and normalize task from complementary data.
Args:
transition: Input transition containing complementary_data.
Returns:
List of task strings if task is present, None otherwise.
"""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
return None
if self.task_key not in complementary_data:
return None
task = complementary_data[self.task_key]
if task is None:
return None
# Convert to list of strings
if isinstance(task, str):
return [task]
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
return task
return None
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Process the transition by tokenizing the task text.
Args:
transition: Input transition containing complementary_data with task text.
Returns:
Modified transition with tokenized task added to observation.
Raises:
ValueError: If tokenizer initialization failed.
"""
task = self.get_task(transition)
if task is None:
return transition
# Tokenize the task (creates CPU tensors)
tokenized_prompt = self._tokenize_text(task)
# Detect device from existing tensors in the transition
target_device = self._detect_device(transition)
# Move tokenized tensors to match the device of other data
if target_device is not None:
tokenized_prompt = {
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
for k, v in tokenized_prompt.items()
}
# Get or create observation dict
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
observation = {}
else:
observation = dict(observation) # Make a copy
# Add tokenized data to observation
observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"]
observation[f"{OBS_LANGUAGE}.attention_mask"] = tokenized_prompt["attention_mask"].to(
dtype=torch.bool
)
transition[TransitionKey.OBSERVATION.value] = observation # type: ignore[misc]
return transition
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
"""Detect device from existing tensors in the transition.
This allows the tokenized tensors to match the device of other data,
which is especially important for multi-GPU training with Accelerate.
Args:
transition: The transition to search for existing tensors.
Returns:
The device of the first tensor found, or None if no tensors exist.
"""
# Check observation tensors first (most likely to exist)
observation = transition.get(TransitionKey.OBSERVATION)
if observation:
for value in observation.values():
if isinstance(value, torch.Tensor):
return value.device
# Check action tensor
action = transition.get(TransitionKey.ACTION)
if isinstance(action, torch.Tensor):
return action.device
# Check other tensor fields
for key in [TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED]:
value = transition.get(key)
if isinstance(value, torch.Tensor):
return value.device
# Check complementary data for tensors
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data:
for value in complementary_data.values():
if isinstance(value, torch.Tensor):
return value.device
return None # No tensors found, keep on CPU
def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
"""Tokenize text using the configured tokenizer.
Args:
text: Text string or list of strings to tokenize.
Returns:
Dictionary containing tokenized output with keys like 'input_ids', 'attention_mask'.
"""
return self._tokenizer(
text,
max_length=self.max_length,
truncation=self.truncation,
padding=self.padding,
padding_side=self.padding_side,
return_tensors="pt",
)
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization.
Note: Only tokenizer_name is saved, not the tokenizer object itself.
When loading, provide the tokenizer via overrides if needed.
"""
config = {
"max_length": self.max_length,
"task_key": self.task_key,
"padding_side": self.padding_side,
"padding": self.padding,
"truncation": self.truncation,
}
# Only include tokenizer_name if it was used (not when tokenizer object was provided)
if self.tokenizer_name is not None:
config["tokenizer_name"] = self.tokenizer_name
return config
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Add tokenized task features to the feature contract.
Args:
features: Input feature dictionary.
Returns:
Updated feature dictionary with tokenized task features added.
"""
# Add features for tokenized output if they don't exist
# Standard tokenizer output includes tokens and attention_mask
tokens_key = f"{OBS_LANGUAGE}.tokens"
attention_mask_key = f"{OBS_LANGUAGE}.attention_mask"
if tokens_key not in features:
features[tokens_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
if attention_mask_key not in features:
features[attention_mask_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
return features

View File

@@ -59,7 +59,7 @@ lerobot-record \
import logging
import time
from dataclasses import asdict, dataclass, field
from dataclasses import asdict, dataclass
from pathlib import Path
from pprint import pformat
@@ -72,19 +72,9 @@ from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_processor
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import RobotProcessor
from lerobot.processor.converters import (
to_dataset_frame,
to_output_robot_action,
to_transition_robot_observation,
to_transition_teleop_action,
)
from lerobot.processor.normalize_processor import rename_stats
from lerobot.processor.pipeline import IdentityProcessor, TransitionKey
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
@@ -158,8 +148,6 @@ class DatasetRecordConfig:
# Number of episodes to record before batch encoding videos
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
video_encoding_batch_size: int = 1
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
def __post_init__(self):
if self.single_task is None:
@@ -198,36 +186,6 @@ class RecordConfig:
return ["policy"]
""" --------------- record_loop() data flow --------------------------
[ Robot ]
V
[ robot.get_observation() ] ---> raw_obs
V
[ robot_observation_processor ] ---> obs_transition
V
.-----( ACTION LOGIC )------------------.
V V
[ From Teleoperator ] [ From Policy ]
| |
| [teleop.get_action] -> raw_action | [predict_action]
| | | |
| V | V
| [teleop_action_processor] | |
| | | |
'---> teleop_transition '---> policy_transition
| |
'-------------------------.-------------'
V
[ robot_action_processor ] --> robot_action_to_send
V
[ robot.send_action() ] -- (Robot Executes)
V
( Transitions are merged & added to Dataset )
V
( Rerun Log / Loop Wait )
"""
@safe_stop_image_writer
def record_loop(
robot: Robot,
@@ -236,30 +194,15 @@ def record_loop(
dataset: LeRobotDataset | None = None,
teleop: Teleoperator | list[Teleoperator] | None = None,
policy: PreTrainedPolicy | None = None,
preprocessor: RobotProcessor | None = None,
postprocessor: RobotProcessor | None = None,
control_time_s: int | None = None,
teleop_action_processor: RobotProcessor | None = None, # runs after teleop
robot_action_processor: RobotProcessor | None = None, # runs before robot
robot_observation_processor: RobotProcessor | None = None, # runs after robot
single_task: str | None = None,
display_data: bool = False,
):
teleop_action_processor = teleop_action_processor or RobotProcessor(
steps=[IdentityProcessor()], to_transition=to_transition_teleop_action, to_output=lambda tr: tr
)
robot_action_processor = robot_action_processor or RobotProcessor(
steps=[IdentityProcessor()], to_transition=lambda tr: tr, to_output=to_output_robot_action
)
robot_observation_processor = robot_observation_processor or RobotProcessor(
steps=[IdentityProcessor()], to_transition=to_transition_robot_observation, to_output=lambda tr: tr
)
if dataset is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
teleop_arm = teleop_keyboard = None
if isinstance(teleop, list): # For LeKiwi
if isinstance(teleop, list):
teleop_keyboard = next((t for t in teleop if isinstance(t, KeyboardTeleop)), None)
teleop_arm = next(
(
@@ -275,20 +218,9 @@ def record_loop(
"For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot."
)
# Reset policy and processor if they are provided
if policy is not None and preprocessor is not None and postprocessor is not None:
# if policy is given it needs cleaning up
if policy is not None:
policy.reset()
preprocessor.reset()
postprocessor.reset()
# Reset custom pipelines
teleop_action_processor.reset()
robot_action_processor.reset()
robot_observation_processor.reset()
policy_transition = None
teleop_transition = None
obs_transition = None
timestamp = 0
start_episode_t = time.perf_counter()
@@ -299,87 +231,51 @@ def record_loop(
events["exit_early"] = False
break
# Get robot observation
obs = robot.get_observation()
observation = robot.get_observation()
# Applies a pipeline to the raw robot observation, default is IdentityProcessor
obs_transition = robot_observation_processor(obs)
# Get action from either policy or teleop
if policy is not None and preprocessor is not None and postprocessor is not None:
if dataset is not None:
observation_frame = to_dataset_frame(
obs_transition, dataset.features
) # Convert the observation to the dataset format
if policy is not None or dataset is not None:
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
if policy is not None:
action_values = predict_action(
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
observation_frame,
policy,
get_safe_torch_device(policy.config.device),
policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
action_names = dataset.features["action"]["names"]
policy_action = {f"action.{name}": float(action_values[i]) for i, name in enumerate(action_names)}
policy_transition = {
TransitionKey.ACTION: policy_action,
TransitionKey.COMPLEMENTARY_DATA: {},
}
elif isinstance(teleop, Teleoperator):
act = teleop.get_action()
# Applies a pipeline to the raw teleop action, default is IdentityProcessor
teleop_transition = teleop_action_processor(act)
elif isinstance(teleop, list):
action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
elif policy is None and isinstance(teleop, Teleoperator):
action = teleop.get_action()
elif policy is None and isinstance(teleop, list):
# TODO(pepijn, steven): clean the record loop for use of multiple robots (possibly with pipeline)
arm_action = teleop_arm.get_action()
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
keyboard_action = teleop_keyboard.get_action()
base_action = robot._from_keyboard_to_base_action(keyboard_action)
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
teleop_transition = teleop_action_processor(act)
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
else:
logging.info(
"No policy or teleoperator provided, skipping action generation. "
"This is likely to happen during environment reset."
"No policy or teleoperator provided, skipping action generation."
"This is likely to happen when resetting the environment without a teleop device."
"The robot won't be at its rest position at the start of the next episode."
)
# Still continue to next loop to respect timing
continue
# Applies a pipeline to the action, default is IdentityProcessor
# IMPORTANT: action_pipeline.to_output must return a dict suitable for robot.send_action()
if policy_transition is not None:
robot_action_to_send = robot_action_processor(policy_transition)
else:
robot_action_to_send = robot_action_processor(teleop_transition)
# Send action to robot
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
# TODO(pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
_ = robot.send_action(robot_action_to_send)
# so action actually sent is saved in the dataset.
sent_action = robot.send_action(action)
# Write to dataset
if dataset is not None:
# If to_dataset_frame is provided, use it to merge the transitions.
merged = []
if obs_transition is not None: # The observation from the robot
merged.append(obs_transition)
if teleop_transition is not None: # The action from teleop
merged.append(teleop_transition)
if policy_transition is not None: # The action from policy
merged.append(policy_transition)
frame = to_dataset_frame(
merged if len(merged) > 1 else merged[0], dataset.features
) # Convert the observation to the dataset format
dataset.add_frame(frame, task=single_task)
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
frame = {**observation_frame, **action_frame, "task": single_task}
dataset.add_frame(frame)
if display_data:
log_rerun_data([obs_transition, teleop_transition or policy_transition])
log_rerun_data(observation, action)
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
@@ -405,7 +301,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
if hasattr(robot, "cameras") and len(robot.cameras) > 0:
@@ -426,23 +321,10 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
use_videos=cfg.dataset.video,
image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
# Load pretrained policy
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
preprocessor = None
postprocessor = None
if cfg.policy is not None:
preprocessor, postprocessor = make_processor(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
preprocessor_overrides={
"device_processor": {"device": cfg.policy.device},
"rename_processor": {"rename_map": cfg.dataset.rename_map},
},
)
robot.connect()
if teleop is not None:
@@ -450,49 +332,46 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
listener, events = init_keyboard_listener()
with VideoEncodingManager(dataset):
recorded_episodes = 0
while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
recorded_episodes = 0
while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
record_loop(
robot=robot,
events=events,
fps=cfg.dataset.fps,
teleop=teleop,
policy=policy,
dataset=dataset,
control_time_s=cfg.dataset.episode_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
)
# Execute a few seconds without recording to give time to manually reset the environment
# Skip reset for the last episode to be recorded
if not events["stop_recording"] and (
(recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"]
):
log_say("Reset the environment", cfg.play_sounds)
record_loop(
robot=robot,
events=events,
fps=cfg.dataset.fps,
teleop=teleop,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=cfg.dataset.episode_time_s,
control_time_s=cfg.dataset.reset_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
)
# Execute a few seconds without recording to give time to manually reset the environment
# Skip reset for the last episode to be recorded
if not events["stop_recording"] and (
(recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"]
):
log_say("Reset the environment", cfg.play_sounds)
record_loop(
robot=robot,
events=events,
fps=cfg.dataset.fps,
teleop=teleop,
control_time_s=cfg.dataset.reset_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
)
if events["rerecord_episode"]:
log_say("Re-record episode", cfg.play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
if events["rerecord_episode"]:
log_say("Re-record episode", cfg.play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
dataset.save_episode()
recorded_episodes += 1
dataset.save_episode()
recorded_episodes += 1
log_say("Stop recording", cfg.play_sounds, blocking=True)
@@ -510,5 +389,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
return dataset
if __name__ == "__main__":
def main():
record()
if __name__ == "__main__":
main()

View File

@@ -14,5 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_so100_follower import SO100FollowerConfig
from .config_so100_follower import SO100FollowerConfig, SO100FollowerEndEffectorConfig
from .so100_follower import SO100Follower
from .so100_follower_end_effector import SO100FollowerEndEffector

View File

@@ -39,3 +39,35 @@ class SO100FollowerConfig(RobotConfig):
# Set to `True` for backward compatibility with previous policies/dataset
use_degrees: bool = False
@RobotConfig.register_subclass("so100_follower_end_effector")
@dataclass
class SO100FollowerEndEffectorConfig(SO100FollowerConfig):
"""Configuration for the SO100FollowerEndEffector robot."""
# Path to URDF file for kinematics
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
urdf_path: str | None = None
# End-effector frame name in the URDF
target_frame_name: str = "gripper_frame_link"
# Default bounds for the end-effector position (in meters)
end_effector_bounds: dict[str, list[float]] = field(
default_factory=lambda: {
"min": [-1.0, -1.0, -1.0], # min x, y, z
"max": [1.0, 1.0, 1.0], # max x, y, z
}
)
max_gripper_pos: float = 50
end_effector_step_sizes: dict[str, float] = field(
default_factory=lambda: {
"x": 0.02,
"y": 0.02,
"z": 0.02,
}
)

View File

@@ -1,465 +0,0 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
import numpy as np
from scipy.spatial.transform import Rotation
from lerobot.configs.types import PolicyFeature
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor.pipeline import (
ActionProcessor,
ComplementaryDataProcessor,
EnvTransition,
ObservationProcessor,
ProcessorStepRegistry,
TransitionKey,
)
from lerobot.robots.robot import Robot
@ProcessorStepRegistry.register("ee_reference_and_delta")
@dataclass
class EEReferenceAndDelta:
"""
Compute the desired end-effector pose from the target pose and the current pose.
Input ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
"complementary_data.raw_joint_positions": dict,
}
Output ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
}
"""
kinematics: RobotKinematics
end_effector_step_sizes: dict
motor_names: list[str]
use_latched_reference: bool = (
True # If True, latch reference on enable; if False, always use current pose
)
reference_ee_pose: np.ndarray | None = field(default=None, init=False, repr=False)
_prev_enabled: bool = field(default=False, init=False, repr=False)
_command_when_disabled: np.ndarray | None = field(default=None, init=False, repr=False)
def __call__(self, transition: EnvTransition) -> EnvTransition:
act = transition.get(TransitionKey.ACTION) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
# Get joint positions from complimentary data
raw = comp.get("raw_joint_positions", None)
if raw is None:
raise ValueError(
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
)
if "reference_joint_positions" in comp:
q = comp["reference_joint_positions"]
else:
q = np.array([float(raw[n]) for n in self.motor_names], dtype=float)
# Current pose from FK on measured joints
t_curr = self.kinematics.forward_kinematics(q)
enabled = bool(act.pop("action.enabled", 0))
tx = float(act.pop("action.target_x", 0.0))
ty = float(act.pop("action.target_y", 0.0))
tz = float(act.pop("action.target_z", 0.0))
wx = float(act.pop("action.target_wx", 0.0))
wy = float(act.pop("action.target_wy", 0.0))
wz = float(act.pop("action.target_wz", 0.0))
desired = None
if enabled:
ref = t_curr
if self.use_latched_reference:
# Latched reference mode: latch reference at the rising edge
if not self._prev_enabled or self.reference_ee_pose is None:
self.reference_ee_pose = t_curr.copy()
ref = self.reference_ee_pose if self.reference_ee_pose is not None else t_curr
delta_p = np.array(
[
tx * self.end_effector_step_sizes["x"],
ty * self.end_effector_step_sizes["y"],
tz * self.end_effector_step_sizes["z"],
],
dtype=float,
)
r_abs = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
desired = np.eye(4, dtype=float)
desired[:3, :3] = ref[:3, :3] @ r_abs
desired[:3, 3] = ref[:3, 3] + delta_p
self._command_when_disabled = desired.copy()
else:
# While disabled, keep sending the same command to avoid drift.
if self._command_when_disabled is None:
# If we've never had an enabled command yet, freeze current FK pose once.
self._command_when_disabled = t_curr.copy()
desired = self._command_when_disabled.copy()
# Write action fields
pos = desired[:3, 3]
tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec()
act.update(
{
"action.ee.x": float(pos[0]),
"action.ee.y": float(pos[1]),
"action.ee.z": float(pos[2]),
"action.ee.wx": float(tw[0]),
"action.ee.wy": float(tw[1]),
"action.ee.wz": float(tw[2]),
}
)
self._prev_enabled = enabled
transition[TransitionKey.ACTION] = act
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@ProcessorStepRegistry.register("ee_bounds_and_safety")
@dataclass
class EEBoundsAndSafety(ActionProcessor):
"""
Clip the end-effector pose to the bounds and check for jumps.
Input ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
}
Output ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
}
"""
end_effector_bounds: dict
max_ee_step_m: float = 0.05
max_ee_twist_step_rad: float = 0.20
_last_pos: np.ndarray | None = field(default=None, init=False, repr=False)
def action(self, act: dict | None) -> dict:
x = act.pop("action.ee.x", None)
y = act.pop("action.ee.y", None)
z = act.pop("action.ee.z", None)
wx = act.pop("action.ee.wx", None)
wy = act.pop("action.ee.wy", None)
wz = act.pop("action.ee.wz", None)
if None in (x, y, z, wx, wy, wz):
return act
pos = np.array([x, y, z], dtype=float)
twist = np.array([wx, wy, wz], dtype=float)
# Clip position
pos = np.clip(pos, self.end_effector_bounds["min"], self.end_effector_bounds["max"])
# Check for jumps in position
if self._last_pos is not None:
dpos = pos - self._last_pos
n = float(np.linalg.norm(dpos))
if n > self.max_ee_step_m and n > 0:
pos = self._last_pos + dpos * (self.max_ee_step_m / n)
raise ValueError(f"EE jump {n:.3f}m > {self.max_ee_step_m}m")
self._last_pos = pos
self._last_twist = twist
act.update(
{
"action.ee.x": float(pos[0]),
"action.ee.y": float(pos[1]),
"action.ee.z": float(pos[2]),
"action.ee.wx": float(twist[0]),
"action.ee.wy": float(twist[1]),
"action.ee.wz": float(twist[2]),
}
)
return act
def reset(self):
self._last_pos = None
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# Because this is last step we specify the dataset features of this step that we want to be stored in the dataset
features["action.ee.x"] = float
features["action.ee.y"] = float
features["action.ee.z"] = float
features["action.ee.wx"] = float
features["action.ee.wy"] = float
features["action.ee.wz"] = float
return features
@ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints")
@dataclass
class InverseKinematicsEEToJoints:
"""
Compute the desired joint positions from the desired end-effector pose.
Input ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
"complementary_data.raw_joint_positions": dict,
}
Output ACTION keys:
{
"action.joint_name_1.pos": float,
"action.joint_name_2.pos": float,
...
"action.joint_name_n.pos": float,
}
"""
kinematics: RobotKinematics
motor_names: list[str]
q_curr: np.ndarray | None = field(default=None, init=False, repr=False)
initial_guess_current_joints: bool = True
def __call__(self, transition: EnvTransition) -> EnvTransition:
act = transition.get(TransitionKey.ACTION) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
x = act.get("action.ee.x", None)
y = act.get("action.ee.y", None)
z = act.get("action.ee.z", None)
wx = act.get("action.ee.wx", None)
wy = act.get("action.ee.wy", None)
wz = act.get("action.ee.wz", None)
if None in (x, y, z, wx, wy, wz):
# Nothing to do; restore what we popped and return
act.update(
{
"action.ee.x": x,
"action.ee.y": y,
"action.ee.z": z,
"action.ee.wx": wx,
"action.ee.wy": wy,
"action.ee.wz": wz,
}
)
transition[TransitionKey.ACTION] = act
return transition
# Get joint positions from complimentary data
raw = comp.get("raw_joint_positions", None)
if raw is None:
raise ValueError(
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
)
if self.initial_guess_current_joints: # Use current joints as initial guess
self.q_curr = np.array([float(raw[n]) for n in self.motor_names], dtype=float)
else: # Use previous ik solution as initial guess
if self.q_curr is None:
self.q_curr = np.array([float(raw[n]) for n in self.motor_names], dtype=float)
# Build desired 4x4 transform from pos + rotvec (twist)
t_des = np.eye(4, dtype=float)
t_des[:3, :3] = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
t_des[:3, 3] = [x, y, z]
# Compute inverse kinematics
q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des)
self.q_curr = q_target
new_act = dict(act)
for i, name in enumerate(self.motor_names):
if name == "gripper":
new_act["observation.state.gripper.pos"] = float(raw["gripper"])
else:
new_act[f"action.{name}.pos"] = float(q_target[i])
transition[TransitionKey.ACTION] = new_act
if not self.initial_guess_current_joints:
transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We specify the dataset features of this step that we want to be stored in the dataset
features["action.ee.x"] = float
features["action.ee.y"] = float
features["action.ee.z"] = float
features["action.ee.wx"] = float
features["action.ee.wy"] = float
features["action.ee.wz"] = float
features["observation.state.gripper.pos"] = float
features["action.gripper.pos"] = float
return features
def reset(self):
self.q_curr = None
@ProcessorStepRegistry.register("gripper_velocity_to_joint")
@dataclass
class GripperVelocityToJoint:
"""
Convert the gripper velocity to a joint velocity.
Input ACTION keys:
{
"action.gripper": float,
}
Output ACTION keys:
{
"action.gripper.pos": float,
}
"""
motor_names: list[str]
speed_factor: float = 20.0
clip_min: float = 0.0
clip_max: float = 100.0
discrete_gripper: bool = False
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition.get(TransitionKey.OBSERVATION) or {}
act = transition.get(TransitionKey.ACTION) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if "action.gripper" not in act:
return transition
if "gripper" not in self.motor_names:
new_act = dict(act)
new_act.pop("action.gripper", None)
transition[TransitionKey.ACTION] = new_act
return transition
if self.discrete_gripper:
# Discrete gripper actions are in [0, 1, 2]
# 0: open, 1: close, 2: stay
# We need to shift them to [-1, 0, 1] and then scale them to clip_max
gripper_action = act.get("action.gripper", 1.0)
gripper_action = gripper_action - 1.0
gripper_action *= self.clip_max
act["action.gripper"] = gripper_action
# Get current gripper position from complementary data
raw = comp.get("raw_joint_positions") or {}
curr_pos = float(raw.get("gripper"))
# Compute desired gripper velocity
u = float(act.get("action.gripper", 0.0))
delta = u * float(self.speed_factor)
gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max))
new_act = dict(act)
new_act["action.gripper.pos"] = gripper_pos
new_act.pop("action.gripper", None)
transition[TransitionKey.ACTION] = new_act
obs.update({"observation.state.gripper.pos": curr_pos})
transition[TransitionKey.OBSERVATION] = obs
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We specify the dataset features of this step that we want to be stored in the dataset
features["observation.state.gripper.pos"] = float
features["action.gripper.pos"] = float
return features
@ProcessorStepRegistry.register("forward_kinematics_joints_to_ee")
@dataclass
class ForwardKinematicsJointsToEE(ObservationProcessor):
"""
Compute the end-effector pose from the joint positions.
Input OBSERVATION keys:
{
"observation.state.{joint_name_1,joint_name_2,...,joint_name_n}.pos": float,
}
Output OBSERVATION keys:
{
"observation.state.ee.{x,y,z,wx,wy,wz}" : float
}
"""
kinematics: RobotKinematics
motor_names: list[str]
def observation(self, obs: dict | None) -> dict:
if not all(f"observation.state.{n}.pos" in obs for n in self.motor_names):
return obs
q = np.array([obs[f"observation.state.{n}.pos"] for n in self.motor_names], dtype=float)
t = self.kinematics.forward_kinematics(q)
pos = t[:3, 3]
tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()
obs.update(
{
"observation.state.ee.x": float(pos[0]),
"observation.state.ee.y": float(pos[1]),
"observation.state.ee.z": float(pos[2]),
"observation.state.ee.wx": float(tw[0]),
"observation.state.ee.wy": float(tw[1]),
"observation.state.ee.wz": float(tw[2]),
}
)
return obs
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We specify the dataset features of this step that we want to be stored in the dataset
for k in ["x", "y", "z", "wx", "wy", "wz"]:
features[f"observation.state.ee.{k}"] = float
return features
@ProcessorStepRegistry.register("add_robot_observation")
@dataclass
class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessor):
"""
Read the robot's current observation and insert it into the transition as complementary data.
- Joint positions are added under complementary_data["raw_joint_positions"] as a dict:
{ "<motor_name>": <float position>, ... }
"""
robot: Robot
def complementary_data(self, comp: dict | None) -> dict:
comp = {} if comp is None else dict(comp)
obs = self.robot.get_observation()
comp["raw_joint_positions"] = {
k.removesuffix(".pos"): float(v)
for k, v in obs.items()
if isinstance(k, str) and k.endswith(".pos")
}
return comp
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features

View File

@@ -0,0 +1,200 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from typing import Any
import numpy as np
from lerobot.cameras import make_cameras_from_configs
from lerobot.errors import DeviceNotConnectedError
from lerobot.model.kinematics import RobotKinematics
from lerobot.motors import Motor, MotorNormMode
from lerobot.motors.feetech import FeetechMotorsBus
from . import SO100Follower
from .config_so100_follower import SO100FollowerEndEffectorConfig
logger = logging.getLogger(__name__)
class SO100FollowerEndEffector(SO100Follower):
"""
SO100Follower robot with end-effector space control.
This robot inherits from SO100Follower but transforms actions from
end-effector space to joint space before sending them to the motors.
"""
config_class = SO100FollowerEndEffectorConfig
name = "so100_follower_end_effector"
def __init__(self, config: SO100FollowerEndEffectorConfig):
super().__init__(config)
self.bus = FeetechMotorsBus(
port=self.config.port,
motors={
"shoulder_pan": Motor(1, "sts3215", MotorNormMode.DEGREES),
"shoulder_lift": Motor(2, "sts3215", MotorNormMode.DEGREES),
"elbow_flex": Motor(3, "sts3215", MotorNormMode.DEGREES),
"wrist_flex": Motor(4, "sts3215", MotorNormMode.DEGREES),
"wrist_roll": Motor(5, "sts3215", MotorNormMode.DEGREES),
"gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100),
},
calibration=self.calibration,
)
self.cameras = make_cameras_from_configs(config.cameras)
self.config = config
# Initialize the kinematics module for the so100 robot
if self.config.urdf_path is None:
raise ValueError(
"urdf_path must be provided in the configuration for end-effector control. "
"Please set urdf_path in your SO100FollowerEndEffectorConfig."
)
self.kinematics = RobotKinematics(
urdf_path=self.config.urdf_path,
target_frame_name=self.config.target_frame_name,
)
# Store the bounds for end-effector position
self.end_effector_bounds = self.config.end_effector_bounds
self.current_ee_pos = None
self.current_joint_pos = None
@property
def action_features(self) -> dict[str, Any]:
"""
Define action features for end-effector control.
Returns dictionary with dtype, shape, and names.
"""
return {
"dtype": "float32",
"shape": (4,),
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3},
}
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
"""
Transform action from end-effector space to joint space and send to motors.
Args:
action: Dictionary with keys 'delta_x', 'delta_y', 'delta_z' for end-effector control
or a numpy array with [delta_x, delta_y, delta_z]
Returns:
The joint-space action that was sent to the motors
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Convert action to numpy array if not already
if isinstance(action, dict):
if all(k in action for k in ["delta_x", "delta_y", "delta_z"]):
delta_ee = np.array(
[
action["delta_x"] * self.config.end_effector_step_sizes["x"],
action["delta_y"] * self.config.end_effector_step_sizes["y"],
action["delta_z"] * self.config.end_effector_step_sizes["z"],
],
dtype=np.float32,
)
if "gripper" not in action:
action["gripper"] = [1.0]
action = np.append(delta_ee, action["gripper"])
else:
logger.warning(
f"Expected action keys 'delta_x', 'delta_y', 'delta_z', got {list(action.keys())}"
)
action = np.zeros(4, dtype=np.float32)
if self.current_joint_pos is None:
# Read current joint positions
current_joint_pos = self.bus.sync_read("Present_Position")
self.current_joint_pos = np.array([current_joint_pos[name] for name in self.bus.motors])
# Calculate current end-effector position using forward kinematics
if self.current_ee_pos is None:
self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos)
# Set desired end-effector position by adding delta
desired_ee_pos = np.eye(4)
desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3] # Keep orientation
# Add delta to position and clip to bounds
desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + action[:3]
if self.end_effector_bounds is not None:
desired_ee_pos[:3, 3] = np.clip(
desired_ee_pos[:3, 3],
self.end_effector_bounds["min"],
self.end_effector_bounds["max"],
)
# Compute inverse kinematics to get joint positions
target_joint_values_in_degrees = self.kinematics.inverse_kinematics(
self.current_joint_pos, desired_ee_pos
)
# Create joint space action dictionary
joint_action = {
f"{key}.pos": target_joint_values_in_degrees[i] for i, key in enumerate(self.bus.motors.keys())
}
# Handle gripper separately if included in action
# Gripper delta action is in the range 0 - 2,
# We need to shift the action to the range -1, 1 so that we can expand it to -Max_gripper_pos, Max_gripper_pos
joint_action["gripper.pos"] = np.clip(
self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos,
5,
self.config.max_gripper_pos,
)
self.current_ee_pos = desired_ee_pos.copy()
self.current_joint_pos = target_joint_values_in_degrees.copy()
self.current_joint_pos[-1] = joint_action["gripper.pos"]
# Send joint space action to parent class
return super().send_action(joint_action)
def get_observation(self) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Read arm position
start = time.perf_counter()
obs_dict = self.bus.sync_read("Present_Position")
obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()}
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
return obs_dict
def reset(self):
self.current_ee_pos = None
self.current_joint_pos = None

View File

@@ -69,7 +69,6 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
raise ValueError(config.type)
# TODO(pepijn): Move to pipeline step to make sure we don't have to do this in the robot code and send action to robot is clean for use in dataset
def ensure_safe_goal_position(
goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[float]
) -> dict[str, float]:

View File

@@ -115,11 +115,11 @@ If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you c
echo ${HF_USER}/aloha_test
```
If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with:
If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with [Rerun](https://github.com/rerun-io/rerun):
```bash
python -m lerobot.scripts.visualize_dataset_html \
--repo-id ${HF_USER}/aloha_test
python -m lerobot.scripts.visualize_dataset \
--repo-id ${HF_USER}/aloha_test --episode 0
```
## Replay an episode

View File

@@ -62,16 +62,9 @@ from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies.factory import make_policy
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.processor.pipeline import TransitionKey
from lerobot.robots import so100_follower # noqa: F401
from lerobot.scripts.rl.gym_manipulator import (
create_transition,
make_processors,
make_robot_env,
step_env_and_process_transition,
)
from lerobot.scripts.rl.gym_manipulator import make_robot_env
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import (
bytes_to_state_dict,
@@ -243,8 +236,7 @@ def act_with_policy(
logging.info("make_env online")
online_env, teleop_device = make_robot_env(cfg=cfg.env)
env_processor, action_processor = make_processors(online_env, teleop_device, cfg.env, cfg.policy.device)
online_env = make_robot_env(cfg=cfg.env)
set_seed(cfg.seed)
device = get_safe_torch_device(cfg.policy.device, log=True)
@@ -265,12 +257,6 @@ def act_with_policy(
assert isinstance(policy, nn.Module)
obs, info = online_env.reset()
env_processor.reset()
action_processor.reset()
# Process initial observation
transition = create_transition(observation=obs, info=info)
transition = env_processor(transition)
# NOTE: For the moment we will solely handle the case of a single environment
sum_reward_episode = 0
@@ -288,61 +274,45 @@ def act_with_policy(
logging.info("[ACTOR] Shutting down act_with_policy")
return
observation = transition[TransitionKey.OBSERVATION]
if interaction_step >= cfg.policy.online_step_before_learning:
# Time policy inference and check if it meets FPS requirement
with policy_timer:
action = policy.select_action(batch=obs)
policy_fps = policy_timer.fps_last
# Time policy inference and check if it meets FPS requirement
with policy_timer:
# Extract observation from transition for policy
action = policy.select_action(batch=observation)
policy_fps = policy_timer.fps_last
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
else:
action = online_env.action_space.sample()
# Use the new step function
new_transition = step_env_and_process_transition(
env=online_env,
transition=transition,
action=action,
env_processor=env_processor,
action_processor=action_processor,
)
# Extract values from processed transition
next_observation = new_transition[TransitionKey.OBSERVATION]
executed_action = new_transition[TransitionKey.ACTION]
reward = new_transition[TransitionKey.REWARD]
done = new_transition.get(TransitionKey.DONE, False)
truncated = new_transition.get(TransitionKey.TRUNCATED, False)
next_obs, reward, done, truncated, info = online_env.step(action)
sum_reward_episode += float(reward)
# Increment total steps counter for intervention rate
episode_total_steps += 1
# Check for intervention from transition info
intervention_info = new_transition[TransitionKey.INFO]
if intervention_info.get(TeleopEvents.IS_INTERVENTION, False):
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
if "is_intervention" in info and info["is_intervention"]:
# NOTE: The action space for demonstration before hand is with the full action space
# but sometimes for example we want to deactivate the gripper
action = info["action_intervention"]
episode_intervention = True
# Increment intervention steps counter
episode_intervention_steps += 1
complementary_info = {
"discrete_penalty": torch.tensor(
[new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
),
}
# Create transition for learner (convert to old format)
list_transition_to_send_to_learner.append(
Transition(
state=observation,
action=executed_action,
state=obs,
action=action,
reward=reward,
next_state=next_observation,
next_state=next_obs,
done=done,
truncated=truncated,
complementary_info=complementary_info,
truncated=truncated, # TODO: (azouitine) Handle truncation properly
complementary_info=info,
)
)
# Update transition for next iteration
transition = new_transition
# assign obs to the next obs and continue the rollout
obs = next_obs
if done or truncated:
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
@@ -377,20 +347,12 @@ def act_with_policy(
)
)
# Reset intervention counters and environment
# Reset intervention counters
sum_reward_episode = 0.0
episode_intervention = False
episode_intervention_steps = 0
episode_total_steps = 0
# Reset environment and processors
obs, info = online_env.reset()
env_processor.reset()
action_processor.reset()
# Process initial observation
transition = create_transition(observation=obs, info=info)
transition = env_processor(transition)
if cfg.env.fps is not None:
dt_time = time.perf_counter() - start_time

View File

@@ -226,7 +226,8 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
value = value.unsqueeze(0)
new_frame[key] = value
new_dataset.add_frame(new_frame, task=task)
new_frame["task"] = task
new_dataset.add_frame(new_frame)
if frame["episode_index"].item() != prev_episode_index:
# Save the episode

File diff suppressed because it is too large Load Diff

View File

@@ -75,7 +75,6 @@ from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.robots import so100_follower # noqa: F401
from lerobot.scripts.rl import learner_service
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.transport import services_pb2_grpc
from lerobot.transport.utils import (
MAX_MESSAGE_SIZE,
@@ -1049,8 +1048,10 @@ def get_observation_features(
return None, None
with torch.no_grad():
observation_features = policy.actor.encoder.get_cached_image_features(observations)
next_observation_features = policy.actor.encoder.get_cached_image_features(next_observations)
observation_features = policy.actor.encoder.get_cached_image_features(observations, normalize=True)
next_observation_features = policy.actor.encoder.get_cached_image_features(
next_observations, normalize=True
)
return observation_features, next_observation_features
@@ -1175,7 +1176,7 @@ def process_transitions(
# Add to offline buffer if it's an intervention
if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
TeleopEvents.IS_INTERVENTION
"is_intervention"
):
offline_replay_buffer.add(**transition)

View File

@@ -302,11 +302,6 @@ class RobotClient:
self.logger.debug(f"Current latest action: {latest_action}")
# Get queue state before changes
old_size, old_timesteps = self._inspect_action_queue()
if not old_timesteps:
old_timesteps = [latest_action] # queue was empty
# Get queue state before changes
old_size, old_timesteps = self._inspect_action_queue()
if not old_timesteps:

View File

@@ -26,13 +26,12 @@ from torch.optim import Optimizer
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.envs.factory import make_env
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy, make_processor
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.scripts.eval import eval_policy
@@ -141,9 +140,6 @@ def train(cfg: TrainPipelineConfig):
cfg=cfg.policy,
ds_meta=dataset.meta,
)
preprocessor, postprocessor = make_processor(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats
)
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
@@ -153,10 +149,6 @@ def train(cfg: TrainPipelineConfig):
if cfg.resume:
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
preprocessor.from_pretrained(cfg.checkpoint_path, config_filename=f"{PREPROCESSOR_DEFAULT_NAME}.json")
postprocessor.from_pretrained(
cfg.checkpoint_path, config_filename=f"{POSTPROCESSOR_DEFAULT_NAME}.json"
)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
@@ -174,7 +166,8 @@ def train(cfg: TrainPipelineConfig):
if hasattr(cfg.policy, "drop_n_last_frames"):
shuffle = False
sampler = EpisodeAwareSampler(
dataset.episode_data_index,
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
drop_n_last_frames=cfg.policy.drop_n_last_frames,
shuffle=True,
)
@@ -211,9 +204,12 @@ def train(cfg: TrainPipelineConfig):
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
batch = next(dl_iter)
batch = preprocessor(batch)
train_tracker.dataloading_s = time.perf_counter() - start_time
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
train_tracker, output_dict = update_policy(
train_tracker,
policy,
@@ -245,9 +241,7 @@ def train(cfg: TrainPipelineConfig):
if cfg.save_checkpoint and is_saving_step:
logging.info(f"Checkpoint policy after step {step}")
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
save_checkpoint(
checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor, postprocessor
)
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
update_last_checkpoint(checkpoint_dir)
if wandb_logger:
wandb_logger.log_policy(checkpoint_dir)
@@ -291,8 +285,6 @@ def train(cfg: TrainPipelineConfig):
if cfg.policy.push_to_hub:
policy.push_model_to_hub(cfg)
preprocessor.push_to_hub(cfg.policy.repo_id)
postprocessor.push_to_hub(cfg.policy.repo_id)
def main():

View File

@@ -1,311 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from contextlib import nullcontext
from pprint import pformat
from typing import Any, Callable
import accelerate
import torch
from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.optim.factory import make_optimizer_and_scheduler
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_state,
save_checkpoint,
update_last_checkpoint,
)
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
has_method,
init_logging,
is_launched_with_accelerate,
)
from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.eval import eval_policy
def update_policy(
train_metrics: MetricsTracker,
policy: PreTrainedPolicy,
batch: Any,
optimizer: Optimizer,
grad_clip_norm: float,
grad_scaler: GradScaler,
lr_scheduler=None,
use_amp: bool = False,
lock=None,
accelerator: Callable = None,
) -> tuple[MetricsTracker, dict]:
start_time = time.perf_counter()
policy.train()
loss, output_dict = policy.forward(batch)
accelerator.backward(loss)
accelerator.unscale_gradients(optimizer=optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(),
grad_clip_norm,
error_if_nonfinite=False,
)
optimizer.step()
optimizer.zero_grad()
# Step through pytorch scheduler at every batch instead of epoch
if lr_scheduler is not None:
lr_scheduler.step()
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
train_metrics.loss = loss.item()
train_metrics.grad_norm = grad_norm.item()
train_metrics.lr = optimizer.param_groups[0]["lr"]
train_metrics.update_s = time.perf_counter() - start_time
return train_metrics, output_dict
@parser.wrap()
def train(cfg: TrainPipelineConfig, accelerator: Callable):
cfg.validate()
logging.info(pformat(cfg.to_dict()))
if accelerator.is_main_process:
# Disable logging on non-main processes.
cfg.wandb.enable = False
if cfg.wandb.enable and cfg.wandb.project:
wandb_logger = WandBLogger(cfg)
else:
wandb_logger = None
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
if cfg.seed is not None:
set_seed(cfg.seed, accelerator=accelerator)
# Check device is available
device = get_safe_torch_device(cfg.device, log=True, accelerator=accelerator)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("Creating dataset")
dataset = make_dataset(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
# using the eval.py instead, with gym_dora environment and dora-rs.
eval_env = None
if cfg.eval_freq > 0 and cfg.env is not None:
logging.info("Creating env")
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size)
logging.info("Creating policy")
policy = make_policy(
cfg=cfg.policy,
device=device,
ds_meta=dataset.meta,
)
policy.to(device)
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
grad_scaler = GradScaler(device, enabled=cfg.use_amp)
step = 0 # number of policy updates (forward + backward + optim)
if cfg.resume:
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
if accelerator.is_main_process:
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
logging.info(f"{dataset.num_episodes=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
if hasattr(cfg.policy, "drop_n_last_frames"):
shuffle = False
sampler = EpisodeAwareSampler(
dataset.episode_data_index,
drop_n_last_frames=cfg.policy.drop_n_last_frames,
shuffle=True,
)
else:
shuffle = True
sampler = None
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=cfg.num_workers,
batch_size=cfg.batch_size,
shuffle=shuffle,
sampler=sampler,
pin_memory=device.type != "cpu",
drop_last=False,
)
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
policy, optimizer, dataloader, lr_scheduler
)
dl_iter = cycle(dataloader)
policy.train()
train_metrics = {
"loss": AverageMeter("loss", ":.3f"),
"grad_norm": AverageMeter("grdn", ":.3f"),
"lr": AverageMeter("lr", ":0.1e"),
"update_s": AverageMeter("updt_s", ":.3f"),
"dataloading_s": AverageMeter("data_s", ":.3f"),
}
train_tracker = MetricsTracker(
cfg.batch_size,
dataset.num_frames,
dataset.num_episodes,
train_metrics,
initial_step=step,
accelerator=accelerator,
)
if accelerator.is_main_process:
logging.info("Start offline training on a fixed dataset")
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
batch = next(dl_iter)
train_tracker.dataloading_s = time.perf_counter() - start_time
train_tracker, output_dict = update_policy(
train_tracker,
policy,
batch,
optimizer,
cfg.optimizer.grad_clip_norm,
grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler,
use_amp=cfg.use_amp,
accelerator=accelerator,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
# increment `step` here.
step += 1
train_tracker.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and accelerator.is_main_process
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps and accelerator.is_main_process
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 and accelerator.is_main_process
if is_log_step:
logging.info(train_tracker)
if wandb_logger:
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
if cfg.save_checkpoint and is_saving_step:
logging.info(f"Checkpoint policy after step {step}")
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
save_checkpoint(
checkpoint_dir,
step,
cfg,
accelerator.unwrap_model(policy),
optimizer,
lr_scheduler,
)
update_last_checkpoint(checkpoint_dir)
if wandb_logger:
wandb_logger.log_policy(checkpoint_dir)
accelerator.wait_for_everyone()
if cfg.env and is_eval_step:
step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}")
with torch.no_grad():
eval_info = eval_policy(
env=eval_env,
policy=accelerator.unwrap_model(policy),
n_episodes=cfg.eval.n_episodes,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
max_episodes_rendered=4,
start_seed=cfg.seed,
)
eval_metrics = {
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
"pc_success": AverageMeter("success", ":.1f"),
"eval_s": AverageMeter("eval_s", ":.3f"),
}
eval_tracker = MetricsTracker(
cfg.batch_size,
dataset.num_frames,
dataset.num_episodes,
eval_metrics,
initial_step=step,
accelerator=None,
)
eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success")
logging.info(eval_tracker)
if wandb_logger:
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
if eval_env:
eval_env.close()
if not accelerator or accelerator.is_main_process:
logging.info("End of training")
if __name__ == "__main__":
init_logging()
# We set step_scheduler_with_optimizer False to prevent accelerate from
# adjusting the lr_scheduler steps based on the num_processes
accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False)
train(accelerator=accelerator)

View File

@@ -79,8 +79,8 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
class EpisodeSampler(torch.utils.data.Sampler):
def __init__(self, dataset: LeRobotDataset, episode_index: int):
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
self.frame_ids = range(from_idx, to_idx)
def __iter__(self) -> Iterator:
@@ -283,7 +283,7 @@ def main():
tolerance_s = kwargs.pop("tolerance_s")
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s)
visualize_dataset(dataset, **vars(args))

View File

@@ -1,482 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
Note: The last frame of the episode doesnt always correspond to a final state.
That's because our datasets are composed of transition from state to state up to
the antepenultimate state associated to the ultimate action to arrive in the final state.
However, there might not be a transition from a final state to another state.
Note: This script aims to visualize the data used to train the neural networks.
~What you see is what you get~. When visualizing image modality, it is often expected to observe
lossly compression artifacts since these images have been decoded from compressed mp4 videos to
save disk space. The compression factor applied has been tuned to not affect success rate.
Example of usage:
- Visualize data stored on a local machine:
```bash
local$ python -m lerobot.scripts.visualize_dataset_html \
--repo-id lerobot/pusht
local$ open http://localhost:9090
```
- Visualize data stored on a distant machine with a local viewer:
```bash
distant$ python -m lerobot.scripts.visualize_dataset_html \
--repo-id lerobot/pusht
local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel
local$ open http://localhost:9090
```
- Select episodes to visualize:
```bash
python -m lerobot.scripts.visualize_dataset_html \
--repo-id lerobot/pusht \
--episodes 7 3 5 1 4
```
"""
import argparse
import csv
import json
import logging
import re
import shutil
import tempfile
from io import StringIO
from pathlib import Path
import numpy as np
import pandas as pd
import requests
from flask import Flask, redirect, render_template, request, url_for
from lerobot import available_datasets
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import IterableNamespace
from lerobot.utils.utils import init_logging
def run_server(
dataset: LeRobotDataset | IterableNamespace | None,
episodes: list[int] | None,
host: str,
port: str,
static_folder: Path,
template_folder: Path,
):
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
@app.route("/")
def hommepage(dataset=dataset):
if dataset:
dataset_namespace, dataset_name = dataset.repo_id.split("/")
return redirect(
url_for(
"show_episode",
dataset_namespace=dataset_namespace,
dataset_name=dataset_name,
episode_id=0,
)
)
dataset_param, episode_param = None, None
all_params = request.args
if "dataset" in all_params:
dataset_param = all_params["dataset"]
if "episode" in all_params:
episode_param = int(all_params["episode"])
if dataset_param:
dataset_namespace, dataset_name = dataset_param.split("/")
return redirect(
url_for(
"show_episode",
dataset_namespace=dataset_namespace,
dataset_name=dataset_name,
episode_id=episode_param if episode_param is not None else 0,
)
)
featured_datasets = [
"lerobot/aloha_static_cups_open",
"lerobot/columbia_cairlab_pusht_real",
"lerobot/taco_play",
]
return render_template(
"visualize_dataset_homepage.html",
featured_datasets=featured_datasets,
lerobot_datasets=available_datasets,
)
@app.route("/<string:dataset_namespace>/<string:dataset_name>")
def show_first_episode(dataset_namespace, dataset_name):
first_episode_id = 0
return redirect(
url_for(
"show_episode",
dataset_namespace=dataset_namespace,
dataset_name=dataset_name,
episode_id=first_episode_id,
)
)
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
repo_id = f"{dataset_namespace}/{dataset_name}"
try:
if dataset is None:
dataset = get_dataset_info(repo_id)
except FileNotFoundError:
return (
"Make sure to convert your LeRobotDataset to v2 & above. See how to convert your dataset at https://github.com/huggingface/lerobot/pull/461",
400,
)
dataset_version = (
str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
)
match = re.search(r"v(\d+)\.", dataset_version)
if match:
major_version = int(match.group(1))
if major_version < 2:
return "Make sure to convert your LeRobotDataset to v2 & above."
episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
dataset_info = {
"repo_id": f"{dataset_namespace}/{dataset_name}",
"num_samples": dataset.num_frames
if isinstance(dataset, LeRobotDataset)
else dataset.total_frames,
"num_episodes": dataset.num_episodes
if isinstance(dataset, LeRobotDataset)
else dataset.total_episodes,
"fps": dataset.fps,
}
if isinstance(dataset, LeRobotDataset):
video_paths = [
dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
]
videos_info = [
{
"url": url_for("static", filename=str(video_path).replace("\\", "/")),
"filename": video_path.parent.name,
}
for video_path in video_paths
]
tasks = dataset.meta.episodes[episode_id]["tasks"]
else:
video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
videos_info = [
{
"url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
+ dataset.video_path.format(
episode_chunk=int(episode_id) // dataset.chunks_size,
video_key=video_key,
episode_index=episode_id,
),
"filename": video_key,
}
for video_key in video_keys
]
response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
)
response.raise_for_status()
# Split into lines and parse each line as JSON
tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
tasks = filtered_tasks_jsonl[0]["tasks"]
videos_info[0]["language_instruction"] = tasks
if episodes is None:
episodes = list(
range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
)
return render_template(
"visualize_dataset_template.html",
episode_id=episode_id,
episodes=episodes,
dataset_info=dataset_info,
videos_info=videos_info,
episode_data_csv_str=episode_data_csv_str,
columns=columns,
ignored_columns=ignored_columns,
)
app.run(host=host, port=port)
def get_ep_csv_fname(episode_id: int):
ep_csv_fname = f"episode_{episode_id}.csv"
return ep_csv_fname
def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index):
"""Get a csv str containing timeseries data of an episode (e.g. state and action).
This file will be loaded by Dygraph javascript to plot data in real time."""
columns = []
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
selected_columns.remove("timestamp")
ignored_columns = []
for column_name in selected_columns:
shape = dataset.features[column_name]["shape"]
shape_dim = len(shape)
if shape_dim > 1:
selected_columns.remove(column_name)
ignored_columns.append(column_name)
# init header of csv with state and action names
header = ["timestamp"]
for column_name in selected_columns:
dim_state = (
dataset.meta.shapes[column_name][0]
if isinstance(dataset, LeRobotDataset)
else dataset.features[column_name].shape[0]
)
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
column_names = dataset.features[column_name]["names"]
while not isinstance(column_names, list):
column_names = list(column_names.values())[0]
else:
column_names = [f"{column_name}_{i}" for i in range(dim_state)]
columns.append({"key": column_name, "value": column_names})
header += column_names
selected_columns.insert(0, "timestamp")
if isinstance(dataset, LeRobotDataset):
from_idx = dataset.episode_data_index["from"][episode_index]
to_idx = dataset.episode_data_index["to"][episode_index]
data = (
dataset.hf_dataset.select(range(from_idx, to_idx))
.select_columns(selected_columns)
.with_format("pandas")
)
else:
repo_id = dataset.repo_id
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
)
df = pd.read_parquet(url)
data = df[selected_columns] # Select specific columns
rows = np.hstack(
(
np.expand_dims(data["timestamp"], axis=1),
*[np.vstack(data[col]) for col in selected_columns[1:]],
)
).tolist()
# Convert data to CSV string
csv_buffer = StringIO()
csv_writer = csv.writer(csv_buffer)
# Write header
csv_writer.writerow(header)
# Write data rows
csv_writer.writerows(rows)
csv_string = csv_buffer.getvalue()
return csv_string, columns, ignored_columns
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
# get first frame of episode (hack to get video_path of the episode)
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
return [
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
for key in dataset.meta.video_keys
]
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
# check if the dataset has language instructions
if "language_instruction" not in dataset.features:
return None
# get first frame index
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
# with the tf.tensor appearing in the string
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
def get_dataset_info(repo_id: str) -> IterableNamespace:
response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
)
response.raise_for_status() # Raises an HTTPError for bad responses
dataset_info = response.json()
dataset_info["repo_id"] = repo_id
return IterableNamespace(dataset_info)
def visualize_dataset_html(
dataset: LeRobotDataset | None,
episodes: list[int] | None = None,
output_dir: Path | None = None,
serve: bool = True,
host: str = "127.0.0.1",
port: int = 9090,
force_override: bool = False,
) -> Path | None:
init_logging()
template_dir = Path(__file__).resolve().parent.parent / "templates"
if output_dir is None:
# Create a temporary directory that will be automatically cleaned up
output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")
output_dir = Path(output_dir)
if output_dir.exists():
if force_override:
shutil.rmtree(output_dir)
else:
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
output_dir.mkdir(parents=True, exist_ok=True)
static_dir = output_dir / "static"
static_dir.mkdir(parents=True, exist_ok=True)
if dataset is None:
if serve:
run_server(
dataset=None,
episodes=None,
host=host,
port=port,
static_folder=static_dir,
template_folder=template_dir,
)
else:
# Create a simlink from the dataset video folder containing mp4 files to the output directory
# so that the http server can get access to the mp4 files.
if isinstance(dataset, LeRobotDataset):
ln_videos_dir = static_dir / "videos"
if not ln_videos_dir.exists():
ln_videos_dir.symlink_to((dataset.root / "videos").resolve().as_posix())
if serve:
run_server(dataset, episodes, host, port, static_dir, template_dir)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
default=None,
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
)
parser.add_argument(
"--root",
type=Path,
default=None,
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
)
parser.add_argument(
"--load-from-hf-hub",
type=int,
default=0,
help="Load videos and parquet files from HF Hub rather than local system.",
)
parser.add_argument(
"--episodes",
type=int,
nargs="*",
default=None,
help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
)
parser.add_argument(
"--output-dir",
type=Path,
default=None,
help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.",
)
parser.add_argument(
"--serve",
type=int,
default=1,
help="Launch web server.",
)
parser.add_argument(
"--host",
type=str,
default="127.0.0.1",
help="Web host used by the http server.",
)
parser.add_argument(
"--port",
type=int,
default=9090,
help="Web port used by the http server.",
)
parser.add_argument(
"--force-override",
type=int,
default=0,
help="Delete the output directory if it exists already.",
)
parser.add_argument(
"--tolerance-s",
type=float,
default=1e-4,
help=(
"Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
"This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
"If not given, defaults to 1e-4."
),
)
args = parser.parse_args()
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
load_from_hf_hub = kwargs.pop("load_from_hf_hub")
root = kwargs.pop("root")
tolerance_s = kwargs.pop("tolerance_s")
dataset = None
if repo_id:
dataset = (
LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
if not load_from_hf_hub
else get_dataset_info(repo_id)
)
visualize_dataset_html(dataset, **vars(args))
if __name__ == "__main__":
main()

View File

@@ -109,7 +109,7 @@ def teleop_loop(
action = teleop.get_action()
if display_data:
observation = robot.get_observation()
log_rerun_data(observation=observation, action=action)
log_rerun_data(observation, action)
robot.send_action(action)
dt_s = time.perf_counter() - loop_start

Some files were not shown because too many files have changed in this diff Show More