Compare commits

..

37 Commits

Author SHA1 Message Date
Jade Choghari
32fc5504cc add various experiments for wavelet 2026-02-13 10:27:02 +00:00
Steven Palma
fc8a388a25 feat(cameras): make backend configurable to the CLI (#2945)
* feat(cameras): make backend configurable to the CLI

* chore(cameras): address feedback

* feat(Enum error messages): adding better instanciation error messages for Enum classes

* chore(Enum error messages): propagating Enum error messages to all camera classes

* chore(comments): removing superfluous comments

* chore(format): applying ruff checks

---------

Co-authored-by: CarolinePascal <caroline8.pascal@gmail.com>
2026-02-11 13:57:25 +01:00
Steven Palma
3c84d271d5 fix(motors): use decorator to fix precommit (#2951) 2026-02-10 18:40:50 +01:00
Steven Palma
1ba3975020 chore: use is_connected decorators (#2948)
* chore: use is_connected decorators

* chore(robots): add is_connected to bi setups too
2026-02-10 17:49:30 +01:00
Steven Palma
35363c5798 chore(linter): ensure motors module passes MyPy type checks (#2939)
* fix: ensure motors module passes MyPy type checks

This commit fixes 62 mypy type errors in the motors module by:

- Updating Protocol classes (PortHandler, PacketHandler, GroupSyncRead,
  GroupSyncWrite) to use class-level attribute declarations instead of
  __init__ body declarations
- Adding missing `broadcastPing` method to PacketHandler Protocol
- Fixing return type annotations (e.g., `_get_motor_model` returns str, not int)
- Fixing parameter types to use `Sequence` for covariant list parameters
- Fixing `Mapping` for covariant dict value types in `_normalize`
- Updating method signatures to be consistent across parent and child classes
  (disable_torque, enable_torque, _get_half_turn_homings)
- Adding explicit `int()` casts for MotorCalibration arguments
- Adding explicit `return None` for functions returning Optional types
- Adding type annotations for variables like `data_list: dict[int, int]`
- Using `# type: ignore[method-assign]` for intentional monkeypatch
- Fixing variable references (using `self.groups` instead of `groups`)

Fixes #1723

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* chore(style): pre-commit after main merge

* chore(linter): solve comments

* chore(linter): apply pre-commit fixes to damiao

* chore(linter): more fixes to damiao

---------

Co-authored-by: yurekami <yurekami@users.noreply.github.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-10 17:35:39 +01:00
whats2000
778db19a17 [Bug Fix] fix(ci): prevent runner group error on fork pushes (#2911)
* fix(ci): prevent runner group error on fork pushes

Add repository check to unbound_deps_tests workflow to ensure
aws-general-8-plus runner group is only used on main repository,
preventing 'Required runner group not found' errors on forks.

* fix(ci): use gating job to prevent runner allocation on forks

The previous approach failed because GitHub evaluates runs-on before if conditions.
Now using a check-repo job that runs on ubuntu-latest first, and all jobs with
special runners depend on it and check its output before being scheduled.

* fix(ci): add gating job to full_tests to prevent runner allocation on forks

Apply the same gating pattern used in unbound_deps_tests to full_tests.yml
to prevent GitHub from trying to allocate custom runners when workflows
run on forks. The check-repo job runs first on ubuntu-latest and all jobs
with custom runners depend on it and check its output.

* fix(ci): add repository check to unbound_deps_tests workflow

Add 'if: github.repository == huggingface/lerobot' check to build-and-push-docker job to prevent runner group access errors on forks, matching the pattern used in nightly.yml

* fix(ci): add repository check to full_tests workflow

Add 'if: github.repository == huggingface/lerobot' check to build-and-push-docker and gpu-tests jobs to prevent runner group access errors on forks

* refactor(ci): remove redundant check from gpu-tests job

gpu-tests depends on build-and-push-docker via needs, so it will automatically skip when the parent job is skipped

* refactor(ci): remove unnecessary fork check from full-tests job

full-tests runs on ubuntu-latest which is available to all forks, no need to restrict it

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-10 15:21:40 +01:00
Jai Kumaar Ratadia
d2d01399d6 docs: clarify installation steps are sequential, not optional (#2925)
* docs: clarify installation steps are sequential, not optional

Add intro paragraph noting conda is one path (not the only one) and
number the three sections as steps so readers understand miniforge and
environment setup are prerequisites, not independent choices.

* Update installation guide link for LeRobot

Signed-off-by: Jai Kumaar Ratadia <jaikumaarratadia@gmail.com>

* Fix link formatting in installation guide again

Signed-off-by: Jai Kumaar Ratadia <jaikumaarratadia@gmail.com>

---------

Signed-off-by: Jai Kumaar Ratadia <jaikumaarratadia@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-10 15:18:32 +01:00
Aoqun Jin
5eba4ce6f4 Change LIBERO init_state_id when reset. (#2899)
* Change LIBERO init_state_id when reset.

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* Change LIBERO init_state_id when reset.

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* pre-commit run

---------

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-10 16:39:17 +03:00
Stepan Feduniak
cca0296cd6 fix(pipeline): use FeatureType for STATE features in Libero processor (#2888)
* fix the types

* pre-commit

---------

Co-authored-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-10 15:55:11 +03:00
Steven Palma
489cb7b6b9 fix(scripts): correct can import check (#2937) 2026-02-09 16:58:32 +01:00
Reece O'Mahoney
e14bdf57d0 Convert tensors to scalars (#2903)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-09 14:46:12 +01:00
Reece O'Mahoney
97e7e0f9ed feat(datasets): improve image transform support (#2885)
* improve image transform support

* add tests

* Add stricter transform check and extra test

* improve subclass check
2026-02-05 15:39:58 +01:00
jwang078
0f39248445 Small docstring fix in diffusion configuration (#2847) 2026-02-03 19:19:00 +01:00
Iori Yanokura
a6370dd783 fix(wandb): truncate init tags to 64-character limit (#995) 2026-02-03 14:17:04 +01:00
Michel Aractingi
14a15f90e7 Add missing RL config options: add_ee_pose_to_observation and gripper_penalty_in_reward (#2873)
* fix(RL) add missing config arguments

* respond to copilot review

* fix(revert penalty in reward): reverting gripper penalty addition in reward. This is already done in compute_loss_discrete_critic.

---------

Co-authored-by: CarolinePascal <caroline8.pascal@gmail.com>
2026-02-02 22:14:03 +01:00
Hirokazu Ishida
9c24a09665 docs: update document in response to Simplify configs PR (#1596)
* docs: update document input/output_shapes -> input/output_features

* fix inconsistent quote (suggested by copilot reviewer)

* docs: shapes => PolicyFeature

* docs: relfect normalization_mapping and remove outdated
2026-02-02 20:05:58 +01:00
Jade Choghari
b18cef2e26 feat(dataset): add subtask support (#2860)
* add subtask

* remove folder

* add docs

* update doc

* add testing

* update test

* update constant naming + doc

* more docs
2026-01-30 19:29:37 +01:00
Caroline Pascal
5c6182176f fix(find zmq): adding a clearer not implemented warning for the ZMQ find_cameras method (#2879)
Co-authored-by: Martino Russi <77496684+nepyope@users.noreply.github.com>
2026-01-30 16:58:13 +01:00
Caroline Pascal
55c0471db9 docs(cameras): revising and improving docs on cameras (#2878)
* docs(cameras): revising and improving docs on cameras

* resolving copilot comments
2026-01-30 16:57:56 +01:00
Michel Aractingi
ec04b7ce3a Feat(dataset_tools.py) Add modify tasks tool (#2875)
* feat(datasets): add modify_tasks function for in-place task editing

Add a new utility function to modify tasks in LeRobotDataset in-place.
This allows users to:
- Set a single task for all episodes
- Set specific tasks for individual episodes
- Combine a default task with per-episode overrides

* feat(edit-dataset): add CLI support for modify_tasks operation

Integrate the modify_tasks function into lerobot_edit_dataset CLI.
Users can now modify dataset tasks via command line:
Supports setting a default task, per-episode tasks, or both combined.

* test(datasets): add tests for modify_tasks function

Add comprehensive test coverage for the modify_tasks utility:
- Single task for all episodes
- Episode-specific task assignment
- Default task with per-episode overrides
- Error handling for missing/invalid arguments
- Verification of task_index correctness
- In-place modification behavior
- Metadata preservation

* respond to copilot review
2026-01-30 13:19:42 +01:00
Michel Aractingi
04cbf669cf fix(sac): make temperature a property to fix checkpoint resume bug (#2877)
* fix(sac): make temperature a property to fix checkpoint resume bug

Temperature was stored as a plain float and not restored after loading
a checkpoint, causing incorrect loss computations until update_temperature()
was called. Changed to a property that always computes from log_alpha,
ensuring correct behavior after checkpoint loading.

* simplify docstrings
2026-01-30 12:23:22 +01:00
Steven Palma
3409ef0dc2 refactor(cameras): cameras API extension (#2808)
* feat(cameras): add new read_latest() method

* fix(cameras): fix threading bug + clear state

* refactor(cameras): multiple improvements

* feat(camera): add context manager to camera base class

* chore(camera): slight modifications to opencv

* test(cameras): update opencv tests according to the changes

* refactor(cameras): reflect desing changes to realsense + deal with depth

* test(cameras): fix realsense tests accordingly to new changes

* refactor(cameras): update reachymini and zmq accordingly

* chore: wrap resource sensitive examples into a try/finally

* test(cameras): add test for new read_latest

* test(cameras): fix problem with image artifact in opencv tests

* test(cameras): fix test_read_latest_high_frequency expectations

* Apply suggestions from code review 1

Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* chore(cameras): address feedback

* feat(cameras): add max_age_ms check in read_latest

* test(cameras): fix read_latest tests

* chore(redundancies): removing redundancies in Reachy 2 camera class

* fix(warmup): replacing the arbitrary time.sleep in by an actual warmup in the RealSense camera class

* chore(format): formatting latest changes

* chore(warning): adding a "to be implemented" warning for read_latest() in Camera base class

* chore(warning): making read_latest() warning message shorter and clearer

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
2026-01-29 11:07:47 +01:00
Steven Palma
4483184875 feat(robots): add bi manual openarm follower and leader (#2835)
* fix(motors): cleanup imports + fix signatures

* feat(motors): add damiao canbus + multiple fixes

* fix(motors): address comments -> last_state + different gains + sleep

* refactor(motors): reduce duplicated code + adressed some comments in the PR

* chore(motors): better timeouts

* tests(motors): damiao test and imports

* chore(deps): fix space

* feat(robot): add openarm leader

Co-authored-by: Pepijn <pepijn@huggingface.co>

* feat(robot): add openarm follower

Co-authored-by: Pepijn <pepijn@huggingface.co>

* refactor(robot): remove mechanical compensations and double arm assumption + rename

* chore(robots): remove left arm references

* refactor(teleop): multiple improvements to leader

* refactor(teleop): multiple improvements to leader

* feat(robots): add open arm to util CLI

* chore(robot): add alias openarm

* Apply suggestions from code review

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* chore(motors): remove normalization tables damiao

* fix(motors): imports and signatures

* feat(motors): add motor_type_str + recv_id to motor class and _get_motor_recv_id raises if no motor_obj.recv_id

* chore(motors): remove normalize from base motor class and damaio

* tests(motors): remove bad tests (to be replaced)

* chore(motors): updated import check

* fix(robots): open arm mirrored config for joint limits

* chore(motors): update position_kd gain values

* chore(robots): set to 0 if openarm is calibrated at connect time

* chore(robots): remove macos in open arm as can doesn't support it

* chore(robots): update for motor_type_str in Motor class

* chore(robots): no default value for can port in open arms

* feat(robots): add bi manual openarm follower and leader

* use constant for kp and kd range and check responses in mit_control_batch()

* Add docs on setting up canbus and use damiao otor bus, also add lerobot_setup_can.py and log if there is not response from a write command

* precommit format

* supress bandit as these are intentional cli commands

* fix setup-can

* add test

* skip test in ci

* nit precommit

* update doc example

* dont import can for tests

* remove comment

* Add openarms docs

* format

* update purchase link

* can to none if nit availabl;e

* add canfd option in bus

* make handshake logic similar to lerobot-can

* type hint

* type check

* add temp teleop test

* remove script

* mock class

* mock class

* ignore linter

* pre-commit

* Add command for bimanual openarm

* fix import

* fix import leader

* fix import draccus

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Pepijn <pepijn@huggingface.co>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-01-28 17:25:57 +01:00
Martino Russi
149628dfd5 add g1 teleoperation (#2791)
* add gravity compensation

* add g1 teleoperation

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2026-01-28 15:17:38 +01:00
Steven Palma
bf337e716d feat(robots): add OpenArm robot & teleoperator (#2795)
* fix(motors): cleanup imports + fix signatures

* feat(motors): add damiao canbus + multiple fixes

* fix(motors): address comments -> last_state + different gains + sleep

* refactor(motors): reduce duplicated code + adressed some comments in the PR

* chore(motors): better timeouts

* tests(motors): damiao test and imports

* chore(deps): fix space

* feat(robot): add openarm leader

Co-authored-by: Pepijn <pepijn@huggingface.co>

* feat(robot): add openarm follower

Co-authored-by: Pepijn <pepijn@huggingface.co>

* refactor(robot): remove mechanical compensations and double arm assumption + rename

* chore(robots): remove left arm references

* refactor(teleop): multiple improvements to leader

* refactor(teleop): multiple improvements to leader

* feat(robots): add open arm to util CLI

* chore(robot): add alias openarm

* Apply suggestions from code review

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* chore(motors): remove normalization tables damiao

* fix(motors): imports and signatures

* feat(motors): add motor_type_str + recv_id to motor class and _get_motor_recv_id raises if no motor_obj.recv_id

* chore(motors): remove normalize from base motor class and damaio

* tests(motors): remove bad tests (to be replaced)

* chore(motors): updated import check

* fix(robots): open arm mirrored config for joint limits

* chore(motors): update position_kd gain values

* chore(robots): set to 0 if openarm is calibrated at connect time

* chore(robots): remove macos in open arm as can doesn't support it

* chore(robots): update for motor_type_str in Motor class

* chore(robots): no default value for can port in open arms

* use constant for kp and kd range and check responses in mit_control_batch()

* Add docs on setting up canbus and use damiao otor bus, also add lerobot_setup_can.py and log if there is not response from a write command

* precommit format

* supress bandit as these are intentional cli commands

* fix setup-can

* add test

* skip test in ci

* nit precommit

* update doc example

* dont import can for tests

* remove comment

* Add openarms docs

* format

* update purchase link

* can to none if nit availabl;e

* add canfd option in bus

* make handshake logic similar to lerobot-can

* type hint

* type check

* add temp teleop test

* remove script

* mock class

* ignore linter

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Pepijn <pepijn@huggingface.co>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-01-28 14:28:51 +01:00
Michel Aractingi
736b43f3cf Fix(aggregate.py) Aggregation of datasets when sub-datasets are already a result of a previous merge (#2861)
* Fix aggeregation of datasets when subdatasets are already a result of a previous merge

* docstring

* respond to copilot review + add regression test

* Remove unnecessary int conversion for indicies
2026-01-28 13:31:27 +01:00
Reece O'Mahoney
f6b1c39b78 docs: update libero (#2857)
* update libero docs

* Update docs/source/libero.mdx

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Jade Choghari <chogharijade@gmail.com>

---------

Signed-off-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-01-27 15:31:53 +01:00
Pepijn
0c0c171d35 Add robot images to docs (#2862)
* Add robot images to docs

* increase img size

* remove img so100
2026-01-27 13:33:45 +01:00
Steven Palma
9cfb5ce546 feat(motors): add damiao motors & can bus (#2788)
* fix(motors): cleanup imports + fix signatures

* feat(motors): add damiao canbus + multiple fixes

* fix(motors): address comments -> last_state + different gains + sleep

* refactor(motors): reduce duplicated code + adressed some comments in the PR

* chore(motors): better timeouts

* tests(motors): damiao test and imports

* chore(deps): fix space

* Apply suggestions from code review

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* chore(motors): remove normalization tables damiao

* fix(motors): imports and signatures

* feat(motors): add motor_type_str + recv_id to motor class and _get_motor_recv_id raises if no motor_obj.recv_id

* chore(motors): remove normalize from base motor class and damaio

* tests(motors): remove bad tests (to be replaced)

* chore(motors): updated import check

* use constant for kp and kd range and check responses in mit_control_batch()

* Add docs on setting up canbus and use damiao otor bus, also add lerobot_setup_can.py and log if there is not response from a write command

* precommit format

* supress bandit as these are intentional cli commands

* fix setup-can

* add test

* skip test in ci

* nit precommit

* update doc example

* dont import can for tests

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Pepijn <pepijn@huggingface.co>
2026-01-26 17:53:25 +01:00
Reece O'Mahoney
366bef915c add task ids to libero env cfg (#2842) 2026-01-26 17:26:49 +01:00
Woojin Wie
9e10eb4a77 fix(robots): update gripper configuration and calibration settings for OMX (#2815) 2026-01-25 22:29:37 +01:00
Steven Palma
6d34a986de feat(ci): trigger manually documentation release version (#2841) 2026-01-22 12:26:17 +01:00
Steven Palma
961277d86e chore(dependencies): Bump lerobot to 0.4.4 (#2840) 2026-01-22 12:24:12 +01:00
Steven Palma
0b067df57d feat(robots): add context managers (#2828) 2026-01-20 18:02:38 +01:00
Tommy in Tongji
9ca680dce2 Update README.md (#2827)
Add Chinese doc link.

Signed-off-by: Tommy in Tongji <36354458+TommyZihao@users.noreply.github.com>
2026-01-20 17:54:24 +01:00
sato_shinji
9919b16b36 fix: ensure action tensors are moved to client_device in async training (#2792)
* feat(async_inference): server always sends CPU tensors, client handles device conversion

* fix:fix the type annotation of RawObservation in src/lerobot/async_inference/helpers.py

* update the import of robot_client

---------

Co-authored-by: Sato shinji <wwwsatoshinji@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: KB <kevin-brian.n-diaye@epita.fr>
2026-01-20 15:17:38 +01:00
Caroline Pascal
d36dfcdf71 fix(discord link): fixing discord link in CONTRIBUTING.md (#2826)
Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>
2026-01-20 15:00:45 +01:00
179 changed files with 9096 additions and 6962 deletions

View File

@@ -18,6 +18,11 @@ name: Documentation
on:
# Allows running this workflow manually from the Actions tab
workflow_dispatch:
inputs:
version:
description: 'Version tag (e.g. v0.1.2) - Leave empty for standard main build'
required: false
type: string
# Triggers the workflow on push events to main for the docs folder
push:
@@ -54,7 +59,13 @@ jobs:
with:
commit_sha: ${{ github.sha }}
package: lerobot
additional_args: --not_python_module ${{ github.event_name == 'release' && format('--version {0}', github.event.release.tag_name) || '' }}
additional_args: >-
--not_python_module
${{
(github.event_name == 'release' && format('--version {0}', github.event.release.tag_name)) ||
(inputs.version != '' && format('--version {0}', inputs.version)) ||
''
}}
secrets:
token: ${{ secrets.HUGGINGFACE_PUSH }}
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}

View File

@@ -101,9 +101,11 @@ jobs:
runs-on:
group: aws-general-8-plus
if: |
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) ||
github.event_name == 'push' ||
github.event_name == 'workflow_dispatch'
github.repository == 'huggingface/lerobot' && (
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) ||
github.event_name == 'push' ||
github.event_name == 'workflow_dispatch'
)
outputs:
image_tag: ${{ steps.set_tag.outputs.image_tag }}
env:

View File

@@ -91,6 +91,7 @@ jobs:
name: Build and Push Docker
runs-on:
group: aws-general-8-plus
if: github.repository == 'huggingface/lerobot'
outputs:
image_tag: ${{ env.DOCKER_IMAGE_NAME }}
env:

View File

@@ -14,7 +14,7 @@ You can contribute in many ways:
- **Documentation:** Improve examples, guides, and docstrings.
- **Feedback:** Submit tickets related to bugs or desired new features.
If you are unsure where to start, join our [Discord Channel](https://discord.gg/JkrYNdmw).
If you are unsure where to start, join our [Discord Channel](https://discord.gg/q8Dzzpym3f).
## Development Setup

View File

@@ -128,6 +128,7 @@ Learn how to implement your own simulation environment or benchmark and distribu
## Resources
- **[Documentation](https://huggingface.co/docs/lerobot/index):** The complete guide to tutorials & API.
- **[Chinese Tutorials: LeRobot+SO-ARM101中文教程-同济子豪兄](https://zihao-ai.feishu.cn/wiki/space/7589642043471924447)** Detailed doc for assembling, teleoperate, dataset, train, deploy. Verified by Seed Studio and 5 global hackathon players.
- **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community.
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.

View File

@@ -1,219 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from soundfile import read
from lerobot.microphones.configs import MicrophoneConfig
from lerobot.microphones.portaudio import PortAudioMicrophone, PortAudioMicrophoneConfig
from lerobot.microphones.utils import (
async_microphones_start_recording,
async_microphones_stop_recording,
make_microphones_from_configs,
)
from lerobot.utils.robot_utils import (
precise_sleep,
)
def main(
microphones_configs: dict[str, MicrophoneConfig],
audio_chunks_number: int,
audio_chunks_duration: float,
repetitions: int,
multiprocessing: bool = False,
):
recording_dir = Path("outputs/audio_benchmark")
recording_dir.mkdir(parents=True, exist_ok=True)
# Create microphones
microphones = make_microphones_from_configs(microphones_configs)
# Connect microphones
for microphone in microphones.values():
microphone.connect()
all_audio_chunks = []
for i in range(repetitions):
print(f"Repetition {i + 1}/{repetitions}...")
# Create audio chunks
audio_chunks = {}
for microphone_key in microphones:
audio_chunks.update({microphone_key: []})
# Start recording
async_microphones_start_recording(
microphones,
output_files=[
recording_dir / f"{microphone_key}_recording_{i}.wav" for microphone_key in microphones
],
multiprocessing=multiprocessing,
)
# Record audio chunks
for j in range(audio_chunks_number):
precise_sleep(audio_chunks_duration)
for microphone_key, microphone in microphones.items():
audio_chunk = microphone.read()
print(f"{microphone_key} - repetition {i} - chunk {j} - samples {audio_chunk.shape[0]}")
audio_chunks[microphone_key].append(audio_chunk)
# Stop recording
async_microphones_stop_recording(microphones)
for microphone_key in microphones:
audio_chunks[microphone_key] = np.concatenate(audio_chunks[microphone_key], axis=0)
all_audio_chunks.append(audio_chunks)
# Disconnect microphones
for microphone in microphones.values():
microphone.disconnect()
# Compute statistics
cmap = plt.get_cmap("tab10")
_, ax = plt.subplots(nrows=repetitions, ncols=len(microphones))
chunk_length = np.zeros((repetitions, len(microphones)))
record_length = np.zeros((repetitions, len(microphones)))
for i in range(repetitions):
for j, (microphone_key, microphone) in enumerate(microphones.items()):
# Get recorded audio chunks
recorded_audio_chunks = all_audio_chunks[i][microphone_key]
# Load recorded file
recorded_data, _ = read(recording_dir / f"{microphone_key}_recording_{i}.wav")
if recorded_data.ndim == 1:
recorded_data = np.expand_dims(recorded_data, axis=1)
record_length[i, j] = recorded_data.shape[0]
chunk_length[i, j] = recorded_audio_chunks.shape[0]
for k, (chunk_data, record_data) in enumerate(
zip(recorded_audio_chunks.T, recorded_data.T, strict=False)
):
# Plot audio chunks and recorded data
ax[i, j].plot(
np.arange(0, len(chunk_data)) / microphone.sample_rate,
chunk_data,
label=f"audio chunks - channel {k}",
color=cmap(2 * k),
)
ax[i, j].plot(
np.arange(0, len(record_data)) / microphone.sample_rate,
record_data,
label=f"recorded data - channel {k}",
linestyle="dashed",
color=cmap(2 * k + 1),
)
# Plot absolute difference (errors should be located at the end of the recordings)
if recorded_data.shape[0] - recorded_audio_chunks.shape[0] > 0:
chunk_data = np.append(
chunk_data, np.zeros(int(recorded_data.shape[0] - recorded_audio_chunks.shape[0]))
)
else:
record_data = np.append(
record_data, np.zeros(int(-recorded_data.shape[0] + recorded_audio_chunks.shape[0]))
)
ax[i, j].plot(
np.arange(0, len(record_data)) / microphone.sample_rate,
np.abs(chunk_data - record_data),
label=f"differences - channel {k}",
color="red",
linestyle="dotted",
)
ax[i, j].set_title(f"{microphone_key} - repetition {i}")
ax[i, j].legend()
plt.show()
# Print statistics
differences = record_length - chunk_length
for i, (microphone_key, microphone) in enumerate(microphones.items()):
print(
f"Average recorded duration for {microphone_key} : {np.mean(record_length[:, i]) / microphone.sample_rate:.3f} seconds"
)
print(
f"Average chunk duration for {microphone_key} : {np.mean(chunk_length[:, i]) / microphone.sample_rate:.3f} seconds"
)
print(f"Average difference for {microphone_key} : {np.mean(differences[:, i]):.3f} samples")
print(
f"Average difference for {microphone_key} : {np.mean(differences[:, i]) / microphone.sample_rate:.3f} seconds"
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--microphones_indices",
type=int,
nargs="+",
default=[microphone["index"] for microphone in PortAudioMicrophone.find_microphones()],
)
parser.add_argument(
"--microphones_sample_rate",
type=float,
nargs="+",
default=[None] * len(PortAudioMicrophone.find_microphones()),
)
parser.add_argument(
"--microphones_channels",
type=int,
nargs="+",
default=[None] * len(PortAudioMicrophone.find_microphones()),
)
parser.add_argument("--audio_chunks_number", type=int, default=2)
parser.add_argument(
"--audio_chunks_duration",
type=float,
default=1.0,
)
parser.add_argument(
"--repetitions",
type=int,
default=2,
)
parser.add_argument(
"--multiprocessing",
action="store_true",
)
args = vars(parser.parse_args())
args["microphones_configs"] = {}
for index, sample_rate, channels in zip(
args["microphones_indices"],
args["microphones_sample_rate"],
args["microphones_channels"],
strict=False,
):
microphone_config = PortAudioMicrophoneConfig(
microphone_index=index,
sample_rate=sample_rate,
channels=channels,
)
args["microphones_configs"].update({f"microphone_{index}": microphone_config})
args.pop("microphones_indices")
args.pop("microphones_sample_rate")
args.pop("microphones_channels")
main(**args)

View File

@@ -1,136 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path
import numpy as np
import soundfile as sf
from lerobot.microphones.anyskin import AnyskinSensorConfig
from lerobot.microphones.configs import MicrophoneConfig
from lerobot.microphones.utils import (
async_microphones_start_recording,
async_microphones_stop_recording,
make_microphones_from_configs,
)
from lerobot.utils.robot_utils import (
precise_sleep,
)
def main(
sensors_configs: dict[str, MicrophoneConfig],
multiprocessing: bool = False,
):
recording_dir = Path("outputs/tactile_benchmark")
recording_dir.mkdir(parents=True, exist_ok=True)
# Create microphones
sensors = make_microphones_from_configs(sensors_configs)
# Connect microphones
for sensor in sensors.values():
sensor.connect()
# Create audio chunks
data_chunks = {}
for sensor_key in sensors:
data_chunks.update({sensor_key: []})
# Start recording
async_microphones_start_recording(
sensors,
output_files=[recording_dir / f"{sensor_key}_recording.wav" for sensor_key in sensors],
multiprocessing=multiprocessing,
)
# Record audio chunks
precise_sleep(10.0)
for sensor_key, sensor in sensors.items():
data_chunk = sensor.read()
print(f"{sensor_key} - samples {data_chunk.shape[0]}")
data_chunks[sensor_key].append(data_chunk)
# Stop recording
async_microphones_stop_recording(sensors)
for sensor_key in sensors:
data_chunks[sensor_key] = np.concatenate(data_chunks[sensor_key], axis=0)
# Disconnect microphones
for sensor in sensors.values():
sensor.disconnect()
for sensor_key in sensors:
data, sample_rate = sf.read(recording_dir / f"{sensor_key}_recording.wav")
print(f"{sensor_key} - samples {data.shape[0]}")
print(f"{sensor_key} - sample rate {sample_rate}")
print(f"{sensor_key} - data {data}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--sensors_ports",
type=str,
nargs="+",
)
parser.add_argument(
"--sensors_baud_rate",
type=int,
nargs="+",
)
parser.add_argument(
"--sensors_sample_rate",
type=int,
nargs="+",
)
parser.add_argument(
"--sensors_channels",
type=int,
nargs="+",
)
parser.add_argument(
"--multiprocessing",
action="store_true",
)
args = vars(parser.parse_args())
args["sensors_configs"] = {}
for port, baud_rate, sample_rate, channels in zip(
args["sensors_ports"],
args["sensors_baud_rate"],
args["sensors_sample_rate"],
args["sensors_channels"],
strict=False,
):
channels = [1, 2, 3, 4, 5]
sensor_config = AnyskinSensorConfig(
sensor_port=port,
baud_rate=baud_rate,
sample_rate=sample_rate,
channels=channels,
)
args["sensors_configs"].update({f"sensor_{port}": sensor_config})
args.pop("sensors_ports")
args.pop("sensors_baud_rate")
args.pop("sensors_sample_rate")
args.pop("sensors_channels")
main(**args)

134
benchmarks/tokens/README.md Normal file
View File

@@ -0,0 +1,134 @@
# Action tokenizer benchmark
## Questions
What is the trade-off between:
- **Compression**: how many tokens are needed to represent an action chunk (e.g. horizon × action_dim floats)?
- **Reconstruction quality**: how well does encode-then-decode preserve the original actions?
- **Speed**: how long does encoding and decoding take per chunk?
How to choose an action tokenizer?
- Which tokenizer architecture (e.g. dct + BPE, DCT + BPE)?
- Which **action horizon** and **encoded dimensions** to use?
- Which **normalization** (QUANTILES, MEAN_STD, MIN_MAX) and **delta transform** (relative vs absolute actions)?
- How do reconstruction error and compression ratio vary across datasets and tokenizer settings?
This benchmark loads action chunks from a LeRobot dataset using the same pipeline as `lerobot-train-tokenizer`, runs a trained action tokenizer in encode/decode mode, and reports reconstruction error, compression stats, and timing. Results are saved as JSON under `outputs/` for comparison and analysis.
## Variables
**Dataset & chunking**
- **repo_id**: LeRobot dataset (e.g. `lerobot/pusht`). Action statistics and normalization are taken from the dataset metadata when available.
- **action_horizon**: Number of future steps per action chunk (must match the tokenizers training).
- **encoded_dims**: Dimension ranges to encode (e.g. `0:6` or `0:6,7:14`). Must match the tokenizer.
- **max_episodes**: Cap on episodes to load (default: all).
- **sample_fraction**: Fraction of chunks to sample per episode (default `0.2`) to keep runtime manageable.
**Transform & normalization**
- **normalization_mode**: `IDENTITY`, `MEAN_STD`, `MIN_MAX`, `QUANTILES`, `QUANTILE10`. Should match the tokenizers training.
- **delta_dims**: Comma-separated dimension indices for delta (relative) transform.
- **use_delta_transform**: Whether to convert actions to relative to current state for those dimensions.
- **state_key**: Dataset key for state (e.g. `observation.state`) used when applying delta transform.
**Tokenizer & evaluation**
- **action_tokenizer_path**: Path or HuggingFace repo id of the trained tokenizer (e.g. `outputs/wavetoken`).
- **max_chunks_for_reconstruction**: Max number of chunks to use for reconstruction and timing (default `500`) to limit runtime.
### Main parameters
| parameter | default | description |
| -------------------------------- | ---------------------------- | ------------------------------------------------ |
| **action_tokenizer_path** | (required) | Path or Hub id of the trained action tokenizer. |
| **repo_id** | (required) | LeRobot dataset repo id. |
| **action_horizon** | `10` | Future steps per chunk. |
| **encoded_dims** | `0:6` | Dimension ranges to encode (e.g. `0:6,7:14`). |
| **normalization_mode** | `QUANTILES` | Normalization mode for actions. |
| **max_episodes** | all | Max episodes to load. |
| **sample_fraction** | `0.2` | Fraction of chunks sampled per episode. |
| **max_chunks_for_reconstruction**| `500` | Chunks used for reconstruction and timing. |
| **output_dir** | `outputs/action_tokenizer_benchmark` | Directory for results JSON. |
## Metrics
**Reconstruction (lower is better)**
- **reconstruction_mae**: Mean absolute error between original and decoded action chunks.
- **reconstruction_mse**: Mean squared error.
- **reconstruction_rmse**: Root mean squared error.
- **reconstruction_max_abs_error**: Maximum absolute error over all dimensions and samples.
- **per_dimension_mae**: MAE per action dimension (list of length `action_dim`).
**Compression**
- **compression_ratio**: Ratio (action_horizon × action_dim) / mean number of tokens. Higher means more compression.
- **mean_token_length**, **std_token_length**: Mean and standard deviation of token count per chunk.
- **min_token_length**, **max_token_length**: Min and max token count.
- **p50_token_length**, **p99_token_length**: 50th and 99th percentile token counts.
**Timing (seconds per chunk)**
- **mean_encode_time_sec**: Mean time to encode one chunk.
- **mean_decode_time_sec**: Mean time to decode one chunk.
The JSON output also includes **num_chunks_evaluated** and **total_chunks_available** for context.
## How the benchmark works
1. **Load dataset**: LeRobot dataset is loaded for the given `repo_id` and `root`.
2. **Build action chunks**: For each episode (up to `max_episodes`), action chunks are built with the same logic as `lerobot-train-tokenizer`: sliding window of length `action_horizon`, optional delta transform, and per-episode sampling with `sample_fraction`.
3. **Extract and normalize**: Only `encoded_dims` are kept. Normalization is applied using the datasets action stats when available, according to `normalization_mode`.
4. **Encode / decode**: A random sample of chunks (size `max_chunks_for_reconstruction`) is encoded and then decoded with the tokenizer. Encode and decode times are recorded per chunk.
5. **Compute metrics**: Reconstruction metrics are computed between original and decoded chunks; compression and timing stats are aggregated.
6. **Save results**: A JSON file is written to `output_dir` with name `{timestamp}_{repo_id}_action_tokenizer_results.json`, containing the full config and all metrics.
The pipeline (chunking, dimensions, normalization, delta) must match how the tokenizer was trained; otherwise reconstruction error can be large or the tokenizer may raise.
## Caveats
- The tokenizers **action_horizon** and **action_dim** (and optionally DCT settings) are fixed at training time. The benchmark infers dimensions from the dataset and encoded dims; the tokenizer path must correspond to a model trained with the same horizon and encoded dimensions.
- Reconstruction is evaluated in **normalized space** (the same space the tokenizer sees). For interpretation in raw action space, you would need to invert normalization outside this script.
- Only one tokenizer and one dataset are evaluated per run. To compare tokenizers or datasets, run the script multiple times and compare the saved JSON files.
## Example
Quick run with a local tokenizer and a small number of episodes:
```bash
python benchmarks/tokens/run_action_tokenizer_benchmark.py \
--action-tokenizer-path=outputs/wavetoken \
--repo-id=lerobot/pusht \
--action-horizon=10 \
--max-episodes=50 \
--output-dir=outputs/action_tokenizer_benchmark
```
With delta transform and custom encoded dimensions:
```bash
python benchmarks/tokens/run_action_tokenizer_benchmark.py \
--action-tokenizer-path=outputs/wavetoken \
--repo-id=lerobot/pusht \
--action-horizon=10 \
--encoded-dims=0:6,7:14 \
--delta-dims=0,1,2,3,4,5 \
--use-delta-transform \
--normalization-mode=QUANTILES \
--max-chunks-for-reconstruction=500 \
--output-dir=outputs/action_tokenizer_benchmark
```
Results are written to e.g. `outputs/action_tokenizer_benchmark/2026-02-12_14-30-00_lerobot_pusht_action_tokenizer_results.json`.
## Results
Results are stored as JSON in the directory given by `--output-dir` (default: `outputs/action_tokenizer_benchmark`). Each file contains:
- **config**: All script arguments (tokenizer path, repo_id, action_horizon, encoded_dims, normalization_mode, etc.) for reproducibility.
- **metrics**: All reconstruction, compression, and timing metrics described above.
To compare runs, load and diff or aggregate these JSON files with your own scripts or notebooks.

View File

@@ -0,0 +1,442 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Benchmark action tokenization: reconstruction error, compression ratio, and timing.
Loads action chunks from a LeRobot dataset, encodes/decodes them with a trained action
tokenizer, and reports:
- Reconstruction: MAE, MSE, RMSE, max absolute error, per-dimension MAE
- Jerk: mean absolute jerk (original and reconstructed), jerk reconstruction MAE
- Compression: ratio (input size / mean tokens), token length stats
- Timing: mean encode/decode time per chunk
Results are saved to outputs/action_tokenizer_benchmark/<timestamp>_results.json.
Example:
```bash
python benchmarks/tokens/run_action_tokenizer_benchmark.py \
--action-tokenizer-path=outputs/wavetoken \
--repo-id=lerobot/pusht \
--action-horizon=10 \
--max-episodes=50 \
--output-dir=outputs/action_tokenizer_benchmark
```
"""
import argparse
import json
import time
from pathlib import Path
import numpy as np
from lerobot.configs.types import NormalizationMode
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import ACTION, OBS_STATE
# Optional: use same helpers as train script if we want to avoid duplication
from lerobot.scripts.lerobot_train_tokenizer import (
apply_normalization,
process_episode,
)
def load_action_chunks(
repo_id: str,
root: str | None,
action_horizon: int,
max_episodes: int | None,
sample_fraction: float,
encoded_dims: str,
delta_dims: str | None,
use_delta_transform: bool,
state_key: str,
normalization_mode: NormalizationMode,
):
"""Load and normalize action chunks from a LeRobot dataset (same pipeline as training)."""
dataset = LeRobotDataset(repo_id=repo_id, root=root)
num_episodes = dataset.num_episodes
if max_episodes is not None:
num_episodes = min(max_episodes, num_episodes)
# Parse encoded dims
encoded_dim_ranges = []
for range_str in encoded_dims.split(","):
start, end = map(int, range_str.strip().split(":"))
encoded_dim_ranges.append((start, end))
total_encoded_dims = sum(end - start for start, end in encoded_dim_ranges)
delta_dim_list = None
if delta_dims is not None and delta_dims.strip():
delta_dim_list = [int(d.strip()) for d in delta_dims.split(",")]
all_chunks = []
for ep_idx in range(num_episodes):
chunks = process_episode(
(
dataset,
ep_idx,
action_horizon,
delta_dim_list,
sample_fraction,
state_key,
use_delta_transform,
)
)
if chunks is not None:
all_chunks.append(chunks)
if not all_chunks:
raise ValueError("No action chunks collected. Check action_horizon and dataset.")
all_chunks = np.concatenate(all_chunks, axis=0)
# Extract encoded dimensions only
encoded_chunks = []
for start, end in encoded_dim_ranges:
encoded_chunks.append(all_chunks[:, :, start:end])
encoded_chunks = np.concatenate(encoded_chunks, axis=-1)
# Normalize
norm_stats = dataset.meta.stats
if norm_stats is not None and ACTION in norm_stats:
action_stats = norm_stats[ACTION]
encoded_dim_indices = []
for start, end in encoded_dim_ranges:
encoded_dim_indices.extend(range(start, end))
encoded_dim_indices = np.array(encoded_dim_indices)
encoded_stats = {}
for stat_name, stat_values in action_stats.items():
if isinstance(stat_values, (list, np.ndarray)):
stat_array = np.array(stat_values)
if len(stat_array) > max(encoded_dim_indices):
encoded_stats[stat_name] = stat_array[encoded_dim_indices]
if encoded_stats:
try:
encoded_chunks = apply_normalization(
encoded_chunks, encoded_stats, normalization_mode, eps=1e-8
)
except ValueError:
pass
return encoded_chunks, total_encoded_dims, action_horizon, dataset.repo_id
def compute_reconstruction_metrics(original: np.ndarray, reconstructed: np.ndarray):
"""Compute reconstruction error metrics (original and reconstructed same shape [N, T, D])."""
diff = reconstructed - original
mae = float(np.mean(np.abs(diff)))
mse = float(np.mean(diff**2))
rmse = float(np.sqrt(mse))
max_abs_err = float(np.max(np.abs(diff)))
# Per-dimension MAE (over N and T)
per_dim_mae = np.mean(np.abs(diff), axis=(0, 1))
per_dim_mae = per_dim_mae.tolist()
return {
"reconstruction_mae": mae,
"reconstruction_mse": mse,
"reconstruction_rmse": rmse,
"reconstruction_max_abs_error": max_abs_err,
"per_dimension_mae": per_dim_mae,
}
def compute_jerk_metrics(original: np.ndarray, reconstructed: np.ndarray) -> dict:
"""Compute jerk (3rd derivative of action w.r.t. time) metrics.
Args:
original: Action chunks [N, T, D].
reconstructed: Reconstructed action chunks [N, T, D].
Returns:
Dict with mean absolute jerk for original, reconstructed, and jerk reconstruction MAE.
"""
# Jerk = 3rd discrete difference along time axis; need T >= 4
if original.shape[1] < 4:
return {}
jerk_orig = np.diff(original, n=3, axis=1) # (N, T-3, D)
jerk_recon = np.diff(reconstructed, n=3, axis=1)
mae_jerk_orig = float(np.mean(np.abs(jerk_orig)))
mae_jerk_recon = float(np.mean(np.abs(jerk_recon)))
jerk_reconstruction_mae = float(np.mean(np.abs(jerk_recon - jerk_orig)))
return {
"jerk_mae_original": mae_jerk_orig,
"jerk_mae_reconstructed": mae_jerk_recon,
"jerk_reconstruction_mae": jerk_reconstruction_mae,
}
def run_benchmark(
action_chunks: np.ndarray,
action_horizon: int,
action_dim: int,
tokenizer_path: str,
max_chunks_for_reconstruction: int | None = 500,
):
"""Encode/decode action chunks and compute metrics."""
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained(tokenizer_path, trust_remote_code=True)
n_chunks = len(action_chunks)
sample_size = n_chunks
if max_chunks_for_reconstruction is not None:
sample_size = min(max_chunks_for_reconstruction, n_chunks)
rng = np.random.RandomState(42)
indices = rng.choice(n_chunks, size=sample_size, replace=False)
sample_chunks = action_chunks[indices]
# Encode
token_lengths = []
encode_times = []
all_tokens = []
for i in range(len(sample_chunks)):
chunk = sample_chunks[i : i + 1]
t0 = time.perf_counter()
tokens = processor(chunk)[0]
encode_times.append(time.perf_counter() - t0)
if isinstance(tokens, list):
token_lengths.append(len(tokens))
all_tokens.append(tokens)
else:
n = tokens.shape[0] if hasattr(tokens, "shape") else len(tokens)
token_lengths.append(n)
all_tokens.append(tokens.tolist() if hasattr(tokens, "tolist") else list(tokens))
# Decode (processor keeps time_horizon/action_dim from encode)
decoded_list = []
decode_times = []
for i, tok_list in enumerate(all_tokens):
t0 = time.perf_counter()
recon = processor.decode(
[tok_list],
time_horizon=action_horizon,
action_dim=action_dim,
)
decode_times.append(time.perf_counter() - t0)
decoded_list.append(recon)
decoded = np.concatenate(decoded_list, axis=0)
# Reconstruction metrics
metrics = compute_reconstruction_metrics(sample_chunks, decoded)
# Jerk metrics (3rd derivative along time)
jerk_metrics = compute_jerk_metrics(sample_chunks, decoded)
metrics.update(jerk_metrics)
# Compression
token_lengths = np.array(token_lengths)
input_size = action_horizon * action_dim
compression_ratio = input_size / float(np.mean(token_lengths))
metrics["compression_ratio"] = compression_ratio
metrics["mean_token_length"] = float(np.mean(token_lengths))
metrics["std_token_length"] = float(np.std(token_lengths))
metrics["min_token_length"] = int(np.min(token_lengths))
metrics["max_token_length"] = int(np.max(token_lengths))
metrics["p50_token_length"] = float(np.percentile(token_lengths, 50))
metrics["p99_token_length"] = float(np.percentile(token_lengths, 99))
# Timing (seconds per chunk)
metrics["mean_encode_time_sec"] = float(np.mean(encode_times))
metrics["mean_decode_time_sec"] = float(np.mean(decode_times))
metrics["num_chunks_evaluated"] = sample_size
metrics["total_chunks_available"] = n_chunks
return metrics
def main(
action_tokenizer_path: str,
repo_id: str,
root: str | None = None,
action_horizon: int = 10,
max_episodes: int | None = 100,
sample_fraction: float = 0.2,
encoded_dims: str = "0:6",
delta_dims: str | None = None,
use_delta_transform: bool = False,
state_key: str = OBS_STATE,
normalization_mode: str = "QUANTILES",
max_chunks_for_reconstruction: int | None = 500,
output_dir: str | None = None,
):
if output_dir is None:
output_dir = "outputs/action_tokenizer_benchmark"
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
try:
norm_mode = NormalizationMode(normalization_mode)
except ValueError:
norm_mode = NormalizationMode.QUANTILES
print("Loading action chunks...")
encoded_chunks, action_dim, horizon, _ = load_action_chunks(
repo_id=repo_id,
root=root,
action_horizon=action_horizon,
max_episodes=max_episodes,
sample_fraction=sample_fraction,
encoded_dims=encoded_dims,
delta_dims=delta_dims,
use_delta_transform=use_delta_transform,
state_key=state_key,
normalization_mode=norm_mode,
)
print(f"Loaded {len(encoded_chunks)} chunks, shape {encoded_chunks.shape} (H={horizon}, D={action_dim})")
print("Running tokenizer benchmark...")
metrics = run_benchmark(
action_chunks=encoded_chunks,
action_horizon=horizon,
action_dim=action_dim,
tokenizer_path=action_tokenizer_path,
max_chunks_for_reconstruction=max_chunks_for_reconstruction,
)
# Attach config for reproducibility
results = {
"config": {
"action_tokenizer_path": action_tokenizer_path,
"repo_id": repo_id,
"action_horizon": action_horizon,
"max_episodes": max_episodes,
"sample_fraction": sample_fraction,
"encoded_dims": encoded_dims,
"delta_dims": delta_dims,
"use_delta_transform": use_delta_transform,
"state_key": state_key,
"normalization_mode": normalization_mode,
},
"metrics": metrics,
}
timestamp = time.strftime("%Y-%m-%d_%H-%M-%S")
safe_repo = repo_id.replace("/", "_")
out_file = output_path / f"{timestamp}_{safe_repo}_action_tokenizer_results.json"
with open(out_file, "w") as f:
json.dump(results, f, indent=2)
print(f"Results saved to {out_file}")
print("Metrics:")
for k, v in metrics.items():
if isinstance(v, list):
print(f" {k}: (length {len(v)})")
else:
print(f" {k}: {v}")
return results
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Benchmark action tokenization (reconstruction error, compression, timing)."
)
parser.add_argument(
"--action-tokenizer-path",
type=str,
required=True,
help="Path or HuggingFace repo id of the trained action tokenizer (e.g. outputs/wavetoken).",
)
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="LeRobot dataset repo id (e.g. lerobot/pusht).",
)
parser.add_argument(
"--root",
type=str,
default=None,
help="Root directory for LeRobot datasets.",
)
parser.add_argument(
"--action-horizon",
type=int,
default=10,
help="Number of future steps per action chunk.",
)
parser.add_argument(
"--max-episodes",
type=int,
default=None,
help="Max episodes to use (default: all).",
)
parser.add_argument(
"--sample-fraction",
type=float,
default=0.2,
help="Fraction of chunks to sample per episode.",
)
parser.add_argument(
"--encoded-dims",
type=str,
default="0:6",
help="Dimension ranges to encode (e.g. 0:6,7:14).",
)
parser.add_argument(
"--delta-dims",
type=str,
default=None,
help="Comma-separated dimensions for delta transform.",
)
parser.add_argument(
"--use-delta-transform",
action="store_true",
help="Apply delta (relative) transform to specified dimensions.",
)
parser.add_argument(
"--state-key",
type=str,
default=OBS_STATE,
help="Dataset key for state (for delta transform).",
)
parser.add_argument(
"--normalization-mode",
type=str,
default="QUANTILES",
choices=[m.value for m in NormalizationMode],
help="Normalization mode for actions.",
)
parser.add_argument(
"--max-chunks-for-reconstruction",
type=int,
default=500,
help="Max chunks to use for reconstruction metrics (default: 500).",
)
parser.add_argument(
"--output-dir",
type=str,
default="outputs/action_tokenizer_benchmark",
help="Directory to save results JSON (default: outputs/action_tokenizer_benchmark).",
)
args = parser.parse_args()
main(
action_tokenizer_path=args.action_tokenizer_path,
repo_id=args.repo_id,
root=args.root,
action_horizon=args.action_horizon,
max_episodes=args.max_episodes,
sample_fraction=args.sample_fraction,
encoded_dims=args.encoded_dims,
delta_dims=args.delta_dims,
use_delta_transform=args.use_delta_transform,
state_key=args.state_key,
normalization_mode=args.normalization_mode,
max_chunks_for_reconstruction=args.max_chunks_for_reconstruction,
output_dir=args.output_dir,
)

View File

@@ -7,8 +7,6 @@
- sections:
- local: il_robots
title: Imitation Learning for Robots
- local: cameras
title: Cameras
- local: bring_your_own_policies
title: Bring Your Own Policies
- local: integrate_hardware
@@ -29,6 +27,8 @@
title: Porting Large Datasets
- local: using_dataset_tools
title: Using the Dataset Tools
- local: dataset_subtask
title: Using Subtasks in the Dataset
title: "Datasets"
- sections:
- local: act
@@ -99,11 +99,19 @@
title: Unitree G1
- local: earthrover_mini_plus
title: Earth Rover Mini
- local: omx
title: OMX
- local: openarm
title: OpenArm
title: "Robots"
- sections:
- local: phone_teleop
title: Phone
title: "Teleoperators"
- sections:
- local: cameras
title: Cameras
title: "Sensors"
- sections:
- local: torch_accelerators
title: PyTorch accelerators
@@ -113,6 +121,8 @@
title: Notebooks
- local: feetech
title: Updating Feetech Firmware
- local: damiao
title: Damiao Motors and CAN Bus
title: "Resources"
- sections:
- local: contributing

View File

@@ -195,6 +195,7 @@ client_cfg = RobotClientConfig(
robot=robot_cfg,
server_address="localhost:8080",
policy_device="mps",
client_device="cpu",
policy_type="smolvla",
pretrained_name_or_path="<user>/smolvla_async",
chunk_size_threshold=0.5,

View File

@@ -1,12 +1,22 @@
# Cameras
LeRobot offers multiple options for video capture, including phone cameras, built-in laptop cameras, external webcams, and Intel RealSense cameras. To efficiently record frames from most cameras, you can use either the `OpenCVCamera` or `RealSenseCamera` class. For additional compatibility details on the `OpenCVCamera` class, refer to the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
LeRobot offers multiple options for video capture:
### Finding your camera
| Class | Supported Cameras |
| ----------------- | ----------------------------------- |
| `OpenCVCamera` | Phone, built-in laptop, USB webcams |
| `ZMQCamera` | Network-connected cameras |
| `RealSenseCamera` | Intel RealSense (with depth) |
| `Reachy2Camera` | Reachy 2 robot cameras |
To instantiate a camera, you need a camera identifier. This identifier might change if you reboot your computer or re-plug your camera, a behavior mostly dependant on your operating system.
> [!TIP]
> For `OpenCVCamera` compatibility details, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
To find the camera indices of the cameras plugged into your system, run the following script:
### Find your camera
Every camera requires a unique identifier to be instantiated, allowing you to distinguish between multiple connected devices.
`OpenCVCamera` and `RealSenseCamera` support auto-discovery. Run the command below to list available devices and their identifiers. Note that these identifiers may change after rebooting your computer or re-plugging the camera, depending on your operating system.
```bash
lerobot-find-cameras opencv # or realsense for Intel Realsense cameras
@@ -14,7 +24,7 @@ lerobot-find-cameras opencv # or realsense for Intel Realsense cameras
The output will look something like this if you have two cameras connected:
```
```bash
--- Detected Cameras ---
Camera #0:
Name: OpenCV Camera @ 0
@@ -33,13 +43,37 @@ Camera #0:
> [!WARNING]
> When using Intel RealSense cameras in `macOS`, you could get this [error](https://github.com/IntelRealSense/librealsense/issues/12307): `Error finding RealSense cameras: failed to set power state`, this can be solved by running the same command with `sudo` permissions. Note that using RealSense cameras in `macOS` is unstable.
## Use Cameras
`ZMQCamera` and `Reachy2Camera` do not support auto-discovery. They must be configured manually by providing their network address and port or robot SDK settings.
Below are two examples, demonstrating how to work with the API.
## Use cameras
- **Asynchronous frame capture** using an OpenCV-based camera
### Frame access modes
All camera classes implement three access modes for capturing frames:
| Method | Behavior | Blocks? | Best For |
| ------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------- | ---------------------------------------- |
| `read()` | Waits for the camera hardware to return a frame. May block for a long time depending on the camera and SDK. | Yes | Simple scripts, sequential capture |
| `async_read(timeout_ms)` | Returns the latest unconsumed frame from background thread. Blocks only if buffer is empty, up to `timeout_ms`. Raises `TimeoutError` if no frame arrives. | With a timeout | Control loops synchronized to camera FPS |
| `read_latest(max_age_ms)` | Peeks at the most recent frame in buffer (may be stale). Raises `TimeoutError` if frame is older than `max_age_ms`. | No | UI visualization, logging, monitoring |
### Usage examples
The following examples show how to use the camera API to configure and capture frames from different camera types.
- **Blocking and non-blocking frame capture** using an OpenCV-based camera
- **Color and depth capture** using an Intel RealSense camera
> [!WARNING]
> Failing to cleanly disconnect cameras can cause resource leaks. Use the context manager protocol to ensure automatic cleanup:
>
> ```python
> with OpenCVCamera(config) as camera:
> ...
> ```
>
> You can also call `connect()` and `disconnect()` manually, but always use a `finally` block for the latter.
<hfoptions id="shell_restart">
<hfoption id="Open CV Camera">
@@ -60,16 +94,30 @@ config = OpenCVCameraConfig(
)
# Instantiate and connect an `OpenCVCamera`, performing a warm-up read (default).
camera = OpenCVCamera(config)
camera.connect()
with OpenCVCamera(config) as camera:
# Read a frame synchronously — blocks until hardware delivers a new frame
frame = camera.read()
print(f"read() call returned frame with shape:", frame.shape)
# Read a frame asynchronously with a timeout — returns the latest unconsumed frame or waits up to timeout_ms for a new one
try:
for i in range(10):
frame = camera.async_read(timeout_ms=200)
print(f"async_read call returned frame {i} with shape:", frame.shape)
except TimeoutError as e:
print(f"No frame received within timeout: {e}")
# Instantly return a frame - returns the most recent frame captured by the camera
try:
initial_frame = camera.read_latest(max_age_ms=1000)
for i in range(10):
frame = camera.read_latest(max_age_ms=1000)
print(f"read_latest call returned frame {i} with shape:", frame.shape)
print(f"Was a new frame received by the camera? {not (initial_frame == frame).any()}")
except TimeoutError as e:
print(f"Frame too old: {e}")
# Read frames asynchronously in a loop via `async_read(timeout_ms)`
try:
for i in range(10):
frame = camera.async_read(timeout_ms=200)
print(f"Async frame {i} shape:", frame.shape)
finally:
camera.disconnect()
```
<!-- prettier-ignore-end -->
@@ -111,10 +159,10 @@ finally:
</hfoption>
</hfoptions>
## Use your phone
## Use your phone's camera
<hfoptions id="use phone">
<hfoption id="Mac">
<hfoption id="iPhone & macOS">
To use your iPhone as a camera on macOS, enable the Continuity Camera feature:
@@ -124,83 +172,49 @@ To use your iPhone as a camera on macOS, enable the Continuity Camera feature:
For more details, visit [Apple support](https://support.apple.com/en-gb/guide/mac-help/mchl77879b8a/mac).
Your iPhone should be detected automatically when running the camera setup script in the next section.
</hfoption>
<hfoption id="Linux">
<hfoption id="OBS virtual camera">
If you want to use your phone as a camera on Linux, follow these steps to set up a virtual camera
If you want to use your phone as a camera using OBS, follow these steps to set up a virtual camera.
1. _Install `v4l2loopback-dkms` and `v4l-utils`_. Those packages are required to create virtual camera devices (`v4l2loopback`) and verify their settings with the `v4l2-ctl` utility from `v4l-utils`. Install them using:
1. _(Linux only) Install `v4l2loopback-dkms` and `v4l-utils`_. These packages create virtual camera devices and verify their settings. Install with:
<!-- prettier-ignore-start -->
```python
```bash
sudo apt install v4l2loopback-dkms v4l-utils
```
<!-- prettier-ignore-end -->
2. _Install [DroidCam](https://droidcam.app) on your phone_. This app is available for both iOS and Android.
3. _Install [OBS Studio](https://obsproject.com)_. This software will help you manage the camera feed. Install it using [Flatpak](https://flatpak.org):
2. _Install the [DroidCam app](https://droidcam.app) on your phone_. This app is available for both iOS and Android.
3. _Download and install [OBS Studio](https://obsproject.com)_.
4. _Download and install the [DroidCam OBS plugin](https://droidcam.app/obs)_.
5. _Start OBS Studio_.
<!-- prettier-ignore-start -->
```python
flatpak install flathub com.obsproject.Studio
```
<!-- prettier-ignore-end -->
4. _Install the DroidCam OBS plugin_. This plugin integrates DroidCam with OBS Studio. Install it with:
<!-- prettier-ignore-start -->
```python
flatpak install flathub com.obsproject.Studio.Plugin.DroidCam
```
<!-- prettier-ignore-end -->
5. _Start OBS Studio_. Launch with:
<!-- prettier-ignore-start -->
```python
flatpak run com.obsproject.Studio
```
<!-- prettier-ignore-end -->
6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480`.
7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it in.
6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480` to avoid the watermarks.
7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video` or `OBS > Preferences... > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it.
8. _Start virtual camera_. In OBS Studio, follow the instructions [here](https://obsproject.com/kb/virtual-camera-guide).
9. _Verify the virtual camera setup_. Use `v4l2-ctl` to list the devices:
9. _Verify the virtual camera setup and resolution_.
- **Linux**: Use `v4l2-ctl` to list devices and check resolution:
```bash
v4l2-ctl --list-devices # find VirtualCam and note its /dev/videoX path
v4l2-ctl -d /dev/videoX --get-fmt-video # replace with your VirtualCam path
```
You should see `VirtualCam` listed and resolution `640x480`.
- **macOS**: Open Photo Booth or FaceTime and select "OBS Virtual Camera" as the input.
- **Windows**: The native Camera app doesn't support virtual cameras. Use a video conferencing app (Zoom, Teams) or run `lerobot-find-cameras opencv` directly to verify.
<!-- prettier-ignore-start -->
```python
v4l2-ctl --list-devices
```
<!-- prettier-ignore-end -->
<details>
<summary><strong>Troubleshooting</strong></summary>
You should see an entry like:
> The virtual camera resolution is incorrect.
```
VirtualCam (platform:v4l2loopback-000):
/dev/video1
```
Delete the virtual camera source and recreate it. The resolution cannot be changed after creation.
10. _Check the camera resolution_. Use `v4l2-ctl` to ensure that the virtual camera output resolution is `640x480`. Change `/dev/video1` to the port of your virtual camera from the output of `v4l2-ctl --list-devices`.
> Error reading frame in background thread for OpenCVCamera(X): OpenCVCamera(X) frame width=640 or height=480 do not match configured width=1920 or height=1080.
<!-- prettier-ignore-start -->
```python
v4l2-ctl -d /dev/video1 --get-fmt-video
```
<!-- prettier-ignore-end -->
This error is caused by OBS Virtual Camera advertising a `1920x1080` resolution despite rescaling. The only fix for now is to comment out the width and height check in `_postprocess_image()`.
You should see an entry like:
```
>>> Format Video Capture:
>>> Width/Height : 640/480
>>> Pixel Format : 'YUYV' (YUYV 4:2:2)
```
Troubleshooting: If the resolution is not correct you will have to delete the Virtual Camera port and try again as it cannot be changed.
If everything is set up correctly, you can proceed with the rest of the tutorial.
</details>
</hfoption>
</hfoptions>
If everything is set up correctly, your phone will appear as a standard OpenCV camera and can be used with `OpenCVCamera`.

165
docs/source/damiao.mdx Normal file
View File

@@ -0,0 +1,165 @@
# Damiao Motors and CAN Bus
This guide covers setup and usage of Damiao motors with LeRobot via CAN bus communication.
Currently, only Linux is supported, as the OpenArms CAN adapter only has drivers for Linux.
## Linux CAN Setup
Before using Damiao motors, you need to set up the CAN interface on your Linux system.
### Install CAN Utilities
```bash
sudo apt-get install can-utils
```
### Configure CAN Interface (Manual)
For standard CAN FD (recommended for OpenArms):
```bash
sudo ip link set can0 down
sudo ip link set can0 type can bitrate 1000000 dbitrate 5000000 fd on
sudo ip link set can0 up
```
For standard CAN (without FD):
```bash
sudo ip link set can0 down
sudo ip link set can0 type can bitrate 1000000
sudo ip link set can0 up
```
### Configure CAN Interface (Using LeRobot)
LeRobot provides a utility script to setup and test CAN interfaces:
```bash
# Setup multiple interfaces (e.g., OpenArms Followers with 2 CAN buses)
lerobot-setup-can --mode=setup --interfaces=can0,can1
```
## Debugging CAN Communication
Use the built-in debug tools to test motor communication:
```bash
# Test motors on all interfaces
lerobot-setup-can --mode=test --interfaces=can0,can1
# Run speed/latency test
lerobot-setup-can --mode=speed --interfaces=can0
```
The test mode will scan for motors (IDs 0x01-0x08) and report which ones respond. Example output:
```
can0: UP (CAN FD)
Motor 0x01 (joint_1): ✓ FOUND
→ Response 0x11 [FD]: 00112233...
Motor 0x02 (joint_2): ✓ FOUND
Motor 0x03 (joint_3): ✗ No response
...
Summary: 2/8 motors found
```
## Usage
### Basic Setup
```python
from lerobot.motors import Motor
from lerobot.motors.damiao import DamiaoMotorsBus
# Define your motors with send/receive CAN IDs
motors = {
"joint_1": Motor(id=0x01, motor_type_str="dm8009", recv_id=0x11),
"joint_2": Motor(id=0x02, motor_type_str="dm4340", recv_id=0x12),
"joint_3": Motor(id=0x03, motor_type_str="dm4310", recv_id=0x13),
}
# Create the bus
bus = DamiaoMotorsBus(
port="can0", # Linux socketcan interface
motors=motors,
)
# Connect
bus.connect()
```
### Reading Motor States
```python
# Read single motor position (degrees)
position = bus.read("Present_Position", "joint_1")
# Read from multiple motors
positions = bus.sync_read("Present_Position") # All motors
positions = bus.sync_read("Present_Position", ["joint_1", "joint_2"])
# Read all states at once (position, velocity, torque)
states = bus.sync_read_all_states()
# Returns: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...}
```
### Writing Motor Commands
```python
# Enable torque
bus.enable_torque()
# Set goal position (degrees)
bus.write("Goal_Position", "joint_1", 45.0)
# Set positions for multiple motors
bus.sync_write("Goal_Position", {
"joint_1": 45.0,
"joint_2": -30.0,
"joint_3": 90.0,
})
# Disable torque
bus.disable_torque()
```
## Configuration Options
| Parameter | Default | Description |
| -------------- | --------- | ----------------------------------------------------------- |
| `port` | - | CAN interface (`can0`) or serial port (`/dev/cu.usbmodem*`) |
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
| `bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
| `data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
## Motor Configuration
Each motor requires:
- `id`: CAN ID for sending commands
- `motor_type`: One of the supported motor types (e.g., `"dm8009"`, `"dm4340"`)
- `recv_id`: CAN ID for receiving responses
OpenArms default IDs follow the pattern: send ID `0x0N`, receive ID `0x1N` where N is the joint number.
## Troubleshooting
### No Response from Motors
1. **Check power**
2. **Verify CAN wiring**: Check CAN-H, CAN-L, and GND connections
3. **Check motor IDs**: Use Damiao Debugging Tools to verify/configure IDs
4. **Test CAN interface**: Run `candump can0` to see if messages are being received
5. **Run diagnostics**: `lerobot-setup-can --mode=test --interfaces=can0`
### Motor Timeout Parameter
If motors were configured with timeout=0, they won't respond to commands. Use Damiao Debugging Tools to set a non-zero timeout value.
### Verify CAN FD Status
```bash
ip -d link show can0 | grep fd
```

View File

@@ -0,0 +1,278 @@
# Using Subtasks in LeRobot Datasets
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
## What are Subtasks?
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
1. "Approach the apple"
2. "Grasp the apple"
3. "Lift the apple"
4. "Move to basket"
5. "Release the apple"
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
width="80%"
/>
<p>
<em>Figure: Overview of subtask annotation.</em>
</p>
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
## Dataset Structure
Subtask information is stored in the dataset metadata:
```
my-dataset/
├── data/
│ └── ...
├── meta/
│ ├── info.json
│ ├── stats.json
│ ├── tasks.parquet
│ ├── subtasks.parquet # Subtask index → subtask string mapping
│ └── episodes/
│ └── ...
└── videos/
└── ...
```
### Subtasks Parquet File
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
| subtask_index | subtask (index column) |
| ------------- | ---------------------- |
| 0 | "Approach the apple" |
| 1 | "Grasp the apple" |
| 2 | "Lift the apple" |
| ... | ... |
### Frame-Level Annotations
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
```python
# Example frame data in the parquet file
{
"index": 42,
"timestamp": 1.4,
"episode_index": 0,
"task_index": 0,
"subtask_index": 2, # References "Lift the apple"
"observation.state": [...],
"action": [...],
}
```
## Annotating Datasets with Subtasks
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
After completing your annotation:
1. Click "Push to Hub" to upload your annotated dataset
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
## Loading Datasets with Subtasks
When you load a dataset with subtask annotations, the subtask information is automatically available:
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Load a dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Access a sample
sample = dataset[100]
# The sample includes both task and subtask information
print(sample["task"]) # "Collect the fruit"
print(sample["subtask"]) # "Grasp the apple"
print(sample["task_index"]) # tensor(0)
print(sample["subtask_index"]) # tensor(2)
```
### Checking for Subtask Support
You can check if a dataset has subtask annotations:
```python
# Check if subtasks are available
has_subtasks = (
"subtask_index" in dataset.features
and dataset.meta.subtasks is not None
)
if has_subtasks:
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
print("Subtasks:", list(dataset.meta.subtasks.index))
```
## Using Subtasks for Training
### With the Tokenizer Processor
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
```python
from lerobot.processor.tokenizer_processor import TokenizerProcessor
from lerobot.processor.pipeline import ProcessorPipeline
# Create a tokenizer processor
tokenizer_processor = TokenizerProcessor(
tokenizer_name_or_path="google/paligemma-3b-pt-224",
padding="max_length",
max_length=64,
)
# The processor will automatically tokenize subtasks if present in the batch
# and add them to the observation under:
# - "observation.subtask.tokens"
# - "observation.subtask.attention_mask"
```
When subtasks are available in the batch, the tokenizer processor adds:
- `observation.subtask.tokens`: Tokenized subtask text
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
### DataLoader with Subtasks
```python
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=16,
shuffle=True,
)
for batch in dataloader:
# Access subtask information in the batch
subtasks = batch["subtask"] # List of subtask strings
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
# Use for training hierarchical policies or reward models
print(f"Batch subtasks: {set(subtasks)}")
```
## Example Datasets with Subtask Annotations
Try loading a dataset with subtask annotations:
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Example dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Explore the subtasks
print("Available subtasks:")
for subtask_name in dataset.meta.subtasks.index:
print(f" - {subtask_name}")
# Get subtask distribution
subtask_counts = {}
for i in range(len(dataset)):
sample = dataset[i]
subtask = sample["subtask"]
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
print("\nSubtask distribution:")
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
print(f" {subtask}: {count} frames")
```
## Use Cases
### 1. Hierarchical Policy Training
Train policies that predict both actions and current subtask:
```python
class HierarchicalPolicy(nn.Module):
def __init__(self, num_subtasks):
super().__init__()
self.action_head = nn.Linear(hidden_dim, action_dim)
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
def forward(self, observations):
features = self.encoder(observations)
actions = self.action_head(features)
subtask_logits = self.subtask_head(features)
return actions, subtask_logits
```
### 2. Stage-Aware Reward Modeling (SARM)
Build reward models that understand task progression:
```python
# SARM predicts:
# - Stage: Which subtask is being executed (discrete)
# - Progress: How far along the subtask (continuous 0-1)
class SARMRewardModel(nn.Module):
def forward(self, observations):
features = self.encoder(observations)
stage_logits = self.stage_classifier(features)
progress = self.progress_regressor(features)
return stage_logits, progress
```
### 3. Progress Visualization
Monitor robot execution by tracking subtask progression:
```python
def visualize_execution(model, observations):
for t, obs in enumerate(observations):
action, subtask_logits = model(obs)
predicted_subtask = subtask_names[subtask_logits.argmax()]
print(f"t={t}: Executing '{predicted_subtask}'")
```
## API Reference
### LeRobotDataset Properties
| Property | Type | Description |
| --------------------------- | ---------------------- | ------------------------------------------ |
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
### Sample Keys
When subtasks are available, each sample includes:
| Key | Type | Description |
| --------------- | -------------- | ------------------------------------ |
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
| `subtask` | `str` | Natural language subtask description |
## Related Resources
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation

View File

@@ -1,5 +1,11 @@
# EarthRover Mini Plus
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Earth_Rover_Mini_5_240c9adc-4f9e-44b7-982f-5d1dc24af1d8.png.webp"
alt="EarthRover Mini Plus"
width="70%"
/>
The EarthRover Mini Plus is a fully open source mobile robot that connects through the cloud using the Frodobots SDK. This lets you control the robot and record datasets for training AI models.
## What You Need

View File

@@ -1,13 +1,15 @@
# Installation
## Install [`miniforge`](https://conda-forge.org/download/)
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
```bash
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
bash Miniforge3-$(uname)-$(uname -m).sh
```
## Environment Setup
## Step 2: Environment Setup
Create a virtual environment with Python 3.10, using conda:
@@ -38,7 +40,7 @@ conda install ffmpeg -c conda-forge
>
> - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
## Install LeRobot 🤗
## Step 3: Install LeRobot 🤗
### From Source

View File

@@ -1,5 +1,11 @@
# LeKiwi
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/1740517739083.jpeg"
alt="LeKiwi"
width="70%"
/>
In the steps below, we explain how to assemble the LeKiwi mobile robot.
## Source the parts

View File

@@ -42,6 +42,7 @@ lerobot-eval \
```
- `--env.task` picks the suite (`libero_object`, `libero_spatial`, etc.).
- `--env.task_ids` picks task ids to run (`[0]`, `[1,2,3]`, etc.). Omit this flag (or set it to `null`) to run all tasks in the suite.
- `--eval.batch_size` controls how many environments run in parallel.
- `--eval.n_episodes` sets how many episodes to run in total.

197
docs/source/omx.mdx Normal file
View File

@@ -0,0 +1,197 @@
## Order and Assemble the parts
First, assemble the OMX hardware following the official assembly guide.
OMX Assembly Guide: https://ai.robotis.com/omx/assembly_guide_omx.html
OMX robots are shipped preconfigured from the factory. Motor IDs, communication parameters, and joint offsets are already set, so no additional motor setup or calibration is required before using LeRobot.
## Install LeRobot 🤗
To install LeRobot, follow our [Installation Guide](./installation)
In addition to these instructions, you need to install the Dynamixel SDK:
```bash
pip install -e ".[dynamixel]"
```
## Connect the robot
To find the port for each bus servo adapter, run this script:
```bash
lerobot-find-port
```
This command runs and when prompted, disconnect the USB cable from either the leader or follower arm and press Enter. The output will show 'The port of this MotorsBus is [port]'. This identifies the port for the disconnected arm. Repeat for the other arm to identify both ports.
<hfoptions id="find_port">
<hfoption id="Mac">
Example output on macOS:
```
Finding all available ports for the MotorBus.
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
Remove the USB cable from your MotorsBus and press Enter when done.
[...Disconnect corresponding leader or follower arm and press Enter...]
The port of this MotorsBus is /dev/tty.usbmodem575E0032081
Reconnect the USB cable.
```
Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm.
</hfoption>
<hfoption id="Linux">
On Linux, we strongly recommend using udev rules to assign persistent and human-readable device names to the OMX leader and follower arms. This avoids issues where device names such as ttyACM0 and ttyACM1 change when the robot is unplugged, replugged, or when the system is rebooted.
#### 1. Find your device serial numbers
You should have obtained the port numbers like ../../ttyACM? for the leader and follower using `lerobot-find-port`. You can match those results with the serial numbers using the `ls -l /dev/serial/by-id/` command.
To create udev rules, you need the unique serial number for each OMX device. The easiest way is to list devices under:
```bash
ls -l /dev/serial/by-id/
```
You will see output similar to:
```bash
usb-ROBOTIS_OpenRB-150_228BDD7B503059384C2E3120FF0A2B19-if00 -> ../../ttyACM0
usb-ROBOTIS_OpenRB-150_67E1ED68503059384C2E3120FF092234-if00 -> ../../ttyACM1
```
In each line, the serial number is the long string after `usb-ROBOTIS_OpenRB-150_` and before `-if00`.
Follower serial: `228BDD7B503059384C2E3120FF0A2B19`
Leader serial: `67E1ED68503059384C2E3120FF092234`
#### 2. Create the udev rule
Create a new udev rule file:
```bash
sudo nano /etc/udev/rules.d/99-omx.rules
```
Paste the following lines, replacing the serial numbers with the values you found above:
```bash
SUBSYSTEM=="tty", ATTRS{idVendor}=="0403", ATTRS{serial}=="228BDD7B503059384C2E3120FF0A2B19", SYMLINK+="omx_follower"
SUBSYSTEM=="tty", ATTRS{idVendor}=="0403", ATTRS{serial}=="67E1ED68503059384C2E3120FF092234", SYMLINK+="omx_leader"
```
Save the file and reload udev rules:
```bash
sudo udevadm control --reload-rules
sudo udevadm trigger
```
Now unplug and replug both devices once.
#### 3. Verify the symlinks
Check that the persistent device names exist:
```bash
ls -l /dev/omx_follower /dev/omx_leader
```
You should see them pointing to ttyACM\* devices:
```bash
/dev/omx_follower -> ttyACM*
/dev/omx_leader -> ttyACM*
```
These names remain stable across reboots and reconnections.
</hfoption>
</hfoptions>
## Teleoperate
After identifying the correct ports, you can directly teleoperate the follower arm using the leader arm.
<hfoptions id="teleoperate">
<hfoption id="Mac">
### Teleoperate without camera
```bash
lerobot-teleoperate \
--robot.type=omx_follower \
--robot.port=<your_follower_port> \
--robot.id=omx_follower_arm \
--teleop.type=omx_leader \
--teleop.port=<your_leader_port> \
--teleop.id=omx_leader_arm
```
During teleoperation, motions of the leader arm are mirrored in real time by the follower arm. OMX is already preconfigured, teleoperation can begin immediately without any calibration steps.
### Teleoperate with camera
You can also enable camera input during teleoperation by providing a camera configuration for the follower arm.
```bash
lerobot-teleoperate \
--robot.type=omx_follower \
--robot.port=<your_follower_port> \
--robot.id=omx_follower_arm \
--robot.cameras="{front: {type: opencv, index_or_path: '/dev/video0', width: 640, height: 480, fps: 30}}" \
--teleop.type=omx_leader \
--teleop.port=<your_leader_port> \
--teleop.id=omx_leader_arm \
--display_data=true
```
When the camera is enabled, the camera stream is displayed in real time and synchronized with the robot state. This setup is useful for visual monitoring and can be reused later for demonstration recording and imitation learning.
</hfoption>
<hfoption id="Linux">
### Teleoperate without camera
```bash
lerobot-teleoperate \
--robot.type=omx_follower \
--robot.port=/dev/omx_follower \
--robot.id=omx_follower_arm \
--teleop.type=omx_leader \
--teleop.port=/dev/omx_leader \
--teleop.id=omx_leader_arm
```
During teleoperation, motions of the leader arm are mirrored in real time by the follower arm. OMX is already preconfigured, teleoperation can begin immediately without any calibration steps.
### Teleoperate with camera
You can also enable camera input during teleoperation by providing a camera configuration for the follower arm.
```bash
lerobot-teleoperate \
--robot.type=omx_follower \
--robot.port=/dev/omx_follower \
--robot.id=omx_follower_arm \
--robot.cameras="{front: {type: opencv, index_or_path: '/dev/video0', width: 640, height: 480, fps: 30}}" \
--teleop.type=omx_leader \
--teleop.port=/dev/omx_leader \
--teleop.id=omx_leader_arm \
--display_data=true
```
When the camera is enabled, the camera stream is displayed in real time and synchronized with the robot state. This setup is useful for visual monitoring and can be reused later for demonstration recording and imitation learning.
</hfoption>
</hfoptions>
Congrats 🎉, your robot is all set to learn a task on its own.
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/robotis).

276
docs/source/openarm.mdx Normal file
View File

@@ -0,0 +1,276 @@
# OpenArm
[OpenArm](https://openarm.dev) is an open-source 7DOF humanoid arm designed for physical AI research and deployment.
To get your OpenArm, assembled or DIY, and join the global community, browse verified and certified manufacturers worldwide at [openarm.dev](https://openarm.dev).
## What's Unique?
- **Human-Scale Design**: OpenArm is designed with human-like proportions, scaled for a person around 160-165cm tall. This provides an optimal balance between practical reach and manageable inertia for safe, responsive operation.
- **Safety-First Architecture**: Built with QDD backdrivable motors and high compliance, OpenArm prioritizes safe human-robot interaction while maintaining practical payload capabilities (6.0kg peak / 4.1kg nominal) for real-world tasks.
- **Built for Durability**: Critical structural components use aluminum and stainless steel construction, ensuring robust performance for repetitive data collection and continuous research use.
- **Fully Accessible & Buildable**: Every component, from CNC parts and 3D-printed casings to electrical wiring is designed to be purchasable and buildable by individual researchers and labs, with complete fabrication data provided.
- **Practical & Affordable**: At $6,500 USD for a complete bimanual system, OpenArm delivers research-grade capabilities at a fraction of traditional humanoid robot costs.
## Platform Requirements
<Tip warning={true}>
**Linux Only**: OpenArm currently only works on Linux. The CAN bus USB adapter
does not have macOS drivers and has not been tested on Windows.
</Tip>
## Safety Guide
Before operating OpenArm, please read the [official safety guide](https://docs.openarm.dev/getting-started/safety-guide). Key points:
- **Secure installation**: Fasten the arm to a flat, stable surface with screws or clamps
- **Safe distance**: Keep body parts and objects outside the range of motion during operation
- **Protective equipment**: Always wear safety goggles; use additional PPE as needed
- **Payload limits**: Do not exceed specified payload limits (6.0kg peak / 4.1kg nominal per arm)
- **Emergency stop**: Know the location and operation of the emergency stop device
- **Regular inspection**: Check for loose screws, damaged mechanical limits, unusual noises, and wiring damage
## Hardware Setup
Follow the official [OpenArm hardware documentation](https://docs.openarm.dev) for:
- Bill of materials and sourcing
- 3D printing instructions
- Mechanical assembly
- Electrical wiring
The hardware repositories are available at [github.com/enactic/openarm](https://github.com/enactic/openarm).
## CAN Bus Setup
OpenArm uses CAN bus communication with Damiao motors. Once you have the CAN bus USB adapter plugged into your Linux PC, follow the [Damiao Motors and CAN Bus guide](./damiao) to configure the interface.
Quick setup:
```bash
# Setup CAN interfaces
lerobot-setup-can --mode=setup --interfaces=can0,can1
# Test motor communication
lerobot-setup-can --mode=test --interfaces=can0,can1
```
## Install LeRobot 🤗
Follow our [Installation Guide](./installation), then install the Damiao motor support:
```bash
pip install -e ".[damiao]"
```
## Usage
### Follower Arm (Robot)
<hfoptions id="follower">
<hfoption id="Command">
```bash
lerobot-calibrate \
--robot.type=openarm_follower \
--robot.port=can0 \
--robot.side=right \
--robot.id=my_openarm_follower
```
</hfoption>
<hfoption id="API example">
```python
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
config = OpenArmFollowerConfig(
port="can0",
side="right", # or "left" for left arm
id="my_openarm_follower",
)
follower = OpenArmFollower(config)
follower.connect()
# Read current state
obs = follower.get_observation()
print(obs)
# Send action (position in degrees)
action = {
"joint_1.pos": 0.0,
"joint_2.pos": 0.0,
"joint_3.pos": 0.0,
"joint_4.pos": 45.0,
"joint_5.pos": 0.0,
"joint_6.pos": 0.0,
"joint_7.pos": 0.0,
"gripper.pos": 0.0,
}
follower.send_action(action)
follower.disconnect()
```
</hfoption>
</hfoptions>
### Leader Arm (Teleoperator)
The leader arm is used for teleoperation - manually moving it to control the follower arm.
<hfoptions id="leader">
<hfoption id="Command">
```bash
lerobot-calibrate \
--teleop.type=openarm_leader \
--teleop.port=can1 \
--teleop.id=my_openarm_leader
```
</hfoption>
<hfoption id="API example">
```python
from lerobot.teleoperators.openarm_leader import OpenArmLeader, OpenArmLeaderConfig
config = OpenArmLeaderConfig(
port="can1",
id="my_openarm_leader",
manual_control=True, # Disable torque for manual movement
)
leader = OpenArmLeader(config)
leader.connect()
# Read current position (as action to send to follower)
action = leader.get_action()
print(action)
leader.disconnect()
```
</hfoption>
</hfoptions>
### Teleoperation
To teleoperate OpenArm with leader-follower control:
```bash
lerobot-teleoperate \
--robot.type=openarm_follower \
--robot.port=can0 \
--robot.side=right \
--robot.id=my_follower \
--teleop.type=openarm_leader \
--teleop.port=can1 \
--teleop.id=my_leader
```
### Bimanual Teleoperation
To teleoperate a bimanual OpenArm setup with two leader and two follower arms:
```bash
lerobot-teleoperate \
--robot.type=bi_openarm_follower \
--robot.left_arm_config.port=can0 \
--robot.left_arm_config.side=left \
--robot.right_arm_config.port=can1 \
--robot.right_arm_config.side=right \
--robot.id=my_bimanual_follower \
--teleop.type=bi_openarm_leader \
--teleop.left_arm_config.port=can2 \
--teleop.right_arm_config.port=can3 \
--teleop.id=my_bimanual_leader
```
### Recording Data
To record a dataset during teleoperation:
```bash
lerobot-record \
--robot.type=openarm_follower \
--robot.port=can0 \
--robot.side=right \
--robot.id=my_follower \
--teleop.type=openarm_leader \
--teleop.port=can1 \
--teleop.id=my_leader \
--repo-id=my_hf_username/my_openarm_dataset \
--fps=30 \
--num-episodes=10
```
## Configuration Options
### Follower Configuration
| Parameter | Default | Description |
| --------------------- | --------- | ---------------------------------------------------------- |
| `port` | - | CAN interface (e.g., `can0`) |
| `side` | `None` | Arm side: `"left"`, `"right"`, or `None` for custom limits |
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
| `max_relative_target` | `None` | Safety limit for relative target positions |
| `position_kp` | Per-joint | Position control proportional gains |
| `position_kd` | Per-joint | Position control derivative gains |
### Leader Configuration
| Parameter | Default | Description |
| ------------------ | --------- | ----------------------------------- |
| `port` | - | CAN interface (e.g., `can1`) |
| `manual_control` | `True` | Disable torque for manual movement |
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
## Motor Configuration
OpenArm uses Damiao motors with the following default configuration:
| Joint | Motor Type | Send ID | Recv ID |
| --------------------------- | ---------- | ------- | ------- |
| joint_1 (Shoulder pan) | DM8009 | 0x01 | 0x11 |
| joint_2 (Shoulder lift) | DM8009 | 0x02 | 0x12 |
| joint_3 (Shoulder rotation) | DM4340 | 0x03 | 0x13 |
| joint_4 (Elbow flex) | DM4340 | 0x04 | 0x14 |
| joint_5 (Wrist roll) | DM4310 | 0x05 | 0x15 |
| joint_6 (Wrist pitch) | DM4310 | 0x06 | 0x16 |
| joint_7 (Wrist rotation) | DM4310 | 0x07 | 0x17 |
| gripper | DM4310 | 0x08 | 0x18 |
## Troubleshooting
### No Response from Motors
1. Check power supply connections
2. Verify CAN wiring (CAN-H, CAN-L, GND)
3. Run diagnostics: `lerobot-setup-can --mode=test --interfaces=can0`
4. See the [Damiao troubleshooting guide](./damiao#troubleshooting) for more details
### CAN Interface Not Found
Ensure the CAN interface is configured:
```bash
ip link show can0
```
## Resources
- [OpenArm Website](https://openarm.dev)
- [OpenArm Documentation](https://docs.openarm.dev)
- [OpenArm GitHub](https://github.com/enactic/openarm)
- [Safety Guide](https://docs.openarm.dev/getting-started/safety-guide)
- [Damiao Motors and CAN Bus](./damiao)

View File

@@ -1,5 +1,18 @@
# SO-101
<div style="display: flex; align-items: center; gap: 10px;">
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/SO101_Follower.webp"
alt="SO-101"
width="60%"
/>
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/SO101_Leader.webp"
alt="SO-101"
width="60%"
/>
</div>
In the steps below, we explain how to assemble our flagship robot, the SO-101.
## Source the parts

View File

@@ -188,7 +188,105 @@ Press `Ctrl+C` to stop the policy.
## Running in Simulation Mode (MuJoCo)
You can now test policies before unleashing them on the physical robot using MuJoCo. To do so simply set `is_simulation=True` in config.
You can test policies before deploying on the physical robot using MuJoCo simulation. Set `is_simulation=True` in config or pass `--robot.is_simulation=true` via CLI.
### Calibrate Exoskeleton Teleoperator
```bash
lerobot-calibrate \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo
```
### Teleoperate in Simulation
```bash
lerobot-teleoperate \
--robot.type=unitree_g1 \
--robot.is_simulation=true \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--fps=100
```
### Record Dataset in Simulation
```bash
python -m lerobot.scripts.lerobot_record \
--robot.type=unitree_g1 \
--robot.is_simulation=true \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--dataset.repo_id=your-username/dataset-name \
--dataset.single_task="Test" \
--dataset.num_episodes=2 \
--dataset.episode_time_s=5 \
--dataset.reset_time_s=5 \
--dataset.push_to_hub=true
```
Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim)
---
## Running on Real Robot
Once the robot server is running on the G1 (see Part 3), you can teleoperate and record on the real robot.
### Start the Camera Server
On the robot, start the ZMQ image server:
```bash
python src/lerobot/cameras/zmq/image_server.py
```
Keep this running in a separate terminal for camera streaming during recording.
### Teleoperate Real Robot
```bash
lerobot-teleoperate \
--robot.type=unitree_g1 \
--robot.is_simulation=false \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--fps=100
```
### Record Dataset on Real Robot
```bash
python -m lerobot.scripts.lerobot_record \
--robot.type=unitree_g1 \
--robot.is_simulation=false \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--dataset.repo_id=your-username/dataset-name \
--dataset.single_task="Test" \
--dataset.num_episodes=2 \
--dataset.episode_time_s=5 \
--dataset.reset_time_s=5 \
--dataset.push_to_hub=true
```
**Note**: Update `server_address` to match your robot's camera server IP.
Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/datasets/nepyope/teleop_test_real)
---
## Additional Resources

View File

@@ -81,24 +81,25 @@ def replay(cfg: ReplayConfig):
actions = dataset.hf_dataset.select_columns(ACTION)
robot.connect()
log_say("Replaying episode", cfg.play_sounds, blocking=True)
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
try:
log_say("Replaying episode", cfg.play_sounds, blocking=True)
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
action_array = actions[idx][ACTION]
action = {}
for i, name in enumerate(dataset.features[ACTION]["names"]):
key = f"{name.removeprefix('main_')}.pos"
action[key] = action_array[i].item()
action_array = actions[idx][ACTION]
action = {}
for i, name in enumerate(dataset.features[ACTION]["names"]):
key = f"{name.removeprefix('main_')}.pos"
action[key] = action_array[i].item()
action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90)
action["elbow_flex.pos"] -= 90
robot.send_action(action)
action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90)
action["elbow_flex.pos"] -= 90
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
robot.disconnect()
dt_s = time.perf_counter() - start_episode_t
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
finally:
robot.disconnect()
if __name__ == "__main__":

View File

@@ -78,40 +78,24 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="lekiwi_evaluate")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
print("Starting evaluate loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
@@ -120,24 +104,42 @@ def main():
robot_observation_processor=robot_observation_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Save episode
dataset.save_episode()
recorded_episodes += 1
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
# Save episode
dataset.save_episode()
recorded_episodes += 1
dataset.finalize()
dataset.push_to_hub()
finally:
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":

View File

@@ -74,40 +74,23 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="lekiwi_record")
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!")
try:
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting record loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {recorded_episodes}")
print("Starting record loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {recorded_episodes}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
dataset=dataset,
teleop=[leader_arm, keyboard],
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
dataset=dataset,
teleop=[leader_arm, keyboard],
control_time_s=RESET_TIME_SEC,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
@@ -115,26 +98,44 @@ def main():
robot_observation_processor=robot_observation_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=[leader_arm, keyboard],
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Save episode
dataset.save_episode()
recorded_episodes += 1
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Clean up
log_say("Stop recording")
robot.disconnect()
leader_arm.disconnect()
keyboard.disconnect()
listener.stop()
# Save episode
dataset.save_episode()
recorded_episodes += 1
finally:
# Clean up
log_say("Stop recording")
robot.disconnect()
leader_arm.disconnect()
keyboard.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":

View File

@@ -42,25 +42,27 @@ def main():
# Connect to the robot
robot.connect()
if not robot.is_connected:
raise ValueError("Robot is not connected!")
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
# Get recorded action from dataset
action = {
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get recorded action from dataset
action = {
name: float(actions[idx][ACTION][i])
for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Send action to robot
_ = robot.send_action(action)
# Send action to robot
_ = robot.send_action(action)
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
robot.disconnect()
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
finally:
robot.disconnect()
if __name__ == "__main__":

View File

@@ -43,13 +43,12 @@ def main():
keyboard.connect()
# Init rerun viewer
init_rerun(session_name="lekiwi_teleop", robot=robot, reset_time=True)
init_rerun(session_name="lekiwi_teleop")
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting teleop loop...")
start = time.perf_counter()
while True:
t0 = time.perf_counter()
@@ -70,7 +69,7 @@ def main():
_ = robot.send_action(action)
# Visualize
log_rerun_data(observation=observation, action=action, log_time=time.perf_counter() - start)
log_rerun_data(observation=observation, action=action)
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))

View File

@@ -142,38 +142,24 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="phone_so100_evaluate")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
print("Starting evaluate loop...")
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
log_say("Reset the environment")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
@@ -182,24 +168,41 @@ def main():
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Save episode
dataset.save_episode()
episode_idx += 1
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
# Save episode
dataset.save_episode()
episode_idx += 1
finally:
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":

View File

@@ -149,38 +149,23 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="phone_so100_record")
if not robot.is_connected or not phone.is_connected:
raise ValueError("Robot or teleop is not connected!")
try:
if not robot.is_connected or not phone.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting record loop. Move your phone to teleoperate the robot...")
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
print("Starting record loop. Move your phone to teleoperate the robot...")
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=phone,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=phone,
control_time_s=RESET_TIME_SEC,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
@@ -188,25 +173,43 @@ def main():
robot_observation_processor=robot_joints_to_ee_pose,
)
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=phone,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
)
# Save episode
dataset.save_episode()
episode_idx += 1
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Clean up
log_say("Stop recording")
robot.disconnect()
phone.disconnect()
listener.stop()
# Save episode
dataset.save_episode()
episode_idx += 1
finally:
# Clean up
log_say("Stop recording")
robot.disconnect()
phone.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":

View File

@@ -73,32 +73,34 @@ def main():
# Connect to the robot
robot.connect()
if not robot.is_connected:
raise ValueError("Robot is not connected!")
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
# Get recorded action from dataset
ee_action = {
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get recorded action from dataset
ee_action = {
name: float(actions[idx][ACTION][i])
for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get robot observation
robot_obs = robot.get_observation()
# Get robot observation
robot_obs = robot.get_observation()
# Dataset EE -> robot joints
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
# Dataset EE -> robot joints
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
# Send action to robot
_ = robot.send_action(joint_action)
# Send action to robot
_ = robot.send_action(joint_action)
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
# Clean up
robot.disconnect()
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
finally:
# Clean up
robot.disconnect()
if __name__ == "__main__":

View File

@@ -89,13 +89,12 @@ def main():
teleop_device.connect()
# Init rerun viewer
init_rerun(session_name="phone_so100_teleop", robot=robot, reset_time=True)
init_rerun(session_name="phone_so100_teleop")
if not robot.is_connected or not teleop_device.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting teleop loop. Move your phone to teleoperate the robot...")
start = time.perf_counter()
while True:
t0 = time.perf_counter()
@@ -112,7 +111,7 @@ def main():
_ = robot.send_action(joint_action)
# Visualize
log_rerun_data(observation=phone_obs, action=joint_action, log_time=time.perf_counter() - start)
log_rerun_data(observation=phone_obs, action=joint_action)
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))

View File

@@ -142,38 +142,24 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="so100_so100_evaluate")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
print("Starting evaluate loop...")
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
log_say("Reset the environment")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
@@ -182,24 +168,41 @@ def main():
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Save episode
dataset.save_episode()
episode_idx += 1
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
# Save episode
dataset.save_episode()
episode_idx += 1
finally:
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":

View File

@@ -146,38 +146,23 @@ def main():
listener, events = init_keyboard_listener()
init_rerun(session_name="recording_phone")
if not leader.is_connected or not follower.is_connected:
raise ValueError("Robot or teleop is not connected!")
try:
if not leader.is_connected or not follower.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting record loop...")
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
print("Starting record loop...")
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=follower,
events=events,
fps=FPS,
teleop=leader,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
# Main record loop
record_loop(
robot=follower,
events=events,
fps=FPS,
teleop=leader,
control_time_s=RESET_TIME_SEC,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=leader_joints_to_ee,
@@ -185,25 +170,44 @@ def main():
robot_observation_processor=follower_joints_to_ee,
)
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=follower,
events=events,
fps=FPS,
teleop=leader,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
)
# Save episode
dataset.save_episode()
episode_idx += 1
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Clean up
log_say("Stop recording")
leader.disconnect()
follower.disconnect()
listener.stop()
# Save episode
dataset.save_episode()
episode_idx += 1
dataset.finalize()
dataset.push_to_hub()
finally:
# Clean up
log_say("Stop recording")
leader.disconnect()
follower.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":

View File

@@ -74,32 +74,35 @@ def main():
# Connect to the robot
robot.connect()
if not robot.is_connected:
raise ValueError("Robot is not connected!")
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
# Get recorded action from dataset
ee_action = {
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get recorded action from dataset
ee_action = {
name: float(actions[idx][ACTION][i])
for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get robot observation
robot_obs = robot.get_observation()
# Get robot observation
robot_obs = robot.get_observation()
# Dataset EE -> robot joints
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
# Dataset EE -> robot joints
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
# Send action to robot
_ = robot.send_action(joint_action)
# Send action to robot
_ = robot.send_action(joint_action)
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
# Clean up
robot.disconnect()
finally:
# Clean up
robot.disconnect()
if __name__ == "__main__":

View File

@@ -94,10 +94,9 @@ def main():
leader.connect()
# Init rerun viewer
init_rerun(session_name="so100_so100_EE_teleop", robot=follower, reset_time=True)
init_rerun(session_name="so100_so100_EE_teleop")
print("Starting teleop loop...")
start = time.perf_counter()
while True:
t0 = time.perf_counter()
@@ -117,9 +116,7 @@ def main():
_ = follower.send_action(follower_joints_act)
# Visualize
log_rerun_data(
observation=leader_ee_act, action=follower_joints_act, log_time=time.perf_counter() - start
)
log_rerun_data(observation=leader_ee_act, action=follower_joints_act)
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))

View File

@@ -30,6 +30,7 @@ def main():
robot=robot_cfg,
server_address=server_address,
policy_device="mps",
client_device="cpu",
policy_type="act",
pretrained_name_or_path="<user>/robot_learning_tutorial_act",
chunk_size_threshold=0.5, # g

View File

@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
[project]
name = "lerobot"
version = "0.4.3"
version = "0.4.4"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
dynamic = ["readme"]
license = { text = "Apache-2.0" }
@@ -102,14 +102,20 @@ grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
damiao = ["python-can>=4.2.0,<5.0.0"]
# Robots
openarms = ["lerobot[damiao]"]
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
unitree_g1 = [
"pyzmq>=26.2.1,<28.0.0",
"onnxruntime>=1.16.0,<2.0.0"
"onnxruntime>=1.16.0,<2.0.0",
"pin>=3.0.0,<4.0.0",
"meshcat>=0.3.0,<0.4.0",
"matplotlib>=3.9.0,<4.0.0",
"casadi>=3.6.0,<4.0.0",
]
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
kinematics = ["lerobot[placo-dep]"]
@@ -147,7 +153,6 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpci
# Features
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
peft = ["lerobot[transformers-dep]", "peft>=0.18.0,<1.0.0"]
audio = ["sounddevice>=0.5.1,<0.6.0", "soundfile>=0.13.1,<0.14.0", "librosa>=0.11.0,<0.12.0", "torchaudio>=2.6.0,<2.10.0"]
# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
@@ -176,7 +181,6 @@ all = [
"lerobot[xvla]",
"lerobot[hilserl]",
"lerobot[async]",
"lerobot[audio]",
"lerobot[dev]",
"lerobot[test]",
"lerobot[video_benchmark]",
@@ -205,6 +209,7 @@ lerobot-info="lerobot.scripts.lerobot_info:main"
lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
# ---------------- Tool Configurations ----------------
[tool.setuptools.packages.find]
@@ -280,6 +285,7 @@ default.extend-ignore-identifiers-re = [
"thw",
"inpt",
"ROBOTIS",
"OT_VALUE"
]
# TODO: Uncomment when ready to use
@@ -354,9 +360,9 @@ ignore_errors = false
module = "lerobot.cameras.*"
ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.motors.*"
# ignore_errors = false
[[tool.mypy.overrides]]
module = "lerobot.motors.*"
ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.robots.*"

View File

@@ -29,7 +29,6 @@ Example:
print(lerobot.available_policies_per_env)
print(lerobot.available_robots)
print(lerobot.available_cameras)
print(lerobot.available_microphones)
print(lerobot.available_motors)
```
@@ -175,13 +174,6 @@ available_cameras = [
"intelrealsense",
]
# lists all available microphones from `lerobot/microphones`
available_microphones = [
"portaudio",
"touchlab",
"anyskin",
]
# lists all available motors from `lerobot/motors`
available_motors = [
"dynamixel",

View File

@@ -126,6 +126,12 @@ class RobotClientConfig:
# Device configuration
policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})
client_device: str = field(
default="cpu",
metadata={
"help": "Device to move actions to after receiving from server (e.g., for downstream planners)"
},
)
# Control behavior configuration
chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"})
@@ -161,6 +167,9 @@ class RobotClientConfig:
if not self.policy_device:
raise ValueError("policy_device cannot be empty")
if not self.client_device:
raise ValueError("client_device cannot be empty")
if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1:
raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}")
@@ -184,6 +193,7 @@ class RobotClientConfig:
"policy_type": self.policy_type,
"pretrained_name_or_path": self.pretrained_name_or_path,
"policy_device": self.policy_device,
"client_device": self.client_device,
"chunk_size_threshold": self.chunk_size_threshold,
"fps": self.fps,
"actions_per_chunk": self.actions_per_chunk,

View File

@@ -18,6 +18,7 @@ import os
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import torch
@@ -39,8 +40,8 @@ from lerobot.utils.utils import init_logging
Action = torch.Tensor
# observation as received from the robot
RawObservation = dict[str, torch.Tensor]
# observation as received from the robot (can be numpy arrays, floats, etc.)
RawObservation = dict[str, Any]
# observation as those recorded in LeRobot dataset (keys are different)
LeRobotObservation = dict[str, torch.Tensor]

View File

@@ -381,6 +381,8 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
action_tensor = torch.stack(processed_actions, dim=1).squeeze(0)
self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}")
action_tensor = action_tensor.detach().cpu()
"""5. Convert to TimedAction list"""
action_chunk = self._time_action_chunk(
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()

View File

@@ -25,6 +25,7 @@ python src/lerobot/async_inference/robot_client.py \
--policy_type=act \
--pretrained_name_or_path=user/model \
--policy_device=mps \
--client_device=cpu \
--actions_per_chunk=50 \
--chunk_size_threshold=0.5 \
--aggregate_fn_name=weighted_average \
@@ -40,6 +41,7 @@ from collections.abc import Callable
from dataclasses import asdict
from pprint import pformat
from queue import Queue
from typing import Any
import draccus
import grpc
@@ -47,10 +49,6 @@ import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.microphones.anyskin.configuration_anyskin import AnyskinSensorConfig # noqa: F401
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig # noqa: F401
from lerobot.microphones.touchlab.configuration_touchlab import TouchLabSensorConfig # noqa: F401
from lerobot.processor import RobotAction
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
@@ -288,6 +286,21 @@ class RobotClient:
timed_actions = pickle.loads(actions_chunk.data) # nosec
deserialize_time = time.perf_counter() - deserialize_start
# Log device type of received actions
if len(timed_actions) > 0:
received_device = timed_actions[0].get_action().device.type
self.logger.debug(f"Received actions on device: {received_device}")
# Move actions to client_device (e.g., for downstream planners that need GPU)
client_device = self.config.client_device
if client_device != "cpu":
for timed_action in timed_actions:
if timed_action.get_action().device.type != client_device:
timed_action.action = timed_action.get_action().to(client_device)
self.logger.debug(f"Converted actions to device: {client_device}")
else:
self.logger.debug(f"Actions kept on device: {client_device}")
self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
# Calculate network latency if we have matching observations
@@ -354,7 +367,7 @@ class RobotClient:
action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
return action
def control_loop_action(self, verbose: bool = False) -> RobotAction:
def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
"""Reading and performing actions in local queue"""
# Lock only for queue operations

View File

@@ -13,5 +13,5 @@
# limitations under the License.
from .camera import Camera
from .configs import CameraConfig, ColorMode, Cv2Rotation
from .configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation
from .utils import make_cameras_from_configs

View File

@@ -15,11 +15,12 @@
# limitations under the License.
import abc
import warnings
from typing import Any
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
from .configs import CameraConfig, ColorMode
from .configs import CameraConfig
class Camera(abc.ABC):
@@ -30,20 +31,12 @@ class Camera(abc.ABC):
Manages basic camera properties (FPS, resolution) and core operations:
- Connection/disconnection
- Frame capture (sync/async)
- Frame capture (sync/async/latest)
Attributes:
fps (int | None): Configured frames per second
width (int | None): Frame width in pixels
height (int | None): Frame height in pixels
Example:
class MyCamera(Camera):
def __init__(self, config): ...
@property
def is_connected(self) -> bool: ...
def connect(self, warmup=True): ...
# Plus other required methods
"""
def __init__(self, config: CameraConfig):
@@ -56,6 +49,32 @@ class Camera(abc.ABC):
self.width: int | None = config.width
self.height: int | None = config.height
def __enter__(self):
"""
Context manager entry.
Automatically connects to the camera.
"""
self.connect()
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
"""
Context manager exit.
Automatically disconnects, ensuring resources are released even on error.
"""
self.disconnect()
def __del__(self) -> None:
"""
Destructor safety net.
Attempts to disconnect if the object is garbage collected without cleanup.
"""
try:
if self.is_connected:
self.disconnect()
except Exception: # nosec B110
pass
@property
@abc.abstractmethod
def is_connected(self) -> bool:
@@ -89,12 +108,10 @@ class Camera(abc.ABC):
pass
@abc.abstractmethod
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""Capture and return a single frame from the camera.
def read(self) -> NDArray[Any]:
"""Capture and return a single frame from the camera synchronously.
Args:
color_mode: Desired color mode for the output frame. If None,
uses the camera's default color mode.
This is a blocking call that will wait for the hardware and its SDK.
Returns:
np.ndarray: Captured frame as a numpy array.
@@ -103,17 +120,64 @@ class Camera(abc.ABC):
@abc.abstractmethod
def async_read(self, timeout_ms: float = ...) -> NDArray[Any]:
"""Asynchronously capture and return a single frame from the camera.
"""Return the most recent new frame.
This method retrieves the latest frame captured by the background thread.
If a new frame is already available in the buffer (captured since the last call),
it returns it immediately.
It blocks up to `timeout_ms` only if the buffer is empty or if the latest frame
was already consumed by a previous `async_read` call.
Essentially, this method return the latest unconsumed frame, waiting if necessary
for a new one to arrive within the specified timeout.
Usage:
- Ideal for control loops where you want to ensure every processed frame
is fresh, effectively synchronizing your loop to the camera's FPS.
- Causes of a timeout usually include: very low camera FPS, heavy processing load,
or if the camera is disconnected.
Args:
timeout_ms: Maximum time to wait for a frame in milliseconds.
Defaults to implementation-specific timeout.
timeout_ms: Maximum time to wait for a new frame in milliseconds.
Defaults to 200ms (0.2s).
Returns:
np.ndarray: Captured frame as a numpy array.
Raises:
TimeoutError: If no new frame arrives within `timeout_ms`.
"""
pass
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
memory buffer. The frame may be stale,
meaning it could have been captured a while ago (hanging camera scenario e.g.).
Usage:
Ideal for scenarios requiring zero latency or decoupled frequencies & when
we want a guaranteed frame, such as UI visualization, logging, or
non-critical monitoring.
Returns:
NDArray[Any]: The frame image (numpy array).
Raises:
TimeoutError: If the latest frame is older than `max_age_ms`.
NotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
warnings.warn(
f"{self.__class__.__name__}.read_latest() is not implemented. "
"Please override read_latest(); it will be required in future releases.",
FutureWarning,
stacklevel=2,
)
return self.async_read()
@abc.abstractmethod
def disconnect(self) -> None:
"""Disconnect from the camera and release resources."""

View File

@@ -25,6 +25,10 @@ class ColorMode(str, Enum):
RGB = "rgb"
BGR = "bgr"
@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(f"`color_mode` is expected to be in {list(cls)}, but {value} is provided.")
class Cv2Rotation(int, Enum):
NO_ROTATION = 0
@@ -32,6 +36,25 @@ class Cv2Rotation(int, Enum):
ROTATE_180 = 180
ROTATE_270 = -90
@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(f"`rotation` is expected to be in {list(cls)}, but {value} is provided.")
# Subset from https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html
class Cv2Backends(int, Enum):
ANY = 0
V4L2 = 200
DSHOW = 700
PVAPI = 800
ANDROID = 1000
AVFOUNDATION = 1200
MSMF = 1400
@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(f"`backend` is expected to be in {list(cls)}, but {value} is provided.")
@dataclass(kw_only=True)
class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus

View File

@@ -32,10 +32,11 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2 # type: ignore # TODO: add type stubs for OpenCV
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from ..camera import Camera
from ..utils import get_cv2_backend, get_cv2_rotation
from ..utils import get_cv2_rotation
from .configuration_opencv import ColorMode, OpenCVCameraConfig
# NOTE(Steven): The maximum opencv device index depends on your operating system. For instance,
@@ -70,34 +71,24 @@ class OpenCVCamera(Camera):
Example:
```python
from lerobot.cameras.opencv import OpenCVCamera
from lerobot.cameras.configuration_opencv import OpenCVCameraConfig, ColorMode, Cv2Rotation
from lerobot.cameras.configuration_opencv import OpenCVCameraConfig
# Basic usage with camera index 0
config = OpenCVCameraConfig(index_or_path=0)
camera = OpenCVCamera(config)
camera.connect()
# Read 1 frame synchronously
# Read 1 frame synchronously (blocking)
color_image = camera.read()
print(color_image.shape)
# Read 1 frame asynchronously
# Read 1 frame asynchronously (waits for new frame with a timeout)
async_image = camera.async_read()
# Get the latest frame immediately (no wait, returns timestamp)
latest_image, timestamp = camera.read_latest()
# When done, properly disconnect the camera using
camera.disconnect()
# Example with custom settings
custom_config = OpenCVCameraConfig(
index_or_path='/dev/video0', # Or use an index
fps=30,
width=1280,
height=720,
color_mode=ColorMode.RGB,
rotation=Cv2Rotation.ROTATE_90
)
custom_camera = OpenCVCamera(custom_config)
# ... connect, read, disconnect ...
```
"""
@@ -123,10 +114,11 @@ class OpenCVCamera(Camera):
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: NDArray[Any] | None = None
self.latest_timestamp: float | None = None
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
self.backend: int = get_cv2_backend()
self.backend: int = config.backend
if self.height and self.width:
self.capture_width, self.capture_height = self.width, self.height
@@ -141,20 +133,23 @@ class OpenCVCamera(Camera):
"""Checks if the camera is currently connected and opened."""
return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened()
@check_if_already_connected
def connect(self, warmup: bool = True) -> None:
"""
Connects to the OpenCV camera specified in the configuration.
Initializes the OpenCV VideoCapture object, sets desired camera properties
(FPS, width, height), and performs initial checks.
(FPS, width, height), starts the background reading thread and performs initial checks.
Args:
warmup (bool): If True, waits at connect() time until at least one valid frame
has been captured by the background thread. Defaults to True.
Raises:
DeviceAlreadyConnectedError: If the camera is already connected.
ConnectionError: If the specified camera index/path is not found or the camera is found but fails to open.
RuntimeError: If the camera opens but fails to apply requested FPS/resolution settings.
ConnectionError: If the specified camera index/path is not found or fails to open.
RuntimeError: If the camera opens but fails to apply requested settings.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
# Use 1 thread for OpenCV operations to avoid potential conflicts or
# blocking in multi-threaded applications, especially during data collection.
@@ -170,15 +165,20 @@ class OpenCVCamera(Camera):
)
self._configure_capture_settings()
self._start_read_thread()
if warmup:
if warmup and self.warmup_s > 0:
start_time = time.time()
while time.time() - start_time < self.warmup_s:
self.read()
self.async_read(timeout_ms=self.warmup_s * 1000)
time.sleep(0.1)
with self.frame_lock:
if self.latest_frame is None:
raise ConnectionError(f"{self} failed to capture frames during warmup.")
logger.info(f"{self} connected.")
@check_if_not_connected
def _configure_capture_settings(self) -> None:
"""
Applies the specified FOURCC, FPS, width, and height settings to the connected camera.
@@ -196,11 +196,8 @@ class OpenCVCamera(Camera):
Raises:
RuntimeError: If the camera fails to set any of the specified properties
to the requested value.
DeviceNotConnectedError: If the camera is not connected when attempting
to configure settings.
DeviceNotConnectedError: If the camera is not connected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
if self.config.fourcc is not None:
@@ -339,6 +336,18 @@ class OpenCVCamera(Camera):
return found_cameras_info
def _read_from_hardware(self) -> NDArray[Any]:
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
ret, frame = self.videocapture.read()
if not ret:
raise RuntimeError(f"{self} read failed (status={ret}).")
return frame
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
@@ -346,11 +355,6 @@ class OpenCVCamera(Camera):
This is a blocking call. It waits for the next available frame from the
camera hardware via OpenCV.
Args:
color_mode (Optional[ColorMode]): If specified, overrides the default
color mode (`self.color_mode`) for this read operation (e.g.,
request RGB even if default is BGR).
Returns:
np.ndarray: The captured frame as a NumPy array in the format
(height, width, channels), using the specified or default
@@ -362,34 +366,31 @@ class OpenCVCamera(Camera):
received frame dimensions don't match expectations before rotation.
ValueError: If an invalid `color_mode` is requested.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
start_time = time.perf_counter()
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
if color_mode is not None:
logger.warning(
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
ret, frame = self.videocapture.read()
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
if not ret or frame is None:
raise RuntimeError(f"{self} read failed (status={ret}).")
processed_frame = self._postprocess_image(frame, color_mode)
self.new_frame_event.clear()
frame = self.async_read(timeout_ms=10000)
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return processed_frame
return frame
def _postprocess_image(self, image: NDArray[Any], color_mode: ColorMode | None = None) -> NDArray[Any]:
def _postprocess_image(self, image: NDArray[Any]) -> NDArray[Any]:
"""
Applies color conversion, dimension validation, and rotation to a raw frame.
Args:
image (np.ndarray): The raw image frame (expected BGR format from OpenCV).
color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None,
uses the instance's default `self.color_mode`.
Returns:
np.ndarray: The processed image frame.
@@ -399,11 +400,10 @@ class OpenCVCamera(Camera):
RuntimeError: If the raw frame dimensions do not match the configured
`width` and `height`.
"""
requested_color_mode = self.color_mode if color_mode is None else color_mode
if requested_color_mode not in (ColorMode.RGB, ColorMode.BGR):
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"Invalid color mode '{requested_color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
)
h, w, c = image.shape
@@ -417,7 +417,7 @@ class OpenCVCamera(Camera):
raise RuntimeError(f"{self} frame channels={c} do not match expected 3 channels (RGB/BGR).")
processed_image = image
if requested_color_mode == ColorMode.RGB:
if self.color_mode == ColorMode.RGB:
processed_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]:
@@ -431,7 +431,7 @@ class OpenCVCamera(Camera):
On each iteration:
1. Reads a color frame
2. Stores result in latest_frame (thread-safe)
2. Stores result in latest_frame and updates timestamp (thread-safe)
3. Sets new_frame_event to notify listeners
Stops on DeviceNotConnectedError, logs other errors and continues.
@@ -439,30 +439,37 @@ class OpenCVCamera(Camera):
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
failure_count = 0
while not self.stop_event.is_set():
try:
color_image = self.read()
raw_frame = self._read_from_hardware()
processed_frame = self._postprocess_image(raw_frame)
capture_time = time.perf_counter()
with self.frame_lock:
self.latest_frame = color_image
self.latest_frame = processed_frame
self.latest_timestamp = capture_time
self.new_frame_event.set()
failure_count = 0
except DeviceNotConnectedError:
break
except Exception as e:
logger.warning(f"Error reading frame in background thread for {self}: {e}")
if failure_count <= 10:
failure_count += 1
logger.warning(f"Error reading frame in background thread for {self}: {e}")
else:
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
def _start_read_thread(self) -> None:
"""Starts or restarts the background read thread if it's not running."""
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=0.1)
if self.stop_event is not None:
self.stop_event.set()
self._stop_read_thread()
self.stop_event = Event()
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
self.thread.daemon = True
self.thread.start()
time.sleep(0.1)
def _stop_read_thread(self) -> None:
"""Signals the background read thread to stop and waits for it to join."""
@@ -475,6 +482,12 @@ class OpenCVCamera(Camera):
self.thread = None
self.stop_event = None
with self.frame_lock:
self.latest_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame asynchronously.
@@ -482,6 +495,7 @@ class OpenCVCamera(Camera):
This method retrieves the most recent frame captured by the background
read thread. It does not block waiting for the camera hardware directly,
but may wait up to timeout_ms for the background thread to provide a frame.
It is “best effort” under high FPS.
Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame
@@ -496,17 +510,14 @@ class OpenCVCamera(Camera):
TimeoutError: If no frame becomes available within the specified timeout.
RuntimeError: If an unexpected error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
self._start_read_thread()
raise RuntimeError(f"{self} read thread is not running.")
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
thread_alive = self.thread is not None and self.thread.is_alive()
raise TimeoutError(
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
f"Read thread alive: {thread_alive}."
f"Read thread alive: {self.thread.is_alive()}."
)
with self.frame_lock:
@@ -518,6 +529,41 @@ class OpenCVCamera(Camera):
return frame
@check_if_not_connected
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
memory buffer. The frame may be stale,
meaning it could have been captured a while ago (hanging camera scenario e.g.).
Returns:
NDArray[Any]: The frame image (numpy array).
Raises:
TimeoutError: If the latest frame is older than `max_age_ms`.
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
with self.frame_lock:
frame = self.latest_frame
timestamp = self.latest_timestamp
if frame is None or timestamp is None:
raise RuntimeError(f"{self} has not captured any frames yet.")
age_ms = (time.perf_counter() - timestamp) * 1e3
if age_ms > max_age_ms:
raise TimeoutError(
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
)
return frame
def disconnect(self) -> None:
"""
Disconnects from the camera and cleans up resources.
@@ -538,4 +584,9 @@ class OpenCVCamera(Camera):
self.videocapture.release()
self.videocapture = None
with self.frame_lock:
self.latest_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
logger.info(f"{self} disconnected.")

View File

@@ -15,9 +15,9 @@
from dataclasses import dataclass
from pathlib import Path
from ..configs import CameraConfig, ColorMode, Cv2Rotation
from ..configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"]
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation", "Cv2Backends"]
@CameraConfig.register_subclass("opencv")
@@ -50,6 +50,7 @@ class OpenCVCameraConfig(CameraConfig):
rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation.
warmup_s: Time reading frames before returning from connect (in seconds)
fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect).
backend: OpenCV backend identifier (https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html). Defaults to ANY.
Note:
- Only 3-channel color output (RGB/BGR) is currently supported.
@@ -62,22 +63,12 @@ class OpenCVCameraConfig(CameraConfig):
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
warmup_s: int = 1
fourcc: str | None = None
backend: Cv2Backends = Cv2Backends.ANY
def __post_init__(self) -> None:
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
if self.rotation not in (
Cv2Rotation.NO_ROTATION,
Cv2Rotation.ROTATE_90,
Cv2Rotation.ROTATE_180,
Cv2Rotation.ROTATE_270,
):
raise ValueError(
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
)
self.color_mode = ColorMode(self.color_mode)
self.rotation = Cv2Rotation(self.rotation)
self.backend = Cv2Backends(self.backend)
if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4):
raise ValueError(

View File

@@ -74,7 +74,4 @@ class Reachy2CameraConfig(CameraConfig):
f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided."
)
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
self.color_mode = ColorMode(self.color_mode)

View File

@@ -32,6 +32,7 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
import cv2 # type: ignore # TODO: add type stubs for OpenCV
import numpy as np # type: ignore # TODO: add type stubs for numpy
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.import_utils import _reachy2_sdk_available
if TYPE_CHECKING or _reachy2_sdk_available:
@@ -80,6 +81,8 @@ class Reachy2Camera(Camera):
self.config = config
self.color_mode = config.color_mode
self.latest_frame: NDArray[Any] | None = None
self.latest_timestamp: float | None = None
self.cam_manager: CameraManager | None = None
@@ -121,16 +124,12 @@ class Reachy2Camera(Camera):
"""
raise NotImplementedError("Camera detection is not implemented for Reachy2 cameras.")
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
This is a blocking call.
Args:
color_mode (Optional[ColorMode]): If specified, overrides the default
color mode (`self.color_mode`) for this read operation (e.g.,
request RGB even if default is BGR).
This method retrieves the most recent frame available in Reachy 2's low-level software.
Returns:
np.ndarray: The captured frame as a NumPy array in the format
@@ -139,12 +138,14 @@ class Reachy2Camera(Camera):
"""
start_time = time.perf_counter()
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.cam_manager is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
if color_mode is not None:
logger.warning(
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8)
if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"):
@@ -165,25 +166,27 @@ class Reachy2Camera(Camera):
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
if frame is None:
return np.empty((0, 0, 3), dtype=np.uint8)
raise RuntimeError(f"Internal error: No frame available for {self}.")
if self.config.color_mode == "rgb":
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
)
if self.color_mode == ColorMode.RGB:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
self.latest_frame = frame
self.latest_timestamp = time.perf_counter()
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return frame
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame.
This method retrieves the most recent frame available in Reachy 2's low-level software.
Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame
to become available. Defaults to 200ms (0.2 seconds).
Same as read()
Returns:
np.ndarray: The latest captured frame as a NumPy array in the format
@@ -194,16 +197,40 @@ class Reachy2Camera(Camera):
TimeoutError: If no frame becomes available within the specified timeout.
RuntimeError: If an unexpected error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
frame = self.read()
return self.read()
if frame is None:
raise RuntimeError(f"Internal error: No frame available for {self}.")
@check_if_not_connected
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
return frame
This method is non-blocking and returns whatever is currently in the
memory buffer. The frame may be stale,
meaning it could have been captured a while ago (hanging camera scenario e.g.).
Returns:
tuple[NDArray, float]:
- The frame image (numpy array).
- The timestamp (time.perf_counter) when this frame was captured.
Raises:
TimeoutError: If the latest frame is older than `max_age_ms`.
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if self.latest_frame is None or self.latest_timestamp is None:
raise RuntimeError(f"{self} has not captured any frames yet.")
age_ms = (time.perf_counter() - self.latest_timestamp) * 1e3
if age_ms > max_age_ms:
raise TimeoutError(
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
)
return self.latest_frame
@check_if_not_connected
def disconnect(self) -> None:
"""
Stops the background read thread (if running).
@@ -211,8 +238,6 @@ class Reachy2Camera(Camera):
Raises:
DeviceNotConnectedError: If the camera is already disconnected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} not connected.")
if self.cam_manager is not None:
self.cam_manager.disconnect()

View File

@@ -30,7 +30,8 @@ try:
except Exception as e:
logging.info(f"Could not import realsense: {e}")
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from ..camera import Camera
from ..configs import ColorMode
@@ -72,15 +73,14 @@ class RealSenseCamera(Camera):
camera = RealSenseCamera(config)
camera.connect()
# Read 1 frame synchronously
# Read 1 frame synchronously (blocking)
color_image = camera.read()
print(color_image.shape)
# Read 1 frame asynchronously
# Read 1 frame asynchronously (waits for new frame with a timeout)
async_image = camera.async_read()
# When done, properly disconnect the camera using
camera.disconnect()
# Get the latest frame immediately (no wait, returns timestamp)
latest_image, timestamp = camera.read_latest()
# Example with depth capture and custom settings
custom_config = RealSenseCameraConfig(
@@ -133,7 +133,9 @@ class RealSenseCamera(Camera):
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: NDArray[Any] | None = None
self.latest_color_frame: NDArray[Any] | None = None
self.latest_depth_frame: NDArray[Any] | None = None
self.latest_timestamp: float | None = None
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
@@ -151,6 +153,7 @@ class RealSenseCamera(Camera):
"""Checks if the camera pipeline is started and streams are active."""
return self.rs_pipeline is not None and self.rs_profile is not None
@check_if_already_connected
def connect(self, warmup: bool = True) -> None:
"""
Connects to the RealSense camera specified in the configuration.
@@ -158,14 +161,16 @@ class RealSenseCamera(Camera):
Initializes the RealSense pipeline, configures the required streams (color
and optionally depth), starts the pipeline, and validates the actual stream settings.
Args:
warmup (bool): If True, waits at connect() time until at least one valid frame
has been captured by the background thread. Defaults to True.
Raises:
DeviceAlreadyConnectedError: If the camera is already connected.
ValueError: If the configuration is invalid (e.g., missing serial/name, name not unique).
ConnectionError: If the camera is found but fails to start the pipeline or no RealSense devices are detected at all.
RuntimeError: If the pipeline starts but fails to apply requested settings.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
self.rs_pipeline = rs.pipeline()
rs_config = rs.config()
@@ -181,15 +186,18 @@ class RealSenseCamera(Camera):
) from e
self._configure_capture_settings()
self._start_read_thread()
if warmup:
time.sleep(
1
) # NOTE(Steven): RS cameras need a bit of time to warm up before the first read. If we don't wait, the first read from the warmup will raise.
start_time = time.time()
while time.time() - start_time < self.warmup_s:
self.read()
time.sleep(0.1)
# NOTE(Steven/Caroline): Enforcing at least one second of warmup as RS cameras need a bit of time before the first read. If we don't wait, the first read from the warmup will raise.
self.warmup_s = max(self.warmup_s, 1)
start_time = time.time()
while time.time() - start_time < self.warmup_s:
self.async_read(timeout_ms=self.warmup_s * 1000)
time.sleep(0.1)
with self.frame_lock:
if self.latest_color_frame is None or self.use_depth and self.latest_depth_frame is None:
raise ConnectionError(f"{self} failed to capture frames during warmup.")
logger.info(f"{self} connected.")
@@ -282,6 +290,7 @@ class RealSenseCamera(Camera):
if self.use_depth:
rs_config.enable_stream(rs.stream.depth)
@check_if_not_connected
def _configure_capture_settings(self) -> None:
"""Sets fps, width, and height from device stream if not already configured.
@@ -291,8 +300,6 @@ class RealSenseCamera(Camera):
Raises:
DeviceNotConnectedError: If device is not connected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.")
if self.rs_profile is None:
raise RuntimeError(f"{self}: rs_profile must be initialized before use.")
@@ -312,6 +319,7 @@ class RealSenseCamera(Camera):
self.width, self.height = actual_width, actual_height
self.capture_width, self.capture_height = actual_width, actual_height
@check_if_not_connected
def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]:
"""
Reads a single frame (depth) synchronously from the camera.
@@ -319,9 +327,6 @@ class RealSenseCamera(Camera):
This is a blocking call. It waits for a coherent set of frames (depth)
from the camera hardware via the RealSense pipeline.
Args:
timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms.
Returns:
np.ndarray: The depth map as a NumPy array (height, width)
of type `np.uint16` (raw depth values in millimeters) and rotation.
@@ -330,44 +335,50 @@ class RealSenseCamera(Camera):
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If reading frames from the pipeline fails or frames are invalid.
"""
if timeout_ms:
logger.warning(
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
)
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if not self.use_depth:
raise RuntimeError(
f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}."
)
start_time = time.perf_counter()
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
self.new_frame_event.clear()
_ = self.async_read(timeout_ms=10000)
with self.frame_lock:
depth_map = self.latest_depth_frame
if depth_map is None:
raise RuntimeError("No depth frame available. Ensure camera is streaming.")
return depth_map
def _read_from_hardware(self):
if self.rs_pipeline is None:
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=10000)
if not ret or frame is None:
raise RuntimeError(f"{self} read_depth failed (status={ret}).")
raise RuntimeError(f"{self} read failed (status={ret}).")
depth_frame = frame.get_depth_frame()
depth_map = np.asanyarray(depth_frame.get_data())
return frame
depth_map_processed = self._postprocess_image(depth_map, depth_frame=True)
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return depth_map_processed
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> NDArray[Any]:
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 0) -> NDArray[Any]:
"""
Reads a single frame (color) synchronously from the camera.
This is a blocking call. It waits for a coherent set of frames (color)
from the camera hardware via the RealSense pipeline.
Args:
timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms.
Returns:
np.ndarray: The captured color frame as a NumPy array
(height, width, channels), processed according to `color_mode` and rotation.
@@ -378,39 +389,36 @@ class RealSenseCamera(Camera):
ValueError: If an invalid `color_mode` is requested.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
start_time = time.perf_counter()
if self.rs_pipeline is None:
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
if color_mode is not None:
logger.warning(
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
if timeout_ms:
logger.warning(
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
)
if not ret or frame is None:
raise RuntimeError(f"{self} read failed (status={ret}).")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
color_frame = frame.get_color_frame()
color_image_raw = np.asanyarray(color_frame.get_data())
self.new_frame_event.clear()
color_image_processed = self._postprocess_image(color_image_raw, color_mode)
frame = self.async_read(timeout_ms=10000)
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return color_image_processed
return frame
def _postprocess_image(
self, image: NDArray[Any], color_mode: ColorMode | None = None, depth_frame: bool = False
) -> NDArray[Any]:
def _postprocess_image(self, image: NDArray[Any], depth_frame: bool = False) -> NDArray[Any]:
"""
Applies color conversion, dimension validation, and rotation to a raw color frame.
Args:
image (np.ndarray): The raw image frame (expected RGB format from RealSense).
color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None,
uses the instance's default `self.color_mode`.
Returns:
np.ndarray: The processed image frame according to `self.color_mode` and `self.rotation`.
@@ -421,9 +429,9 @@ class RealSenseCamera(Camera):
`width` and `height`.
"""
if color_mode and color_mode not in (ColorMode.RGB, ColorMode.BGR):
if self.color_mode and self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"Invalid requested color mode '{color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
f"Invalid requested color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
)
if depth_frame:
@@ -454,7 +462,7 @@ class RealSenseCamera(Camera):
On each iteration:
1. Reads a color frame with 500ms timeout
2. Stores result in latest_frame (thread-safe)
2. Stores result in latest_frame and updates timestamp (thread-safe)
3. Sets new_frame_event to notify listeners
Stops on DeviceNotConnectedError, logs other errors and continues.
@@ -462,25 +470,41 @@ class RealSenseCamera(Camera):
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
failure_count = 0
while not self.stop_event.is_set():
try:
color_image = self.read(timeout_ms=500)
frame = self._read_from_hardware()
color_frame_raw = frame.get_color_frame()
color_frame = np.asanyarray(color_frame_raw.get_data())
processed_color_frame = self._postprocess_image(color_frame)
if self.use_depth:
depth_frame_raw = frame.get_depth_frame()
depth_frame = np.asanyarray(depth_frame_raw.get_data())
processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True)
capture_time = time.perf_counter()
with self.frame_lock:
self.latest_frame = color_image
self.latest_color_frame = processed_color_frame
if self.use_depth:
self.latest_depth_frame = processed_depth_frame
self.latest_timestamp = capture_time
self.new_frame_event.set()
failure_count = 0
except DeviceNotConnectedError:
break
except Exception as e:
logger.warning(f"Error reading frame in background thread for {self}: {e}")
if failure_count <= 10:
failure_count += 1
logger.warning(f"Error reading frame in background thread for {self}: {e}")
else:
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
def _start_read_thread(self) -> None:
"""Starts or restarts the background read thread if it's not running."""
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=0.1)
if self.stop_event is not None:
self.stop_event.set()
self._stop_read_thread()
self.stop_event = Event()
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
@@ -498,7 +522,14 @@ class RealSenseCamera(Camera):
self.thread = None
self.stop_event = None
with self.frame_lock:
self.latest_color_frame = None
self.latest_depth_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame data (color) asynchronously.
@@ -506,6 +537,7 @@ class RealSenseCamera(Camera):
This method retrieves the most recent color frame captured by the background
read thread. It does not block waiting for the camera hardware directly,
but may wait up to timeout_ms for the background thread to provide a frame.
It is “best effort” under high FPS.
Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame
@@ -520,21 +552,18 @@ class RealSenseCamera(Camera):
TimeoutError: If no frame data becomes available within the specified timeout.
RuntimeError: If the background thread died unexpectedly or another error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
self._start_read_thread()
raise RuntimeError(f"{self} read thread is not running.")
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
thread_alive = self.thread is not None and self.thread.is_alive()
raise TimeoutError(
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
f"Read thread alive: {thread_alive}."
f"Read thread alive: {self.thread.is_alive()}."
)
with self.frame_lock:
frame = self.latest_frame
frame = self.latest_color_frame
self.new_frame_event.clear()
if frame is None:
@@ -542,6 +571,42 @@ class RealSenseCamera(Camera):
return frame
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent (color) frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
memory buffer. The frame may be stale,
meaning it could have been captured a while ago (hanging camera scenario e.g.).
Returns:
NDArray[Any]: The frame image (numpy array).
Raises:
TimeoutError: If the latest frame is older than `max_age_ms`.
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
with self.frame_lock:
frame = self.latest_color_frame
timestamp = self.latest_timestamp
if frame is None or timestamp is None:
raise RuntimeError(f"{self} has not captured any frames yet.")
age_ms = (time.perf_counter() - timestamp) * 1e3
if age_ms > max_age_ms:
raise TimeoutError(
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
)
return frame
def disconnect(self) -> None:
"""
Disconnects from the camera, stops the pipeline, and cleans up resources.
@@ -565,4 +630,10 @@ class RealSenseCamera(Camera):
self.rs_pipeline = None
self.rs_profile = None
with self.frame_lock:
self.latest_color_frame = None
self.latest_depth_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
logger.info(f"{self} disconnected.")

View File

@@ -60,20 +60,8 @@ class RealSenseCameraConfig(CameraConfig):
warmup_s: int = 1
def __post_init__(self) -> None:
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
if self.rotation not in (
Cv2Rotation.NO_ROTATION,
Cv2Rotation.ROTATE_90,
Cv2Rotation.ROTATE_180,
Cv2Rotation.ROTATE_270,
):
raise ValueError(
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
)
self.color_mode = ColorMode(self.color_mode)
self.rotation = Cv2Rotation(self.rotation)
values = (self.fps, self.width, self.height)
if any(v is not None for v in values) and any(v is None for v in values):

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
from typing import cast
from lerobot.utils.import_utils import make_device_from_device_class
@@ -68,14 +67,3 @@ def get_cv2_rotation(rotation: Cv2Rotation) -> int | None:
return int(cv2.ROTATE_90_COUNTERCLOCKWISE)
else:
return None
def get_cv2_backend() -> int:
import cv2
if platform.system() == "Windows":
return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION
# elif platform.system() == "Darwin": # macOS
# return cv2.CAP_AVFOUNDATION
else: # Linux and others
return int(cv2.CAP_ANY)

View File

@@ -34,7 +34,8 @@ import cv2
import numpy as np
from numpy.typing import NDArray
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from ..camera import Camera
from ..configs import ColorMode
@@ -45,6 +46,12 @@ logger = logging.getLogger(__name__)
class ZMQCamera(Camera):
"""
Manages camera interactions via ZeroMQ for receiving frames from a remote server.
This class connects to a ZMQ Publisher, subscribes to frame topics, and decodes
incoming JSON messages containing Base64 encoded images. It supports both
synchronous and asynchronous frame reading patterns.
Example usage:
```python
from lerobot.cameras.zmq import ZMQCamera, ZMQCameraConfig
@@ -52,7 +59,16 @@ class ZMQCamera(Camera):
config = ZMQCameraConfig(server_address="192.168.123.164", port=5555, camera_name="head_camera")
camera = ZMQCamera(config)
camera.connect()
frame = camera.read()
# Read 1 frame synchronously (blocking)
color_image = camera.read()
# Read 1 frame asynchronously (waits for new frame with a timeout)
async_image = camera.async_read()
# Get the latest frame immediately (no wait, returns timestamp)
latest_image, timestamp = camera.read_latest()
camera.disconnect()
```
"""
@@ -68,14 +84,17 @@ class ZMQCamera(Camera):
self.color_mode = config.color_mode
self.timeout_ms = config.timeout_ms
# ZMQ Context and Socket
self.context: zmq.Context | None = None
self.socket: zmq.Socket | None = None
self._connected = False
# Threading resources
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: NDArray[Any] | None = None
self.latest_timestamp: float | None = None
self.new_frame_event: Event = Event()
def __str__(self) -> str:
@@ -83,12 +102,17 @@ class ZMQCamera(Camera):
@property
def is_connected(self) -> bool:
"""Checks if the ZMQ socket is initialized and connected."""
return self._connected and self.context is not None and self.socket is not None
@check_if_already_connected
def connect(self, warmup: bool = True) -> None:
"""Connect to ZMQ camera server."""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
"""Connect to ZMQ camera server.
Args:
warmup (bool): If True, waits for the camera to provide at least one
valid frame before returning. Defaults to True.
"""
logger.info(f"Connecting to {self}...")
@@ -103,17 +127,28 @@ class ZMQCamera(Camera):
self.socket.connect(f"tcp://{self.server_address}:{self.port}")
self._connected = True
# Auto-detect resolution
# Auto-detect resolution if not provided
if self.width is None or self.height is None:
h, w = self.read().shape[:2]
# Read directly from hardware because the thread isn't running yet
temp_frame = self._read_from_hardware()
h, w = temp_frame.shape[:2]
self.height = h
self.width = w
logger.info(f"{self} resolution: {w}x{h}")
logger.info(f"{self} resolution detected: {w}x{h}")
self._start_read_thread()
logger.info(f"{self} connected.")
if warmup:
time.sleep(0.1)
# Ensure we have captured at least one frame via the thread
start_time = time.time()
while time.time() - start_time < (self.config.warmup_s): # Wait a bit more than timeout
self.async_read(timeout_ms=self.config.warmup_s * 1000)
time.sleep(0.1)
with self.frame_lock:
if self.latest_frame is None:
raise ConnectionError(f"{self} failed to capture frames during warmup.")
except Exception as e:
self._cleanup()
@@ -131,15 +166,14 @@ class ZMQCamera(Camera):
@staticmethod
def find_cameras() -> list[dict[str, Any]]:
"""ZMQ cameras require manual configuration (server address/port)."""
return []
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Read a single frame from the ZMQ camera.
Detection not implemented for ZMQ cameras. These cameras require manual configuration (server address/port).
"""
raise NotImplementedError("Camera detection is not implemented for ZMQ cameras.")
Returns:
np.ndarray: Decoded frame (height, width, 3)
def _read_from_hardware(self) -> NDArray[Any]:
"""
Reads a single frame directly from the ZMQ socket.
"""
if not self.is_connected or self.socket is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
@@ -147,6 +181,7 @@ class ZMQCamera(Camera):
try:
message = self.socket.recv_string()
except Exception as e:
# Check for ZMQ timeout (EAGAIN/Again) without requiring global zmq import
if type(e).__name__ == "Again":
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
raise
@@ -176,42 +211,114 @@ class ZMQCamera(Camera):
return frame
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
This is a blocking call. It waits for the next available frame from the
camera background thread.
Returns:
np.ndarray: Decoded frame (height, width, 3)
"""
start_time = time.perf_counter()
if color_mode is not None:
logger.warning(
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
self.new_frame_event.clear()
frame = self.async_read(timeout_ms=10000)
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return frame
def _read_loop(self) -> None:
while self.stop_event and not self.stop_event.is_set():
"""
Internal loop run by the background thread for asynchronous reading.
"""
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized.")
failure_count = 0
while not self.stop_event.is_set():
try:
frame = self.read()
frame = self._read_from_hardware()
capture_time = time.perf_counter()
with self.frame_lock:
self.latest_frame = frame
self.latest_timestamp = capture_time
self.new_frame_event.set()
failure_count = 0
except DeviceNotConnectedError:
break
except TimeoutError:
pass
except Exception as e:
logger.warning(f"Read error: {e}")
except (TimeoutError, Exception) as e:
if failure_count <= 10:
failure_count += 1
logger.warning(f"Read error: {e}")
else:
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
def _start_read_thread(self) -> None:
if self.thread and self.thread.is_alive():
return
if self.stop_event is not None:
self.stop_event.set()
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
with self.frame_lock:
self.latest_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
self.stop_event = Event()
self.thread = Thread(target=self._read_loop, daemon=True)
self.thread = Thread(target=self._read_loop, daemon=True, name=f"{self}_read_loop")
self.thread.start()
time.sleep(0.1)
def _stop_read_thread(self) -> None:
if self.stop_event:
if self.stop_event is not None:
self.stop_event.set()
if self.thread and self.thread.is_alive():
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
self.thread = None
self.stop_event = None
def async_read(self, timeout_ms: float = 10000) -> NDArray[Any]:
"""Read latest frame asynchronously (non-blocking)."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
with self.frame_lock:
self.latest_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
if not self.thread or not self.thread.is_alive():
self._start_read_thread()
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame asynchronously.
Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame
to become available. Defaults to 200ms.
Returns:
np.ndarray: The latest captured frame.
Raises:
DeviceNotConnectedError: If the camera is not connected.
TimeoutError: If no frame data becomes available within the specified timeout.
RuntimeError: If the background thread is not running.
"""
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
raise TimeoutError(f"{self} async_read timeout after {timeout_ms}ms")
@@ -225,11 +332,54 @@ class ZMQCamera(Camera):
return frame
@check_if_not_connected
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
memory buffer. The frame may be stale,
meaning it could have been captured a while ago (hanging camera scenario e.g.).
Returns:
NDArray[Any]: The frame image (numpy array).
Raises:
TimeoutError: If the latest frame is older than `max_age_ms`.
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
with self.frame_lock:
frame = self.latest_frame
timestamp = self.latest_timestamp
if frame is None or timestamp is None:
raise RuntimeError(f"{self} has not captured any frames yet.")
age_ms = (time.perf_counter() - timestamp) * 1e3
if age_ms > max_age_ms:
raise TimeoutError(
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
)
return frame
def disconnect(self) -> None:
"""Disconnect from ZMQ camera."""
if not self.is_connected and not self.thread:
if not self.is_connected and self.thread is None:
raise DeviceNotConnectedError(f"{self} not connected.")
self._stop_read_thread()
if self.thread is not None:
self._stop_read_thread()
self._cleanup()
with self.frame_lock:
self.latest_frame = None
self.latest_timestamp = None
self.new_frame_event.clear()
logger.info(f"{self} disconnected.")

View File

@@ -29,12 +29,10 @@ class ZMQCameraConfig(CameraConfig):
camera_name: str = "zmq_camera"
color_mode: ColorMode = ColorMode.RGB
timeout_ms: int = 5000
warmup_s: int = 1
def __post_init__(self) -> None:
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
self.color_mode = ColorMode(self.color_mode)
if self.timeout_ms <= 0:
raise ValueError(f"`timeout_ms` must be positive, but {self.timeout_ms} is provided.")

View File

@@ -45,12 +45,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
input_shapes: A dictionary defining the shapes of the input data for the policy.
output_shapes: A dictionary defining the shapes of the output data for the policy.
input_normalization_modes: A dictionary with key representing the modality and the value specifies the
normalization mode to apply.
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
the original scale.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
"""
n_obs_steps: int = 1
@@ -151,12 +151,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
return {}
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
@property
def audio_features(self) -> dict[str, PolicyFeature]:
if not self.input_features:
return {}
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.AUDIO}
@property
def action_feature(self) -> PolicyFeature | None:
if not self.output_features:

View File

@@ -20,7 +20,6 @@ from enum import Enum
class FeatureType(str, Enum):
STATE = "STATE"
VISUAL = "VISUAL"
AUDIO = "AUDIO"
ENV = "ENV"
ACTION = "ACTION"
REWARD = "REWARD"

View File

@@ -26,8 +26,6 @@ import tqdm
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.utils import (
DEFAULT_AUDIO_FILE_SIZE_IN_MB,
DEFAULT_AUDIO_PATH,
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
@@ -43,7 +41,7 @@ from lerobot.datasets.utils import (
write_stats,
write_tasks,
)
from lerobot.datasets.video_utils import concatenate_media_files, get_media_duration_in_s
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
@@ -112,28 +110,72 @@ def update_meta_data(
meta_idx,
data_idx,
videos_idx,
audios_idx,
):
"""Updates metadata DataFrame with new chunk, file, and timestamp indices.
Adjusts all indices and timestamps to account for previously aggregated
data and videos in the destination dataset.
For data file indices, uses the 'src_to_dst' mapping from aggregate_data()
to correctly map source file indices to their destination locations.
Args:
df: DataFrame containing the metadata to be updated.
dst_meta: Destination dataset metadata.
meta_idx: Dictionary containing current metadata chunk and file indices.
data_idx: Dictionary containing current data chunk and file indices.
videos_idx: Dictionary containing current video indices and timestamps.
audios_idx: Dictionary containing current audio indices and timestamps.
Returns:
pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
"""
df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
# Update data file indices using source-to-destination mapping
# This is critical for handling datasets that are already results of a merge
data_src_to_dst = data_idx.get("src_to_dst", {})
if data_src_to_dst:
# Store original indices for lookup
df["_orig_data_chunk"] = df["data/chunk_index"].copy()
df["_orig_data_file"] = df["data/file_index"].copy()
# Vectorized mapping from (src_chunk, src_file) to (dst_chunk, dst_file)
# This is much faster than per-row iteration for large metadata tables
mapping_index = pd.MultiIndex.from_tuples(
list(data_src_to_dst.keys()),
names=["chunk_index", "file_index"],
)
mapping_values = list(data_src_to_dst.values())
mapping_df = pd.DataFrame(
mapping_values,
index=mapping_index,
columns=["dst_chunk", "dst_file"],
)
# Construct a MultiIndex for each row based on original data indices
row_index = pd.MultiIndex.from_arrays(
[df["_orig_data_chunk"], df["_orig_data_file"]],
names=["chunk_index", "file_index"],
)
# Align mapping to rows; missing keys fall back to the default destination
reindexed = mapping_df.reindex(row_index)
reindexed[["dst_chunk", "dst_file"]] = reindexed[["dst_chunk", "dst_file"]].fillna(
{"dst_chunk": data_idx["chunk"], "dst_file": data_idx["file"]}
)
# Assign mapped destination indices back to the DataFrame
df["data/chunk_index"] = reindexed["dst_chunk"].to_numpy()
df["data/file_index"] = reindexed["dst_file"].to_numpy()
# Clean up temporary columns
df = df.drop(columns=["_orig_data_chunk", "_orig_data_file"])
else:
# Fallback to simple offset (backward compatibility for single-file sources)
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
for key, video_idx in videos_idx.items():
# Store original video file indices before updating
orig_chunk_col = f"videos/{key}/chunk_index"
@@ -149,8 +191,7 @@ def update_meta_data(
if src_to_dst:
# Map each episode to its correct destination file and apply offset
for idx in df.index:
# Convert to Python int to avoid numpy type mismatch in dict lookup
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
# Get destination chunk/file for this source file
dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"]))
@@ -166,8 +207,7 @@ def update_meta_data(
df[orig_chunk_col] = video_idx["chunk"]
df[orig_file_col] = video_idx["file"]
for idx in df.index:
# Convert to Python int to avoid numpy type mismatch in dict lookup
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
offset = src_to_offset.get(src_key, 0)
df.at[idx, f"videos/{key}/from_timestamp"] += offset
df.at[idx, f"videos/{key}/to_timestamp"] += offset
@@ -183,36 +223,6 @@ def update_meta_data(
# Clean up temporary columns
df = df.drop(columns=["_orig_chunk", "_orig_file"])
for key, audio_idx in audios_idx.items():
# Store original audio file indices before updating
orig_chunk_col = f"audio/{key}/chunk_index"
orig_file_col = f"audio/{key}/file_index"
df["_orig_chunk"] = df[orig_chunk_col].copy()
df["_orig_file"] = df[orig_file_col].copy()
# Update chunk and file indices to point to destination
df[orig_chunk_col] = audio_idx["chunk"]
df[orig_file_col] = audio_idx["file"]
# Apply per-source-file timestamp offsets
src_to_offset = audio_idx.get("src_to_offset", {})
if src_to_offset:
# Apply offset based on original source file
for idx in df.index:
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
offset = src_to_offset.get(src_key, 0)
df.at[idx, f"audio/{key}/from_timestamp"] += offset
df.at[idx, f"audio/{key}/to_timestamp"] += offset
else:
# Fallback to simple offset (for backward compatibility)
df[f"audio/{key}/from_timestamp"] = (
df[f"audio/{key}/from_timestamp"] + audio_idx["latest_duration"]
)
df[f"audio/{key}/to_timestamp"] = df[f"audio/{key}/to_timestamp"] + audio_idx["latest_duration"]
# Clean up temporary columns
df = df.drop(columns=["_orig_chunk", "_orig_file"])
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
@@ -227,7 +237,6 @@ def aggregate_datasets(
aggr_root: Path | None = None,
data_files_size_in_mb: float | None = None,
video_files_size_in_mb: float | None = None,
audio_files_size_in_mb: float | None = None,
chunk_size: int | None = None,
):
"""Aggregates multiple LeRobot datasets into a single unified dataset.
@@ -245,7 +254,6 @@ def aggregate_datasets(
aggr_root: Optional root path for the aggregated dataset.
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
audio_files_size_in_mb: Maximum size for audio files in MB (defaults to DEFAULT_AUDIO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
"""
logging.info("Start aggregate_datasets")
@@ -254,8 +262,6 @@ def aggregate_datasets(
data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
if video_files_size_in_mb is None:
video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
if audio_files_size_in_mb is None:
audio_files_size_in_mb = DEFAULT_AUDIO_FILE_SIZE_IN_MB
if chunk_size is None:
chunk_size = DEFAULT_CHUNK_SIZE
@@ -268,7 +274,6 @@ def aggregate_datasets(
)
fps, robot_type, features = validate_all_metadata(all_metadata)
video_keys = [key for key in features if features[key]["dtype"] == "video"]
audio_keys = [key for key in features if features[key]["dtype"] == "audio"]
dst_meta = LeRobotDatasetMetadata.create(
repo_id=aggr_repo_id,
@@ -280,7 +285,6 @@ def aggregate_datasets(
chunks_size=chunk_size,
data_files_size_in_mb=data_files_size_in_mb,
video_files_size_in_mb=video_files_size_in_mb,
audio_files_size_in_mb=audio_files_size_in_mb,
)
logging.info("Find all tasks")
@@ -292,18 +296,18 @@ def aggregate_datasets(
videos_idx = {
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys
}
audios_idx = {
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in audio_keys
}
dst_meta.episodes = {}
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
audios_idx = aggregate_audio(src_meta, dst_meta, audios_idx, audio_files_size_in_mb, chunk_size)
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audios_idx)
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
# Clear the src_to_dst mapping after processing each source dataset
# to avoid interference between different source datasets
data_idx.pop("src_to_dst", None)
dst_meta.info["total_episodes"] += src_meta.total_episodes
dst_meta.info["total_frames"] += src_meta.total_frames
@@ -355,10 +359,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
dst_file_durations = video_idx["dst_file_durations"]
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
# Convert to Python int to ensure consistent dict keys
src_chunk_idx = int(src_chunk_idx)
src_file_idx = int(src_file_idx)
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
video_key=key,
chunk_index=src_chunk_idx,
@@ -371,7 +371,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
file_index=file_idx,
)
src_duration = get_media_duration_in_s(src_path, media_type="video")
src_duration = get_video_duration_in_s(src_path)
dst_key = (chunk_idx, file_idx)
if not dst_path.exists():
@@ -410,7 +410,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
current_dst_duration = dst_file_durations.get(dst_key, 0)
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
concatenate_media_files(
concatenate_video_files(
[dst_path, src_path],
dst_path,
)
@@ -425,111 +425,22 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
return videos_idx
def aggregate_audio(src_meta, dst_meta, audios_idx, audio_files_size_in_mb, chunk_size):
"""Aggregates audio files from a source dataset into the destination dataset.
Handles audio file concatenation and rotation based on file size limits.
Creates new audio files when size limits are exceeded.
Args:
src_meta: Source dataset metadata.
dst_meta: Destination dataset metadata.
audio_idx: Dictionary tracking audio chunk and file indices.
audio_files_size_in_mb: Maximum size for audio files in MB (defaults to DEFAULT_AUDIO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
Returns:
dict: Updated audio_idx with current chunk and file indices.
"""
for key in audios_idx:
audios_idx[key]["episode_duration"] = 0
# Track offset for each source (chunk, file) pair
audios_idx[key]["src_to_offset"] = {}
for key, audio_idx in audios_idx.items():
unique_chunk_file_pairs = {
(chunk, file)
for chunk, file in zip(
src_meta.episodes[f"audio/{key}/chunk_index"],
src_meta.episodes[f"audio/{key}/file_index"],
strict=False,
)
}
unique_chunk_file_pairs = sorted(unique_chunk_file_pairs)
chunk_idx = audio_idx["chunk"]
file_idx = audio_idx["file"]
current_offset = audio_idx["latest_duration"]
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
src_path = src_meta.root / DEFAULT_AUDIO_PATH.format(
audio_key=key,
chunk_index=src_chunk_idx,
file_index=src_file_idx,
)
dst_path = dst_meta.root / DEFAULT_AUDIO_PATH.format(
audio_key=key,
chunk_index=chunk_idx,
file_index=file_idx,
)
src_duration = get_media_duration_in_s(src_path, media_type="audio")
if not dst_path.exists():
# Store offset before incrementing
audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(str(src_path), str(dst_path))
audios_idx[key]["episode_duration"] += src_duration
current_offset += src_duration
continue
# Check file sizes before appending
src_size = get_file_size_in_mb(src_path)
dst_size = get_file_size_in_mb(dst_path)
if dst_size + src_size >= audio_files_size_in_mb:
# Rotate to a new file, this source becomes start of new destination
# So its offset should be 0
audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
dst_path = dst_meta.root / DEFAULT_AUDIO_PATH.format(
audio_key=key,
chunk_index=chunk_idx,
file_index=file_idx,
)
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(str(src_path), str(dst_path))
# Reset offset for next file
current_offset = src_duration
else:
# Append to existing video file - use current accumulated offset
audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
concatenate_media_files(
[dst_path, src_path],
dst_path,
)
current_offset += src_duration
audios_idx[key]["episode_duration"] += src_duration
audios_idx[key]["chunk"] = chunk_idx
audios_idx[key]["file"] = file_idx
return audios_idx
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
"""Aggregates data chunks from a source dataset into the destination dataset.
Reads source data files, updates indices to match the aggregated dataset,
and writes them to the destination with proper file rotation.
Tracks a `src_to_dst` mapping from source (chunk, file) to destination (chunk, file)
which is critical for correctly updating episode metadata when source datasets
have multiple data files (e.g., from a previous merge operation).
Args:
src_meta: Source dataset metadata.
dst_meta: Destination dataset metadata.
data_idx: Dictionary tracking data chunk and file indices.
data_files_size_in_mb: Maximum size for data files in MB.
chunk_size: Maximum number of files per chunk.
Returns:
dict: Updated data_idx with current chunk and file indices.
@@ -547,6 +458,10 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
# retrieve features schema for proper image typing in parquet
hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None
# Track source to destination file mapping for metadata update
# This is critical for handling datasets that are already results of a merge
src_to_dst: dict[tuple[int, int], tuple[int, int]] = {}
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
chunk_index=src_chunk_idx, file_index=src_file_idx
@@ -559,7 +474,9 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
df = pd.read_parquet(src_path)
df = update_data_df(df, src_meta, dst_meta)
data_idx = append_or_create_parquet_file(
# Write data and get the actual destination file it was written to
# This avoids duplicating the rotation logic here
data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file(
df,
src_path,
data_idx,
@@ -571,10 +488,16 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
hf_features=hf_features,
)
# Record the mapping from source to actual destination
src_to_dst[(src_chunk_idx, src_file_idx)] = (dst_chunk, dst_file)
# Add the mapping to data_idx for use in metadata update
data_idx["src_to_dst"] = src_to_dst
return data_idx
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audios_idx):
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
"""Aggregates metadata from a source dataset into the destination dataset.
Reads source metadata files, updates all indices and timestamps,
@@ -586,7 +509,6 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audio
meta_idx: Dictionary tracking metadata chunk and file indices.
data_idx: Dictionary tracking data chunk and file indices.
videos_idx: Dictionary tracking video indices and timestamps.
audios_idx: Dictionary tracking audio indices and timestamps.
Returns:
dict: Updated meta_idx with current chunk and file indices.
@@ -610,10 +532,9 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audio
meta_idx,
data_idx,
videos_idx,
audios_idx,
)
meta_idx = append_or_create_parquet_file(
meta_idx, _ = append_or_create_parquet_file(
df,
src_path,
meta_idx,
@@ -627,8 +548,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audio
# Increment latest_duration by the total duration added from this source dataset
for k in videos_idx:
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
for k in audios_idx:
audios_idx[k]["latest_duration"] += audios_idx[k]["episode_duration"]
return meta_idx
@@ -642,7 +562,7 @@ def append_or_create_parquet_file(
contains_images: bool = False,
aggr_root: Path = None,
hf_features: datasets.Features | None = None,
):
) -> tuple[dict[str, int], tuple[int, int]]:
"""Appends data to an existing parquet file or creates a new one based on size constraints.
Manages file rotation when size limits are exceeded to prevent individual files
@@ -660,9 +580,11 @@ def append_or_create_parquet_file(
hf_features: Optional HuggingFace Features schema for proper image typing.
Returns:
dict: Updated index dictionary with current chunk and file indices.
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
and (dst_chunk, dst_file) is the actual destination file the data was written to.
"""
dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
dst_chunk, dst_file = idx["chunk"], idx["file"]
dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
if not dst_path.exists():
dst_path.parent.mkdir(parents=True, exist_ok=True)
@@ -670,14 +592,15 @@ def append_or_create_parquet_file(
to_parquet_with_hf_images(df, dst_path, features=hf_features)
else:
df.to_parquet(dst_path)
return idx
return idx, (dst_chunk, dst_file)
src_size = get_parquet_file_size_in_mb(src_path)
dst_size = get_parquet_file_size_in_mb(dst_path)
if dst_size + src_size >= max_mb:
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
dst_chunk, dst_file = idx["chunk"], idx["file"]
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
new_path.parent.mkdir(parents=True, exist_ok=True)
final_df = df
target_path = new_path
@@ -696,7 +619,7 @@ def append_or_create_parquet_file(
else:
final_df.to_parquet(target_path)
return idx
return idx, (dst_chunk, dst_file)
def finalize_aggregation(aggr_meta, all_metadata):

View File

@@ -1,275 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
import av
import torch
import torchaudio
import torchcodec
from numpy import ceil
CHANNELS_LAYOUTS_MAPPING = {
1: "mono",
2: "stereo",
3: "2.1",
4: "3.1",
5: "4.1",
6: "5.1",
7: "6.1",
8: "7.1",
16: "hexadecagonal",
24: "22.2",
}
def decode_audio(
audio_path: Path | str,
timestamps: list[float],
duration: float,
start_time_s: float | None = 0.0,
backend: str | None = "torchcodec",
) -> torch.Tensor:
"""
Decodes audio using the specified backend.
Args:
audio_path (Path): Path to the audio file.
timestamps (list[float]): List of (starting) timestamps to extract audio chunks.
duration (float): Duration of the audio chunks in seconds.
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec".
Returns:
torch.Tensor: Decoded audio chunks.
Currently supports torchaudio.
"""
if backend == "torchcodec":
return decode_audio_torchcodec(audio_path, timestamps, duration, start_time_s)
elif backend == "torchaudio":
return decode_audio_torchaudio(audio_path, timestamps, duration, start_time_s)
else:
raise ValueError(f"Unsupported video backend: {backend}")
def decode_audio_torchcodec(
audio_path: Path | str,
timestamps: list[float],
duration: float,
start_time_s: float | None = 0.0,
log_loaded_timestamps: bool = False,
) -> torch.Tensor:
# TODO(CarolinePascal) : add channels selection
audio_decoder = torchcodec.decoders.AudioDecoder(audio_path)
audio_sample_rate = audio_decoder.metadata.sample_rate
audio_channels = audio_decoder.metadata.num_channels
# TODO(CarolinePascal) : assert ts < total record duration
audio_chunks = []
timestamps = [
timestamp + start_time_s for timestamp in timestamps
] # Add an offset of start_time_s to each timestamp
for ts in timestamps:
current_audio_chunk = audio_decoder.get_samples_played_in_range(
start_seconds=max(0.0, ts - duration), stop_seconds=ts
)
current_audio_chunk_data = current_audio_chunk.data
# Case where the requested audio chunk starts before the beginning of the audio stream
if ts - duration < 0:
# No useful audio sample has been recorded
if ts < 1 / audio_sample_rate:
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
current_audio_chunk_data = torch.zeros(
(audio_channels, int(ceil(duration * audio_sample_rate)))
)
# At least one useful audio sample has been recorded
else:
# Pad the beginning of the audio chunk with zeros
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
current_audio_chunk_data = torch.nn.functional.pad(
current_audio_chunk_data,
(int(ceil((duration - ts) * audio_sample_rate)), 0, 0, 0), # left, right, top, bottom
)
if log_loaded_timestamps:
logging.info(
f"audio chunk loaded at timestamp={current_audio_chunk.pts_seconds:.4f} with duration={current_audio_chunk.duration_seconds:.4f}"
)
audio_chunks.append(current_audio_chunk_data)
audio_chunks = torch.stack(audio_chunks)
assert len(timestamps) == len(audio_chunks)
return audio_chunks
def decode_audio_torchaudio(
audio_path: Path | str,
timestamps: list[float],
duration: float,
start_time_s: float | None = 0.0,
log_loaded_timestamps: bool = False,
) -> torch.Tensor:
# TODO(CarolinePascal) : add channels selection
audio_path = str(audio_path)
reader = torchaudio.io.StreamReader(src=audio_path)
audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
audio_channels = reader.get_src_stream_info(reader.default_audio_stream).num_channels
# TODO(CarolinePascal) : assert ts < total record duration
# TODO(CarolinePascal) : sort timestamps ?
reader.add_basic_audio_stream(
frames_per_chunk=int(ceil(duration * audio_sample_rate)), # Too much is better than not enough
buffer_chunk_size=-1, # No dropping frames
format="fltp", # Format as float32
)
audio_chunks = []
timestamps = [
timestamp + start_time_s for timestamp in timestamps
] # Add an offset of start_time_s to each timestamp
for ts in timestamps:
reader.seek(max(0.0, ts - duration)) # Default to closest audio sample. Needs to be non-negative !
status = reader.fill_buffer()
if status != 0:
# Should not happen, but just in case
logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
current_audio_chunk = reader.pop_chunks()[0]
current_audio_chunk_data = current_audio_chunk.t() # Channel first format
# Case where the requested audio chunk starts before the beginning of the audio stream
if ts - duration < 0:
# No useful audio sample has been recorded
if ts < 1 / audio_sample_rate:
current_audio_chunk_data = torch.zeros(
(audio_channels, int(ceil(duration * audio_sample_rate)))
)
# At least one useful audio sample has been recorded
else:
# Remove the superfluous last samples of the audio chunk
current_audio_chunk_data = current_audio_chunk_data[:, : int(ceil(ts * audio_sample_rate))]
# Pad the beginning of the audio chunk with zeros
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
current_audio_chunk_data = torch.nn.functional.pad(
current_audio_chunk_data,
(int(ceil((duration - ts) * audio_sample_rate)), 0, 0, 0), # left, right, top, bottom
)
if log_loaded_timestamps:
logging.info(
f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}"
)
audio_chunks.append(current_audio_chunk_data)
audio_chunks = torch.stack(audio_chunks)
assert len(timestamps) == len(audio_chunks)
return audio_chunks
def encode_audio(
input_path: Path | str,
output_path: Path | str,
codec: str = "aac", # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options
bit_rate: int | None = None,
sample_rate: int | None = None,
log_level: int | None = av.logging.ERROR,
overwrite: bool = False,
) -> None:
"""Encodes an audio file using ffmpeg."""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=overwrite)
# Set logging level
if log_level is not None:
# "While less efficient, it is generally preferable to modify logging with Pythons logging"
logging.getLogger("libav").setLevel(log_level)
# Open input file
with av.open(str(input_path), "r") as input:
input_stream = input.streams.audio[0] # Assuming the first stream is the audio stream to be encoded
# Define sub-sampling options
if sample_rate is None:
sample_rate = input_stream.rate
# Create and open output file (overwrite by default)
with av.open(str(output_path), "w") as output:
output_stream = output.add_stream(
codec, rate=sample_rate, layout=CHANNELS_LAYOUTS_MAPPING[input_stream.channels]
)
if bit_rate is not None:
output_stream.bit_rate = bit_rate
# Loop through input WAV packets and encode them
for input_frame in input.decode(
input_stream
): # This step handles both demuxing and decoding under the hood
packet = output_stream.encode(input_frame)
if packet:
output.mux(packet)
# Flush the encoder
packet = output_stream.encode()
if packet:
output.mux(packet)
# Reset logging level
if log_level is not None:
av.logging.restore_default_callback()
if not output_path.exists():
raise OSError(f"Audio encoding did not work. File not found: {output_path}.")
def get_audio_info(video_path: Path | str) -> dict:
# Set logging level
logging.getLogger("libav").setLevel(av.logging.ERROR)
# Getting audio stream information
audio_info = {}
with av.open(str(video_path), "r") as audio_file:
try:
audio_stream = audio_file.streams.audio[0]
except IndexError:
# Reset logging level
av.logging.restore_default_callback()
return {"has_audio": False}
audio_info["audio.channels"] = audio_stream.channels
audio_info["audio.codec"] = audio_stream.codec.canonical_name
# In an ideal loseless case : bit depth x sample rate x channels = bit rate.
# In an actual compressed case, the bit rate is set according to the compression level : the lower the bit rate, the more compression is applied.
audio_info["audio.bit_rate"] = audio_stream.bit_rate
audio_info["audio.sample_rate"] = audio_stream.sample_rate # Number of samples per second
# In an ideal loseless case : fixed number of bits per sample.
# In an actual compressed case : variable number of bits per sample (often reduced to match a given depth rate).
audio_info["audio.bit_depth"] = audio_stream.format.bits
audio_info["audio.channel_layout"] = audio_stream.layout.name
audio_info["has_audio"] = True
# Reset logging level
av.logging.restore_default_callback()
return audio_info

View File

@@ -15,7 +15,7 @@
# limitations under the License.
import numpy as np
from lerobot.datasets.utils import load_audio_from_path, load_image_as_numpy
from lerobot.datasets.utils import load_image_as_numpy
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
@@ -245,20 +245,6 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
return images
def sample_audio_from_path(audio_path: str) -> np.ndarray:
"""Samples audio data from an audio recording stored in a WAV file."""
data = load_audio_from_path(audio_path)
sampled_indices = sample_indices(len(data))
return data[sampled_indices]
def sample_audio_from_data(data: np.ndarray) -> np.ndarray:
"""Samples audio data from an audio recording stored in a numpy array."""
sampled_indices = sample_indices(len(data))
return data[sampled_indices]
def _reshape_stats_by_axis(
stats: dict[str, np.ndarray],
axis: int | tuple[int, ...] | None,
@@ -526,13 +512,6 @@ def compute_episode_stats(
ep_ft_array = sample_images(data)
axes_to_reduce = (0, 2, 3)
keepdims = True
elif features[key]["dtype"] == "audio":
try:
ep_ft_array = sample_audio_from_path(data[0])
except TypeError: # Should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
ep_ft_array = sample_audio_from_data(data)
axes_to_reduce = 0
keepdims = True
else:
ep_ft_array = data
axes_to_reduce = 0

View File

@@ -1396,6 +1396,132 @@ BYTES_PER_KIB = 1024
BYTES_PER_MIB = BYTES_PER_KIB * BYTES_PER_KIB
def modify_tasks(
dataset: LeRobotDataset,
new_task: str | None = None,
episode_tasks: dict[int, str] | None = None,
) -> LeRobotDataset:
"""Modify tasks in a LeRobotDataset.
This function allows you to either:
1. Set a single task for the entire dataset (using `new_task`)
2. Set specific tasks for specific episodes (using `episode_tasks`)
You can combine both: `new_task` sets the default, and `episode_tasks` overrides
specific episodes.
The dataset is modified in-place, updating only the task-related files:
- meta/tasks.parquet
- data/**/*.parquet (task_index column)
- meta/episodes/**/*.parquet (tasks column)
- meta/info.json (total_tasks)
Args:
dataset: The source LeRobotDataset to modify.
new_task: A single task string to apply to all episodes. If None and episode_tasks
is also None, raises an error.
episode_tasks: Optional dict mapping episode indices to their task strings.
Overrides `new_task` for specific episodes.
Examples:
Set a single task for all episodes:
dataset = modify_tasks(dataset, new_task="Pick up the cube")
Set different tasks for specific episodes:
dataset = modify_tasks(
dataset,
episode_tasks={0: "Task A", 1: "Task B", 2: "Task A"}
)
Set a default task with overrides:
dataset = modify_tasks(
dataset,
new_task="Default task",
episode_tasks={5: "Special task for episode 5"}
)
"""
if new_task is None and episode_tasks is None:
raise ValueError("Must specify at least one of new_task or episode_tasks")
if episode_tasks is not None:
valid_indices = set(range(dataset.meta.total_episodes))
invalid = set(episode_tasks.keys()) - valid_indices
if invalid:
raise ValueError(f"Invalid episode indices: {invalid}")
# Ensure episodes metadata is loaded
if dataset.meta.episodes is None:
dataset.meta.episodes = load_episodes(dataset.root)
# Build the mapping from episode index to task string
episode_to_task: dict[int, str] = {}
for ep_idx in range(dataset.meta.total_episodes):
if episode_tasks and ep_idx in episode_tasks:
episode_to_task[ep_idx] = episode_tasks[ep_idx]
elif new_task is not None:
episode_to_task[ep_idx] = new_task
else:
# Keep original task if not overridden and no default provided
original_tasks = dataset.meta.episodes[ep_idx]["tasks"]
if not original_tasks:
raise ValueError(f"Episode {ep_idx} has no tasks and no default task was provided")
episode_to_task[ep_idx] = original_tasks[0]
# Collect all unique tasks and create new task mapping
unique_tasks = sorted(set(episode_to_task.values()))
new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks)
task_to_index = {task: idx for idx, task in enumerate(unique_tasks)}
logging.info(f"Modifying tasks in {dataset.repo_id}")
logging.info(f"New tasks: {unique_tasks}")
root = dataset.root
# Update data files - modify task_index column
logging.info("Updating data files...")
data_dir = root / DATA_DIR
for parquet_path in tqdm(sorted(data_dir.rglob("*.parquet")), desc="Updating data"):
df = pd.read_parquet(parquet_path)
# Build a mapping from episode_index to new task_index for rows in this file
episode_indices_in_file = df["episode_index"].unique()
ep_to_new_task_idx = {
ep_idx: task_to_index[episode_to_task[ep_idx]] for ep_idx in episode_indices_in_file
}
# Update task_index column
df["task_index"] = df["episode_index"].map(ep_to_new_task_idx)
df.to_parquet(parquet_path, index=False)
# Update episodes metadata - modify tasks column
logging.info("Updating episodes metadata...")
episodes_dir = root / "meta" / "episodes"
for parquet_path in tqdm(sorted(episodes_dir.rglob("*.parquet")), desc="Updating episodes"):
df = pd.read_parquet(parquet_path)
# Update tasks column
df["tasks"] = df["episode_index"].apply(lambda ep_idx: [episode_to_task[ep_idx]])
df.to_parquet(parquet_path, index=False)
# Write new tasks.parquet
write_tasks(new_task_df, root)
# Update info.json
dataset.meta.info["total_tasks"] = len(unique_tasks)
write_info(dataset.meta.info, root)
# Reload metadata to reflect changes
dataset.meta.tasks = new_task_df
dataset.meta.episodes = load_episodes(root)
logging.info(f"Tasks: {unique_tasks}")
return dataset
def convert_image_to_video_dataset(
dataset: LeRobotDataset,
output_dir: Path,

View File

@@ -33,16 +33,12 @@ import torch.utils
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.errors import RevisionNotFoundError
from lerobot.datasets.audio_utils import decode_audio, encode_audio, get_audio_info
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.datasets.utils import (
DEFAULT_AUDIO_CHUNK_DURATION,
DEFAULT_EPISODES_PATH,
DEFAULT_FEATURES,
DEFAULT_IMAGE_PATH,
DEFAULT_INITIAL_AUDIO_BUFFER_DURATION,
DEFAULT_RAW_AUDIO_PATH,
INFO_PATH,
_validate_feature_names,
check_delta_timestamps,
@@ -61,6 +57,7 @@ from lerobot.datasets.utils import (
load_info,
load_nested_dataset,
load_stats,
load_subtasks,
load_tasks,
update_chunk_file_indices,
validate_episode_buffer,
@@ -72,15 +69,13 @@ from lerobot.datasets.utils import (
)
from lerobot.datasets.video_utils import (
VideoFrame,
concatenate_media_files,
concatenate_video_files,
decode_video_frames,
encode_video_frames,
get_media_duration_in_s,
get_safe_default_codec,
get_video_duration_in_s,
get_video_info,
)
from lerobot.microphones import Microphone
from lerobot.microphones.utils import async_microphones_start_recording
from lerobot.utils.constants import HF_LEROBOT_HOME
CODEBASE_VERSION = "v3.0"
@@ -168,6 +163,7 @@ class LeRobotDatasetMetadata:
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks = load_tasks(self.root)
self.subtasks = load_subtasks(self.root)
self.episodes = load_episodes(self.root)
self.stats = load_stats(self.root)
@@ -220,19 +216,6 @@ class LeRobotDatasetMetadata:
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
return Path(fpath)
def get_audio_file_path(self, ep_index: int, audio_key: str) -> Path:
if self.episodes is None:
self.episodes = load_episodes(self.root)
if ep_index >= len(self.episodes):
raise IndexError(
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
)
ep = self.episodes[ep_index]
chunk_idx = ep[f"audio/{audio_key}/chunk_index"]
file_idx = ep[f"audio/{audio_key}/file_index"]
fpath = self.audio_path.format(audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx)
return Path(fpath)
@property
def data_path(self) -> str:
"""Formattable string for the parquet files."""
@@ -243,11 +226,6 @@ class LeRobotDatasetMetadata:
"""Formattable string for the video files."""
return self.info["video_path"]
@property
def audio_path(self) -> str | None:
"""Formattable string for the audio files."""
return self.info["audio_path"]
@property
def robot_type(self) -> str | None:
"""Robot type used in recording this dataset."""
@@ -278,11 +256,6 @@ class LeRobotDatasetMetadata:
"""Keys to access visual modalities (regardless of their storage method)."""
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
@property
def audio_keys(self) -> list[str]:
"""Keys to access audio modalities."""
return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]
@property
def names(self) -> dict[str, list | dict]:
"""Names of the various dimensions of vector modalities."""
@@ -323,11 +296,6 @@ class LeRobotDatasetMetadata:
"""Max size of video file in mega bytes."""
return self.info["video_files_size_in_mb"]
@property
def audio_files_size_in_mb(self) -> int:
"""Max size of audio file in mega bytes."""
return self.info["audio_files_size_in_mb"]
def get_task_index(self, task: str) -> int | None:
"""
Given a task in natural language, returns its task_index if the task already exists in the dataset,
@@ -469,27 +437,11 @@ class LeRobotDatasetMetadata:
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
self.info["features"][key]["info"] = get_video_info(video_path)
def update_audio_info(self, audio_key: str | None = None) -> None:
"""
Warning: this function writes info from first episode audio, implicitly assuming that all audio have
been encoded the same way. Also, this means it assumes the first episode exists.
"""
if audio_key is not None and audio_key not in self.audio_keys:
raise ValueError(f"Audio key {audio_key} not found in dataset")
audio_keys = [audio_key] if audio_key is not None else self.audio_keys
for key in audio_keys:
if not self.features[key].get("info", None):
audio_path = self.root / self.audio_path.format(audio_key=key, chunk_index=0, file_index=0)
self.info["features"][key]["info"] = get_audio_info(audio_path)
self.info["features"][key]["info"]["start_time_s"] = DEFAULT_INITIAL_AUDIO_BUFFER_DURATION
def update_chunk_settings(
self,
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
audio_files_size_in_mb: int | None = None,
) -> None:
"""Update chunk and file size settings after dataset creation.
@@ -501,7 +453,6 @@ class LeRobotDatasetMetadata:
chunks_size: Maximum number of files per chunk directory. If None, keeps current value.
data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value.
video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value.
audio_files_size_in_mb: Maximum size for audio files in MB. If None, keeps current value.
"""
if chunks_size is not None:
if chunks_size <= 0:
@@ -518,11 +469,6 @@ class LeRobotDatasetMetadata:
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
self.info["video_files_size_in_mb"] = video_files_size_in_mb
if audio_files_size_in_mb is not None:
if audio_files_size_in_mb <= 0:
raise ValueError(f"audio_files_size_in_mb must be positive, got {audio_files_size_in_mb}")
self.info["audio_files_size_in_mb"] = audio_files_size_in_mb
# Update the info file on disk
write_info(self.info, self.root)
@@ -530,13 +476,12 @@ class LeRobotDatasetMetadata:
"""Get current chunk and file size settings.
Returns:
Dict containing chunks_size, data_files_size_in_mb, video_files_size_in_mb, and audio_files_size_in_mb.
Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb.
"""
return {
"chunks_size": self.chunks_size,
"data_files_size_in_mb": self.data_files_size_in_mb,
"video_files_size_in_mb": self.video_files_size_in_mb,
"audio_files_size_in_mb": self.audio_files_size_in_mb,
}
def __repr__(self):
@@ -563,7 +508,6 @@ class LeRobotDatasetMetadata:
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
audio_files_size_in_mb: int | None = None,
) -> "LeRobotDatasetMetadata":
"""Creates metadata for a LeRobotDataset."""
obj = cls.__new__(cls)
@@ -576,6 +520,7 @@ class LeRobotDatasetMetadata:
_validate_feature_names(features)
obj.tasks = None
obj.subtasks = None
obj.episodes = None
obj.stats = None
obj.info = create_empty_dataset_info(
@@ -587,7 +532,6 @@ class LeRobotDatasetMetadata:
chunks_size,
data_files_size_in_mb,
video_files_size_in_mb,
audio_files_size_in_mb,
)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
@@ -623,9 +567,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
revision: str | None = None,
force_cache_sync: bool = False,
download_videos: bool = True,
download_audio: bool = True,
video_backend: str | None = None,
audio_backend: str | None = None,
batch_encoding_size: int = 1,
vcodec: str = "libsvtav1",
):
@@ -659,7 +601,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
task-conditioned training.
- hf_dataset (from datasets.Dataset), which will read any values from parquet files.
- videos (optional) from which frames are loaded to be synchronous with data from parquet files.
- audio (optional) from which audio is loaded to be synchronous with data from parquet files.
A typical LeRobotDataset looks like this from its root path:
.
@@ -685,37 +626,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
│ ├── info.json
│ ├── stats.json
│ └── tasks.parquet
── videos
├── observation.images.laptop
│ │ ├── chunk-000
│ │ │ ├── file-000.mp4
│ │ │ ├── file-001.mp4
│ │ │ └── ...
│ │ ├── chunk-001
│ │ │ └── ...
│ │ └── ...
│ ├── observation.images.phone
│ │ ├── chunk-000
│ │ │ ├── file-000.mp4
│ │ │ ├── file-001.mp4
│ │ │ └── ...
│ │ ├── chunk-001
│ │ │ └── ...
│ │ └── ...
│ └── ...
└── audio
├── observation.audio.laptop
── videos
├── observation.images.laptop
│ ├── chunk-000
│ │ ├── file-000.m4a
│ │ ├── file-001.m4a
│ │ ├── file-000.mp4
│ │ ├── file-001.mp4
│ │ └── ...
│ ├── chunk-001
│ │ └── ...
│ └── ...
├── observation.audio.phone
├── observation.images.phone
│ ├── chunk-000
│ │ ├── file-000.m4a
│ │ ├── file-001.m4a
│ │ ├── file-000.mp4
│ │ ├── file-001.mp4
│ │ └── ...
│ ├── chunk-001
│ │ └── ...
@@ -755,10 +678,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
video files are already present on local disk, they won't be downloaded again. Defaults to
True.
download_audio (bool, optional): Flag to download the audio. Defaults to True.
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'torchcodec'.
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
@@ -776,9 +697,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else get_safe_default_codec()
self.audio_backend = (
audio_backend if audio_backend else "torchcodec"
) # Waiting for torchcodec release #TODO(CarolinePascal)
self.delta_indices = None
self.batch_encoding_size = batch_encoding_size
self.episodes_since_last_encoding = 0
@@ -851,7 +769,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
license: str | None = "apache-2.0",
tag_version: bool = True,
push_videos: bool = True,
push_audio: bool = True,
private: bool = False,
allow_patterns: list[str] | str | None = None,
upload_large_folder: bool = False,
@@ -860,8 +777,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
ignore_patterns = ["images/"]
if not push_videos:
ignore_patterns.append("videos/")
if not push_audio:
ignore_patterns.append("audio/")
hub_api = HfApi()
hub_api.create_repo(
@@ -916,7 +831,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
ignore_patterns=ignore_patterns,
)
def download(self, download_videos: bool = True, download_audio: bool = True) -> None:
def download(self, download_videos: bool = True) -> None:
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
@@ -924,12 +839,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
# TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
ignore_patterns = None if download_videos else "videos/"
files = None
ignore_patterns = []
if not download_videos:
ignore_patterns.append("videos/")
if not download_audio:
ignore_patterns.append("audio/")
if self.episodes is not None:
files = self.get_episodes_file_paths()
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
@@ -944,15 +855,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
for ep_idx in episodes
]
fpaths += video_files
if len(self.meta.audio_keys) > 0:
audio_files = [
str(self.meta.get_compressed_audio_file_path(ep_idx, audio_key))
for audio_key in self.meta.audio_keys
for ep_idx in episodes
]
fpaths += audio_files
# episodes are stored in the same files, so we return unique paths only
fpaths = list(set(fpaths))
return fpaths
@@ -965,7 +867,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
return hf_dataset
def _check_cached_episodes_sufficient(self) -> bool:
"""Check if the cached dataset contains all requested episodes and their video and audio files."""
"""Check if the cached dataset contains all requested episodes and their video files."""
if self.hf_dataset is None or len(self.hf_dataset) == 0:
return False
@@ -993,14 +895,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
if not video_path.exists():
return False
# Check if all required audio files exist
if len(self.meta.audio_keys) > 0:
for ep_idx in requested_episodes:
for audio_key in self.meta.audio_keys:
audio_path = self.root / self.meta.get_audio_file_path(ep_idx, audio_key)
if not audio_path.exists():
return False
return True
def create_hf_dataset(self) -> datasets.Dataset:
@@ -1079,7 +973,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
query_indices: dict[str, list[int]] | None = None,
) -> dict[str, list[float]]:
query_timestamps = {}
for key in self.meta.video_keys + self.meta.audio_keys:
for key in self.meta.video_keys:
if query_indices is not None and key in query_indices:
if self._absolute_to_relative_idx is not None:
relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]]
@@ -1094,7 +988,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
"""
Query dataset for indices across keys, skipping video keys and audio keys.
Query dataset for indices across keys, skipping video keys.
Tries column-first [key][indices] for speed, falls back to row-first.
@@ -1106,7 +1000,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
result: dict = {}
for key, q_idx in query_indices.items():
if key in self.meta.video_keys or key in self.meta.audio_keys:
if key in self.meta.video_keys:
continue
# Map absolute indices to relative indices if needed
relative_indices = (
@@ -1141,28 +1035,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
return item
# TODO(CarolinePascal): add variable query durations
def _query_audio(
self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int
) -> dict[str, torch.Tensor]:
ep = self.meta.episodes[ep_idx]
item = {}
for audio_key, query_ts in query_timestamps.items():
# Episodes are stored sequentially on a single mp4 to reduce the number of files.
# Thus we load the start timestamp of the episode on this mp4 and,
# shift the query timestamp accordingly.
from_timestamp = ep[f"audio/{audio_key}/from_timestamp"]
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
audio_path = self.root / self.meta.get_audio_file_path(ep_idx, audio_key)
start_time_s = self.meta.features[audio_key]["info"].get("start_time_s", 0.0)
audio_chunk = decode_audio(
audio_path, shifted_query_ts, query_duration, start_time_s, self.audio_backend
)
item[audio_key] = audio_chunk.squeeze(0)
return item
def _ensure_hf_dataset_loaded(self):
"""Lazy load the HF dataset only when needed for reading."""
if self._lazy_loading or self.hf_dataset is None:
@@ -1192,12 +1064,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
for key, val in query_result.items():
item[key] = val
if len(self.meta.video_keys) > 0 or len(self.meta.audio_keys) > 0:
if len(self.meta.video_keys) > 0:
current_ts = item["timestamp"].item()
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
video_frames = self._query_videos(query_timestamps, ep_idx)
audio_chunks = self._query_audio(query_timestamps, DEFAULT_AUDIO_CHUNK_DURATION, ep_idx)
item = {**item, **video_frames, **audio_chunks}
item = {**video_frames, **item}
if self.image_transforms is not None:
image_keys = self.meta.camera_keys
@@ -1207,6 +1078,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Add task as a string
task_idx = item["task_index"].item()
item["task"] = self.meta.tasks.iloc[task_idx].name
# add subtask information if available
if "subtask_index" in self.features and self.meta.subtasks is not None:
subtask_idx = item["subtask_index"].item()
item["subtask"] = self.meta.subtasks.iloc[subtask_idx].name
return item
def __repr__(self):
@@ -1245,10 +1122,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
return self.root / fpath
def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index)
return self.root / fpath
def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
return self._get_image_file_path(episode_index, image_key, frame_index=0).parent
@@ -1301,43 +1174,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
compress_level = 1 if self.features[key]["dtype"] == "video" else 6
self._save_image(frame[key], img_path, compress_level)
self.episode_buffer[key].append(str(img_path))
elif self.features[key]["dtype"] == "audio":
if (
self.meta.robot_type == "lekiwi"
): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
self.episode_buffer[key].append(frame[key])
else: # Otherwise, only the audio file path is stored in the episode buffer
if frame_index == 0:
audio_path = self._get_raw_audio_file_path(
episode_index=self.episode_buffer["episode_index"], audio_key=key
)
self.episode_buffer[key].append(str(audio_path))
else:
self.episode_buffer[key].append(frame[key])
self.episode_buffer["size"] += 1
def add_microphone_recording(self, microphone_key: str, microphone: Microphone) -> None:
"""
Starts recording audio data provided by the microphone and directly writes it in a .wav file.
"""
audio_file = self._get_raw_audio_file_path(self.num_episodes, "observation.audio." + microphone_key)
microphone.start_recording(output_file=audio_file)
def add_microphones_recordings(self, microphones: dict[str, Microphone]) -> None:
"""
Starts recording audio data provided by multiple microphones and directly writes it in appropriate .wav files.
"""
output_files = []
for microphone_key in microphones:
output_files.append(
self._get_raw_audio_file_path(self.num_episodes, "observation.audio." + microphone_key)
)
async_microphones_start_recording(microphones, output_files)
def save_episode(
self,
episode_data: dict | None = None,
@@ -1381,12 +1222,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
# are processed separately by storing image path and frame info as meta data
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
continue
elif ft["dtype"] == "audio":
if (
self.meta.robot_type == "lekiwi"
): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
episode_buffer[key] = np.concatenate(episode_buffer[key], axis=0)
continue
episode_buffer[key] = np.stack(episode_buffer[key])
# Wait for image writer to end, so that episode stats over images can be computed
@@ -1395,10 +1230,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
ep_metadata = self._save_episode_data(episode_buffer)
has_video_keys = len(self.meta.video_keys) > 0
has_audio_keys = len(self.meta.audio_keys) > 0
use_batched_encoding = self.batch_encoding_size > 1
if (has_video_keys or has_audio_keys) and not use_batched_encoding:
if has_video_keys and not use_batched_encoding:
num_cameras = len(self.meta.video_keys)
if parallel_encoding and num_cameras > 1:
# TODO(Steven): Ideally we would like to control the number of threads per encoding such that:
@@ -1435,30 +1269,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
for video_key in self.meta.video_keys:
ep_metadata.update(self._save_episode_video(video_key, episode_index))
# TODO(Caroline): add parallel encoding for audio as well
for audio_key in self.meta.audio_keys:
ep_metadata.update(self._save_episode_audio(audio_key, episode_index))
# `meta.save_episode` need to be executed after encoding the videos
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
if (has_video_keys or has_audio_keys) and use_batched_encoding:
if has_video_keys and use_batched_encoding:
# Check if we should trigger batch encoding
self.episodes_since_last_encoding += 1
if self.episodes_since_last_encoding == self.batch_encoding_size:
start_ep = self.num_episodes - self.batch_encoding_size
end_ep = self.num_episodes
if has_video_keys:
self._batch_save_episode_video(start_ep, end_ep)
if has_audio_keys:
self._batch_save_episode_audio(start_ep, end_ep)
self._batch_save_episode_video(start_ep, end_ep)
self.episodes_since_last_encoding = 0
if not episode_data:
# Reset episode buffer and clean up temporary images (if not already deleted during video encoding)
self.clear_episode_buffer(
delete_images=len(self.meta.image_keys) > 0, delete_audio=len(self.meta.audio_keys) > 0
)
self.clear_episode_buffer(delete_images=len(self.meta.image_keys) > 0)
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None:
"""
@@ -1509,70 +1334,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
dtype_backend="pyarrow"
) # allows NaN values along with integers
# Save the current episode's audio metadata to the dataframe
audio_ep_metadata = {}
for audio_key in self.meta.audio_keys:
audio_ep_metadata.update(self._save_episode_audio(audio_key, ep_idx))
audio_ep_metadata.pop("episode_index")
audio_ep_df = pd.DataFrame(audio_ep_metadata, index=[ep_idx]).convert_dtypes(
dtype_backend="pyarrow"
) # allows NaN values along with integers
episode_df = episode_df.combine_first(video_ep_df)
episode_df = episode_df.combine_first(audio_ep_df)
episode_df.to_parquet(episode_df_path)
self.meta.episodes = load_episodes(self.root)
def _batch_save_episode_audio(self, start_episode: int, end_episode: int | None = None) -> None:
"""
Batch save audio for multiple episodes.
Args:
start_episode: Starting episode index (inclusive)
end_episode: Ending episode index (exclusive). If None, encodes all episodes from start_episode to the current episode.
"""
if end_episode is None:
end_episode = self.num_episodes
logging.info(
f"Batch encoding {self.batch_encoding_size} audio for episodes {start_episode} to {end_episode - 1}"
)
chunk_idx = self.meta.episodes[start_episode]["data/chunk_index"]
file_idx = self.meta.episodes[start_episode]["data/file_index"]
episode_df_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
episode_df = pd.read_parquet(episode_df_path)
for ep_idx in range(start_episode, end_episode):
logging.info(f"Encoding audio for episode {ep_idx}")
if (
self.meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx
or self.meta.episodes[ep_idx]["data/file_index"] != file_idx
):
# The current episode is in a new chunk or file.
# Save previous episode dataframe and update the Hugging Face dataset by reloading it.
episode_df.to_parquet(episode_df_path)
self.meta.episodes = load_episodes(self.root)
# Load new episode dataframe
chunk_idx = self.meta.episodes[ep_idx]["data/chunk_index"]
file_idx = self.meta.episodes[ep_idx]["data/file_index"]
episode_df_path = self.root / DEFAULT_EPISODES_PATH.format(
chunk_index=chunk_idx, file_index=file_idx
)
episode_df = pd.read_parquet(episode_df_path)
# Save the current episode's video metadata to the dataframe
audio_ep_metadata = {}
for audio_key in self.meta.audio_keys:
audio_ep_metadata.update(self._save_episode_audio(audio_key, ep_idx))
audio_ep_metadata.pop("episode_index")
audio_ep_df = pd.DataFrame(audio_ep_metadata, index=[ep_idx]).convert_dtypes(
dtype_backend="pyarrow"
) # allows NaN values along with integers
episode_df = episode_df.combine_first(audio_ep_df)
episode_df.to_parquet(episode_df_path)
self.meta.episodes = load_episodes(self.root)
@@ -1683,7 +1445,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
ep_path = temp_path
ep_size_in_mb = get_file_size_in_mb(ep_path)
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="video")
ep_duration_in_s = get_video_duration_in_s(ep_path)
if (
episode_index == 0
@@ -1729,7 +1491,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
latest_duration_in_s = 0.0
else:
# Update latest video file
concatenate_media_files(
concatenate_video_files(
[latest_path, ep_path],
latest_path,
)
@@ -1751,79 +1513,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
}
return metadata
def _save_episode_audio(self, audio_key: str, episode_index: int) -> dict:
# Encode episode audio into a temporary audio file
ep_path = self._encode_temporary_episode_audio(audio_key, episode_index)
ep_size_in_mb = get_file_size_in_mb(ep_path)
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="audio")
if (
episode_index == 0
or self.meta.latest_episode is None
or f"audio/{audio_key}/chunk_index" not in self.meta.latest_episode
):
# Initialize indices for a new dataset made of the first episode data
chunk_idx, file_idx = 0, 0
if self.meta.episodes is not None and len(self.meta.episodes) > 0:
# It means we are resuming recording, so we need to load the latest episode
# Update the indices to avoid overwriting the latest episode
old_chunk_idx = self.meta.episodes[-1][f"audio/{audio_key}/chunk_index"]
old_file_idx = self.meta.episodes[-1][f"audio/{audio_key}/file_index"]
chunk_idx, file_idx = update_chunk_file_indices(
old_chunk_idx, old_file_idx, self.meta.chunks_size
)
latest_duration_in_s = 0.0
new_path = self.root / self.meta.audio_path.format(
audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx
)
new_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(ep_path), str(new_path))
else:
# Retrieve information from the latest updated audio file using latest_episode
latest_ep = self.meta.latest_episode
chunk_idx = latest_ep[f"audio/{audio_key}/chunk_index"][0]
file_idx = latest_ep[f"audio/{audio_key}/file_index"][0]
latest_path = self.root / self.meta.audio_path.format(
audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx
)
latest_size_in_mb = get_file_size_in_mb(latest_path)
latest_duration_in_s = latest_ep[f"audio/{audio_key}/to_timestamp"][0]
if latest_size_in_mb + ep_size_in_mb >= self.meta.audio_files_size_in_mb:
# Move temporary episode audio to a new audio file in the dataset
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
new_path = self.root / self.meta.audio_path.format(
audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx
)
new_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(ep_path), str(new_path))
latest_duration_in_s = 0.0
else:
# Update latest audio file
concatenate_media_files(
[latest_path, ep_path],
latest_path,
)
# Remove temporary directory
shutil.rmtree(str(ep_path.parent))
# Update audio info (only needed when first episode is encoded since it reads from episode 0)
if episode_index == 0:
self.meta.update_audio_info(audio_key)
write_info(self.meta.info, self.meta.root) # ensure audio info always written properly
metadata = {
"episode_index": episode_index,
f"audio/{audio_key}/chunk_index": chunk_idx,
f"audio/{audio_key}/file_index": file_idx,
f"audio/{audio_key}/from_timestamp": latest_duration_in_s,
f"audio/{audio_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
}
return metadata
def clear_episode_buffer(self, delete_images: bool = True, delete_audio: bool = True) -> None:
def clear_episode_buffer(self, delete_images: bool = True) -> None:
# Clean up image files for the current episode buffer
if delete_images:
# Wait for the async image writer to finish
@@ -1837,16 +1527,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
if img_dir.is_dir():
shutil.rmtree(img_dir)
# Clean up audio files for the current episode buffer
if delete_audio:
episode_index = self.episode_buffer["episode_index"]
if isinstance(episode_index, np.ndarray):
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
for audio_key in self.meta.audio_keys:
audio_file = self._get_raw_audio_file_path(episode_index, audio_key)
if audio_file.is_file():
audio_file.unlink()
# Reset the buffer
self.episode_buffer = self.create_episode_buffer()
@@ -1883,18 +1563,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
return _encode_video_worker(video_key, episode_index, self.root, self.fps, self.vcodec)
def _encode_temporary_episode_audio(self, audio_key: str, episode_index: int) -> Path:
"""
Use ffmpeg to convert raw audio files into m4a audio files.
Note: `encode_episode_audio` is a blocking call. Making it asynchronous shouldn't speedup encoding,
since audio encoding with ffmpeg is already using multithreading.
"""
temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{audio_key}_{episode_index:03d}.m4a"
raw_audio_file = self._get_raw_audio_file_path(episode_index, audio_key)
encode_audio(raw_audio_file, temp_path, overwrite=True)
raw_audio_file.unlink()
return temp_path
@classmethod
def create(
cls,
@@ -1908,7 +1576,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_writer_processes: int = 0,
image_writer_threads: int = 0,
video_backend: str | None = None,
audio_backend: str | None = None,
batch_encoding_size: int = 1,
vcodec: str = "libsvtav1",
) -> "LeRobotDataset":
@@ -1953,9 +1620,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj._lazy_loading = False
obj._recorded_frames = 0
obj._writer_closed_for_reading = False
obj.audio_backend = (
audio_backend if audio_backend is not None else "torchcodec"
) # Waiting for torchcodec release #TODO(CarolinePascal)
return obj
@@ -1976,7 +1640,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
tolerances_s: dict | None = None,
download_videos: bool = True,
video_backend: str | None = None,
audio_backend: str | None = None,
):
super().__init__()
self.repo_ids = repo_ids
@@ -1994,7 +1657,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
video_backend=video_backend,
audio_backend=audio_backend,
)
for repo_id in repo_ids
]

View File

@@ -216,16 +216,17 @@ class ImageTransformsConfig:
def make_transform_from_config(cfg: ImageTransformConfig):
if cfg.type == "Identity":
return v2.Identity(**cfg.kwargs)
elif cfg.type == "ColorJitter":
return v2.ColorJitter(**cfg.kwargs)
elif cfg.type == "SharpnessJitter":
if cfg.type == "SharpnessJitter":
return SharpnessJitter(**cfg.kwargs)
elif cfg.type == "RandomAffine":
return v2.RandomAffine(**cfg.kwargs)
else:
raise ValueError(f"Transform '{cfg.type}' is not valid.")
transform_cls = getattr(v2, cfg.type, None)
if isinstance(transform_cls, type) and issubclass(transform_cls, Transform):
return transform_cls(**cfg.kwargs)
raise ValueError(
f"Transform '{cfg.type}' is not valid. It must be a class in "
f"torchvision.transforms.v2 or 'SharpnessJitter'."
)
class ImageTransforms(Transform):

View File

@@ -36,7 +36,6 @@ from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from soundfile import read
from torchvision import transforms
from lerobot.configs.types import FeatureType, PolicyFeature
@@ -51,7 +50,6 @@ from lerobot.utils.utils import SuppressProgressBars, is_valid_numpy_dtype_strin
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
DEFAULT_AUDIO_FILE_SIZE_IN_MB = 100 # Max size per file
INFO_PATH = "meta/info.json"
STATS_PATH = "meta/stats.json"
@@ -59,19 +57,14 @@ STATS_PATH = "meta/stats.json"
EPISODES_DIR = "meta/episodes"
DATA_DIR = "data"
VIDEO_DIR = "videos"
AUDIO_DIR = "audio"
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
DEFAULT_AUDIO_PATH = AUDIO_DIR + "/{audio_key}/" + CHUNK_FILE_PATTERN + ".m4a"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
DEFAULT_RAW_AUDIO_PATH = "raw_audio/{audio_key}/episode_{episode_index:06d}.wav"
DEFAULT_AUDIO_CHUNK_DURATION = 0.5 # seconds
DEFAULT_INITIAL_AUDIO_BUFFER_DURATION = 1.0 # seconds
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
@@ -361,6 +354,14 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
return tasks
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
"""Load subtasks from subtasks.parquet if it exists."""
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
if subtasks_path.exists():
return pd.read_parquet(subtasks_path)
return None
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
This function writes episode-level metadata to a single parquet file.
@@ -416,16 +417,6 @@ def load_image_as_numpy(
return img_array
def load_audio_from_path(fpath: str | Path) -> np.ndarray:
audio_data, _ = read(fpath, dtype="float32")
# Fill missing channel dimension when loading mono audio data
if audio_data.ndim == 1:
audio_data = np.expand_dims(audio_data, axis=1)
return audio_data
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
"""Convert a batch from a Hugging Face dataset to torch tensors.
@@ -594,7 +585,7 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
"""
hf_features = {}
for key, ft in features.items():
if ft["dtype"] == "video" or ft["dtype"] == "audio":
if ft["dtype"] == "video":
continue
elif ft["dtype"] == "image":
hf_features[key] = datasets.Image()
@@ -657,12 +648,7 @@ def hw_to_dataset_features(
for key, ftype in hw_features.items()
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
}
cam_fts = {
key: shape for key, shape in hw_features.items() if isinstance(shape, tuple) and len(shape) == 3
}
mic_fts = {
key: shape for key, shape in hw_features.items() if isinstance(shape, tuple) and len(shape) == 2
}
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
if joint_fts and prefix == ACTION:
features[prefix] = {
@@ -685,14 +671,6 @@ def hw_to_dataset_features(
"names": ["height", "width", "channels"],
}
for key, parameters in mic_fts.items():
features[f"{prefix}.audio.{key}"] = {
"dtype": "audio",
"shape": (len(parameters[1]),),
"names": ["channels"],
"info": {"sample_rate": parameters[0]},
}
_validate_feature_names(features)
return features
@@ -722,8 +700,6 @@ def build_dataset_frame(
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
elif ft["dtype"] in ["image", "video"]:
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
elif ft["dtype"] == "audio":
frame[key] = values[key.removeprefix(f"{prefix}.audio.")]
return frame
@@ -757,10 +733,6 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1])
elif ft["dtype"] == "audio":
type = FeatureType.AUDIO
if len(shape) != 2:
raise ValueError(f"Number of dimensions of {key} != 2 (shape={shape})")
elif key == OBS_ENV_STATE:
type = FeatureType.ENV
elif key.startswith(OBS_STR):
@@ -839,7 +811,6 @@ def create_empty_dataset_info(
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
audio_files_size_in_mb: int | None = None,
) -> dict:
"""Create a template dictionary for a new dataset's `info.json`.
@@ -849,10 +820,6 @@ def create_empty_dataset_info(
features (dict): The LeRobot features dictionary for the dataset.
use_videos (bool): Whether the dataset will store videos.
robot_type (str | None): The type of robot used, if any.
chunks_size (int | None): The maximum number of files per chunk directory.
data_files_size_in_mb (int | None): The maximum size for data files in MB.
video_files_size_in_mb (int | None): The maximum size for video files in MB.
audio_files_size_in_mb (int | None): The maximum size for audio files in MB.
Returns:
dict: A dictionary with the initial dataset metadata.
@@ -866,12 +833,10 @@ def create_empty_dataset_info(
"chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
"data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
"video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
"audio_files_size_in_mb": audio_files_size_in_mb or DEFAULT_AUDIO_FILE_SIZE_IN_MB,
"fps": fps,
"splits": {},
"data_path": DEFAULT_DATA_PATH,
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
"audio_path": DEFAULT_AUDIO_PATH,
"features": features,
}
@@ -1095,8 +1060,6 @@ def validate_feature_dtype_and_shape(
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
elif expected_dtype in ["image", "video"]:
return validate_feature_image_or_video(name, expected_shape, value)
elif expected_dtype == "audio":
return validate_feature_audio(name, expected_shape, value)
elif expected_dtype == "string":
return validate_feature_string(name, value)
else:
@@ -1163,23 +1126,6 @@ def validate_feature_image_or_video(
return error_message
def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray):
error_message = ""
if isinstance(value, np.ndarray):
actual_shape = value.shape
c = expected_shape
if (len(actual_shape) != 2 and len(actual_shape) != 1) or actual_shape[-1] != c[
-1
]: # The number of frames might be different
error_message += (
f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{c}'.\n"
)
else:
error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n"
return error_message
def validate_feature_string(name: str, value: str) -> str:
"""Validate a feature that is expected to be a string.

View File

@@ -59,8 +59,6 @@ from requests import HTTPError
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.datasets.utils import (
DEFAULT_AUDIO_FILE_SIZE_IN_MB,
DEFAULT_AUDIO_PATH,
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
@@ -81,7 +79,7 @@ from lerobot.datasets.utils import (
write_stats,
write_tasks,
)
from lerobot.datasets.video_utils import concatenate_media_files, get_media_duration_in_s
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.utils import init_logging
@@ -313,12 +311,12 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
ep_size_in_mb = get_file_size_in_mb(ep_path)
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="video")
ep_duration_in_s = get_video_duration_in_s(ep_path)
# Check if adding this episode would exceed the limit
if size_in_mb + ep_size_in_mb >= video_file_size_in_mb and len(paths_to_cat) > 0:
# Size limit would be exceeded, save current accumulation WITHOUT this episode
concatenate_media_files(
concatenate_video_files(
paths_to_cat,
new_root
/ DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
@@ -354,7 +352,7 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
# Write remaining videos if any
if paths_to_cat:
concatenate_media_files(
concatenate_video_files(
paths_to_cat,
new_root
/ DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
@@ -369,124 +367,8 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
return episodes_metadata
def get_audio_keys(root):
info = load_info(root)
features = info["features"]
audio_keys = [key for key, ft in features.items() if ft["dtype"] == "audio"]
return audio_keys
def convert_audios(root: Path, new_root: Path, audio_file_size_in_mb: int):
logging.info(f"Converting audios from {root} to {new_root}")
audio_keys = get_audio_keys(root)
if len(audio_keys) == 0:
return None
audio_keys = sorted(audio_keys)
eps_metadata_per_mic = []
for microphone in audio_keys:
eps_metadata = convert_audios_of_microphone(root, new_root, microphone, audio_file_size_in_mb)
eps_metadata_per_mic.append(eps_metadata)
num_eps_per_mic = [len(eps_mic_map) for eps_mic_map in eps_metadata_per_mic]
if len(set(num_eps_per_mic)) != 1:
raise ValueError(f"All microphones dont have same number of episodes ({num_eps_per_mic}).")
episodes_metadata = []
num_microphones = len(audio_keys)
num_episodes = num_eps_per_mic[0]
for ep_idx in tqdm.tqdm(range(num_episodes), desc="convert audios"):
# Sanity check
ep_ids = [
eps_metadata_per_mic[mic_idx][ep_idx]["episode_index"] for mic_idx in range(num_microphones)
]
ep_ids += [ep_idx]
if len(set(ep_ids)) != 1:
raise ValueError(f"All episode indices need to match ({ep_ids}).")
ep_dict = {}
for mic_idx in range(num_microphones):
ep_dict.update(eps_metadata_per_mic[mic_idx][ep_idx])
episodes_metadata.append(ep_dict)
return episodes_metadata
def convert_audios_of_microphone(root: Path, new_root: Path, audio_key: str, audio_file_size_in_mb: int):
# Access old paths to m4a
audios_dir = root / "audio"
ep_paths = sorted(audios_dir.glob(f"*/{audio_key}/*.m4a"))
ep_idx = 0
chunk_idx = 0
file_idx = 0
size_in_mb = 0
duration_in_s = 0.0
paths_to_cat = []
episodes_metadata = []
for ep_path in tqdm.tqdm(ep_paths, desc=f"convert audios of {audio_key}"):
ep_size_in_mb = get_file_size_in_mb(ep_path)
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="audio")
# Check if adding this episode would exceed the limit
if size_in_mb + ep_size_in_mb >= audio_file_size_in_mb and len(paths_to_cat) > 0:
# Size limit would be exceeded, save current accumulation WITHOUT this episode
concatenate_media_files(
paths_to_cat,
new_root
/ DEFAULT_AUDIO_PATH.format(audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx),
)
# Update episodes metadata for the file we just saved
for i, _ in enumerate(paths_to_cat):
past_ep_idx = ep_idx - len(paths_to_cat) + i
episodes_metadata[past_ep_idx][f"audio/{audio_key}/chunk_index"] = chunk_idx
episodes_metadata[past_ep_idx][f"audio/{audio_key}/file_index"] = file_idx
# Move to next file and start fresh with current episode
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
size_in_mb = 0
duration_in_s = 0.0
paths_to_cat = []
# Add current episode metadata
ep_metadata = {
"episode_index": ep_idx,
f"audio/{audio_key}/chunk_index": chunk_idx, # Will be updated when file is saved
f"audio/{audio_key}/file_index": file_idx, # Will be updated when file is saved
f"audio/{audio_key}/from_timestamp": duration_in_s,
f"audio/{audio_key}/to_timestamp": duration_in_s + ep_duration_in_s,
}
episodes_metadata.append(ep_metadata)
# Add current episode to accumulation
paths_to_cat.append(ep_path)
size_in_mb += ep_size_in_mb
duration_in_s += ep_duration_in_s
ep_idx += 1
# Write remaining videos if any
if paths_to_cat:
concatenate_media_files(
paths_to_cat,
new_root
/ DEFAULT_AUDIO_PATH.format(audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx),
)
# Update episodes metadata for the final file
for i, _ in enumerate(paths_to_cat):
past_ep_idx = ep_idx - len(paths_to_cat) + i
episodes_metadata[past_ep_idx][f"audio/{audio_key}/chunk_index"] = chunk_idx
episodes_metadata[past_ep_idx][f"audio/{audio_key}/file_index"] = file_idx
return episodes_metadata
def generate_episode_metadata_dict(
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_videos=None, episodes_audios=None
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_videos=None
):
num_episodes = len(episodes_metadata)
episodes_legacy_metadata_vals = list(episodes_legacy_metadata.values())
@@ -510,30 +392,16 @@ def generate_episode_metadata_dict(
ep_video = episodes_videos[i]
ep_ids_set.add(ep_video["episode_index"])
if episodes_audios is None:
ep_audio = {}
else:
ep_audio = episodes_audios[i]
ep_ids_set.add(ep_audio["episode_index"])
if len(ep_ids_set) != 1:
raise ValueError(f"Number of episodes is not the same ({ep_ids_set}).")
ep_dict = {
**ep_metadata,
**ep_video,
**ep_audio,
**ep_legacy_metadata,
**flatten_dict({"stats": ep_stats}),
}
ep_dict = {**ep_metadata, **ep_video, **ep_legacy_metadata, **flatten_dict({"stats": ep_stats})}
ep_dict["meta/episodes/chunk_index"] = 0
ep_dict["meta/episodes/file_index"] = 0
yield ep_dict
def convert_episodes_metadata(
root, new_root, episodes_metadata, episodes_video_metadata=None, episodes_audio_metadata=None
):
def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata=None):
logging.info(f"Converting episodes metadata from {root} to {new_root}")
episodes_legacy_metadata = legacy_load_episodes(root)
@@ -542,19 +410,13 @@ def convert_episodes_metadata(
num_eps_set = {len(episodes_legacy_metadata), len(episodes_metadata)}
if episodes_video_metadata is not None:
num_eps_set.add(len(episodes_video_metadata))
if episodes_audio_metadata is not None:
num_eps_set.add(len(episodes_audio_metadata))
if len(num_eps_set) != 1:
raise ValueError(f"Number of episodes is not the same ({num_eps_set}).")
ds_episodes = Dataset.from_generator(
lambda: generate_episode_metadata_dict(
episodes_legacy_metadata,
episodes_metadata,
episodes_stats,
episodes_video_metadata,
episodes_audio_metadata,
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_video_metadata
)
)
write_episodes(ds_episodes, new_root)
@@ -563,22 +425,20 @@ def convert_episodes_metadata(
write_stats(stats, new_root)
def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb, audio_file_size_in_mb):
def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb):
info = load_info(root)
info["codebase_version"] = V30
del info["total_chunks"]
del info["total_videos"]
info["data_files_size_in_mb"] = data_file_size_in_mb
info["video_files_size_in_mb"] = video_file_size_in_mb
info["audio_files_size_in_mb"] = audio_file_size_in_mb
info["data_path"] = DEFAULT_DATA_PATH
info["video_path"] = DEFAULT_VIDEO_PATH if info["video_path"] is not None else None
info["audio_path"] = DEFAULT_AUDIO_PATH if info["audio_path"] is not None else None
info["fps"] = int(info["fps"])
logging.info(f"Converting info from {root} to {new_root}")
for key in info["features"]:
if info["features"][key]["dtype"] == "video" or info["features"][key]["dtype"] == "audio":
# already has fps in video_info or audio_info
if info["features"][key]["dtype"] == "video":
# already has fps in video_info
continue
info["features"][key]["fps"] = info["fps"]
write_info(info, new_root)
@@ -589,7 +449,6 @@ def convert_dataset(
branch: str | None = None,
data_file_size_in_mb: int | None = None,
video_file_size_in_mb: int | None = None,
audio_file_size_in_mb: int | None = None,
root: str | Path | None = None,
push_to_hub: bool = True,
force_conversion: bool = False,
@@ -598,8 +457,6 @@ def convert_dataset(
data_file_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
if video_file_size_in_mb is None:
video_file_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
if audio_file_size_in_mb is None:
audio_file_size_in_mb = DEFAULT_AUDIO_FILE_SIZE_IN_MB
# First check if the dataset already has a v3.0 version
if root is None and not force_conversion:
@@ -641,10 +498,7 @@ def convert_dataset(
convert_tasks(root, new_root)
episodes_metadata = convert_data(root, new_root, data_file_size_in_mb)
episodes_videos_metadata = convert_videos(root, new_root, video_file_size_in_mb)
episodes_audios_metadata = convert_audios(root, new_root, audio_file_size_in_mb)
convert_episodes_metadata(
root, new_root, episodes_metadata, episodes_videos_metadata, episodes_audios_metadata
)
convert_episodes_metadata(root, new_root, episodes_metadata, episodes_videos_metadata)
shutil.move(str(root), str(old_root))
shutil.move(str(new_root), str(root))
@@ -657,7 +511,7 @@ def convert_dataset(
print(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
pass
hub_api.delete_files(
delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*", "audio/chunk*"],
delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*"],
repo_id=repo_id,
revision=branch,
repo_type="dataset",
@@ -695,12 +549,6 @@ if __name__ == "__main__":
default=None,
help="File size in MB. Defaults to 100 for data and 500 for videos.",
)
parser.add_argument(
"--audio-file-size-in-mb",
type=int,
default=None,
help="File size in MB. Defaults to 100 for audio.",
)
parser.add_argument(
"--root",
type=str,

View File

@@ -397,42 +397,42 @@ def encode_video_frames(
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
def concatenate_media_files(
input_media_paths: list[Path | str], output_media_path: Path, overwrite: bool = True
def concatenate_video_files(
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
):
"""
Concatenate multiple media files (video & audio) into a single media file using pyav.
Concatenate multiple video files into a single video file using pyav.
This function takes a list of input media file paths and concatenates them into a single
output media file. It uses ffmpeg's concat demuxer with stream copy mode for fast
This function takes a list of video input file paths and concatenates them into a single
output video file. It uses ffmpeg's concat demuxer with stream copy mode for fast
concatenation without re-encoding.
Args:
input_media_paths: Ordered list of input media file paths to concatenate.
output_media_path: Path to the output media file.
overwrite: Whether to overwrite the output media file if it already exists. Default is True.
input_video_paths: Ordered list of input video file paths to concatenate.
output_video_path: Path to the output video file.
overwrite: Whether to overwrite the output video file if it already exists. Default is True.
Note:
- Creates a temporary .ffconcat file and container audio/video file that are cleaned up after use.
- Uses ffmpeg's concat demuxer which requires all input media files to have the same
- Creates a temporary directory for intermediate files that is cleaned up after use.
- Uses ffmpeg's concat demuxer which requires all input videos to have the same
codec, resolution, and frame rate for proper concatenation.
"""
output_media_path = Path(output_media_path)
output_video_path = Path(output_video_path)
if output_media_path.exists() and not overwrite:
logging.warning(f"Media file already exists: {output_media_path}. Skipping concatenation.")
if output_video_path.exists() and not overwrite:
logging.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
return
output_media_path.parent.mkdir(parents=True, exist_ok=True)
output_video_path.parent.mkdir(parents=True, exist_ok=True)
if len(input_media_paths) == 0:
raise FileNotFoundError("No input media paths provided.")
if len(input_video_paths) == 0:
raise FileNotFoundError("No input video paths provided.")
# Create a temporary .ffconcat file to list the input media paths
# Create a temporary .ffconcat file to list the input video paths
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
tmp_concatenate_file.write("ffconcat version 1.0\n")
for input_path in input_media_paths:
for input_path in input_video_paths:
tmp_concatenate_file.write(f"file '{str(input_path.resolve())}'\n")
tmp_concatenate_file.flush()
tmp_concatenate_path = tmp_concatenate_file.name
@@ -442,12 +442,11 @@ def concatenate_media_files(
tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"}
) # safe = 0 allows absolute paths as well as relative paths
# Using an intermediate container to store the concatenated media file is necessary to avoid inplace concatenation read-write race conditions.
with tempfile.NamedTemporaryFile(suffix=output_media_path.suffix, delete=False) as tmp_named_file:
tmp_output_media_path = tmp_named_file.name
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
tmp_output_video_path = tmp_named_file.name
output_container = av.open(
tmp_output_media_path, mode="w", options={"movflags": "faststart"}
tmp_output_video_path, mode="w", options={"movflags": "faststart"}
) # faststart is to move the metadata to the beginning of the file to speed up loading
# Replicate input streams in output container
@@ -462,7 +461,6 @@ def concatenate_media_files(
stream_map[input_stream.index].time_base = input_stream.time_base
# Demux + remux packets (no re-encode)
last_dts = None
for packet in input_container.demux():
# Skip packets from un-mapped streams
if packet.stream.index not in stream_map:
@@ -471,16 +469,6 @@ def concatenate_media_files(
# Skip demux flushing packets
if packet.dts is None:
continue
else:
# Enforce strictly increasing decoding timestamps (DTS)
if last_dts is not None and packet.dts <= last_dts:
shift = last_dts - packet.dts + 1
packet.dts += shift
packet.pts += shift # Presenting timestamps (PTS) are the same as DTS here
logging.warning(
f"Non-monotonic DTS; previous: {last_dts}, current: {packet.dts - shift}; changing to {packet.dts}. This may result in incorrect timestamps in the output file."
)
last_dts = packet.dts
output_stream = stream_map[packet.stream.index]
packet.stream = output_stream
@@ -488,7 +476,7 @@ def concatenate_media_files(
input_container.close()
output_container.close()
shutil.move(tmp_output_media_path, output_media_path)
shutil.move(tmp_output_video_path, output_video_path)
Path(tmp_concatenate_path).unlink()
@@ -524,6 +512,38 @@ with warnings.catch_warnings():
register_feature(VideoFrame, "VideoFrame")
def get_audio_info(video_path: Path | str) -> dict:
# Set logging level
logging.getLogger("libav").setLevel(av.logging.ERROR)
# Getting audio stream information
audio_info = {}
with av.open(str(video_path), "r") as audio_file:
try:
audio_stream = audio_file.streams.audio[0]
except IndexError:
# Reset logging level
av.logging.restore_default_callback()
return {"has_audio": False}
audio_info["audio.channels"] = audio_stream.channels
audio_info["audio.codec"] = audio_stream.codec.canonical_name
# In an ideal loseless case : bit depth x sample rate x channels = bit rate.
# In an actual compressed case, the bit rate is set according to the compression level : the lower the bit rate, the more compression is applied.
audio_info["audio.bit_rate"] = audio_stream.bit_rate
audio_info["audio.sample_rate"] = audio_stream.sample_rate # Number of samples per second
# In an ideal loseless case : fixed number of bits per sample.
# In an actual compressed case : variable number of bits per sample (often reduced to match a given depth rate).
audio_info["audio.bit_depth"] = audio_stream.format.bits
audio_info["audio.channel_layout"] = audio_stream.layout.name
audio_info["has_audio"] = True
# Reset logging level
av.logging.restore_default_callback()
return audio_info
def get_video_info(video_path: Path | str) -> dict:
# Set logging level
logging.getLogger("libav").setLevel(av.logging.ERROR)
@@ -553,6 +573,9 @@ def get_video_info(video_path: Path | str) -> dict:
# Reset logging level
av.logging.restore_default_callback()
# Adding audio stream information
video_info.update(**get_audio_info(video_path))
return video_info
@@ -567,22 +590,22 @@ def get_video_pixel_channels(pix_fmt: str) -> int:
raise ValueError("Unknown format")
def get_media_duration_in_s(media_path: Path | str, media_type: str = "video") -> float:
def get_video_duration_in_s(video_path: Path | str) -> float:
"""
Get the duration of a media file (video & audio) in seconds using PyAV.
Get the duration of a video file in seconds using PyAV.
Args:
media_path: Path to the media file.
video_path: Path to the video file.
Returns:
Duration of the media file in seconds.
Duration of the video in seconds.
"""
with av.open(str(media_path)) as container:
# Get the first stream
stream = container.streams.video[0] if media_type == "video" else container.streams.audio[0]
with av.open(str(video_path)) as container:
# Get the first video stream
video_stream = container.streams.video[0]
# Calculate duration: stream.duration * stream.time_base gives duration in seconds
if stream.duration is not None:
duration = float(stream.duration * stream.time_base)
if video_stream.duration is not None:
duration = float(video_stream.duration * video_stream.time_base)
else:
# Fallback to container duration if stream duration is not available
duration = float(container.duration / av.time_base)
@@ -591,12 +614,12 @@ def get_media_duration_in_s(media_path: Path | str, media_type: str = "video") -
class VideoEncodingManager:
"""
Context manager that ensures proper video and audio encoding and data cleanup even if exceptions occur.
Context manager that ensures proper video encoding and data cleanup even if exceptions occur.
This manager handles:
- Batch encoding for any remaining episodes when recording interrupted
- Cleaning up temporary image and audio files from interrupted episodes
- Removing empty image and audio directories
- Cleaning up temporary image files from interrupted episodes
- Removing empty image directories
Args:
dataset: The LeRobotDataset instance
@@ -623,7 +646,6 @@ class VideoEncodingManager:
f"from episode {start_ep} to {end_ep - 1}"
)
self.dataset._batch_save_episode_video(start_ep, end_ep)
self.dataset._batch_save_episode_audio(start_ep, end_ep)
# Finalize the dataset to properly close all writers
self.dataset.finalize()
@@ -640,15 +662,6 @@ class VideoEncodingManager:
f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
)
shutil.rmtree(img_dir)
for key in self.dataset.meta.audio_keys:
audio_file = self.dataset._get_raw_audio_file_path(
episode_index=interrupted_episode_index, audio_key=key
)
if audio_file.exists():
logging.debug(
f"Cleaning up interrupted episode audio for episode {interrupted_episode_index}, microphone {key}"
)
audio_file.unlink()
# Clean up any remaining images directory if it's empty
img_dir = self.dataset.root / "images"
@@ -662,16 +675,4 @@ class VideoEncodingManager:
else:
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
# Clean up any remaining audio directory if it's empty
audio_dir = self.dataset.root / "raw_audio"
# Check for any remaining WAV files
wav_files = list(audio_dir.rglob("*.wav"))
if len(wav_files) == 0:
# Only remove the raw_audio directory if no WAV files remain
if audio_dir.exists():
shutil.rmtree(audio_dir)
logging.debug("Cleaned up empty audio directory")
else:
logging.debug(f"Audio directory is not empty, containing {len(wav_files)} WAV files")
return False # Don't suppress the original exception

View File

@@ -205,6 +205,7 @@ class ObservationConfig:
add_joint_velocity_to_observation: bool = False
add_current_to_observation: bool = False
add_ee_pose_to_observation: bool = False
display_cameras: bool = False
@@ -260,6 +261,7 @@ class HILSerlRobotEnvConfig(EnvConfig):
@dataclass
class LiberoEnv(EnvConfig):
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
task_ids: list[int] | None = None
fps: int = 30
episode_length: int | None = None
obs_type: str = "pixels_agent_pos"
@@ -338,10 +340,10 @@ class LiberoEnv(EnvConfig):
@property
def gym_kwargs(self) -> dict:
return {
"obs_type": self.obs_type,
"render_mode": self.render_mode,
}
kwargs: dict[str, Any] = {"obs_type": self.obs_type, "render_mode": self.render_mode}
if self.task_ids is not None:
kwargs["task_ids"] = self.task_ids
return kwargs
@EnvConfig.register_subclass("metaworld")

View File

@@ -112,6 +112,7 @@ class LiberoEnv(gym.Env):
visualization_height: int = 480,
init_states: bool = True,
episode_index: int = 0,
n_envs: int = 1,
camera_name_mapping: dict[str, str] | None = None,
num_steps_wait: int = 10,
control_mode: str = "relative",
@@ -145,7 +146,9 @@ class LiberoEnv(gym.Env):
self.episode_length = episode_length
# Load once and keep
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._reset_stride = n_envs # when performing a reset, append `_reset_stride` to `init_state_id`.
self.init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._env = self._make_envs_task(task_suite, self.task_id)
default_steps = 500
@@ -295,7 +298,8 @@ class LiberoEnv(gym.Env):
self._env.seed(seed)
raw_obs = self._env.reset()
if self.init_states and self._init_states is not None:
raw_obs = self._env.set_init_state(self._init_states[self._init_state_id])
raw_obs = self._env.set_init_state(self._init_states[self.init_state_id % len(self._init_states)])
self.init_state_id += self._reset_stride # Change init_state_id when reset
# After reset, objects may be unstable (slightly floating, intersecting, etc.).
# Step the simulator with a no-op action for a few frames so everything settles.
@@ -373,6 +377,7 @@ def _make_env_fns(
init_states=init_states,
episode_length=episode_length,
episode_index=episode_index,
n_envs=n_envs,
control_mode=control_mode,
**local_kwargs,
)

View File

@@ -1,45 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..configs import MicrophoneConfig
@MicrophoneConfig.register_subclass("anyskin")
@dataclass
class AnyskinSensorConfig(MicrophoneConfig):
"""Configuration class for Anyskin tactile sensors (technically not a microphone, but behaves like one acquisition-wise).
This class provides configuration options for Anyskin tactile sensors, including serial port, sample rate and channels.
Example configurations:
```python
# Basic configurations
AnyskinSensorConfig("/dev/ttyACM0", 16000) # Serial port /dev/ttyACM0, 16000Hz
AnyskinSensorConfig("/dev/ttyACM1", 44100) # Serial port /dev/ttyACM1, 44100Hz
```
Attributes:
sensor_port: Serial port of the tactile sensor.
baud_rate: Baud rate of the tactile sensor.
sample_rate: Sample rate in Hz for the tactile sensor.
channels: List of channel numbers to use for the tactile sensor.
"""
sensor_port: str
baud_rate: int = 115_200
sensor_id: int = 0
burst_mode: bool = True
temp_filtered: bool = False

View File

@@ -1,473 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Provides the AnyskinSensor class for capturing tactile data from Anyskin tactile sensors.
"""
from doctest import master
import logging
import time
from multiprocessing import (
Event as process_Event,
JoinableQueue as process_Queue,
Process,
)
from pathlib import Path
from queue import Empty
from threading import Barrier, Event, Event as thread_Event, Thread
from typing import Any
from lerobot.utils.hub import T
import numpy as np
from serial import Serial, serialutil
from soundfile import SoundFile
from lerobot.utils.errors import (
DeviceAlreadyConnectedError,
DeviceAlreadyRecordingError,
DeviceNotConnectedError,
DeviceNotRecordingError,
)
from lerobot.utils.shared_array import SharedArray
from ..microphone import Microphone
from .configuration_anyskin import AnyskinSensorConfig
from anyskin import AnySkinBase, AnySkinDummy
logger = logging.getLogger(__name__)
MAX_MAGNETS_CHANNELS = 5
class AnyskinSensor(Microphone):
"""
The AnyskinSensor class handles all Anyskin tactile sensors.
A AnyskinSensor instance requires the serial port of the tactile sensor, which may be obtained using `python -m lerobot.find_port`. It also requires the recording sample rate as well as the list of recorded channels.
Example of usage:
```python
from lerobot.common.robot_devices.microphones.configs import AnyskinSensorConfig
config = AnyskinSensorConfig(sensor_port="/dev/ttyACM0", baud_rate=115200, sample_rate=115, channels=[1])
microphone = AnyskinSensor(config)
microphone.connect()
microphone.start_recording("some/output/file.wav")
...
audio_readings = microphone.read() # Gets all recorded audio data since the last read or since the beginning of the recording. The longer the period the longer the reading time !
...
microphone.stop_recording()
microphone.disconnect()
```
"""
def __init__(self, config: AnyskinSensorConfig):
""" "
Initializes the AnyskinSensor instance.
Args:
config: The configuration settings for the sensor.
"""
super().__init__(config)
# Sensor port
self.sensor_port = config.sensor_port
# Baud rate
self.baud_rate = config.baud_rate
# Input audio recording process and events
self.record_process = None
self.record_stop_event = process_Event()
self.record_start_event = process_Event()
self.record_close_event = process_Event()
self.record_is_started_event = process_Event()
self.audio_callback_start_event = process_Event()
# Process-safe concurrent queue to send audio from the recording process to the writing process/thread
self.write_queue = process_Queue()
# SharedArray to store audio from the recording process.
self.read_shared_array = None
self.local_read_shared_array = None
# Thread/Process to handle data writing in a separate thread/process (safely)
self.write_thread = None
self.write_stop_event = None
self.write_is_started_event = None
self.logs = {}
def __str__(self) -> str:
return f"{self.__class__.__name__}({self.sensor_port})"
@property
def is_connected(self) -> bool:
"""Check if the sensor is currently connected.
Returns:
bool: True if the sensor is connected and ready to start recording,
False otherwise.
"""
return self.record_process is not None and self.record_process.is_alive()
@property
def is_recording(self) -> bool:
"""Check if the sensor is currently recording.
Returns:
bool: True if the sensor is recording, False otherwise.
"""
return self.record_is_started_event.is_set()
@property
def is_writing(self) -> bool:
"""Check if the sensor is currently writing to a file.
Returns:
bool: True if the sensor is writing to a file, False otherwise.
"""
return self.write_thread is not None and self.write_is_started_event.is_set()
@staticmethod
def find_microphones() -> list[dict[str, Any]]:
"""Detects available sensors connected to the system.
Returns:
List[Dict[str, Any]]: A list of dictionaries,
where each dictionary contains information about a detected sensor.
"""
pass
def connect(self) -> None:
"""
Establish connection to the sensor.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"Sensor connected to {self.sensor_port} is already connected.")
# Create or reset queue and shared array
self.read_shared_array = SharedArray(
shape=(self.sample_rate * 10, len(self.channels)),
dtype=np.dtype("int16"),
)
self.local_read_shared_array = self.read_shared_array.get_local_array()
self.write_queue = process_Queue()
# Reset events
self.record_start_event.clear()
self.record_stop_event.clear()
self.record_close_event.clear()
self.record_is_started_event.clear()
self.audio_callback_start_event.clear()
# Create and start an audio input stream with a recording callback
# Remark: this is done in a separate process so that audio recording is not impacted by the main thread CPU usage, especially the busy_wait function.
process_init_event = process_Event()
self.record_process = Process(
target=self._record_process,
args=(
self.sensor_port,
self.baud_rate,
self.channels,
process_init_event,
self.record_start_event,
self.record_stop_event,
self.record_close_event,
self.record_is_started_event,
self.audio_callback_start_event,
self.write_queue,
self.read_shared_array,
),
)
self.record_process.daemon = True
self.record_process.start()
is_init = process_init_event.wait(
timeout=5.0
) # Wait for the recording process to be started, and to potentially raise an error on failure.
if not self.is_connected or not is_init:
raise RuntimeError(f"Error connecting sensor connected to {self.sensor_port}.")
logger.info(f"{self} connected.")
@staticmethod
def _record_process(
sensor_port,
baud_rate,
channels,
process_init_event,
record_start_event,
record_stop_event,
record_close_event,
record_is_started_event,
audio_callback_start_event,
write_queue,
read_shared_array,
) -> None:
channels_index = np.array(channels) - 1
local_read_shared_array = read_shared_array.get_local_array()
def tactile_callback(tactile_sensor: AnySkinBase):
"""
Parse the tactile data from the raw input data.
"""
if audio_callback_start_event.is_set():
timestamp, indata = tactile_sensor.get_sample()
indata = indata.reshape(-1, MAX_MAGNETS_CHANNELS)
write_queue.put_nowait(indata[:, channels_index])
read_shared_array.write(local_read_shared_array, indata[:, channels_index])
try:
tactile_sensor = AnySkinBase(
num_mags=MAX_MAGNETS_CHANNELS,
port=sensor_port,
baudrate=baud_rate,
burst_mode=True,
device_id=0, #TODO(CarolinePascal): create an abstract increasing id for each sensor
temp_filtered=False,
) #TODO(CarolinePascal): add timeout on serial connection ?
except (serialutil.SerialException, AttributeError) as e:
raise RuntimeError(f"Error connecting sensor connected to {sensor_port}: {e}")
process_init_event.set()
while True:
start_flag = record_start_event.wait(timeout=0.1)
if record_close_event.is_set():
break
elif not start_flag:
continue
record_is_started_event.set()
while not record_stop_event.is_set():
tactile_callback(tactile_sensor) # Initial flush is already done in the constructor.
record_is_started_event.clear()
tactile_sensor.close() # Closes the inherited serial connection.
def disconnect(self) -> None:
"""
Disconnect the sensor and release any resources.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
if self.is_recording:
self.stop_recording()
self.record_close_event.set()
self.read_shared_array.delete()
self.write_queue.close()
self.record_process.join()
if self.is_connected:
raise RuntimeError(f"Error disconnecting sensor connected to {self.sensor_port}.")
logger.info(f"{self} disconnected.")
def start_recording(
self,
output_file: str | Path | None = None,
multiprocessing: bool | None = False,
overwrite: bool | None = True,
barrier: Barrier | None = None,
) -> None:
"""
Start recording tactile data from the sensor.
Args:
output_file: Optional path to save the recorded tactile data.
multiprocessing: If True, enables multiprocessing for recording. Defaults to multithreading otherwise.
overwrite: If True, overwrites existing files at output_file path.
barrier: If not None, ensures that multiple sensors start recording at the same time.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
if self.is_recording:
raise DeviceAlreadyRecordingError(f"Sensor connected to {self.sensor_port} is already recording.")
# Reset queue and shared memory
self.read_shared_array.reset()
self._clear_queue(self.write_queue)
# Reset stop event
self.record_stop_event.clear()
# Write recordings into a file if output_file is provided
if output_file is not None:
output_file = Path(output_file)
output_file.parent.mkdir(parents=True, exist_ok=True)
if output_file.exists():
if overwrite:
output_file.unlink()
else:
raise FileExistsError(
f"Output file {output_file} already exists. Set overwrite to True to overwrite it."
)
if multiprocessing:
self.write_stop_event = process_Event()
self.write_is_started_event = process_Event()
self.write_thread = Process(
target=AnyskinSensor._write_loop,
args=(
self.write_queue,
self.write_stop_event,
self.write_is_started_event,
self.sample_rate,
self.channels,
output_file,
),
)
else:
self.write_stop_event = thread_Event()
self.write_is_started_event = thread_Event()
self.write_thread = Thread(
target=AnyskinSensor._write_loop,
args=(
self.write_queue,
self.write_stop_event,
self.write_is_started_event,
self.sample_rate,
self.channels,
output_file,
),
)
self.write_thread.daemon = True
self.write_thread.start()
self.write_is_started_event.wait() # Wait for the writing thread/process to be started.
self.record_start_event.set() # Start the input audio stream process
self.record_is_started_event.wait() # Wait for the input audio stream process to be actually started
if barrier is not None:
barrier.wait() # Wait for multiple input audio streams to be started at the same time
self.audio_callback_start_event.set()
if not self.is_recording:
raise RuntimeError(f"Error starting recording for sensor connected to {self.sensor_port}.")
if output_file is not None and not self.is_writing:
raise RuntimeError(f"Error starting writing for sensor connected to {self.sensor_port}.")
def _read(self) -> np.ndarray:
"""
Thread/Process-safe callback to read available audio data
"""
return self.read_shared_array.read(self.local_read_shared_array, flush=True)
def read(self) -> np.ndarray:
"""Capture and return a single audio chunk from the sensor.
Returns:
np.ndarray: Captured audio chunk as a numpy array.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
if not self.is_recording:
raise RuntimeError(f"Sensor connected to {self.sensor_port} is not recording.")
start_time = time.perf_counter()
tactile_readings = self._read()
# log the number of seconds it took to read the audio chunk
self.logs["delta_timestamp_s"] = time.perf_counter() - start_time
# log the utc time at which the audio chunk was received
self.logs["timestamp_utc"] = time.perf_counter()
return tactile_readings
def _read_loop(self) -> None:
"""Internal loop run by the background thread for asynchronous reading."""
def stop_recording(self) -> None:
"""Stop recording audio from the sensor."""
if not self.is_connected:
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
if not self.is_recording:
raise DeviceNotRecordingError(f"Sensor connected to {self.sensor_port} is not recording.")
self.audio_callback_start_event.clear()
self.record_start_event.clear() # Ensures the audio stream is not started again !
self.record_stop_event.set()
self.read_shared_array.reset()
self._clear_queue(self.write_queue, join_queue=True)
if self.is_writing:
self.write_stop_event.set()
self.write_thread.join()
timeout = 1.0
while self.is_recording and timeout > 0:
time.sleep(0.01)
timeout -= 0.01
if self.is_recording:
raise RuntimeError(f"Error stopping recording for sensor connected to {self.sensor_port}.")
if self.is_writing:
raise RuntimeError(f"Error stopping writing for sensor connected to {self.sensor_port}.")
def __del__(self) -> None:
if self.is_connected:
self.disconnect()
@staticmethod
def _clear_queue(queue, join_queue: bool = False):
"""
Clears the queue by getting all items until it is empty. The longer the queue, the longer it takes to clear it.
"""
try:
while True:
queue.get_nowait()
queue.task_done()
except Empty:
if join_queue:
queue.join()
return
@staticmethod
def _write_loop(
queue,
write_stop_event: Event,
write_is_started_event: Event,
sample_rate: int,
channels: list[int],
output_file: Path,
) -> None:
"""
Thread/Process-safe loop to write audio data into a file.
"""
# Can only be run on a single process/thread for file writing safety
with SoundFile(
output_file,
mode="w",
samplerate=sample_rate,
channels=len(channels),
format="WAV",
subtype="FLOAT", # Subtype for float32 values
) as file:
write_is_started_event.set()
while not write_stop_event.is_set():
try:
file.write(
queue.get(timeout=0.005)
) # Timeout set as the usual sounddevice buffer size. get_nowait is not possible here as it saturates the thread.
queue.task_done()
except Empty:
continue
write_is_started_event.clear()

View File

@@ -1,140 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from pathlib import Path
from threading import Barrier
from typing import Any
import numpy as np
from .configs import MicrophoneConfig
class Microphone(abc.ABC):
"""Base class for microphone implementations.
Defines a standard interface for microphone operations across different backends.
Subclasses must implement all abstract methods.
Manages basic microphone properties (sample rate, channels) and core operations:
- Connection/disconnection
- Start/stop recording
- Audio chunk reading
Attributes:
sample_rate (int | None): Configured sample rate in Hz
channels (list[int] | None): List of channel numbers to record
Example:
class MyMicrophone(Microphone):
def __init__(self, config): ...
@property
def is_connected(self) -> bool: ...
def connect(self): ...
# Plus other required methods
"""
def __init__(self, config: MicrophoneConfig):
"""Initialize the microphone with the given configuration.
Args:
config: Microphone configuration containing sample rate and channels.
"""
self.sample_rate: int | None = config.sample_rate
self.channels: list[int] | None = config.channels
@property
@abc.abstractmethod
def is_connected(self) -> bool:
"""Check if the microphone is currently connected.
Returns:
bool: True if the microphone is connected and ready to start recording,
False otherwise.
"""
pass
@property
@abc.abstractmethod
def is_recording(self) -> bool:
"""Check if the microphone is currently recording.
Returns:
bool: True if the microphone is recording, False otherwise.
"""
pass
@property
@abc.abstractmethod
def is_writing(self) -> bool:
"""Check if the microphone is currently writing to a file.
Returns:
bool: True if the microphone is writing to a file, False otherwise.
"""
pass
@staticmethod
@abc.abstractmethod
def find_microphones() -> list[dict[str, Any]]:
"""Detects available microphones connected to the system.
Returns:
List[Dict[str, Any]]: A list of dictionaries,
where each dictionary contains information about a detected microphone.
"""
pass
@abc.abstractmethod
def connect(self) -> None:
"""Establish connection to the microphone."""
pass
@abc.abstractmethod
def start_recording(
self,
output_file: str | Path | None = None,
multiprocessing: bool | None = False,
overwrite: bool | None = True,
barrier: Barrier | None = None,
) -> None:
"""Start recording audio from the microphone.
Args:
output_file: Optional path to save the recorded audio.
multiprocessing: If True, enables multiprocessing for recording. Defaults to multithreading otherwise.
overwrite: If True, overwrites existing files at output_file path.
barrier: If not None, ensures that multiple microphones start recording at the same time.
"""
pass
@abc.abstractmethod
def read(self) -> np.ndarray:
"""Capture and return a single audio chunk from the microphone.
Returns:
np.ndarray: Captured audio chunk as a numpy array.
"""
pass
@abc.abstractmethod
def stop_recording(self) -> None:
"""Stop recording audio from the microphone."""
pass
@abc.abstractmethod
def disconnect(self) -> None:
"""Disconnect the microphone and release any resources."""
pass

View File

@@ -1,41 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..configs import MicrophoneConfig
@MicrophoneConfig.register_subclass("portaudio")
@dataclass
class PortAudioMicrophoneConfig(MicrophoneConfig):
"""Configuration class for PortAudio-based microphone devices.
This class provides configuration options for microphones accessed through PortAudio with the sounddevice Python package.
including device index, sample rate and channels.
Example configurations:
```python
# Basic configurations
PortAudioMicrophoneConfig(0, 16000, [1]) # Device index 0, 16000Hz, mono
PortAudioMicrophoneConfig(1, 44100, [1, 2]) # Device index 1, 44100Hz, stereo
```
Attributes:
microphone_index: Device index for the microphone.
sample_rate: Sample rate in Hz for the microphone.
channels: List of channel numbers to use for the microphone.
"""
microphone_index: int

View File

@@ -1,394 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import time
from collections.abc import Callable
from threading import Event, Thread
from typing import Any
import numpy as np
from sounddevice import PortAudioError
from lerobot.utils.robot_utils import precise_sleep
# --- Interface definitions for InputStream ---
class IInputStream(abc.ABC):
@abc.abstractmethod
def __init__(
self,
samplerate: float | None = None,
blocksize: int | None = None,
device: int | str | None = None,
channels: int | None = None,
dtype: str | np.dtype | None = None,
latency: float | str | None = None,
callback: Callable[[Any, int, Any, Any], None] | None = None,
):
pass
@abc.abstractmethod
def start(self) -> None:
pass
@abc.abstractmethod
def stop(self) -> None:
pass
@abc.abstractmethod
def close(self) -> None:
pass
class ISounddeviceSDK(abc.ABC):
"""Interface defining the contract for the Sounddevice SDK."""
InputStream: type[IInputStream]
@abc.abstractmethod
def query_devices(self, device: int | str | None = None, kind: str | None = None) -> list[dict[str, Any]]:
pass
# --- Real SDK Adapter ---
class SounddeviceSDKAdapter(ISounddeviceSDK):
"""Adapts the real sounddevice library to the ISounddeviceSDK interface."""
_sounddevice = None
def __init__(self):
try:
import sounddevice
SounddeviceSDKAdapter._sounddevice = sounddevice
except ImportError as e:
raise ImportError("sounddevice library not found") from e
# --- Inner Class Implementation ---
class RealInputStream(IInputStream):
def __init__(
self,
samplerate: int | None = None,
blocksize: int | None = None,
device: int | None = None,
channels: int | None = None,
dtype: str | np.dtype | None = None,
latency: float | str | None = None,
callback: Callable[[Any, int, Any, Any], None] | None = None,
):
import sounddevice
self._input_stream = sounddevice.InputStream(
samplerate=samplerate,
blocksize=blocksize,
device=device,
channels=channels,
dtype=dtype,
latency=latency,
callback=callback,
)
def start(self) -> None:
self._input_stream.start()
def stop(self) -> None:
self._input_stream.stop()
def close(self) -> None:
self._input_stream.close()
def __del__(self):
self._input_stream.stop()
self._input_stream.close()
@property
def active(self) -> bool:
return self._input_stream.active
@property
def stopped(self) -> bool:
return self._input_stream.stopped
@property
def closed(self) -> bool:
return self._input_stream.closed
InputStream = RealInputStream
def query_devices(self, device: int | str | None = None, kind: str | None = None) -> list[dict[str, Any]]:
return SounddeviceSDKAdapter._sounddevice.query_devices(device, kind)
# Emulates a 48kHz stereo microphone
VALID_DTYPE = {
"float32",
"int32",
"int16",
"int8",
"uint8",
np.float32,
np.int32,
np.int16,
np.int8,
np.uint8,
}
VALID_LATENCY = {"low", "high"}
VALID_DEVICES = [
{
"index": 0,
"name": "Built-in Microphone",
"hostapi": 0,
"max_input_channels": 2,
"max_output_channels": 0,
"default_low_input_latency": 0.01,
"default_low_output_latency": 0.001,
"default_high_input_latency": 0.1,
"default_high_output_latency": 0.01,
"default_samplerate": 48000.0,
},
{
"index": 1,
"name": "Built-in Output",
"hostapi": 0,
"max_input_channels": 0,
"max_output_channels": 2,
"default_low_input_latency": 0.04,
"default_low_output_latency": 0.04,
"default_high_input_latency": 0.12,
"default_high_output_latency": 0.12,
"default_samplerate": 48000.0,
},
{
"index": 2,
"name": "USB Audio Device",
"hostapi": 0,
"max_input_channels": 1,
"max_output_channels": 0,
"default_low_input_latency": 0.03,
"default_low_output_latency": 0.01,
"default_high_input_latency": 0.04,
"default_high_output_latency": 0.03,
"default_samplerate": 16000.0,
},
]
# -- Fake SDK Adapter ---
class FakeSounddeviceSDKAdapter(ISounddeviceSDK):
"""Implements the ISounddeviceSDK interface with fake behaviour for testing."""
# --- Inner Class Implementation ---
class FakeInputStream(IInputStream):
def __init__(
self,
samplerate: float | None = None,
blocksize: int | None = None,
device: int | str | None = None,
channels: int | None = None,
dtype: str | None = None,
latency: str | None = None,
callback: Callable[[Any, int, Any, Any], None] | None = None,
):
self.samplerate = samplerate
self.blocksize = blocksize
self.device = device
self.channels = channels
self.dtype = dtype
self.latency = latency
self.callback = callback
self._validate_settings()
self._active = False
self._closed = False
if self.callback is not None:
self._streaming_thread = Thread(target=self._streaming_loop, daemon=True)
self._streaming_thread_stop_event = Event()
@property
def active(self) -> bool:
"""True when the stream is active, False otherwise."""
return self._active
@property
def stopped(self) -> bool:
"""True when the stream is stopped, False otherwise."""
return not self._active
@property
def closed(self) -> bool:
"""True after a call to close(), False otherwise."""
return self._closed
def _get_device_info(self):
"""Returns the device info for the device."""
for device in VALID_DEVICES:
if (isinstance(self.device, int) and device["index"] == self.device) or (
isinstance(self.device, str) and device["name"] == self.device
):
return device
raise PortAudioError(f"No input device matching {self.device}")
def _validate_device(self):
"""Validates the device against the valid devices."""
valid_device_indices = [device["index"] for device in VALID_DEVICES]
valid_device_names = [device["name"] for device in VALID_DEVICES]
if self.device is not None:
if isinstance(self.device, (int, str)):
# Check if device index is valid
if isinstance(self.device, int) and self.device not in valid_device_indices:
raise PortAudioError(f"Error querying device {self.device}")
# Check if device name is valid
if isinstance(self.device, str) and self.device not in valid_device_names:
raise PortAudioError(f"No input device matching {self.device}")
else:
raise PortAudioError(f"Device must be int or str, got {type(self.device)}")
else:
# Default to first input device
input_devices = [d for d in VALID_DEVICES if d["max_input_channels"] > 0]
if input_devices:
self.device = input_devices[0]["index"]
def _validate_samplerate(self):
"""Validates the samplerate against the device's maximum samplerate."""
device_info = self._get_device_info()
if self.samplerate is None:
self.samplerate = device_info["default_samplerate"]
elif self.samplerate > device_info["default_samplerate"] or self.samplerate < 1000:
raise PortAudioError("Error opening InputStream: Invalid sample rate")
def _validate_channels(self):
"""Validates the channels against the device's maximum channels."""
device_info = self._get_device_info()
if self.channels is None:
self.channels = device_info["max_input_channels"]
elif self.channels > device_info["max_input_channels"] or self.channels < 1:
raise PortAudioError("Error opening InputStream: Invalid number of channels")
def _validate_dtype(self):
"""Validates the dtype against the valid dtypes."""
if self.dtype is not None:
if self.dtype not in VALID_DTYPE:
raise PortAudioError("Invalid input sample format")
else:
self.dtype = "float32" # Default dtype
def _validate_latency(self):
"""Validates the latency against the valid latencies."""
if self.latency is not None:
if self.latency not in VALID_LATENCY:
raise PortAudioError("Invalid latency")
else:
self.latency = "low" # Default latency
if isinstance(self.latency, str):
device_info = self._get_device_info()
if self.latency == "low":
self.latency = device_info["default_low_input_latency"]
elif self.latency == "high":
self.latency = device_info["default_high_input_latency"]
def _validate_settings(self):
"""Validates the input parameters against available devices and valid options."""
self._validate_device()
self._validate_samplerate()
self._validate_channels()
self._validate_dtype()
self._validate_latency()
def _simulated_audio_data(self) -> np.ndarray:
"""Generates a simulated audio signal for testing purposes with proper value ranges."""
duration_samples = int(self.samplerate * self.latency)
# Generate output according to dtype
if self.dtype in {"float32", np.float32}:
# Generate values between -1 and 1 for float32
data = np.random.uniform(-1.0, 1.0, (duration_samples, self.channels)).astype(self.dtype)
else:
# Use np.iinfo to get proper range for integer types
info = np.iinfo(self.dtype)
data = np.random.randint(
info.min, info.max + 1, (duration_samples, self.channels), dtype=self.dtype
)
return data
def _streaming_loop(self):
if self.callback is not None:
while not self._streaming_thread_stop_event.is_set():
precise_sleep(self.latency)
tmp_data = self._simulated_audio_data()
self.callback(
tmp_data,
len(tmp_data),
time.perf_counter(),
None,
)
def start(self) -> None:
"""Start the fake input stream."""
if not self.active and self.callback is not None:
self._streaming_thread.start()
self._active = True
def stop(self) -> None:
"""Stop the fake input stream."""
if self.callback is not None:
self._streaming_thread_stop_event.set()
self._streaming_thread.join()
self._active = False
def close(self) -> None:
"""Close the fake input stream."""
if self.active and self.callback is not None:
self.stop()
self._active = False
self._closed = True
def __del__(self):
self.close()
InputStream = FakeInputStream
def query_devices(self, device: int | str | None = None, kind: str | None = None) -> list[dict[str, Any]]:
"""Returns a realistic list of audio devices including speakers and microphones."""
if device is not None:
# Return specific device
for valid_device in VALID_DEVICES:
if (isinstance(device, int) and valid_device["index"] == device) or (
isinstance(device, str) and valid_device["name"] == device
):
return valid_device
raise PortAudioError(f"Error querying device {device}")
elif kind is not None:
for valid_device in VALID_DEVICES:
if (
valid_device["max_input_channels"] > 0
and kind == "input"
or valid_device["max_output_channels"] > 0
and kind == "output"
):
return valid_device
raise PortAudioError(f"No {kind} device found")
return VALID_DEVICES

View File

@@ -1,566 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Provides the PortAudioMicrophone class for capturing audio from microphones using the PortAudio library through the sounddevice Python package.
"""
import logging
import time
from multiprocessing import (
Event as process_Event,
JoinableQueue as process_Queue,
Process,
)
from pathlib import Path
from queue import Empty
from threading import Barrier, Event, Event as thread_Event, Thread
from typing import Any
import numpy as np
from soundfile import SoundFile
from lerobot.microphones.portaudio.interface_sounddevice_sdk import ISounddeviceSDK, SounddeviceSDKAdapter
from lerobot.utils.errors import (
DeviceAlreadyConnectedError,
DeviceAlreadyRecordingError,
DeviceNotConnectedError,
DeviceNotRecordingError,
)
from lerobot.utils.shared_array import SharedArray
from ..microphone import Microphone
from .configuration_portaudio import PortAudioMicrophoneConfig
logger = logging.getLogger(__name__)
class PortAudioMicrophone(Microphone):
"""
The PortAudioMicrophone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, across all OS (Linux, Mac, Windows).
A PortAudioMicrophone instance requires the sounddevice index of the microphone, which may be obtained using `python -m sounddevice`. It also requires the recording sample rate as well as the list of recorded channels.
Example of usage:
```python
from lerobot.common.robot_devices.microphones.configs import PortAudioMicrophoneConfig
config = PortAudioMicrophoneConfig(microphone_index=0, sample_rate=16000, channels=[1])
microphone = PortAudioMicrophone(config)
microphone.connect()
microphone.start_recording("some/output/file.wav")
...
audio_readings = microphone.read() # Gets all recorded audio data since the last read or since the beginning of the recording. The longer the period the longer the reading time !
...
microphone.stop_recording()
microphone.disconnect()
```
"""
def __init__(self, config: PortAudioMicrophoneConfig, sounddevice_sdk: ISounddeviceSDK = None):
"""
Initializes the PortAudioMicrophone instance.
Args:
config: The configuration settings for the microphone.
"""
super().__init__(config)
if sounddevice_sdk is None:
self.sounddevice_sdk = SounddeviceSDKAdapter()
else:
self.sounddevice_sdk = sounddevice_sdk
# Microphone index
self.microphone_index = config.microphone_index
# Input audio recording process and events
self.record_process = None
self.record_stop_event = process_Event()
self.record_start_event = process_Event()
self.record_close_event = process_Event()
self.record_is_started_event = process_Event()
self.audio_callback_start_event = process_Event()
# Process-safe concurrent queue to send audio from the recording process to the writing process/thread
self.write_queue = process_Queue()
# SharedArray to store audio from the recording process.
self.read_shared_array = None
self.local_read_shared_array = None
# Thread/Process to handle data writing in a separate thread/process (safely)
self.write_thread = None
self.write_stop_event = None
self.write_is_started_event = None
self.logs = {}
def __str__(self) -> str:
return f"{self.__class__.__name__}({self.microphone_index})"
@property
def is_connected(self) -> bool:
return self.record_process is not None and self.record_process.is_alive()
@property
def is_recording(self) -> bool:
return self.record_is_started_event.is_set()
@property
def is_writing(self) -> bool:
return self.write_thread is not None and self.write_is_started_event.is_set()
@staticmethod
def find_microphones(
device: int | str | None = None, sounddevice_sdk: ISounddeviceSDK = None
) -> list[dict[str, Any]] | dict[str, Any]:
"""
Detects available microphones connected to the system.
Args:
device: The device to find microphones for. If None, all microphones are found.
Returns:
List[Dict[str, Any]]: A list of dictionaries,
where each dictionary contains information about a detected microphone : index, name, sample rate, channels.
"""
if sounddevice_sdk is None:
sounddevice_sdk = SounddeviceSDKAdapter()
found_microphones_info = []
devices = sounddevice_sdk.query_devices()
for d in devices:
if d["max_input_channels"] > 0:
microphone_info = {
"index": d["index"],
"name": d["name"],
"sample_rate": int(d["default_samplerate"]),
"channels": np.arange(1, d["max_input_channels"] + 1),
}
if device is None or (
(isinstance(device, int) and d["index"] == device)
or (isinstance(device, str) and d["name"] == device)
):
found_microphones_info.append(microphone_info)
if device is not None:
if len(found_microphones_info) == 0:
raise RuntimeError(f"No microphone found for device {device}")
else:
return found_microphones_info[0]
if len(found_microphones_info) == 0:
logger.warning("No microphone found !")
return found_microphones_info
def _configure_capture_settings(self) -> None:
"""
Validates the microphone index, sample rate and channels settings specified in the constructor's config to the un-connected microphone.
This method actually checks the specified settings and fills the sample rate and channels settings if not specified before attempting to start a PortAudio stream.
Raises:
RuntimeError: If one of the specified settings is not compatible with the microphone.
DeviceAlreadyConnectedError: If the microphone is connected when attempting to configure settings.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(
f"Cannot configure settings for {self} as it is already connected."
)
self._validate_microphone_index()
self._validate_sample_rate()
self._validate_channels()
def _validate_microphone_index(self) -> None:
""" "Validates the microphone index against available devices by checking if it has at least one input channel."""
try:
PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)
except RuntimeError as e:
raise RuntimeError(
f"{e}. Available microphones: {PortAudioMicrophone.find_microphones(sounddevice_sdk=self.sounddevice_sdk)}"
) from e
def _validate_sample_rate(self) -> None:
"""Validates the sample rate against the actual microphone's default sample rate."""
actual_sample_rate = PortAudioMicrophone.find_microphones(
self.microphone_index, self.sounddevice_sdk
)["sample_rate"]
if self.sample_rate is not None:
try:
self.sample_rate = int(self.sample_rate)
except ValueError as e:
raise RuntimeError(
f"Cannot convert the provided sample rate ({self.sample_rate} Hz) to an integer."
) from e
if self.sample_rate > actual_sample_rate or self.sample_rate < 1000:
raise RuntimeError(
f"Provided sample rate {self.sample_rate} is either too low or too high compared to the sample rate of the microphone {actual_sample_rate}."
)
else:
if self.sample_rate < actual_sample_rate:
logger.warning(
"Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted."
)
else:
self.sample_rate = actual_sample_rate
def _validate_channels(self) -> None:
"""Validates the channels against the actual microphone's maximum input channels."""
actual_channels = PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)[
"channels"
]
if self.channels is not None and len(self.channels) > 0:
if not all(channel in actual_channels for channel in self.channels):
raise RuntimeError(
f"Some of the provided channels {self.channels} are outside the possible channel range of the microphone {actual_channels}."
)
else:
self.channels = actual_channels
# Get channels index instead of number for slicing
self.channels_index = np.array(self.channels) - 1
def connect(self) -> None:
"""
Connects the microphone and checks if the requested acquisition parameters are compatible with the microphone.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"Microphone {self.microphone_index} is already connected.")
self._configure_capture_settings()
# Create or reset queue and shared array
self.read_shared_array = SharedArray(
shape=(self.sample_rate * 10, len(self.channels)),
dtype=np.dtype("float32"),
)
self.local_read_shared_array = self.read_shared_array.get_local_array()
self.write_queue = process_Queue()
# Reset events
self.record_start_event.clear()
self.record_stop_event.clear()
self.record_close_event.clear()
self.record_is_started_event.clear()
self.audio_callback_start_event.clear()
# Create and start an audio input stream with a recording callback
# Remark: this is done in a separate process so that audio recording is not impacted by the main thread CPU usage, especially the precise_sleep function.
process_init_event = process_Event()
self.record_process = Process(
target=self._record_process,
args=(
self.microphone_index,
self.sample_rate,
self.channels,
process_init_event,
self.record_start_event,
self.record_stop_event,
self.record_close_event,
self.record_is_started_event,
self.audio_callback_start_event,
self.write_queue,
self.read_shared_array,
self.sounddevice_sdk,
),
)
self.record_process.daemon = True
self.record_process.start()
is_init = process_init_event.wait(
timeout=5.0
) # Wait for the recording process to be started, and to potentially raise an error on failure.
if not self.is_connected or not is_init:
raise RuntimeError(f"Error connecting microphone {self.microphone_index}.")
logger.info(f"{self} connected.")
def disconnect(self) -> None:
"""
Disconnects the microphone and stops the recording.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
if self.is_recording:
self.stop_recording()
self.record_close_event.set()
self.read_shared_array.delete()
self.write_queue.close()
self.record_process.join()
if self.is_connected:
raise RuntimeError(f"Error disconnecting microphone {self.microphone_index}.")
logger.info(f"{self} disconnected.")
def _read(self) -> np.ndarray:
"""
Thread/Process-safe callback to read available audio data
"""
return self.read_shared_array.read(self.local_read_shared_array, flush=True)
def read(self) -> np.ndarray:
"""
Reads the last audio chunk recorded by the microphone, e.g. all samples recorded since the last read or since the beginning of the recording.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
if not self.is_recording:
raise RuntimeError(f"Microphone {self.microphone_index} is not recording.")
start_time = time.perf_counter()
audio_readings = self._read()
# log the number of seconds it took to read the audio chunk
self.logs["delta_timestamp_s"] = time.perf_counter() - start_time
# log the utc time at which the audio chunk was received
self.logs["timestamp_utc"] = time.perf_counter()
return audio_readings
@staticmethod
def _record_process(
microphone_index,
sample_rate,
channels,
process_init_event,
record_start_event,
record_stop_event,
record_close_event,
record_is_started_event,
audio_callback_start_event,
write_queue,
read_shared_array,
sounddevice_sdk,
) -> None:
"""
Process callback used to create an unpickable sounddevice audio input stream with a recording callback and start, stop and close it based on multiprocessing events.
"""
channels_index = np.array(channels) - 1
local_read_shared_array = read_shared_array.get_local_array()
def audio_callback(indata, frames, timestamp, status) -> None:
"""
Low-level sounddevice callback.
"""
if status:
logger.warning(status)
if audio_callback_start_event.is_set():
write_queue.put_nowait(indata[:, channels_index])
read_shared_array.write(local_read_shared_array, indata[:, channels_index])
# Create the audio stream
# InputStream must be instantiated in the process as it is not pickable.
stream = sounddevice_sdk.InputStream(
device=microphone_index,
samplerate=sample_rate,
channels=max(channels),
dtype="float32",
blocksize=0, # Varying input buffer length, but no additional latency
latency="low", # Low latency mode (not enabled by default !)
# never_drop_input=True, # Disabled as it generates an error for some devices
callback=audio_callback,
)
process_init_event.set()
while True:
start_flag = record_start_event.wait(timeout=0.1)
if record_close_event.is_set():
break
elif not start_flag:
continue
stream.start()
record_is_started_event.set()
record_stop_event.wait()
stream.stop() # stream.stop() waits for all buffers to be processed, stream.abort() flushes the buffers !
record_is_started_event.clear()
stream.close()
def start_recording(
self,
output_file: str | None = None,
multiprocessing: bool | None = False,
overwrite: bool | None = True,
barrier: Barrier | None = None,
) -> None:
"""
Starts the recording of the microphone. If output_file is provided, the audio will be written to this file.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
if self.is_recording:
raise DeviceAlreadyRecordingError(f"Microphone {self.microphone_index} is already recording.")
# Reset queue and shared memory
self.read_shared_array.reset()
self._clear_queue(self.write_queue)
# Reset stop event
self.record_stop_event.clear()
# Write recordings into a file if output_file is provided
if output_file is not None:
output_file = Path(output_file)
output_file.parent.mkdir(parents=True, exist_ok=True)
if output_file.exists():
if overwrite:
output_file.unlink()
else:
raise FileExistsError(
f"Output file {output_file} already exists. Set overwrite to True to overwrite it."
)
if multiprocessing:
self.write_stop_event = process_Event()
self.write_is_started_event = process_Event()
self.write_thread = Process(
target=PortAudioMicrophone._write_loop,
args=(
self.write_queue,
self.write_stop_event,
self.write_is_started_event,
self.sample_rate,
self.channels,
output_file,
),
)
else:
self.write_stop_event = thread_Event()
self.write_is_started_event = thread_Event()
self.write_thread = Thread(
target=PortAudioMicrophone._write_loop,
args=(
self.write_queue,
self.write_stop_event,
self.write_is_started_event,
self.sample_rate,
self.channels,
output_file,
),
)
self.write_thread.daemon = True
self.write_thread.start()
self.write_is_started_event.wait() # Wait for the writing thread/process to be started.
self.record_start_event.set() # Start the input audio stream process
self.record_is_started_event.wait() # Wait for the input audio stream process to be actually started
if barrier is not None:
barrier.wait() # Wait for multiple input audio streams to be started at the same time
self.audio_callback_start_event.set()
if not self.is_recording:
raise RuntimeError(f"Error starting recording for microphone {self.microphone_index}.")
if output_file is not None and not self.is_writing:
raise RuntimeError(f"Error starting writing for microphone {self.microphone_index}.")
def stop_recording(self) -> None:
"""
Stops the recording of the microphones.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
if not self.is_recording:
raise DeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.")
self.audio_callback_start_event.clear()
self.record_start_event.clear() # Ensures the audio stream is not started again !
self.record_stop_event.set()
# Wait for the stream to be stopped (might lead to race condition if the stream is not properly stopped on array reset and queue clearing)
timeout = 1.0
while self.is_recording and timeout > 0:
time.sleep(0.01)
timeout -= 0.01
self.read_shared_array.reset()
self._clear_queue(self.write_queue, join_queue=True)
if self.is_writing:
self.write_stop_event.set()
self.write_thread.join()
if self.is_recording:
raise RuntimeError(f"Error stopping recording for microphone {self.microphone_index}.")
if self.is_writing:
raise RuntimeError(f"Error stopping writing for microphone {self.microphone_index}.")
@staticmethod
def _write_loop(
queue,
write_stop_event: Event,
write_is_started_event: Event,
sample_rate: int,
channels: list[int],
output_file: Path,
) -> None:
"""
Thread/Process-safe loop to write audio data into a file.
"""
# Can only be run on a single process/thread for file writing safety
with SoundFile(
output_file,
mode="w",
samplerate=sample_rate,
channels=len(channels),
format="WAV",
subtype="FLOAT", # By default, a much lower quality WAV file is created !
) as file:
write_is_started_event.set()
while not write_stop_event.is_set():
try:
file.write(
queue.get(timeout=0.005)
) # Timeout set as the usual sounddevice buffer size. get_nowait is not possible here as it saturates the thread.
queue.task_done()
except Empty:
continue
write_is_started_event.clear()
def __del__(self) -> None:
if self.is_connected:
self.disconnect()
@staticmethod
def _clear_queue(queue, join_queue: bool = False):
"""
Clears the queue by getting all items until it is empty. The longer the queue, the longer it takes to clear it.
"""
try:
while True:
queue.get_nowait()
queue.task_done()
except Empty:
if join_queue:
queue.join()
return

View File

@@ -1,42 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..configs import MicrophoneConfig
@MicrophoneConfig.register_subclass("touchlab")
@dataclass
class TouchLabSensorConfig(MicrophoneConfig):
"""Configuration class for TouchLab tactile sensors (technically not a microphone, but behaves like one acquisition-wise).
This class provides configuration options for TouchLab tactile sensors, including serial port, sample rate and channels.
Example configurations:
```python
# Basic configurations
TouchLabSensorConfig("/dev/ttyACM0", 16000) # Serial port /dev/ttyACM0, 16000Hz
TouchLabSensorConfig("/dev/ttyACM1", 44100) # Serial port /dev/ttyACM1, 44100Hz
```
Attributes:
sensor_port: Serial port of the tactile sensor.
baud_rate: Baud rate of the tactile sensor.
sample_rate: Sample rate in Hz for the tactile sensor.
channels: List of channel numbers to use for the tactile sensor.
"""
sensor_port: str
baud_rate: int = 115_200

View File

@@ -1,469 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Provides the TouchLabSensor class for capturing tactile data from TouchLab tactile sensors.
"""
import logging
import time
from multiprocessing import (
Event as process_Event,
JoinableQueue as process_Queue,
Process,
)
from pathlib import Path
from queue import Empty
from threading import Barrier, Event, Event as thread_Event, Thread
from typing import Any
import numpy as np
from serial import Serial
from soundfile import SoundFile
from lerobot.utils.errors import (
DeviceAlreadyConnectedError,
DeviceAlreadyRecordingError,
DeviceNotConnectedError,
DeviceNotRecordingError,
)
from lerobot.utils.shared_array import SharedArray
from ..microphone import Microphone
from .configuration_touchlab import TouchLabSensorConfig
logger = logging.getLogger(__name__)
MAX_SERIAL_READ_SIZE = 512
class TouchLabSensor(Microphone):
"""
The TouchLabSensor class handles all TouchLab tactile sensors.
A TouchLabSensor instance requires the serial port of the tactile sensor, which may be obtained using `python -m lerobot.find_port`. It also requires the recording sample rate as well as the list of recorded channels.
Example of usage:
```python
from lerobot.common.robot_devices.microphones.configs import TouchLabSensorConfig
config = TouchLabSensorConfig(sensor_port="/dev/ttyACM0", baud_rate=115200, sample_rate=115, channels=[1])
microphone = TouchLabSensor(config)
microphone.connect()
microphone.start_recording("some/output/file.wav")
...
audio_readings = microphone.read() # Gets all recorded audio data since the last read or since the beginning of the recording. The longer the period the longer the reading time !
...
microphone.stop_recording()
microphone.disconnect()
```
"""
def __init__(self, config: TouchLabSensorConfig):
""" "
Initializes the TouchLabSensor instance.
Args:
config: The configuration settings for the sensor.
"""
super().__init__(config)
# Sensor port
self.sensor_port = config.sensor_port
# Baud rate
self.baud_rate = config.baud_rate
# Input audio recording process and events
self.record_process = None
self.record_stop_event = process_Event()
self.record_start_event = process_Event()
self.record_close_event = process_Event()
self.record_is_started_event = process_Event()
self.audio_callback_start_event = process_Event()
# Process-safe concurrent queue to send audio from the recording process to the writing process/thread
self.write_queue = process_Queue()
# SharedArray to store audio from the recording process.
self.read_shared_array = None
self.local_read_shared_array = None
# Thread/Process to handle data writing in a separate thread/process (safely)
self.write_thread = None
self.write_stop_event = None
self.write_is_started_event = None
self.logs = {}
def __str__(self) -> str:
return f"{self.__class__.__name__}({self.sensor_port})"
@property
def is_connected(self) -> bool:
"""Check if the sensor is currently connected.
Returns:
bool: True if the sensor is connected and ready to start recording,
False otherwise.
"""
return self.record_process is not None and self.record_process.is_alive()
@property
def is_recording(self) -> bool:
"""Check if the sensor is currently recording.
Returns:
bool: True if the sensor is recording, False otherwise.
"""
return self.record_is_started_event.is_set()
@property
def is_writing(self) -> bool:
"""Check if the sensor is currently writing to a file.
Returns:
bool: True if the sensor is writing to a file, False otherwise.
"""
return self.write_thread is not None and self.write_is_started_event.is_set()
@staticmethod
def find_microphones() -> list[dict[str, Any]]:
"""Detects available sensors connected to the system.
Returns:
List[Dict[str, Any]]: A list of dictionaries,
where each dictionary contains information about a detected sensor.
"""
pass
def connect(self) -> None:
"""
Establish connection to the sensor.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"Sensor connected to {self.sensor_port} is already connected.")
# Create or reset queue and shared array
self.read_shared_array = SharedArray(
shape=(self.sample_rate * 10, len(self.channels)),
dtype=np.dtype("int16"),
)
self.local_read_shared_array = self.read_shared_array.get_local_array()
self.write_queue = process_Queue()
# Reset events
self.record_start_event.clear()
self.record_stop_event.clear()
self.record_close_event.clear()
self.record_is_started_event.clear()
self.audio_callback_start_event.clear()
# Create and start an audio input stream with a recording callback
# Remark: this is done in a separate process so that audio recording is not impacted by the main thread CPU usage, especially the precise_sleep function.
process_init_event = process_Event()
self.record_process = Process(
target=self._record_process,
args=(
self.sensor_port,
self.baud_rate,
self.channels,
process_init_event,
self.record_start_event,
self.record_stop_event,
self.record_close_event,
self.record_is_started_event,
self.audio_callback_start_event,
self.write_queue,
self.read_shared_array,
),
)
self.record_process.daemon = True
self.record_process.start()
is_init = process_init_event.wait(
timeout=5.0
) # Wait for the recording process to be started, and to potentially raise an error on failure.
if not self.is_connected or not is_init:
raise RuntimeError(f"Error connecting sensor connected to {self.sensor_port}.")
logger.info(f"{self} connected.")
@staticmethod
def _record_process(
sensor_port,
baud_rate,
channels,
process_init_event,
record_start_event,
record_stop_event,
record_close_event,
record_is_started_event,
audio_callback_start_event,
write_queue,
read_shared_array,
) -> None:
channels_index = np.array(channels) - 1
local_read_shared_array = read_shared_array.get_local_array()
def tactile_callback(serial_connection):
"""
Parse the tactile data from the raw input data.
"""
buffer = serial_connection.readline()
if audio_callback_start_event.is_set():
strings = buffer.decode("utf8").split(",")
num_taxels = len(strings)
if num_taxels > 0 and num_taxels < MAX_SERIAL_READ_SIZE: # Make sure we didn't read rubbish
indata = np.empty((1, num_taxels))
for i in range(num_taxels):
indata[0, i] = int(strings[i])
write_queue.put_nowait(indata[:, channels_index])
read_shared_array.write(local_read_shared_array, indata[:, channels_index])
process_init_event.set()
while True:
start_flag = record_start_event.wait(timeout=0.1)
if record_close_event.is_set():
break
elif not start_flag:
continue
with Serial(sensor_port, baud_rate, timeout=0.5) as serial_connection:
serial_connection.flush()
record_is_started_event.set()
while not record_stop_event.is_set():
tactile_callback(serial_connection)
record_is_started_event.clear()
serial_connection.close()
def disconnect(self) -> None:
"""
Disconnect the sensor and release any resources.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
if self.is_recording:
self.stop_recording()
self.record_close_event.set()
self.read_shared_array.delete()
self.write_queue.close()
self.record_process.join()
if self.is_connected:
raise RuntimeError(f"Error disconnecting sensor connected to {self.sensor_port}.")
logger.info(f"{self} disconnected.")
def start_recording(
self,
output_file: str | Path | None = None,
multiprocessing: bool | None = False,
overwrite: bool | None = True,
barrier: Barrier | None = None,
) -> None:
"""
Start recording tactile data from the sensor.
Args:
output_file: Optional path to save the recorded tactile data.
multiprocessing: If True, enables multiprocessing for recording. Defaults to multithreading otherwise.
overwrite: If True, overwrites existing files at output_file path.
barrier: If not None, ensures that multiple sensors start recording at the same time.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
if self.is_recording:
raise DeviceAlreadyRecordingError(f"Sensor connected to {self.sensor_port} is already recording.")
# Reset queue and shared memory
self.read_shared_array.reset()
self._clear_queue(self.write_queue)
# Reset stop event
self.record_stop_event.clear()
# Write recordings into a file if output_file is provided
if output_file is not None:
output_file = Path(output_file)
output_file.parent.mkdir(parents=True, exist_ok=True)
if output_file.exists():
if overwrite:
output_file.unlink()
else:
raise FileExistsError(
f"Output file {output_file} already exists. Set overwrite to True to overwrite it."
)
if multiprocessing:
self.write_stop_event = process_Event()
self.write_is_started_event = process_Event()
self.write_thread = Process(
target=TouchLabSensor._write_loop,
args=(
self.write_queue,
self.write_stop_event,
self.write_is_started_event,
self.sample_rate,
self.channels,
output_file,
),
)
else:
self.write_stop_event = thread_Event()
self.write_is_started_event = thread_Event()
self.write_thread = Thread(
target=TouchLabSensor._write_loop,
args=(
self.write_queue,
self.write_stop_event,
self.write_is_started_event,
self.sample_rate,
self.channels,
output_file,
),
)
self.write_thread.daemon = True
self.write_thread.start()
self.write_is_started_event.wait() # Wait for the writing thread/process to be started.
self.record_start_event.set() # Start the input audio stream process
self.record_is_started_event.wait() # Wait for the input audio stream process to be actually started
if barrier is not None:
barrier.wait() # Wait for multiple input audio streams to be started at the same time
self.audio_callback_start_event.set()
if not self.is_recording:
raise RuntimeError(f"Error starting recording for sensor connected to {self.sensor_port}.")
if output_file is not None and not self.is_writing:
raise RuntimeError(f"Error starting writing for sensor connected to {self.sensor_port}.")
def _read(self) -> np.ndarray:
"""
Thread/Process-safe callback to read available audio data
"""
return self.read_shared_array.read(self.local_read_shared_array, flush=True)
def read(self) -> np.ndarray:
"""Capture and return a single audio chunk from the sensor.
Returns:
np.ndarray: Captured audio chunk as a numpy array.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
if not self.is_recording:
raise RuntimeError(f"Sensor connected to {self.sensor_port} is not recording.")
start_time = time.perf_counter()
tactile_readings = self._read()
# log the number of seconds it took to read the audio chunk
self.logs["delta_timestamp_s"] = time.perf_counter() - start_time
# log the utc time at which the audio chunk was received
self.logs["timestamp_utc"] = time.perf_counter()
return tactile_readings
def _read_loop(self) -> None:
"""Internal loop run by the background thread for asynchronous reading."""
def stop_recording(self) -> None:
"""Stop recording audio from the sensor."""
if not self.is_connected:
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
if not self.is_recording:
raise DeviceNotRecordingError(f"Sensor connected to {self.sensor_port} is not recording.")
self.audio_callback_start_event.clear()
self.record_start_event.clear() # Ensures the audio stream is not started again !
self.record_stop_event.set()
self.read_shared_array.reset()
self._clear_queue(self.write_queue, join_queue=True)
if self.is_writing:
self.write_stop_event.set()
self.write_thread.join()
timeout = 1.0
while self.is_recording and timeout > 0:
time.sleep(0.01)
timeout -= 0.01
if self.is_recording:
raise RuntimeError(f"Error stopping recording for sensor connected to {self.sensor_port}.")
if self.is_writing:
raise RuntimeError(f"Error stopping writing for sensor connected to {self.sensor_port}.")
def __del__(self) -> None:
if self.is_connected:
self.disconnect()
@staticmethod
def _clear_queue(queue, join_queue: bool = False):
"""
Clears the queue by getting all items until it is empty. The longer the queue, the longer it takes to clear it.
"""
try:
while True:
queue.get_nowait()
queue.task_done()
except Empty:
if join_queue:
queue.join()
return
@staticmethod
def _write_loop(
queue,
write_stop_event: Event,
write_is_started_event: Event,
sample_rate: int,
channels: list[int],
output_file: Path,
) -> None:
"""
Thread/Process-safe loop to write audio data into a file.
"""
# Can only be run on a single process/thread for file writing safety
with SoundFile(
output_file,
mode="w",
samplerate=sample_rate,
channels=len(channels),
format="WAV",
subtype="PCM_16", # Subtype for int16 values
) as file:
write_is_started_event.set()
while not write_stop_event.is_set():
try:
file.write(
queue.get(timeout=0.005)
) # Timeout set as the usual sounddevice buffer size. get_nowait is not possible here as it saturates the thread.
queue.task_done()
except Empty:
continue
write_is_started_event.clear()

View File

@@ -1,93 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from multiprocessing import Barrier
from threading import Thread
from .configs import MicrophoneConfig
from .microphone import Microphone
def make_microphones_from_configs(microphone_configs: dict[str, MicrophoneConfig]) -> dict[str, Microphone]:
microphones = {}
for key, cfg in microphone_configs.items():
if cfg.type == "portaudio":
from .portaudio import PortAudioMicrophone
microphones[key] = PortAudioMicrophone(cfg)
elif cfg.type == "touchlab":
from .touchlab import TouchLabSensor
microphones[key] = TouchLabSensor(cfg)
elif cfg.type == "anyskin":
from .anyskin import AnyskinSensor
microphones[key] = AnyskinSensor(cfg)
else:
raise ValueError(f"The microphone type '{cfg.type}' is not valid.")
return microphones
def async_microphones_start_recording(
microphones: dict[str, Microphone],
output_files: list[str | None] | None = None,
multiprocessing: bool = False,
overwrite: bool = True,
) -> None:
"""
Starts recording on multiple microphones asynchronously to avoid delays.
Args:
microphones: A dictionary of microphones.
output_files: A list of output files.
multiprocessing: If True, enables multiprocessing for recording.
overwrite: If True, overwrites existing files at output_file path.
"""
start_recording_threads = []
if output_files is None:
output_files = [None] * len(microphones)
barrier = Barrier(len(microphones))
for microphone, output_file in zip(microphones.values(), output_files, strict=False):
start_recording_threads.append(
Thread(target=microphone.start_recording, args=(output_file, multiprocessing, overwrite, barrier))
)
for thread in start_recording_threads:
thread.start()
for thread in start_recording_threads:
thread.join()
def async_microphones_stop_recording(microphones: dict[str, Microphone]) -> None:
"""
Stops recording on multiple microphones asynchronously to avoid delays.
Args:
microphones: A dictionary of microphones.
"""
stop_recording_threads = []
for microphone in microphones.values():
stop_recording_threads.append(Thread(target=microphone.stop_recording))
for thread in stop_recording_threads:
thread.start()
for thread in stop_recording_threads:
thread.join()

View File

@@ -14,4 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus
from .motors_bus import (
Motor,
MotorCalibration,
MotorNormMode,
)

View File

@@ -18,7 +18,7 @@ from dataclasses import dataclass
os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1"
from lerobot.motors import MotorCalibration, MotorsBus
from .motors_bus import MotorCalibration, MotorsBus
BAR_LEN, BAR_THICKNESS = 450, 8
HANDLE_R = 10
@@ -221,7 +221,7 @@ class RangeFinderGUI:
self.bus = bus
self.groups = groups if groups is not None else {"all": list(bus.motors)}
self.group_names = list(groups)
self.group_names = list(self.groups)
self.current_group = self.group_names[0]
if not bus.is_connected:
@@ -230,18 +230,20 @@ class RangeFinderGUI:
self.calibration = bus.read_calibration()
self.res_table = bus.model_resolution_table
self.present_cache = {
m: bus.read("Present_Position", m, normalize=False) for motors in groups.values() for m in motors
m: bus.read("Present_Position", m, normalize=False)
for motors in self.groups.values()
for m in motors
}
pygame.init()
self.font = pygame.font.Font(None, FONT_SIZE)
label_pad = max(self.font.size(m)[0] for ms in groups.values() for m in ms)
label_pad = max(self.font.size(m)[0] for ms in self.groups.values() for m in ms)
self.label_pad = label_pad
width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10
self.controls_bottom = 10 + SAVE_H
self.base_y = self.controls_bottom + TOP_GAP
height = self.base_y + PADDING_Y * len(groups[self.current_group]) + 40
height = self.base_y + PADDING_Y * len(self.groups[self.current_group]) + 40
self.screen = pygame.display.set_mode((width, height))
pygame.display.set_caption("Motors range finder")

View File

@@ -1,3 +1,5 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,5 +14,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_anyskin import AnyskinSensorConfig
from .sensor_anyskin import AnyskinSensor
from .damiao import DamiaoMotorsBus
from .tables import *

View File

@@ -0,0 +1,859 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Portions of this file are derived from DM_Control_Python by cmjang.
# Licensed under the MIT License; see `LICENSE` for the full text:
# https://github.com/cmjang/DM_Control_Python
import logging
import time
from contextlib import contextmanager
from copy import deepcopy
from functools import cached_property
from typing import TYPE_CHECKING, Any, TypedDict
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.import_utils import _can_available
if TYPE_CHECKING or _can_available:
import can
else:
class can: # noqa: N801
Message = object
interface = None
import numpy as np
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import enter_pressed, move_cursor_up
from ..motors_bus import Motor, MotorCalibration, MotorsBusBase, NameOrID, Value
from .tables import (
AVAILABLE_BAUDRATES,
CAN_CMD_DISABLE,
CAN_CMD_ENABLE,
CAN_CMD_REFRESH,
CAN_CMD_SET_ZERO,
CAN_PARAM_ID,
DEFAULT_BAUDRATE,
DEFAULT_TIMEOUT_MS,
MIT_KD_RANGE,
MIT_KP_RANGE,
MOTOR_LIMIT_PARAMS,
MotorType,
)
logger = logging.getLogger(__name__)
LONG_TIMEOUT_SEC = 0.1
MEDIUM_TIMEOUT_SEC = 0.01
SHORT_TIMEOUT_SEC = 0.001
PRECISE_TIMEOUT_SEC = 0.0001
class MotorState(TypedDict):
position: float
velocity: float
torque: float
temp_mos: float
temp_rotor: float
class DamiaoMotorsBus(MotorsBusBase):
"""
The Damiao implementation for a MotorsBus using CAN bus communication.
This class uses python-can for CAN bus communication with Damiao motors.
For more info, see:
- python-can documentation: https://python-can.readthedocs.io/en/stable/
- Seedstudio documentation: https://wiki.seeedstudio.com/damiao_series/
- DM_Control_Python repo: https://github.com/cmjang/DM_Control_Python
"""
# CAN-specific settings
available_baudrates = deepcopy(AVAILABLE_BAUDRATES)
default_baudrate = DEFAULT_BAUDRATE
default_timeout = DEFAULT_TIMEOUT_MS
def __init__(
self,
port: str,
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
can_interface: str = "auto",
use_can_fd: bool = True,
bitrate: int = 1000000,
data_bitrate: int | None = 5000000,
):
"""
Initialize the Damiao motors bus.
Args:
port: CAN interface name (e.g., "can0" for Linux, "/dev/cu.usbmodem*" for macOS)
motors: Dictionary mapping motor names to Motor objects
calibration: Optional calibration data
can_interface: CAN interface type - "auto" (default), "socketcan" (Linux), or "slcan" (macOS/serial)
use_can_fd: Whether to use CAN FD mode (default: True for OpenArms)
bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps)
data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False
"""
super().__init__(port, motors, calibration)
self.port = port
self.can_interface = can_interface
self.use_can_fd = use_can_fd
self.bitrate = bitrate
self.data_bitrate = data_bitrate
self.canbus: can.interface.Bus | None = None
self._is_connected = False
# Map motor names to CAN IDs
self._motor_can_ids: dict[str, int] = {}
self._recv_id_to_motor: dict[int, str] = {}
self._motor_types: dict[str, MotorType] = {}
for name, motor in self.motors.items():
if motor.motor_type_str is None:
raise ValueError(f"Motor '{name}' is missing required 'motor_type'")
self._motor_types[name] = getattr(MotorType, motor.motor_type_str.upper().replace("-", "_"))
# Map recv_id to motor name for filtering responses
if motor.recv_id is not None:
self._recv_id_to_motor[motor.recv_id] = name
# State cache for handling packet drops safely
self._last_known_states: dict[str, MotorState] = {
name: {
"position": 0.0,
"velocity": 0.0,
"torque": 0.0,
"temp_mos": 0.0,
"temp_rotor": 0.0,
}
for name in self.motors
}
# Dynamic gains storage
# Defaults: Kp=10.0 (Stiffness), Kd=0.5 (Damping)
self._gains: dict[str, dict[str, float]] = {name: {"kp": 10.0, "kd": 0.5} for name in self.motors}
@property
def is_connected(self) -> bool:
"""Check if the CAN bus is connected."""
return self._is_connected and self.canbus is not None
@check_if_already_connected
def connect(self, handshake: bool = True) -> None:
"""
Open the CAN bus and initialize communication.
Args:
handshake: If True, ping all motors to verify they're present
"""
try:
# Auto-detect interface type based on port name
if self.can_interface == "auto":
if self.port.startswith("/dev/"):
self.can_interface = "slcan"
logger.info(f"Auto-detected slcan interface for port {self.port}")
else:
self.can_interface = "socketcan"
logger.info(f"Auto-detected socketcan interface for port {self.port}")
# Connect to CAN bus
kwargs = {
"channel": self.port,
"bitrate": self.bitrate,
"interface": self.can_interface,
}
if self.can_interface == "socketcan" and self.use_can_fd and self.data_bitrate is not None:
kwargs.update({"data_bitrate": self.data_bitrate, "fd": True})
logger.info(
f"Connected to {self.port} with CAN FD (bitrate={self.bitrate}, data_bitrate={self.data_bitrate})"
)
else:
logger.info(f"Connected to {self.port} with {self.can_interface} (bitrate={self.bitrate})")
self.canbus = can.interface.Bus(**kwargs)
self._is_connected = True
if handshake:
self._handshake()
logger.debug(f"{self.__class__.__name__} connected via {self.can_interface}.")
except Exception as e:
self._is_connected = False
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
def _handshake(self) -> None:
"""
Verify all motors are present and populate initial state cache.
Raises ConnectionError if any motor fails to respond.
"""
logger.info("Starting handshake with motors...")
# Drain any pending messages
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
while self.canbus.recv(timeout=0.01):
pass
missing_motors = []
for motor_name in self.motors:
motor_id = self._get_motor_id(motor_name)
recv_id = self._get_motor_recv_id(motor_name)
# Send enable command
data = [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, CAN_CMD_ENABLE]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
self.canbus.send(msg)
# Wait for response with longer timeout
response = None
start_time = time.time()
while time.time() - start_time < 0.1:
response = self.canbus.recv(timeout=0.1)
if response and response.arbitration_id == recv_id:
break
response = None
if response is None:
missing_motors.append(motor_name)
else:
self._process_response(motor_name, msg)
time.sleep(MEDIUM_TIMEOUT_SEC)
if missing_motors:
raise ConnectionError(
f"Handshake failed. The following motors did not respond: {missing_motors}. "
"Check power (24V) and CAN wiring."
)
logger.info("Handshake successful. All motors ready.")
@check_if_not_connected
def disconnect(self, disable_torque: bool = True) -> None:
"""
Close the CAN bus connection.
Args:
disable_torque: If True, disable torque on all motors before disconnecting
"""
if disable_torque:
try:
self.disable_torque()
except Exception as e:
logger.warning(f"Failed to disable torque during disconnect: {e}")
if self.canbus:
self.canbus.shutdown()
self.canbus = None
self._is_connected = False
logger.debug(f"{self.__class__.__name__} disconnected.")
def configure_motors(self) -> None:
"""Configure all motors with default settings."""
# Damiao motors don't require much configuration in MIT mode
# Just ensure they're enabled
for motor in self.motors:
self._send_simple_command(motor, CAN_CMD_ENABLE)
time.sleep(MEDIUM_TIMEOUT_SEC)
def _send_simple_command(self, motor: NameOrID, command_byte: int) -> None:
"""Helper to send simple 8-byte commands (Enable, Disable, Zero)."""
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
recv_id = self._get_motor_recv_id(motor)
data = [0xFF] * 7 + [command_byte]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
self.canbus.send(msg)
if msg := self._recv_motor_response(expected_recv_id=recv_id):
self._process_response(motor_name, msg)
else:
logger.debug(f"No response from {motor_name} after command 0x{command_byte:02X}")
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Enable torque on selected motors."""
target_motors = self._get_motors_list(motors)
for motor in target_motors:
for _ in range(num_retry + 1):
try:
self._send_simple_command(motor, CAN_CMD_ENABLE)
break
except Exception as e:
if _ == num_retry:
raise e
time.sleep(MEDIUM_TIMEOUT_SEC)
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Disable torque on selected motors."""
target_motors = self._get_motors_list(motors)
for motor in target_motors:
for _ in range(num_retry + 1):
try:
self._send_simple_command(motor, CAN_CMD_DISABLE)
break
except Exception as e:
if _ == num_retry:
raise e
time.sleep(MEDIUM_TIMEOUT_SEC)
@contextmanager
def torque_disabled(self, motors: str | list[str] | None = None):
"""
Context manager that guarantees torque is re-enabled.
This helper is useful to temporarily disable torque when configuring motors.
"""
self.disable_torque(motors)
try:
yield
finally:
self.enable_torque(motors)
def set_zero_position(self, motors: str | list[str] | None = None) -> None:
"""Set current position as zero for selected motors."""
target_motors = self._get_motors_list(motors)
for motor in target_motors:
self._send_simple_command(motor, CAN_CMD_SET_ZERO)
time.sleep(MEDIUM_TIMEOUT_SEC)
def _refresh_motor(self, motor: NameOrID) -> can.Message | None:
"""Refresh motor status and return the response."""
motor_id = self._get_motor_id(motor)
recv_id = self._get_motor_recv_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd)
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
self.canbus.send(msg)
return self._recv_motor_response(expected_recv_id=recv_id)
def _recv_motor_response(
self, expected_recv_id: int | None = None, timeout: float = 0.001
) -> can.Message | None:
"""
Receive a response from a motor.
Args:
expected_recv_id: If provided, only return messages from this CAN ID
timeout: Timeout in seconds (default: 1ms for high-speed operation)
Returns:
CAN message if received, None otherwise
"""
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
try:
start_time = time.time()
messages_seen = []
while time.time() - start_time < timeout:
msg = self.canbus.recv(timeout=PRECISE_TIMEOUT_SEC)
if msg:
messages_seen.append(f"0x{msg.arbitration_id:02X}")
if expected_recv_id is None or msg.arbitration_id == expected_recv_id:
return msg
logger.debug(
f"Ignoring message from 0x{msg.arbitration_id:02X}, expected 0x{expected_recv_id:02X}"
)
if logger.isEnabledFor(logging.DEBUG):
if messages_seen:
logger.debug(
f"Received {len(messages_seen)} msgs from {set(messages_seen)}, expected 0x{expected_recv_id:02X}"
)
else:
logger.debug(f"No CAN messages received (expected 0x{expected_recv_id:02X})")
except Exception as e:
logger.debug(f"Failed to receive CAN message: {e}")
return None
def _recv_all_responses(
self, expected_recv_ids: list[int], timeout: float = 0.002
) -> dict[int, can.Message]:
"""
Efficiently receive responses from multiple motors at once.
Uses the OpenArms pattern: collect all available messages within timeout.
Args:
expected_recv_ids: List of CAN IDs we expect responses from
timeout: Total timeout in seconds (default: 2ms)
Returns:
Dictionary mapping recv_id to CAN message
"""
responses: dict[int, can.Message] = {}
expected_set = set(expected_recv_ids)
start_time = time.time()
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
try:
while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout:
# 100us poll timeout
msg = self.canbus.recv(timeout=PRECISE_TIMEOUT_SEC)
if msg and msg.arbitration_id in expected_set:
responses[msg.arbitration_id] = msg
if len(responses) == len(expected_recv_ids):
break
except Exception as e:
logger.debug(f"Error receiving responses: {e}")
return responses
def _encode_mit_packet(
self,
motor_type: MotorType,
kp: float,
kd: float,
position_degrees: float,
velocity_deg_per_sec: float,
torque: float,
) -> list[int]:
"""Helper to encode control parameters into 8 bytes for MIT mode."""
# Convert degrees to radians
position_rad = np.radians(position_degrees)
velocity_rad_per_sec = np.radians(velocity_deg_per_sec)
# Get motor limits
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
# Encode parameters
kp_uint = self._float_to_uint(kp, *MIT_KP_RANGE, 12)
kd_uint = self._float_to_uint(kd, *MIT_KD_RANGE, 12)
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
dq_uint = self._float_to_uint(velocity_rad_per_sec, -vmax, vmax, 12)
tau_uint = self._float_to_uint(torque, -tmax, tmax, 12)
# Pack data
data = [0] * 8
data[0] = (q_uint >> 8) & 0xFF
data[1] = q_uint & 0xFF
data[2] = dq_uint >> 4
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
data[4] = kp_uint & 0xFF
data[5] = kd_uint >> 4
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
data[7] = tau_uint & 0xFF
return data
def _mit_control(
self,
motor: NameOrID,
kp: float,
kd: float,
position_degrees: float,
velocity_deg_per_sec: float,
torque: float,
) -> None:
"""Send MIT control command to a motor."""
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types[motor_name]
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
self.canbus.send(msg)
recv_id = self._get_motor_recv_id(motor)
if msg := self._recv_motor_response(expected_recv_id=recv_id):
self._process_response(motor_name, msg)
else:
logger.debug(f"No response from {motor_name} after MIT control command")
def _mit_control_batch(
self,
commands: dict[NameOrID, tuple[float, float, float, float, float]],
) -> None:
"""
Send MIT control commands to multiple motors in batch.
Sends all commands first, then collects responses.
Args:
commands: Dict mapping motor name/ID to (kp, kd, position_deg, velocity_deg/s, torque)
Example: {'joint_1': (10.0, 0.5, 45.0, 0.0, 0.0), ...}
"""
if not commands:
return
recv_id_to_motor: dict[int, str] = {}
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
# Step 1: Send all MIT control commands
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types[motor_name]
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
self.canbus.send(msg)
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
# Step 2: Collect responses and update state cache
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=SHORT_TIMEOUT_SEC)
for recv_id, motor_name in recv_id_to_motor.items():
if msg := responses.get(recv_id):
self._process_response(motor_name, msg)
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
"""Convert float to unsigned integer for CAN transmission."""
x = max(x_min, min(x_max, x)) # Clamp to range
span = x_max - x_min
data_norm = (x - x_min) / span
return int(data_norm * ((1 << bits) - 1))
def _uint_to_float(self, x: int, x_min: float, x_max: float, bits: int) -> float:
"""Convert unsigned integer from CAN to float."""
span = x_max - x_min
data_norm = float(x) / ((1 << bits) - 1)
return data_norm * span + x_min
def _decode_motor_state(
self, data: bytearray | bytes, motor_type: MotorType
) -> tuple[float, float, float, int, int]:
"""
Decode motor state from CAN data.
Returns: (position_deg, velocity_deg_s, torque, temp_mos, temp_rotor)
"""
if len(data) < 8:
raise ValueError("Invalid motor state data")
# Extract encoded values
q_uint = (data[1] << 8) | data[2]
dq_uint = (data[3] << 4) | (data[4] >> 4)
tau_uint = ((data[4] & 0x0F) << 8) | data[5]
t_mos = data[6]
t_rotor = data[7]
# Get motor limits
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
# Decode to physical values
position_rad = self._uint_to_float(q_uint, -pmax, pmax, 16)
velocity_rad_per_sec = self._uint_to_float(dq_uint, -vmax, vmax, 12)
torque = self._uint_to_float(tau_uint, -tmax, tmax, 12)
return np.degrees(position_rad), np.degrees(velocity_rad_per_sec), torque, t_mos, t_rotor
def _process_response(self, motor: str, msg: can.Message) -> None:
"""Decode a message and update the motor state cache."""
try:
motor_type = self._motor_types[motor]
pos, vel, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
self._last_known_states[motor] = {
"position": pos,
"velocity": vel,
"torque": torque,
"temp_mos": float(t_mos),
"temp_rotor": float(t_rotor),
}
except Exception as e:
logger.warning(f"Failed to decode response from {motor}: {e}")
@check_if_not_connected
def read(self, data_name: str, motor: str) -> Value:
"""Read a value from a single motor. Positions are always in degrees."""
# Refresh motor to get latest state
msg = self._refresh_motor(motor)
if msg is None:
motor_id = self._get_motor_id(motor)
recv_id = self._get_motor_recv_id(motor)
raise ConnectionError(
f"No response from motor '{motor}' (send ID: 0x{motor_id:02X}, recv ID: 0x{recv_id:02X}). "
f"Check that: 1) Motor is powered (24V), 2) CAN wiring is correct, "
f"3) Motor IDs are configured correctly using Damiao Debugging Tools"
)
self._process_response(motor, msg)
return self._get_cached_value(motor, data_name)
def _get_cached_value(self, motor: str, data_name: str) -> Value:
"""Retrieve a specific value from the cache."""
state = self._last_known_states[motor]
mapping: dict[str, Any] = {
"Present_Position": state["position"],
"Present_Velocity": state["velocity"],
"Present_Torque": state["torque"],
"Temperature_MOS": state["temp_mos"],
"Temperature_Rotor": state["temp_rotor"],
}
if data_name not in mapping:
raise ValueError(f"Unknown data_name: {data_name}")
return mapping[data_name]
@check_if_not_connected
def write(
self,
data_name: str,
motor: str,
value: Value,
) -> None:
"""
Write a value to a single motor. Positions are always in degrees.
Can write 'Goal_Position', 'Kp', or 'Kd'.
"""
if data_name in ("Kp", "Kd"):
self._gains[motor][data_name.lower()] = float(value)
elif data_name == "Goal_Position":
kp = self._gains[motor]["kp"]
kd = self._gains[motor]["kd"]
self._mit_control(motor, kp, kd, float(value), 0.0, 0.0)
else:
raise ValueError(f"Writing {data_name} not supported in MIT mode")
def sync_read(
self,
data_name: str,
motors: str | list[str] | None = None,
) -> dict[str, Value]:
"""
Read the same value from multiple motors simultaneously.
"""
target_motors = self._get_motors_list(motors)
self._batch_refresh(target_motors)
result = {}
for motor in target_motors:
result[motor] = self._get_cached_value(motor, data_name)
return result
def sync_read_all_states(
self,
motors: str | list[str] | None = None,
*,
num_retry: int = 0,
) -> dict[str, MotorState]:
"""
Read ALL motor states (position, velocity, torque) from multiple motors in ONE refresh cycle.
Returns:
Dictionary mapping motor names to state dicts with keys: 'position', 'velocity', 'torque'
Example: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...}
"""
target_motors = self._get_motors_list(motors)
self._batch_refresh(target_motors)
result = {}
for motor in target_motors:
result[motor] = self._last_known_states[motor].copy()
return result
def _batch_refresh(self, motors: list[str]) -> None:
"""Internal helper to refresh a list of motors and update cache."""
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
# Send refresh commands
for motor in motors:
motor_id = self._get_motor_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(
arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd
)
self.canbus.send(msg)
# Collect responses
expected_recv_ids = [self._get_motor_recv_id(m) for m in motors]
responses = self._recv_all_responses(expected_recv_ids, timeout=MEDIUM_TIMEOUT_SEC)
# Update cache
for motor in motors:
recv_id = self._get_motor_recv_id(motor)
msg = responses.get(recv_id)
if msg:
self._process_response(motor, msg)
else:
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
@check_if_not_connected
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
"""
Write values to multiple motors simultaneously. Positions are always in degrees.
"""
if data_name in ("Kp", "Kd"):
key = data_name.lower()
for motor, val in values.items():
self._gains[motor][key] = float(val)
elif data_name == "Goal_Position":
# Step 1: Send all MIT control commands
recv_id_to_motor: dict[int, str] = {}
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
for motor, value_degrees in values.items():
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types[motor_name]
kp = self._gains[motor]["kp"]
kd = self._gains[motor]["kd"]
data = self._encode_mit_packet(motor_type, kp, kd, float(value_degrees), 0.0, 0.0)
msg = can.Message(
arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd
)
self.canbus.send(msg)
precise_sleep(PRECISE_TIMEOUT_SEC)
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
# Step 2: Collect responses and update state cache
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=MEDIUM_TIMEOUT_SEC)
for recv_id, motor_name in recv_id_to_motor.items():
if msg := responses.get(recv_id):
self._process_response(motor_name, msg)
else:
# Fall back to individual writes
for motor, value in values.items():
self.write(data_name, motor, value)
def read_calibration(self) -> dict[str, MotorCalibration]:
"""Read calibration data from motors."""
# Damiao motors don't store calibration internally
# Return existing calibration or empty dict
return self.calibration if self.calibration else {}
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
"""Write calibration data to motors."""
# Damiao motors don't store calibration internally
# Just cache it in memory
if cache:
self.calibration = calibration_dict
def record_ranges_of_motion(
self,
motors: str | list[str] | None = None,
display_values: bool = True,
) -> tuple[dict[str, Value], dict[str, Value]]:
"""
Interactively record the min/max values of each motor in degrees.
Move the joints by hand (with torque disabled) while the method streams live positions.
Press Enter to finish.
"""
target_motors = self._get_motors_list(motors)
self.disable_torque(target_motors)
time.sleep(LONG_TIMEOUT_SEC)
start_positions = self.sync_read("Present_Position", target_motors)
mins = start_positions.copy()
maxes = start_positions.copy()
print("\nMove joints through their full range of motion. Press ENTER when done.")
user_pressed_enter = False
while not user_pressed_enter:
positions = self.sync_read("Present_Position", target_motors)
for motor in target_motors:
if motor in positions:
mins[motor] = min(positions[motor], mins.get(motor, positions[motor]))
maxes[motor] = max(positions[motor], maxes.get(motor, positions[motor]))
if display_values:
print("\n" + "=" * 50)
print(f"{'MOTOR':<20} | {'MIN (deg)':>12} | {'POS (deg)':>12} | {'MAX (deg)':>12}")
print("-" * 50)
for motor in target_motors:
if motor in positions:
print(
f"{motor:<20} | {mins[motor]:>12.1f} | {positions[motor]:>12.1f} | {maxes[motor]:>12.1f}"
)
if enter_pressed():
user_pressed_enter = True
if display_values and not user_pressed_enter:
move_cursor_up(len(target_motors) + 4)
time.sleep(LONG_TIMEOUT_SEC)
self.enable_torque(target_motors)
for motor in target_motors:
if (motor in mins) and (motor in maxes) and (int(abs(maxes[motor] - mins[motor])) < 5):
raise ValueError(f"Motor {motor} has insufficient range of motion (< 5 degrees)")
return mins, maxes
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
"""Convert motor specification to list of motor names."""
if motors is None:
return list(self.motors.keys())
elif isinstance(motors, str):
return [motors]
elif isinstance(motors, list):
return motors
else:
raise TypeError(f"Invalid motors type: {type(motors)}")
def _get_motor_id(self, motor: NameOrID) -> int:
"""Get CAN ID for a motor."""
if isinstance(motor, str):
if motor in self.motors:
return self.motors[motor].id
else:
raise ValueError(f"Unknown motor: {motor}")
else:
return motor
def _get_motor_name(self, motor: NameOrID) -> str:
"""Get motor name from name or ID."""
if isinstance(motor, str):
return motor
else:
for name, m in self.motors.items():
if m.id == motor:
return name
raise ValueError(f"Unknown motor ID: {motor}")
def _get_motor_recv_id(self, motor: NameOrID) -> int:
"""Get motor recv_id from name or ID."""
motor_name = self._get_motor_name(motor)
motor_obj = self.motors.get(motor_name)
if motor_obj and motor_obj.recv_id is not None:
return motor_obj.recv_id
else:
raise ValueError(f"Motor {motor_obj} doesn't have a valid recv_id (None).")
@cached_property
def is_calibrated(self) -> bool:
"""Check if motors are calibrated."""
return bool(self.calibration)

View File

@@ -0,0 +1,209 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration tables for Damiao motors."""
from enum import IntEnum
# Motor type definitions
class MotorType(IntEnum):
DM3507 = 0
DM4310 = 1
DM4310_48V = 2
DM4340 = 3
DM4340_48V = 4
DM6006 = 5
DM8006 = 6
DM8009 = 7
DM10010L = 8
DM10010 = 9
DMH3510 = 10
DMH6215 = 11
DMG6220 = 12
# Control modes
class ControlMode(IntEnum):
MIT = 1
POS_VEL = 2
VEL = 3
TORQUE_POS = 4
# Motor variable IDs (RID)
class MotorVariable(IntEnum):
UV_VALUE = 0
KT_VALUE = 1
OT_VALUE = 2
OC_VALUE = 3
ACC = 4
DEC = 5
MAX_SPD = 6
MST_ID = 7
ESC_ID = 8
TIMEOUT = 9
CTRL_MODE = 10
DAMP = 11
INERTIA = 12
HW_VER = 13
SW_VER = 14
SN = 15
NPP = 16
RS = 17
LS = 18
FLUX = 19
GR = 20
PMAX = 21
VMAX = 22
TMAX = 23
I_BW = 24
KP_ASR = 25
KI_ASR = 26
KP_APR = 27
KI_APR = 28
OV_VALUE = 29
GREF = 30
DETA = 31
V_BW = 32
IQ_C1 = 33
VL_C1 = 34
CAN_BR = 35
SUB_VER = 36
U_OFF = 50
V_OFF = 51
K1 = 52
K2 = 53
M_OFF = 54
DIR = 55
P_M = 80
XOUT = 81
# Motor limit parameters [PMAX, VMAX, TMAX]
# PMAX: Maximum position (rad)
# VMAX: Maximum velocity (rad/s)
# TMAX: Maximum torque (N·m)
MOTOR_LIMIT_PARAMS = {
MotorType.DM3507: (12.5, 30, 10),
MotorType.DM4310: (12.5, 30, 10),
MotorType.DM4310_48V: (12.5, 50, 10),
MotorType.DM4340: (12.5, 8, 28),
MotorType.DM4340_48V: (12.5, 10, 28),
MotorType.DM6006: (12.5, 45, 20),
MotorType.DM8006: (12.5, 45, 40),
MotorType.DM8009: (12.5, 45, 54),
MotorType.DM10010L: (12.5, 25, 200),
MotorType.DM10010: (12.5, 20, 200),
MotorType.DMH3510: (12.5, 280, 1),
MotorType.DMH6215: (12.5, 45, 10),
MotorType.DMG6220: (12.5, 45, 10),
}
# Motor model names
MODEL_NAMES = {
MotorType.DM3507: "dm3507",
MotorType.DM4310: "dm4310",
MotorType.DM4310_48V: "dm4310_48v",
MotorType.DM4340: "dm4340",
MotorType.DM4340_48V: "dm4340_48v",
MotorType.DM6006: "dm6006",
MotorType.DM8006: "dm8006",
MotorType.DM8009: "dm8009",
MotorType.DM10010L: "dm10010l",
MotorType.DM10010: "dm10010",
MotorType.DMH3510: "dmh3510",
MotorType.DMH6215: "dmh6215",
MotorType.DMG6220: "dmg6220",
}
# Motor resolution table (encoder counts per revolution)
MODEL_RESOLUTION = {
"dm3507": 65536,
"dm4310": 65536,
"dm4310_48v": 65536,
"dm4340": 65536,
"dm4340_48v": 65536,
"dm6006": 65536,
"dm8006": 65536,
"dm8009": 65536,
"dm10010l": 65536,
"dm10010": 65536,
"dmh3510": 65536,
"dmh6215": 65536,
"dmg6220": 65536,
}
# CAN baudrates supported by Damiao motors
AVAILABLE_BAUDRATES = [
125000, # 0: 125 kbps
200000, # 1: 200 kbps
250000, # 2: 250 kbps
500000, # 3: 500 kbps
1000000, # 4: 1 mbps (default for OpenArms)
2000000, # 5: 2 mbps
2500000, # 6: 2.5 mbps
3200000, # 7: 3.2 mbps
4000000, # 8: 4 mbps
5000000, # 9: 5 mbps
]
DEFAULT_BAUDRATE = 1000000 # 1 Mbps is standard for OpenArms
# Default timeout in milliseconds
DEFAULT_TIMEOUT_MS = 1000
# OpenArms specific configurations
# Based on: https://docs.openarm.dev/software/setup/configure-test
# OpenArms has 7 DOF per arm (14 total for dual arm)
OPENARMS_ARM_MOTOR_IDS = {
"joint_1": {"send": 0x01, "recv": 0x11}, # J1 - Shoulder pan
"joint_2": {"send": 0x02, "recv": 0x12}, # J2 - Shoulder lift
"joint_3": {"send": 0x03, "recv": 0x13}, # J3 - Elbow flex
"joint_4": {"send": 0x04, "recv": 0x14}, # J4 - Wrist flex
"joint_5": {"send": 0x05, "recv": 0x15}, # J5 - Wrist roll
"joint_6": {"send": 0x06, "recv": 0x16}, # J6 - Wrist pitch
"joint_7": {"send": 0x07, "recv": 0x17}, # J7 - Wrist rotation
}
OPENARMS_GRIPPER_MOTOR_IDS = {
"gripper": {"send": 0x08, "recv": 0x18}, # J8 - Gripper
}
# Default motor types for OpenArms
OPENARMS_DEFAULT_MOTOR_TYPES = {
"joint_1": MotorType.DM8009, # Shoulder pan - high torque
"joint_2": MotorType.DM8009, # Shoulder lift - high torque
"joint_3": MotorType.DM4340, # Shoulder rotation
"joint_4": MotorType.DM4340, # Elbow flex
"joint_5": MotorType.DM4310, # Wrist roll
"joint_6": MotorType.DM4310, # Wrist pitch
"joint_7": MotorType.DM4310, # Wrist rotation
"gripper": MotorType.DM4310, # Gripper
}
# MIT control parameter ranges
MIT_KP_RANGE = (0.0, 500.0)
MIT_KD_RANGE = (0.0, 5.0)
# CAN frame command IDs
CAN_CMD_ENABLE = 0xFC
CAN_CMD_DISABLE = 0xFD
CAN_CMD_SET_ZERO = 0xFE
CAN_CMD_REFRESH = 0xCC
CAN_CMD_QUERY_PARAM = 0x33
CAN_CMD_WRITE_PARAM = 0x55
CAN_CMD_SAVE_PARAM = 0xAA
# CAN ID for parameter operations
CAN_PARAM_ID = 0x7FF

View File

@@ -22,9 +22,8 @@ import logging
from copy import deepcopy
from enum import Enum
from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
from ..encoding_utils import decode_twos_complement, encode_twos_complement
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
from .tables import (
AVAILABLE_BAUDRATES,
MODEL_BAUDRATE_TABLE,
@@ -100,7 +99,7 @@ def _split_into_byte_chunks(value: int, length: int) -> list[int]:
return data
class DynamixelMotorsBus(MotorsBus):
class DynamixelMotorsBus(SerialMotorsBus):
"""
The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with
the motors. For more info, see the Dynamixel SDK Documentation:
@@ -182,10 +181,10 @@ class DynamixelMotorsBus(MotorsBus):
for motor, m in self.motors.items():
calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=drive_modes[motor],
homing_offset=offsets[motor],
range_min=mins[motor],
range_max=maxes[motor],
drive_mode=int(drive_modes[motor]),
homing_offset=int(offsets[motor]),
range_min=int(mins[motor]),
range_max=int(maxes[motor]),
)
return calibration
@@ -199,15 +198,15 @@ class DynamixelMotorsBus(MotorsBus):
if cache:
self.calibration = calibration_dict
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None:
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
@@ -236,7 +235,7 @@ class DynamixelMotorsBus(MotorsBus):
On Dynamixel Motors:
Present_Position = Actual_Position + Homing_Offset
"""
half_turn_homings = {}
half_turn_homings: dict[NameOrID, Value] = {}
for motor, pos in positions.items():
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
@@ -259,6 +258,6 @@ class DynamixelMotorsBus(MotorsBus):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return
return None
return {id_: data[0] for id_, data in data_list.items()}

View File

@@ -17,9 +17,8 @@ from copy import deepcopy
from enum import Enum
from pprint import pformat
from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
from ..encoding_utils import decode_sign_magnitude, encode_sign_magnitude
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
from .tables import (
FIRMWARE_MAJOR_VERSION,
FIRMWARE_MINOR_VERSION,
@@ -96,7 +95,7 @@ def patch_setPacketTimeout(self, packet_length): # noqa: N802
self.packet_timeout = (self.tx_time_per_byte * packet_length) + (self.tx_time_per_byte * 3.0) + 50
class FeetechMotorsBus(MotorsBus):
class FeetechMotorsBus(SerialMotorsBus):
"""
The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on the
python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk.
@@ -127,7 +126,7 @@ class FeetechMotorsBus(MotorsBus):
self.port_handler = scs.PortHandler(self.port)
# HACK: monkeypatch
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__(
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign]
self.port_handler, scs.PortHandler
)
self.packet_handler = scs.PacketHandler(protocol_version)
@@ -263,9 +262,9 @@ class FeetechMotorsBus(MotorsBus):
calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=0,
homing_offset=offsets[motor],
range_min=mins[motor],
range_max=maxes[motor],
homing_offset=int(offsets[motor]),
range_min=int(mins[motor]),
range_max=int(maxes[motor]),
)
return calibration
@@ -285,7 +284,7 @@ class FeetechMotorsBus(MotorsBus):
On Feetech Motors:
Present_Position = Actual_Position - Homing_Offset
"""
half_turn_homings = {}
half_turn_homings: dict[NameOrID, Value] = {}
for motor, pos in positions.items():
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
@@ -293,18 +292,18 @@ class FeetechMotorsBus(MotorsBus):
return half_turn_homings
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
self.write("Lock", motor, 0, num_retry=num_retry)
def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None:
def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None:
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry)
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
addr, length = get_address(self.model_ctrl_table, model, "Lock")
self._write(addr, length, motor_id, 0, num_retry=num_retry)
self._write(addr, length, motor, 0, num_retry=num_retry)
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
self.write("Lock", motor, 1, num_retry=num_retry)
@@ -335,7 +334,7 @@ class FeetechMotorsBus(MotorsBus):
def _broadcast_ping(self) -> tuple[dict[int, int], int]:
import scservo_sdk as scs
data_list = {}
data_list: dict[int, int] = {}
status_length = 6
@@ -415,7 +414,7 @@ class FeetechMotorsBus(MotorsBus):
if not self._is_comm_success(comm):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return
return None
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
if ids_errors:

View File

@@ -19,8 +19,11 @@
# TODO(aliberts): Add block noqa when feature below is available
# https://github.com/astral-sh/ruff/issues/3711
from __future__ import annotations
import abc
import logging
from collections.abc import Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
@@ -41,6 +44,81 @@ Value: TypeAlias = int | float
logger = logging.getLogger(__name__)
class MotorsBusBase(abc.ABC):
"""
Base class for all motor bus implementations.
This is a minimal interface that all motor buses must implement, regardless of their
communication protocol (serial, CAN, etc.).
"""
def __init__(
self,
port: str,
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
):
self.port = port
self.motors = motors
self.calibration = calibration if calibration else {}
@abc.abstractmethod
def connect(self, handshake: bool = True) -> None:
"""Establish connection to the motors."""
pass
@abc.abstractmethod
def disconnect(self, disable_torque: bool = True) -> None:
"""Disconnect from the motors."""
pass
@property
@abc.abstractmethod
def is_connected(self) -> bool:
"""Check if connected to the motors."""
pass
@abc.abstractmethod
def read(self, data_name: str, motor: str) -> Value:
"""Read a value from a single motor."""
pass
@abc.abstractmethod
def write(self, data_name: str, motor: str, value: Value) -> None:
"""Write a value to a single motor."""
pass
@abc.abstractmethod
def sync_read(self, data_name: str, motors: str | list[str] | None = None) -> dict[str, Value]:
"""Read a value from multiple motors."""
pass
@abc.abstractmethod
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
"""Write values to multiple motors."""
pass
@abc.abstractmethod
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Enable torque on selected motors."""
pass
@abc.abstractmethod
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Disable torque on selected motors."""
pass
@abc.abstractmethod
def read_calibration(self) -> dict[str, MotorCalibration]:
"""Read calibration parameters from the motors."""
pass
@abc.abstractmethod
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
"""Write calibration parameters to the motors."""
pass
def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]:
ctrl_table = model_ctrl_table.get(model)
if ctrl_table is None:
@@ -97,18 +175,21 @@ class Motor:
id: int
model: str
norm_mode: MotorNormMode
motor_type_str: str | None = None
recv_id: int | None = None
class PortHandler(Protocol):
def __init__(self, port_name):
self.is_open: bool
self.baudrate: int
self.packet_start_time: float
self.packet_timeout: float
self.tx_time_per_byte: float
self.is_using: bool
self.port_name: str
self.ser: serial.Serial
is_open: bool
baudrate: int
packet_start_time: float
packet_timeout: float
tx_time_per_byte: float
is_using: bool
port_name: str
ser: serial.Serial
def __init__(self, port_name: str) -> None: ...
def openPort(self): ...
def closePort(self): ...
@@ -161,19 +242,22 @@ class PacketHandler(Protocol):
def regWriteTxRx(self, port, id, address, length, data): ...
def syncReadTx(self, port, start_address, data_length, param, param_length): ...
def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ...
def broadcastPing(self, port): ...
class GroupSyncRead(Protocol):
def __init__(self, port, ph, start_address, data_length):
self.port: str
self.ph: PortHandler
self.start_address: int
self.data_length: int
self.last_result: bool
self.is_param_changed: bool
self.param: list
self.data_dict: dict
port: str
ph: PortHandler
start_address: int
data_length: int
last_result: bool
is_param_changed: bool
param: list
data_dict: dict
def __init__(
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
) -> None: ...
def makeParam(self): ...
def addParam(self, id): ...
def removeParam(self, id): ...
@@ -186,15 +270,17 @@ class GroupSyncRead(Protocol):
class GroupSyncWrite(Protocol):
def __init__(self, port, ph, start_address, data_length):
self.port: str
self.ph: PortHandler
self.start_address: int
self.data_length: int
self.is_param_changed: bool
self.param: list
self.data_dict: dict
port: str
ph: PortHandler
start_address: int
data_length: int
is_param_changed: bool
param: list
data_dict: dict
def __init__(
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
) -> None: ...
def makeParam(self): ...
def addParam(self, id, data): ...
def removeParam(self, id): ...
@@ -203,15 +289,15 @@ class GroupSyncWrite(Protocol):
def txPacket(self): ...
class MotorsBus(abc.ABC):
class SerialMotorsBus(MotorsBusBase):
"""
A MotorsBus allows to efficiently read and write to the attached motors.
A SerialMotorsBus allows to efficiently read and write to motors connected via serial communication.
It represents several motors daisy-chained together and connected through a serial port.
There are currently two implementations of this abstract class:
There are currently two implementations of this class:
- DynamixelMotorsBus
- FeetechMotorsBus
Note: This class may evolve in the future should we add support for other types of bus.
This class is specifically for serial-based motor protocols (Dynamixel, Feetech, etc.).
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
To find the port, you can run our utility script:
@@ -260,9 +346,7 @@ class MotorsBus(abc.ABC):
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
):
self.port = port
self.motors = motors
self.calibration = calibration if calibration else {}
super().__init__(port, motors, calibration)
self.port_handler: PortHandler
self.packet_handler: PacketHandler
@@ -323,7 +407,7 @@ class MotorsBus(abc.ABC):
else:
raise TypeError(f"'{motor}' should be int, str.")
def _get_motor_model(self, motor: NameOrID) -> int:
def _get_motor_model(self, motor: NameOrID) -> str:
if isinstance(motor, str):
return self.motors[motor].model
elif isinstance(motor, int):
@@ -331,17 +415,19 @@ class MotorsBus(abc.ABC):
else:
raise TypeError(f"'{motor}' should be int, str.")
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
def _get_motors_list(self, motors: NameOrID | Sequence[NameOrID] | None) -> list[str]:
if motors is None:
return list(self.motors)
elif isinstance(motors, str):
return [motors]
elif isinstance(motors, list):
return motors.copy()
elif isinstance(motors, int):
return [self._id_to_name(motors)]
elif isinstance(motors, Sequence):
return [m if isinstance(m, str) else self._id_to_name(m) for m in motors]
else:
raise TypeError(motors)
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> dict[int, Value]:
if isinstance(values, (int | float)):
return dict.fromkeys(self.ids, values)
elif isinstance(values, dict):
@@ -532,7 +618,7 @@ class MotorsBus(abc.ABC):
self.set_baudrate(self.default_baudrate)
@abc.abstractmethod
def _find_single_motor(self, motor: str, initial_baudrate: int | None) -> tuple[int, int]:
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
pass
@abc.abstractmethod
@@ -545,13 +631,13 @@ class MotorsBus(abc.ABC):
pass
@abc.abstractmethod
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Disable torque on selected motors.
Disabling Torque allows to write to the motors' permanent memory area (EPROM/EEPROM).
Args:
motors (int | str | list[str] | None, optional): Target motors. Accepts a motor name, an ID, a
motors ( str | list[str] | None, optional): Target motors. Accepts a motor name, an ID, a
list of names or `None` to affect every registered motor. Defaults to `None`.
num_retry (int, optional): Number of additional retry attempts on communication failure.
Defaults to 0.
@@ -563,18 +649,19 @@ class MotorsBus(abc.ABC):
pass
@abc.abstractmethod
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
"""Enable torque on selected motors.
Args:
motor (int): Same semantics as :pymeth:`disable_torque`. Defaults to `None`.
motors (int | str | list[str] | None, optional): Same semantics as :pymeth:`disable_torque`.
Defaults to `None`.
num_retry (int, optional): Number of additional retry attempts on communication failure.
Defaults to 0.
"""
pass
@contextmanager
def torque_disabled(self, motors: int | str | list[str] | None = None):
def torque_disabled(self, motors: str | list[str] | None = None):
"""Context-manager that guarantees torque is re-enabled.
This helper is useful to temporarily disable torque when configuring motors.
@@ -651,24 +738,19 @@ class MotorsBus(abc.ABC):
"""
pass
def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None:
def reset_calibration(self, motors: NameOrID | Sequence[NameOrID] | None = None) -> None:
"""Restore factory calibration for the selected motors.
Homing offset is set to ``0`` and min/max position limits are set to the full usable range.
The in-memory :pyattr:`calibration` is cleared.
Args:
motors (NameOrID | list[NameOrID] | None, optional): Selection of motors. `None` (default)
motors (NameOrID | Sequence[NameOrID] | None, optional): Selection of motors. `None` (default)
resets every motor.
"""
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
motor_names = self._get_motors_list(motors)
for motor in motors:
for motor in motor_names:
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
self.write("Homing_Offset", motor, 0, normalize=False)
@@ -677,7 +759,9 @@ class MotorsBus(abc.ABC):
self.calibration = {}
def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) -> dict[NameOrID, Value]:
def set_half_turn_homings(
self, motors: NameOrID | Sequence[NameOrID] | None = None
) -> dict[NameOrID, Value]:
"""Centre each motor range around its current position.
The function computes and writes a homing offset such that the present position becomes exactly one
@@ -687,17 +771,12 @@ class MotorsBus(abc.ABC):
motors (NameOrID | list[NameOrID] | None, optional): Motors to adjust. Defaults to all motors (`None`).
Returns:
dict[NameOrID, Value]: Mapping *motor → written homing offset*.
dict[str, Value]: Mapping *motor name → written homing offset*.
"""
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
motor_names = self._get_motors_list(motors)
self.reset_calibration(motors)
actual_positions = self.sync_read("Present_Position", motors, normalize=False)
self.reset_calibration(motor_names)
actual_positions = self.sync_read("Present_Position", motor_names, normalize=False)
homing_offsets = self._get_half_turn_homings(actual_positions)
for motor, offset in homing_offsets.items():
self.write("Homing_Offset", motor, offset)
@@ -709,8 +788,8 @@ class MotorsBus(abc.ABC):
pass
def record_ranges_of_motion(
self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
self, motors: NameOrID | Sequence[NameOrID] | None = None, display_values: bool = True
) -> tuple[dict[str, Value], dict[str, Value]]:
"""Interactively record the min/max encoder values of each motor.
Move the joints by hand (with torque disabled) while the method streams live positions. Press
@@ -722,30 +801,25 @@ class MotorsBus(abc.ABC):
display_values (bool, optional): When `True` (default) a live table is printed to the console.
Returns:
tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: Two dictionaries *mins* and *maxes* with the
tuple[dict[str, Value], dict[str, Value]]: Two dictionaries *mins* and *maxes* with the
extreme values observed for each motor.
"""
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
motor_names = self._get_motors_list(motors)
start_positions = self.sync_read("Present_Position", motors, normalize=False)
start_positions = self.sync_read("Present_Position", motor_names, normalize=False)
mins = start_positions.copy()
maxes = start_positions.copy()
user_pressed_enter = False
while not user_pressed_enter:
positions = self.sync_read("Present_Position", motors, normalize=False)
positions = self.sync_read("Present_Position", motor_names, normalize=False)
mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()}
maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()}
if display_values:
print("\n-------------------------------------------")
print(f"{'NAME':<15} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}")
for motor in motors:
for motor in motor_names:
print(f"{motor:<15} | {mins[motor]:>6} | {positions[motor]:>6} | {maxes[motor]:>6}")
if enter_pressed():
@@ -753,9 +827,9 @@ class MotorsBus(abc.ABC):
if display_values and not user_pressed_enter:
# Move cursor up to overwrite the previous output
move_cursor_up(len(motors) + 3)
move_cursor_up(len(motor_names) + 3)
same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]]
same_min_max = [motor for motor in motor_names if mins[motor] == maxes[motor]]
if same_min_max:
raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}")
@@ -878,12 +952,12 @@ class MotorsBus(abc.ABC):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
else:
return
return None
if self._is_error(error):
if raise_on_error:
raise RuntimeError(self.packet_handler.getRxPacketError(error))
else:
return
return None
return model_number
@@ -930,12 +1004,13 @@ class MotorsBus(abc.ABC):
err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
id_value = self._decode_sign(data_name, {id_: value})
decoded = self._decode_sign(data_name, {id_: value})
if normalize and data_name in self.normalized_data:
id_value = self._normalize(id_value)
normalized = self._normalize(decoded)
return normalized[id_]
return id_value[id_]
return decoded[id_]
def _read(
self,
@@ -946,7 +1021,7 @@ class MotorsBus(abc.ABC):
num_retry: int = 0,
raise_on_error: bool = True,
err_msg: str = "",
) -> tuple[int, int]:
) -> tuple[int, int, int]:
if length == 1:
read_fn = self.packet_handler.read1ByteTxRx
elif length == 2:
@@ -996,13 +1071,14 @@ class MotorsBus(abc.ABC):
model = self.motors[motor].model
addr, length = get_address(self.model_ctrl_table, model, data_name)
int_value = int(value)
if normalize and data_name in self.normalized_data:
value = self._unnormalize({id_: value})[id_]
int_value = self._unnormalize({id_: value})[id_]
value = self._encode_sign(data_name, {id_: value})[id_]
int_value = self._encode_sign(data_name, {id_: int_value})[id_]
err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
err_msg = f"Failed to write '{data_name}' on {id_=} with '{int_value}' after {num_retry + 1} tries."
self._write(addr, length, id_, int_value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
def _write(
self,
@@ -1036,7 +1112,7 @@ class MotorsBus(abc.ABC):
def sync_read(
self,
data_name: str,
motors: str | list[str] | None = None,
motors: NameOrID | Sequence[NameOrID] | None = None,
*,
normalize: bool = True,
num_retry: int = 0,
@@ -1045,7 +1121,7 @@ class MotorsBus(abc.ABC):
Args:
data_name (str): Register name.
motors (str | list[str] | None, optional): Motors to query. `None` (default) reads every motor.
motors (NameOrID | Sequence[NameOrID] | None, optional): Motors to query. `None` (default) reads every motor.
normalize (bool, optional): Normalisation flag. Defaults to `True`.
num_retry (int, optional): Retry attempts. Defaults to `0`.
@@ -1066,16 +1142,17 @@ class MotorsBus(abc.ABC):
addr, length = get_address(self.model_ctrl_table, model, data_name)
err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
ids_values, _ = self._sync_read(
raw_ids_values, _ = self._sync_read(
addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
)
ids_values = self._decode_sign(data_name, ids_values)
decoded = self._decode_sign(data_name, raw_ids_values)
if normalize and data_name in self.normalized_data:
ids_values = self._normalize(ids_values)
normalized = self._normalize(decoded)
return {self._id_to_name(id_): value for id_, value in normalized.items()}
return {self._id_to_name(id_): value for id_, value in ids_values.items()}
return {self._id_to_name(id_): value for id_, value in decoded.items()}
def _sync_read(
self,
@@ -1147,21 +1224,24 @@ class MotorsBus(abc.ABC):
num_retry (int, optional): Retry attempts. Defaults to `0`.
"""
ids_values = self._get_ids_values_dict(values)
models = [self._id_to_model(id_) for id_ in ids_values]
raw_ids_values = self._get_ids_values_dict(values)
models = [self._id_to_model(id_) for id_ in raw_ids_values]
if self._has_different_ctrl_tables:
assert_same_address(self.model_ctrl_table, models, data_name)
model = next(iter(models))
addr, length = get_address(self.model_ctrl_table, model, data_name)
int_ids_values = {id_: int(val) for id_, val in raw_ids_values.items()}
if normalize and data_name in self.normalized_data:
ids_values = self._unnormalize(ids_values)
int_ids_values = self._unnormalize(raw_ids_values)
ids_values = self._encode_sign(data_name, ids_values)
int_ids_values = self._encode_sign(data_name, int_ids_values)
err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
err_msg = f"Failed to sync write '{data_name}' with ids_values={int_ids_values} after {num_retry + 1} tries."
self._sync_write(
addr, length, int_ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
)
def _sync_write(
self,
@@ -1194,3 +1274,7 @@ class MotorsBus(abc.ABC):
for id_, value in ids_values.items():
data = self._serialize_data(value, length)
self.sync_writer.addParam(id_, data)
# Backward compatibility alias
MotorsBus: TypeAlias = SerialMotorsBus

View File

@@ -28,7 +28,7 @@ class ACTConfig(PreTrainedConfig):
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and 'output_shapes`.
Those are: `input_features` and `output_features`.
Notes on the inputs and outputs:
- Either:
@@ -48,21 +48,12 @@ class ACTConfig(PreTrainedConfig):
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
environment, and throws the other 50 out.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights.
@@ -98,7 +89,6 @@ class ACTConfig(PreTrainedConfig):
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"AUDIO": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
@@ -109,10 +99,6 @@ class ACTConfig(PreTrainedConfig):
vision_backbone: str = "resnet18"
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
replace_final_stride_with_dilation: int = False
# Audio backbone.
audio_backbone: str = vision_backbone
pretrained_backbone_weights_audio: str | None = None
replace_final_stride_with_dilation_audio: int = False
# Transformer layers.
pre_norm: bool = False
dim_model: int = 512
@@ -175,10 +161,8 @@ class ACTConfig(PreTrainedConfig):
return None
def validate_features(self) -> None:
if not (self.image_features or self.audio_features) and not self.env_state_feature:
raise ValueError(
"You must provide at least one image/audio or the environment state among the inputs."
)
if not self.image_features and not self.env_state_feature:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
@property
def observation_delta_indices(self) -> None:

View File

@@ -35,7 +35,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION, OBS_AUDIO, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
class ACTPolicy(PreTrainedPolicy):
@@ -106,8 +106,6 @@ class ACTPolicy(PreTrainedPolicy):
"""
self.eval() # keeping the policy in eval mode as it could be set to train mode while queue is consumed
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
# we are ensembling over.
if self.config.temporal_ensemble_coeff is not None:
actions = self.predict_action_chunk(batch)
action = self.temporal_ensembler.update(actions)
@@ -333,26 +331,12 @@ class ACT(nn.Module):
# Note: The forward method of this returns a dict: {"feature_map": output}.
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
# Backbone for audio feature extraction.
if self.config.audio_features:
audio_backbone_model = getattr(torchvision.models, config.audio_backbone)(
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation_audio],
weights=config.pretrained_backbone_weights_audio,
norm_layer=FrozenBatchNorm2d,
)
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
# feature map).
# Note: The forward method of this returns a dict: {"feature_map": output}.
self.audio_backbone = IntermediateLayerGetter(
audio_backbone_model, return_layers={"layer4": "feature_map"}
)
# Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = ACTEncoder(config)
self.decoder = ACTDecoder(config)
# Transformer encoder input projections. The tokens will be structured like
# [latent, (robot_state), (env_state), (image_feature_map_pixels), (audio_feature)].
# [latent, (robot_state), (env_state), (image_feature_map_pixels)].
if self.config.robot_state_feature:
self.encoder_robot_state_input_proj = nn.Linear(
self.config.robot_state_feature.shape[0], config.dim_model
@@ -366,10 +350,6 @@ class ACT(nn.Module):
self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, config.dim_model, kernel_size=1
)
if self.config.audio_features:
self.encoder_audio_feat_input_proj = nn.Conv2d(
audio_backbone_model.fc.in_features, config.dim_model, kernel_size=1
)
# Transformer encoder positional embeddings.
n_1d_tokens = 1 # for the latent
if self.config.robot_state_feature:
@@ -379,8 +359,6 @@ class ACT(nn.Module):
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
if self.config.image_features:
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
if self.config.audio_features:
self.encoder_audio_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
# Transformer decoder.
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
@@ -505,21 +483,6 @@ class ACT(nn.Module):
encoder_in_tokens.extend(list(cam_features))
encoder_in_pos_embed.extend(list(cam_pos_embed))
if self.config.audio_features:
for audio in batch[OBS_AUDIO]:
audio_features = self.audio_backbone(audio)["feature_map"]
audio_pos_embed = self.encoder_audio_feat_pos_embed(audio_features).to(
dtype=audio_features.dtype
)
audio_features = self.encoder_audio_feat_input_proj(audio_features)
# Rearrange features to (sequence, batch, dim).
audio_features = einops.rearrange(audio_features, "b c h w -> (h w) b c")
audio_pos_embed = einops.rearrange(audio_pos_embed, "b c h w -> (h w) b c")
encoder_in_tokens.extend(list(audio_features))
encoder_in_pos_embed.extend(list(audio_pos_embed))
# Stack all tokens along the sequence dimension.
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0)

View File

@@ -17,11 +17,9 @@ from typing import Any
import torch
from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
AudioProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
@@ -65,15 +63,6 @@ def make_act_pre_post_processors(
stats=dataset_stats,
device=config.device,
),
AudioProcessorStep(
output_height=224,
output_width=224,
output_channels=3,
input_audio_chunk_duration=DEFAULT_AUDIO_CHUNK_DURATION,
input_sample_rate=48000,
intermediate_sample_rate=16000,
n_fft=1024,
),
]
output_steps = [
UnnormalizerProcessorStep(

View File

@@ -30,7 +30,7 @@ class DiffusionConfig(PreTrainedConfig):
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and `output_shapes`.
Those are: `input_features` and `output_features`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
@@ -48,21 +48,12 @@ class DiffusionConfig(PreTrainedConfig):
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
See `DiffusionPolicy.select_action` for more details.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
@@ -73,7 +64,7 @@ class DiffusionConfig(PreTrainedConfig):
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view.
use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view.
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
You may provide a variable number of dimensions, therefore also controlling the degree of
downsampling.

View File

@@ -239,8 +239,10 @@ class SACPolicy(
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
def update_temperature(self):
self.temperature = self.log_alpha.exp().item()
@property
def temperature(self) -> float:
"""Return the current temperature value, always in sync with log_alpha."""
return self.log_alpha.exp().item()
def compute_loss_critic(
self,
@@ -457,11 +459,10 @@ class SACPolicy(
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
self.target_entropy = -np.prod(dim) / 2
def _init_temperature(self):
"""Set up temperature parameter and initial log_alpha."""
def _init_temperature(self) -> None:
"""Set up temperature parameter (log_alpha)."""
temp_init = self.config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
self.temperature = self.log_alpha.exp().item()
class SACObservationEncoder(nn.Module):

View File

@@ -378,16 +378,16 @@ class SmolVLAPolicy(PreTrainedPolicy):
actions_is_pad = batch.get("actions_id_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
loss_dict["losses_after_forward"] = losses.clone()
loss_dict["losses_after_forward"] = losses.clone().mean().item()
if actions_is_pad is not None:
in_episode_bound = ~actions_is_pad
losses = losses * in_episode_bound.unsqueeze(-1)
loss_dict["losses_after_in_ep_bound"] = losses.clone()
loss_dict["losses_after_in_ep_bound"] = losses.clone().mean().item()
# Remove padding
losses = losses[:, :, : self.config.max_action_dim]
loss_dict["losses_after_rm_padding"] = losses.clone()
loss_dict["losses_after_rm_padding"] = losses.clone().mean().item()
if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims

View File

@@ -30,7 +30,7 @@ class TDMPCConfig(PreTrainedConfig):
camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.
Those are: `input_features`, `output_features`, and perhaps `max_random_shift_ratio`.
Args:
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
@@ -40,24 +40,12 @@ class TDMPCConfig(PreTrainedConfig):
is an alternative to using action repeats. If this is set to more than 1, then we require
`n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this
approach of using multiple steps from the plan is not in the original implementation.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to
match the original implementation.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping
to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max"
normalization mode here.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding.
state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding.
latent_dim: Observation's latent embedding dimension.

View File

@@ -106,7 +106,7 @@ def prepare_observation_for_inference(
This function takes a dictionary of NumPy arrays, performs necessary
preprocessing, and prepares it for model inference. The steps include:
1. Converting NumPy arrays to PyTorch tensors.
2. Normalizing and permuting image data and audio data (if any).
2. Normalizing and permuting image data (if any).
3. Adding a batch dimension to each tensor.
4. Moving all tensors to the specified compute device.
5. Adding task and robot type information to the dictionary.
@@ -129,9 +129,6 @@ def prepare_observation_for_inference(
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
elif "audio" in name:
observation[name] = observation[name].type(torch.float32)
observation[name] = observation[name].permute(1, 0).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)

View File

@@ -32,7 +32,7 @@ class VQBeTConfig(PreTrainedConfig):
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and `output_shapes`.
Those are: `input_features` and `output_features`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
@@ -46,21 +46,12 @@ class VQBeTConfig(PreTrainedConfig):
current step and additional steps going back).
n_action_pred_token: Total number of current token and future tokens that VQ-BeT predicts.
action_chunk_size: Action chunk size of each action prediction token.
input_shapes: A dictionary defining the shapes of the input data for the policy.
The key represents the input data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "observation.image" refers to an input from
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
Importantly, shapes doesnt include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .audio_processor import AudioProcessorStep
from .batch_processor import AddBatchDimensionProcessorStep
from .converters import (
batch_to_transition,
@@ -81,7 +80,6 @@ __all__ = [
"ActionProcessorStep",
"AddTeleopActionAsComplimentaryDataStep",
"AddTeleopEventsAsInfoStep",
"AudioProcessorStep",
"ComplementaryDataProcessorStep",
"batch_to_transition",
"create_transition",

View File

@@ -1,130 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from torch import Tensor
from torchaudio.functional import amplitude_to_DB
from torchaudio.transforms import MelSpectrogram, Resample
from torchvision.transforms import Compose, Lambda, Resize
from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
from lerobot.utils.constants import OBS_AUDIO
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="audio_processor")
class AudioProcessorStep(ObservationProcessorStep):
"""
Processes audio waveform data into a mel-spectrogram image representation.
**Audio Processing:**
- Averages waveform data over all channels.
- Resamples the waveform to 16kHz.
- Converts the waveform to a mel-spectrogram.
- Converts the mel-spectrogram to decibels.
- Resizes the mel-spectrogram to 224×224.
- Converts the mel-spectrogram to a channel-first, normalized tensor.
Attributes:
output_height: Height of the output mel-spectrogram image in pixels.
output_width: Width of the output mel-spectrogram image in pixels.
output_channels: Number of channels in the output image (3 for RGB-like format).
input_audio_chunk_duration: Duration of the input audio chunk in seconds.
input_sample_rate: Original sample rate of the input audio in Hz.
intermediate_sample_rate: Reduced intermediate sample rate in Hz.
Downsampling improves the temporal resolution but reduces the frequency range.
n_fft: Size of the FFT window for spectrogram computation.
Increasing the window size increases the frequency resolution but decreases the temporal resolution.
hop_length: Number of samples between successive frames, computed automatically to match the output_width.
Decreasing the hop length increases the temporal resolution but decreases the frequency resolution.
n_mels: Number of mel filter banks, computed automatically to match the output_height.
Increasing the number of banks increases the number of rows in the spectrogram and the frequency resolution.
mel_spectrogram_transform: The complete audio processing pipeline.
"""
output_height: int = 224
output_width: int = 224
output_channels: int = 3
input_audio_chunk_duration: float = DEFAULT_AUDIO_CHUNK_DURATION
input_sample_rate: int = 48000
intermediate_sample_rate: int = 16000
n_fft: int = 1024
# Parameters computed from other parameters at initialization
hop_length: int = field(init=False)
n_mels: int = field(init=False)
mel_spectrogram_transform: Compose = field(init=False, repr=False)
def __post_init__(self):
self.hop_length = int(
self.intermediate_sample_rate * self.input_audio_chunk_duration
- self.n_fft // self.output_width
- 1
)
self.n_mels = self.output_height
self.mel_spectrogram_transform = Compose(
[
Lambda(lambda x: x.mean(dim=1)), # Average over all channels (second dimension after batch)
Resample(orig_freq=self.input_sample_rate, new_freq=self.intermediate_sample_rate),
MelSpectrogram(
sample_rate=self.intermediate_sample_rate,
n_fft=self.n_fft,
hop_length=self.hop_length,
n_mels=self.n_mels,
power=2, # Power spectrum
),
Lambda(
lambda x: amplitude_to_DB(x, multiplier=10, amin=1e-10, db_multiplier=0)
), # Convert to decibels
Resize(
(self.output_height, self.output_width)
), # Resize spectrogram to output_height×output_width
Lambda(
lambda x: x.unsqueeze(1).expand(-1, self.output_channels, -1, -1)
), # Duplicate across 3 channels to mimic RGB images. Dimensions are [batch, rgb, height, width].
]
)
def _process_observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
"""
Processes audio data contained in the provided observation.
"""
processed_obs = observation.copy()
# Process single audio observation
if OBS_AUDIO in processed_obs:
audio_data = processed_obs[OBS_AUDIO]
if isinstance(audio_data, Tensor) and audio_data.dim() == 3: # Batch, Channels, Samples
processed_obs[OBS_AUDIO] = self.mel_spectrogram_transform(audio_data)
# Process multiple audio observations
for key, value in processed_obs.items():
if (
key.startswith(f"{OBS_AUDIO}.") and isinstance(value, Tensor) and value.dim() == 3
): # Batch, Channels, Samples
processed_obs[key] = self.mel_spectrogram_transform(value)
return processed_obs
def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
return self._process_observation(observation)

View File

@@ -25,7 +25,7 @@ from dataclasses import dataclass, field
from torch import Tensor
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_AUDIO, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from .core import EnvTransition, PolicyAction
from .pipeline import (
@@ -88,8 +88,6 @@ class AddBatchDimensionObservationStep(ObservationProcessorStep):
- State vectors (1D tensors).
- Single images (3D tensors).
- Dictionaries of multiple images (3D tensors).
- Single audio waveforms (2D tensors).
- Dictionaries of multiple audio waveforms (2D tensors).
"""
def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
@@ -119,18 +117,6 @@ class AddBatchDimensionObservationStep(ObservationProcessorStep):
for key, value in observation.items():
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
observation[key] = value.unsqueeze(0)
# Process single audio observation - add batch dim if 2D
if OBS_AUDIO in observation:
audio_value = observation[OBS_AUDIO]
if isinstance(audio_value, Tensor) and audio_value.dim() == 2:
observation[OBS_AUDIO] = audio_value.unsqueeze(0)
# Process multiple audio observations - add batch dim if 2D
for key, value in observation.items():
if key.startswith(f"{OBS_AUDIO}.") and isinstance(value, Tensor) and value.dim() == 2:
observation[key] = value.unsqueeze(0)
return observation
def transform_features(

View File

@@ -168,11 +168,12 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
"""
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
def create_transition(

View File

@@ -17,7 +17,7 @@ from dataclasses import dataclass
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@@ -92,7 +92,7 @@ class LiberoProcessorStep(ObservationProcessorStep):
# copy over non-STATE features
for ft, feats in features.items():
if ft != PipelineFeatureType.STATE:
if ft != FeatureType.STATE:
new_features[ft] = feats.copy()
# rebuild STATE features
@@ -100,13 +100,11 @@ class LiberoProcessorStep(ObservationProcessorStep):
# add our new flattened state
state_feats[OBS_STATE] = PolicyFeature(
key=OBS_STATE,
type=FeatureType.STATE,
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
dtype="float32",
description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."),
)
new_features[PipelineFeatureType.STATE] = state_feats
new_features[FeatureType.STATE] = state_feats
return new_features

View File

@@ -18,16 +18,18 @@
import math
import time
from dataclasses import dataclass
from typing import Any, Protocol, TypeVar, runtime_checkable
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable
import numpy as np
import torch
import torchvision.transforms.functional as F # noqa: N812
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents
if TYPE_CHECKING:
from lerobot.teleoperators.teleoperator import Teleoperator
from .core import EnvTransition, PolicyAction, TransitionKey
from .pipeline import (
ComplementaryDataProcessorStep,
@@ -69,10 +71,10 @@ class HasTeleopEvents(Protocol):
# Type variable constrained to Teleoperator subclasses that also implement events
TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator)
TeleopWithEvents = TypeVar("TeleopWithEvents", bound="Teleoperator")
def _check_teleop_with_events(teleop: Teleoperator) -> None:
def _check_teleop_with_events(teleop: "Teleoperator") -> None:
"""
Runtime check that a teleoperator implements the `HasTeleopEvents` protocol.
@@ -103,7 +105,7 @@ class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
teleop_device: The teleoperator instance to get the action from.
"""
teleop_device: Teleoperator
teleop_device: "Teleoperator"
def complementary_data(self, complementary_data: dict) -> dict:
"""
@@ -312,7 +314,7 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
class GripperPenaltyProcessorStep(ProcessorStep):
"""
Applies a penalty for inefficient gripper usage.
@@ -327,26 +329,27 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
penalty: float = -0.01
max_gripper_pos: float = 30.0
def complementary_data(self, complementary_data: dict) -> dict:
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
Calculates the gripper penalty and adds it to the complementary data.
Args:
complementary_data: The incoming complementary data, which should contain
raw joint positions.
transition: The incoming environment transition.
Returns:
A new complementary data dictionary with the `discrete_penalty` key added.
The modified transition with the penalty added to complementary data.
"""
action = self.transition.get(TransitionKey.ACTION)
new_transition = transition.copy()
action = new_transition.get(TransitionKey.ACTION)
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
raw_joint_positions = complementary_data.get("raw_joint_positions")
if raw_joint_positions is None:
return complementary_data
return new_transition
current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None)
if current_gripper_pos is None:
return complementary_data
return new_transition
# Gripper action is a PolicyAction at this stage
gripper_action = action[-1].item()
@@ -362,11 +365,12 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
gripper_penalty = self.penalty * int(gripper_penalty_bool)
# Create new complementary data with penalty info
# Update complementary data with penalty info
new_complementary_data = dict(complementary_data)
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_complementary_data
return new_transition
def get_config(self) -> dict[str, Any]:
"""

View File

@@ -34,6 +34,8 @@ from lerobot.utils.constants import (
ACTION_TOKEN_MASK,
ACTION_TOKENS,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK,
OBS_LANGUAGE_SUBTASK_TOKENS,
OBS_LANGUAGE_TOKENS,
)
from lerobot.utils.import_utils import _transformers_available
@@ -139,6 +141,32 @@ class TokenizerProcessorStep(ObservationProcessorStep):
return None
def get_subtask(self, transition: EnvTransition) -> list[str] | None:
"""
Extracts the subtask from the transition's complementary data.
Args:
transition: The environment transition.
Returns:
A list of subtask strings, or None if the subtask key is not found or the value is None.
"""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
return None
subtask = complementary_data.get("subtask")
if subtask is None:
return None
# Standardize to a list of strings for the tokenizer
if isinstance(subtask, str):
return [subtask]
elif isinstance(subtask, list) and all(isinstance(t, str) for t in subtask):
return subtask
return None
def observation(self, observation: RobotObservation) -> RobotObservation:
"""
Tokenizes the task description and adds it to the observation dictionary.
@@ -176,6 +204,24 @@ class TokenizerProcessorStep(ObservationProcessorStep):
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
# Tokenize subtask if available
subtask = self.get_subtask(self.transition)
if subtask is not None:
tokenized_subtask = self._tokenize_text(subtask)
# Move new tokenized tensors to the detected device
if target_device is not None:
tokenized_subtask = {
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
for k, v in tokenized_subtask.items()
}
# Add tokenized subtask to the observation
new_observation[OBS_LANGUAGE_SUBTASK_TOKENS] = tokenized_subtask["input_ids"]
new_observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = tokenized_subtask["attention_mask"].to(
dtype=torch.bool
)
return new_observation
def _detect_device(self, transition: EnvTransition) -> torch.device | None:

Some files were not shown because too many files have changed in this diff Show More