Compare commits

..

14 Commits

Author SHA1 Message Date
Pepijn
46e9e22b05 feat(eval): thread-safe policy copies for max_parallel_tasks > 1
eval_policy_all already supports running multiple task groups concurrently via
ThreadPoolExecutor, but policy.reset() was not thread-safe: all threads shared
the same policy object and its mutable state (action queues, temporal buffers).

Fix: each thread receives a shallow copy of the policy. copy.copy() creates a
new Python object whose _parameters dict is a shared reference — same tensor
storage, zero extra VRAM — while reset() rebinds per-episode state to fresh
objects per thread.

Caveat: ACT with temporal_ensemble_coeff is not safe with this approach (its
reset() mutates a shared sub-object). Keep max_parallel_tasks=1 for that config.

For MetaWorld (50 tasks, no temporal ensembling), max_parallel_tasks=4 raises
GPU utilization from ~20% to ~60-80% with no additional VRAM cost.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-03 17:11:36 +02:00
Pepijn
b43f9ab048 feat(envs): lazy env init + AsyncVectorEnv as default for n_envs > 1
LiberoEnv and MetaworldEnv previously allocated GPU resources (EGL context,
OpenGL framebuffer) in __init__, before AsyncVectorEnv's fork(). Worker
processes inherited stale GPU handles, causing EGL_BAD_CONTEXT crashes on
first render.

Fix: defer OffScreenRenderEnv / MT1 construction to _ensure_env(), called on
first reset() or step() inside the worker subprocess. Each worker creates its
own clean context after fork().

Also fixes lerobot_eval.py:170 (add_envs_task TODO): replace with
env.call("task") which works with both SyncVectorEnv and AsyncVectorEnv.

AsyncVectorEnv is now the default for n_envs > 1; auto-downgraded to
SyncVectorEnv when n_envs=1 (no benefit, less overhead).

Expected speedup: ~15-20x for LIBERO Spatial with batch_size=50.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-03 17:10:10 +02:00
Pepijn
0045f88355 merge: resolve conflicts from main into refactor/benchmark-dispatch
Keep refactored dispatch pattern (no factory.py edits for new benchmarks).
Incorporate main's "Verifying your integration" section and class naming fix.

Made-with: Cursor
2026-04-03 14:49:36 +02:00
Pepijn
4dbbcca496 docs(benchmarks): add benchmark integration guide and standardize benchmark docs (#3270)
* docs(benchmarks): add benchmark integration guide and standardize benchmark docs

Add a comprehensive guide for adding new benchmarks to LeRobot, and
refactor the existing LIBERO and Meta-World docs to follow the new
standardized template.

Made-with: Cursor

* docs(benchmarks): clean up adding-benchmarks guide for clarity

Rewrite for simpler language, better structure, and easier navigation.
Move quick-reference table to the top, fold eval explanation into
architecture section, condense the doc template to a bulleted outline.

Made-with: Cursor

* fix link

* fix task count

* Update docs/source/adding_benchmarks.mdx

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* Update docs/source/metaworld.mdx

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* Update docs/source/adding_benchmarks.mdx

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* Update docs/source/adding_benchmarks.mdx

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* Update docs/source/adding_benchmarks.mdx

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* docs(benchmarks): add verification checklist to adding-benchmarks guide

Made-with: Cursor

---------

Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
2026-04-03 14:44:53 +02:00
Pepijn
89ce91f69f Merge branch 'docs/adding-benchmarks-guide' into refactor/benchmark-dispatch 2026-04-03 13:56:49 +02:00
Pepijn
90e614f6b9 fix task count 2026-04-03 13:48:37 +02:00
Pepijn
ff4f860e5d fix link 2026-04-03 13:47:17 +02:00
Pepijn
6f2823bfc4 merge: resolve conflicts with docs/adding-benchmarks-guide
Incorporate cleaner writing from the docs branch while reflecting the
refactored dispatch pattern (no factory.py edits needed for new benchmarks).

Made-with: Cursor
2026-04-03 13:45:12 +02:00
Pepijn
77415559b8 docs(benchmarks): clean up adding-benchmarks guide for clarity
Rewrite for simpler language, better structure, and easier navigation.
Move quick-reference table to the top, fold eval explanation into
architecture section, condense the doc template to a bulleted outline.

Made-with: Cursor
2026-04-03 13:36:16 +02:00
Pepijn
24d9b74d81 refactor(envs): move dispatch logic from factory into EnvConfig subclasses
Replace hardcoded if/elif chains in factory.py with create_envs() and
get_env_processors() methods on EnvConfig. New benchmarks now only need
to register a config subclass — no factory.py edits required.

Net -23 lines: factory.py shrinks from ~200 to ~70 lines of logic.

Made-with: Cursor
2026-04-03 13:23:44 +02:00
Pepijn
508358749a docs(benchmarks): add benchmark integration guide and standardize benchmark docs
Add a comprehensive guide for adding new benchmarks to LeRobot, and
refactor the existing LIBERO and Meta-World docs to follow the new
standardized template.

Made-with: Cursor
2026-04-02 20:43:31 +02:00
Pepijn
818892a38b feat(dagger): Add HIL/Dagger/HG-Dagger/RaC style data collection (#2833)
* feat: HIL data collection, RTC interpolator, and action queue improvements

- Add Human-in-the-Loop (HIL) data collection examples (sync + RTC)
- Add HIL data collection documentation
- Add ActionInterpolator for smoother policy control at higher rates
- Integrate interpolator into lerobot-record and eval_with_real_robot
- Add action queue clear() and get_processed_left_over() methods
- Add rtc/__init__.py for cleaner imports

* docs: expand Related Work section with paper summaries

* fix: only record dataset frames at original fps, not at interpolated rate

The interpolator speeds up robot control (e.g. 2x) but dataset frames
should still be recorded at the original fps. Interpolated-only
iterations now only send actions to the robot without writing to the
dataset.

* refactor: merge HIL sync and RTC scripts into single file with --rtc.enabled toggle

Combines hil_data_collection.py and hil_data_collection_rtc.py into one
script. RTC is toggled via --rtc.enabled=true (defaults to off for sync
inference). Deletes the separate hil_data_collection_rtc.py and updates
docs to reflect the single-script usage.

* test: add ActionInterpolator test suite (29 tests)

Covers constructor validation, passthrough (multiplier=1), 2x and 3x
interpolation with exact value checks, reset/episode boundaries,
control interval calculation, multi-dim actions, and simulated
control loop integration.

* test: add ActionQueue + ActionInterpolator integration tests

Verifies the interpolator doesn't interfere with RTC's leftover chunk
tracking: queue consumption rate matches base fps regardless of
multiplier, get_left_over/get_processed_left_over only change on
queue.get(), merge preserves smooth interpolation across chunks,
and interpolator reset is independent of queue state.

* feat: register SO follower/leader configs in HIL script

Adds SOFollowerRobotConfig and SOLeaderTeleopConfig imports so
SO100/SO101 robots can be used via --robot.type=so_follower
and --teleop.type=so_leader. Updates docs accordingly.

Made-with: Cursor

* docs: remove em dashes from HIL documentation

Made-with: Cursor

* refactor: rename examples/rac to examples/hil

Updates directory name and all references in docs and script docstrings.

Made-with: Cursor

* fix: encorperate pr feedback comments

* refactor(tests): enhance ActionInterpolator test structure and add detailed docstrings

* feedback pr and test fix

* fix(test): pass correct real_delay in interpolator delay test

The test was passing real_delay=0 and relying on _check_delays to
silently override it with the index-based diff. Now passes real_delay=3
to match the 3 actions consumed during the simulated inference period.


* fix pr feedback

* ordering

* update hil script

* fix

* default name

* fix(bi_openarm): use kw_only=True to fix dataclass field ordering

BiOpenArmFollowerConfig overrides `id` with a default, making it
positional in the child — non-default `left_arm_config` then follows a
default field, which Python dataclasses forbid. Adding kw_only=True
(matching the parent RobotConfig) removes positional constraints.

Made-with: Cursor

* style: format long line in hil_data_collection.py

Made-with: Cursor

* pr feedback

---------

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
2026-04-02 19:53:59 +02:00
Pepijn
66fef25ded docs(toctree): add Benchmarks section for LIBERO and Meta-World (#3268)
* docs(toctree): add Benchmarks section for LIBERO and Meta-World

Move LIBERO and Meta-World pages out of the Simulation section into a
dedicated Benchmarks section so benchmark-specific docs are easier to
find and the Simulation section stays focused on environment hubs.

Made-with: Cursor

* docs(toctree): move IsaacLab Arena into Benchmarks section

Include NVIDIA IsaacLab Arena Environments alongside LIBERO and
Meta-World in the Benchmarks section.

Made-with: Cursor
2026-04-02 19:52:39 +02:00
Pepijn
2cf08b7a4b Add create reward visualization (#3155)
* Add create reward visualization and multimodal analysis tool

* add example for creating progress video for sarm

* nit

* precommit

* refactor: address review comments on create_progress_videos.py

- Add shebang and Apache 2.0 license header
- Replace hardcoded absolute OUTPUT_DIR with relative default (./progress_videos)
- Add argparse CLI (--repo-id, --episode, --camera-key, --output-dir, --gif)
- Wrap entrypoint in def main()
- Replace all print() with logging
- Use logging.error/warning instead of traceback.print_exc
- Release VideoCapture via try/finally; consolidate triple-open into single seek
- Eliminate intermediate clip file: seek directly via CAP_PROP_POS_MSEC
- Make MP4 the default output, GIF opt-in via --gif flag
- Add return types to all functions
- Add Args/Returns docstrings
- Use descriptive variable names throughout

Made-with: Cursor

* refactor: move create_progress_videos.py to examples/dataset/ for consistency

Made-with: Cursor

* refactor: address PR review comments on create_progress_videos.py

- Replace Unicode ellipsis and multiplication sign with ASCII equivalents
- Fix step numbering from 1-5 to 1-4 (only 4 actual steps)
- Move frame_width reading into convert_mp4_to_gif
- Remove unused text_height variable

Made-with: Cursor
2026-04-02 16:58:07 +02:00
89 changed files with 4346 additions and 5610 deletions

View File

@@ -1,219 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from soundfile import read
from lerobot.microphones.configs import MicrophoneConfig
from lerobot.microphones.portaudio import PortAudioMicrophone, PortAudioMicrophoneConfig
from lerobot.microphones.utils import (
async_microphones_start_recording,
async_microphones_stop_recording,
make_microphones_from_configs,
)
from lerobot.utils.robot_utils import (
precise_sleep,
)
def main(
microphones_configs: dict[str, MicrophoneConfig],
audio_chunks_number: int,
audio_chunks_duration: float,
repetitions: int,
multiprocessing: bool = False,
):
recording_dir = Path("outputs/audio_benchmark")
recording_dir.mkdir(parents=True, exist_ok=True)
# Create microphones
microphones = make_microphones_from_configs(microphones_configs)
# Connect microphones
for microphone in microphones.values():
microphone.connect()
all_audio_chunks = []
for i in range(repetitions):
print(f"Repetition {i + 1}/{repetitions}...")
# Create audio chunks
audio_chunks = {}
for microphone_key in microphones:
audio_chunks.update({microphone_key: []})
# Start recording
async_microphones_start_recording(
microphones,
output_files=[
recording_dir / f"{microphone_key}_recording_{i}.wav" for microphone_key in microphones
],
multiprocessing=multiprocessing,
)
# Record audio chunks
for j in range(audio_chunks_number):
precise_sleep(audio_chunks_duration)
for microphone_key, microphone in microphones.items():
audio_chunk = microphone.read()
print(f"{microphone_key} - repetition {i} - chunk {j} - samples {audio_chunk.shape[0]}")
audio_chunks[microphone_key].append(audio_chunk)
# Stop recording
async_microphones_stop_recording(microphones)
for microphone_key in microphones:
audio_chunks[microphone_key] = np.concatenate(audio_chunks[microphone_key], axis=0)
all_audio_chunks.append(audio_chunks)
# Disconnect microphones
for microphone in microphones.values():
microphone.disconnect()
# Compute statistics
cmap = plt.get_cmap("tab10")
_, ax = plt.subplots(nrows=repetitions, ncols=len(microphones))
chunk_length = np.zeros((repetitions, len(microphones)))
record_length = np.zeros((repetitions, len(microphones)))
for i in range(repetitions):
for j, (microphone_key, microphone) in enumerate(microphones.items()):
# Get recorded audio chunks
recorded_audio_chunks = all_audio_chunks[i][microphone_key]
# Load recorded file
recorded_data, _ = read(recording_dir / f"{microphone_key}_recording_{i}.wav")
if recorded_data.ndim == 1:
recorded_data = np.expand_dims(recorded_data, axis=1)
record_length[i, j] = recorded_data.shape[0]
chunk_length[i, j] = recorded_audio_chunks.shape[0]
for k, (chunk_data, record_data) in enumerate(
zip(recorded_audio_chunks.T, recorded_data.T, strict=False)
):
# Plot audio chunks and recorded data
ax[i, j].plot(
np.arange(0, len(chunk_data)) / microphone.sample_rate,
chunk_data,
label=f"audio chunks - channel {k}",
color=cmap(2 * k),
)
ax[i, j].plot(
np.arange(0, len(record_data)) / microphone.sample_rate,
record_data,
label=f"recorded data - channel {k}",
linestyle="dashed",
color=cmap(2 * k + 1),
)
# Plot absolute difference (errors should be located at the end of the recordings)
if recorded_data.shape[0] - recorded_audio_chunks.shape[0] > 0:
chunk_data = np.append(
chunk_data, np.zeros(int(recorded_data.shape[0] - recorded_audio_chunks.shape[0]))
)
else:
record_data = np.append(
record_data, np.zeros(int(-recorded_data.shape[0] + recorded_audio_chunks.shape[0]))
)
ax[i, j].plot(
np.arange(0, len(record_data)) / microphone.sample_rate,
np.abs(chunk_data - record_data),
label=f"differences - channel {k}",
color="red",
linestyle="dotted",
)
ax[i, j].set_title(f"{microphone_key} - repetition {i}")
ax[i, j].legend()
plt.show()
# Print statistics
differences = record_length - chunk_length
for i, (microphone_key, microphone) in enumerate(microphones.items()):
print(
f"Average recorded duration for {microphone_key} : {np.mean(record_length[:, i]) / microphone.sample_rate:.3f} seconds"
)
print(
f"Average chunk duration for {microphone_key} : {np.mean(chunk_length[:, i]) / microphone.sample_rate:.3f} seconds"
)
print(f"Average difference for {microphone_key} : {np.mean(differences[:, i]):.3f} samples")
print(
f"Average difference for {microphone_key} : {np.mean(differences[:, i]) / microphone.sample_rate:.3f} seconds"
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--microphones_indices",
type=int,
nargs="+",
default=[microphone["index"] for microphone in PortAudioMicrophone.find_microphones()],
)
parser.add_argument(
"--microphones_sample_rate",
type=float,
nargs="+",
default=[None] * len(PortAudioMicrophone.find_microphones()),
)
parser.add_argument(
"--microphones_channels",
type=int,
nargs="+",
default=[None] * len(PortAudioMicrophone.find_microphones()),
)
parser.add_argument("--audio_chunks_number", type=int, default=2)
parser.add_argument(
"--audio_chunks_duration",
type=float,
default=1.0,
)
parser.add_argument(
"--repetitions",
type=int,
default=2,
)
parser.add_argument(
"--multiprocessing",
action="store_true",
)
args = vars(parser.parse_args())
args["microphones_configs"] = {}
for index, sample_rate, channels in zip(
args["microphones_indices"],
args["microphones_sample_rate"],
args["microphones_channels"],
strict=False,
):
microphone_config = PortAudioMicrophoneConfig(
microphone_index=index,
sample_rate=sample_rate,
channels=channels,
)
args["microphones_configs"].update({f"microphone_{index}": microphone_config})
args.pop("microphones_indices")
args.pop("microphones_sample_rate")
args.pop("microphones_channels")
main(**args)

View File

@@ -1,137 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path
import numpy as np
import soundfile as sf
from lerobot.microphones.configs import MicrophoneConfig
from lerobot.microphones.touchlab import TouchLabSensorConfig
from lerobot.microphones.utils import (
async_microphones_start_recording,
async_microphones_stop_recording,
make_microphones_from_configs,
)
from lerobot.utils.robot_utils import (
precise_sleep,
)
def main(
sensors_configs: dict[str, MicrophoneConfig],
multiprocessing: bool = False,
):
recording_dir = Path("outputs/tactile_benchmark")
recording_dir.mkdir(parents=True, exist_ok=True)
# Create microphones
sensors = make_microphones_from_configs(sensors_configs)
# Connect microphones
for sensor in sensors.values():
sensor.connect()
# Create audio chunks
data_chunks = {}
for sensor_key in sensors:
data_chunks.update({sensor_key: []})
# Start recording
async_microphones_start_recording(
sensors,
output_files=[recording_dir / f"{sensor_key}_recording.wav" for sensor_key in sensors],
multiprocessing=multiprocessing,
)
# Record audio chunks
precise_sleep(10.0)
for sensor_key, sensor in sensors.items():
data_chunk = sensor.read()
print(f"{sensor_key} - samples {data_chunk.shape[0]}")
data_chunks[sensor_key].append(data_chunk)
# Stop recording
async_microphones_stop_recording(sensors)
for sensor_key in sensors:
data_chunks[sensor_key] = np.concatenate(data_chunks[sensor_key], axis=0)
# Disconnect microphones
for sensor in sensors.values():
sensor.disconnect()
for sensor_key in sensors:
data, sample_rate = sf.read(recording_dir / f"{sensor_key}_recording.wav")
print(f"{sensor_key} - samples {data.shape[0]}")
print(f"{sensor_key} - sample rate {sample_rate}")
print(f"{sensor_key} - data {data}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--sensors_ports",
type=str,
nargs="+",
)
parser.add_argument(
"--sensors_baud_rate",
type=int,
nargs="+",
)
parser.add_argument(
"--sensors_sample_rate",
type=int,
nargs="+",
)
parser.add_argument(
"--sensors_channels",
type=int,
nargs="+",
)
parser.add_argument(
"--multiprocessing",
action="store_true",
)
args = vars(parser.parse_args())
args["sensors_configs"] = {}
for port, baud_rate, sample_rate, channels in zip(
args["sensors_ports"],
args["sensors_baud_rate"],
args["sensors_sample_rate"],
args["sensors_channels"],
strict=False,
):
if isinstance(channels, int):
channels = [channels]
sensor_config = TouchLabSensorConfig(
sensor_port=port,
baud_rate=baud_rate,
sample_rate=sample_rate,
channels=channels,
)
args["sensors_configs"].update({f"sensor_{port}": sensor_config})
args.pop("sensors_ports")
args.pop("sensors_baud_rate")
args.pop("sensors_sample_rate")
args.pop("sensors_channels")
main(**args)

View File

@@ -17,6 +17,8 @@
title: Train RL in Simulation
- local: multi_gpu_training
title: Multi GPU training
- local: hil_data_collection
title: Human In the Loop Data Collection
- local: peft_training
title: Training with PEFT (e.g., LoRA)
- local: rename_map
@@ -69,13 +71,17 @@
title: Environments from the Hub
- local: envhub_leisaac
title: Control & Train Robots in Sim (LeIsaac)
title: "Simulation"
- sections:
- local: adding_benchmarks
title: Adding a New Benchmark
- local: libero
title: LIBERO
- local: metaworld
title: Meta-World
- local: envhub_isaaclab_arena
title: NVIDIA IsaacLab Arena Environments
- local: libero
title: Using Libero
- local: metaworld
title: Using MetaWorld
title: "Simulation"
title: "Benchmarks"
- sections:
- local: introduction_processors
title: Introduction to Robot Processors

View File

@@ -0,0 +1,320 @@
# Adding a New Benchmark
This guide walks you through adding a new simulation benchmark to LeRobot. Follow the steps in order and use the existing benchmarks as templates.
A benchmark in LeRobot is a set of [Gymnasium](https://gymnasium.farama.org/) environments that wrap a third-party simulator (like LIBERO or Meta-World) behind a standard `gym.Env` interface. The `lerobot-eval` CLI then runs evaluation uniformly across all benchmarks.
## Existing benchmarks at a glance
Before diving in, here is what is already integrated:
| Benchmark | Env file | Config class | Tasks | Action dim | Processor |
| -------------- | ------------------- | ------------------ | ------------------- | ------------ | ---------------------------- |
| LIBERO | `envs/libero.py` | `LiberoEnv` | 130 across 5 suites | 7 | `LiberoProcessorStep` |
| Meta-World | `envs/metaworld.py` | `MetaworldEnv` | 50 (MT50) | 4 | None |
| IsaacLab Arena | Hub-hosted | `IsaaclabArenaEnv` | Configurable | Configurable | `IsaaclabArenaProcessorStep` |
Use `src/lerobot/envs/libero.py` and `src/lerobot/envs/metaworld.py` as reference implementations.
## How it all fits together
### Data flow
During evaluation, data moves through four stages:
```
1. gym.Env ──→ raw observations (numpy dicts)
2. Preprocessing ──→ standard LeRobot keys + task description
(preprocess_observation, add_envs_task in envs/utils.py)
3. Processors ──→ env-specific then policy-specific transforms
(env_preprocessor, policy_preprocessor)
4. Policy ──→ select_action() ──→ action tensor
then reverse: policy_postprocessor → env_postprocessor → numpy action → env.step()
```
Most benchmarks only need to care about stage 1 (producing observations in the right format) and optionally stage 3 (if env-specific transforms are needed).
### Environment structure
`make_env()` returns a nested dict of vectorized environments:
```python
dict[str, dict[int, gym.vector.VectorEnv]]
# ^suite ^task_id
```
A single-task env (e.g. PushT) looks like `{"pusht": {0: vec_env}}`.
A multi-task benchmark (e.g. LIBERO) looks like `{"libero_spatial": {0: vec0, 1: vec1, ...}, ...}`.
### How evaluation runs
All benchmarks are evaluated the same way by `lerobot-eval`:
1. `make_env()` builds the nested `{suite: {task_id: VectorEnv}}` dict.
2. `eval_policy_all()` iterates over every suite and task.
3. For each task, it runs `n_episodes` rollouts via `rollout()`.
4. Results are aggregated hierarchically: episode, task, suite, overall.
5. Metrics include `pc_success` (success rate), `avg_sum_reward`, and `avg_max_reward`.
The critical piece: your env must return `info["is_success"]` on every `step()` call. This is how the eval loop knows whether a task was completed.
## What your environment must provide
LeRobot does not enforce a strict observation schema. Instead it relies on a set of conventions that all benchmarks follow.
### Env attributes
Your `gym.Env` must set these attributes:
| Attribute | Type | Why |
| -------------------- | ----- | ---------------------------------------------------- |
| `_max_episode_steps` | `int` | `rollout()` uses this to cap episode length |
| `task_description` | `str` | Passed to VLA policies as a language instruction |
| `task` | `str` | Fallback identifier if `task_description` is not set |
### Success reporting
Your `step()` and `reset()` must include `"is_success"` in the `info` dict:
```python
info = {"is_success": True} # or False
return observation, reward, terminated, truncated, info
```
### Observations
The simplest approach is to map your simulator's outputs to the standard keys that `preprocess_observation()` already understands. Do this inside your `gym.Env` (e.g. in a `_format_raw_obs()` helper):
| Your env should output | LeRobot maps it to | What it is |
| ------------------------- | -------------------------- | ------------------------------------- |
| `"pixels"` (single array) | `observation.image` | Single camera image, HWC uint8 |
| `"pixels"` (dict) | `observation.images.<cam>` | Multiple cameras, each HWC uint8 |
| `"agent_pos"` | `observation.state` | Proprioceptive state vector |
| `"environment_state"` | `observation.env_state` | Full environment state (e.g. PushT) |
| `"robot_state"` | `observation.robot_state` | Nested robot state dict (e.g. LIBERO) |
If your simulator uses different key names, you have two options:
1. **Recommended:** Rename them to the standard keys inside your `gym.Env` wrapper.
2. **Alternative:** Write an env processor to transform observations after `preprocess_observation()` runs (see step 4 below).
### Actions
Actions are continuous numpy arrays in a `gym.spaces.Box`. The dimensionality depends on your benchmark (7 for LIBERO, 4 for Meta-World, etc.). Policies adapt to different action dimensions through their `input_features` / `output_features` config.
### Feature declaration
Each `EnvConfig` subclass declares two dicts that tell the policy what to expect:
- `features` — maps feature names to `PolicyFeature(type, shape)` (e.g. action dim, image shape).
- `features_map` — maps raw observation keys to LeRobot convention keys (e.g. `"agent_pos"` to `"observation.state"`).
## Step by step
<Tip>
At minimum, you need two files: a **gym.Env wrapper** and an **EnvConfig
subclass** with a `create_envs()` override. Everything else is optional or
documentation. No changes to `factory.py` are needed.
</Tip>
### Checklist
| File | Required | Why |
| ---------------------------------------- | -------- | ------------------------------------------------------------ |
| `src/lerobot/envs/<benchmark>.py` | Yes | Wraps the simulator as a standard gym.Env |
| `src/lerobot/envs/configs.py` | Yes | Registers your benchmark and its `create_envs()` for the CLI |
| `src/lerobot/processor/env_processor.py` | Optional | Custom observation/action transforms |
| `src/lerobot/envs/utils.py` | Optional | Only if you need new raw observation keys |
| `pyproject.toml` | Yes | Declares benchmark-specific dependencies |
| `docs/source/<benchmark>.mdx` | Yes | User-facing documentation page |
| `docs/source/_toctree.yml` | Yes | Adds your page to the docs sidebar |
### 1. The gym.Env wrapper (`src/lerobot/envs/<benchmark>.py`)
Create a `gym.Env` subclass that wraps the third-party simulator:
```python
class MyBenchmarkEnv(gym.Env):
metadata = {"render_modes": ["rgb_array"], "render_fps": <fps>}
def __init__(self, task_suite, task_id, ...):
super().__init__()
self.task = <task_name_string>
self.task_description = <natural_language_instruction>
self._max_episode_steps = <max_steps>
self.observation_space = spaces.Dict({...})
self.action_space = spaces.Box(low=..., high=..., shape=(...,), dtype=np.float32)
def reset(self, seed=None, **kwargs):
... # return (observation, info) — info must contain {"is_success": False}
def step(self, action: np.ndarray):
... # return (obs, reward, terminated, truncated, info) — info must contain {"is_success": <bool>}
def render(self):
... # return RGB image as numpy array
def close(self):
...
```
Also provide a factory function that returns the nested dict structure:
```python
def create_mybenchmark_envs(
task: str,
n_envs: int,
gym_kwargs: dict | None = None,
env_cls: type | None = None,
) -> dict[str, dict[int, Any]]:
"""Create {suite_name: {task_id: VectorEnv}} for MyBenchmark."""
...
```
See `create_libero_envs()` (multi-suite, multi-task) and `create_metaworld_envs()` (difficulty-grouped tasks) for reference.
### 2. The config (`src/lerobot/envs/configs.py`)
Register a config dataclass so users can select your benchmark with `--env.type=<name>`. Each config owns its environment creation and processor logic via two methods:
- **`create_envs(n_envs, use_async_envs)`** — Returns `{suite: {task_id: VectorEnv}}`. The base class default uses `gym.make()` for single-task envs. Multi-task benchmarks override this.
- **`get_env_processors()`** — Returns `(preprocessor, postprocessor)`. The base class default returns identity (no-op) pipelines. Override if your benchmark needs observation/action transforms.
```python
@EnvConfig.register_subclass("<benchmark_name>")
@dataclass
class MyBenchmarkEnvConfig(EnvConfig):
task: str = "<default_task>"
fps: int = <fps>
obs_type: str = "pixels_agent_pos"
features: dict[str, PolicyFeature] = field(default_factory=lambda: {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(<action_dim>,)),
})
features_map: dict[str, str] = field(default_factory=lambda: {
ACTION: ACTION,
"agent_pos": OBS_STATE,
"pixels": OBS_IMAGE,
})
def __post_init__(self):
... # populate features based on obs_type
@property
def gym_kwargs(self) -> dict:
return {"obs_type": self.obs_type, "render_mode": self.render_mode}
def create_envs(self, n_envs: int, use_async_envs: bool = False):
"""Override for multi-task benchmarks or custom env creation."""
from lerobot.envs.<benchmark> import create_<benchmark>_envs
return create_<benchmark>_envs(task=self.task, n_envs=n_envs, ...)
def get_env_processors(self):
"""Override if your benchmark needs observation/action transforms."""
from lerobot.processor.pipeline import PolicyProcessorPipeline
from lerobot.processor.env_processor import MyBenchmarkProcessorStep
return (
PolicyProcessorPipeline(steps=[MyBenchmarkProcessorStep()]),
PolicyProcessorPipeline(steps=[]),
)
```
Key points:
- The `register_subclass` name is what users pass on the CLI (`--env.type=<name>`).
- `features` tells the policy what the environment produces.
- `features_map` maps raw observation keys to LeRobot convention keys.
- **No changes to `factory.py` needed** — the factory delegates to `cfg.create_envs()` and `cfg.get_env_processors()` automatically.
### 3. Env processor (optional — `src/lerobot/processor/env_processor.py`)
Only needed if your benchmark requires observation transforms beyond what `preprocess_observation()` handles (e.g. image flipping, coordinate conversion). Define the processor step here and return it from `get_env_processors()` in your config (see step 2):
```python
@dataclass
@ProcessorStepRegistry.register(name="<benchmark>_processor")
class MyBenchmarkProcessorStep(ObservationProcessorStep):
def _process_observation(self, observation):
processed = observation.copy()
# your transforms here
return processed
def transform_features(self, features):
return features # update if shapes change
def observation(self, observation):
return self._process_observation(observation)
```
See `LiberoProcessorStep` for a full example (image rotation, quaternion-to-axis-angle conversion).
### 4. Dependencies (`pyproject.toml`)
Add a new optional-dependency group:
```toml
mybenchmark = ["my-benchmark-pkg==1.2.3", "lerobot[scipy-dep]"]
```
Pinning rules:
- **Always pin** benchmark packages to exact versions for reproducibility (e.g. `metaworld==3.0.0`).
- **Add platform markers** when needed (e.g. `; sys_platform == 'linux'`).
- **Pin fragile transitive deps** if known (e.g. `gymnasium==1.1.0` for Meta-World).
- **Document constraints** in your benchmark doc page.
Users install with:
```bash
pip install -e ".[mybenchmark]"
```
### 5. Documentation (`docs/source/<benchmark>.mdx`)
Write a user-facing page following the template in the next section. See `docs/source/libero.mdx` and `docs/source/metaworld.mdx` for full examples.
### 6. Table of contents (`docs/source/_toctree.yml`)
Add your benchmark to the "Benchmarks" section:
```yaml
- sections:
- local: libero
title: LIBERO
- local: metaworld
title: Meta-World
- local: envhub_isaaclab_arena
title: NVIDIA IsaacLab Arena Environments
- local: <your_benchmark>
title: <Your Benchmark Name>
title: "Benchmarks"
```
## Verifying your integration
After completing the steps above, confirm that everything works:
1. **Install** — `pip install -e ".[mybenchmark]"` and verify the dependency group installs cleanly.
2. **Smoke test env creation** — call `make_env()` with your config in Python, check that the returned dict has the expected `{suite: {task_id: VectorEnv}}` shape, and that `reset()` returns observations with the right keys.
3. **Run a full eval** — `lerobot-eval --env.type=<name> --env.task=<task> --eval.n_episodes=1 --eval.batch_size=1 --policy.path=<any_compatible_policy>` to exercise the full pipeline end-to-end.
4. **Check success detection** — verify that `info["is_success"]` flips to `True` when the task is actually completed. This is what the eval loop uses to compute success rates.
## Writing a benchmark doc page
Each benchmark `.mdx` page should include:
- **Title and description** — 1-2 paragraphs on what the benchmark tests and why it matters.
- **Links** — paper, GitHub repo, project website (if available).
- **Overview image or GIF.**
- **Available tasks** — table of task suites with counts and brief descriptions.
- **Installation** — `pip install -e ".[<benchmark>]"` plus any extra steps (env vars, system packages).
- **Evaluation** — recommended `lerobot-eval` command with `n_episodes` and `batch_size` for reproducible results. Include single-task and multi-task examples if applicable.
- **Policy inputs and outputs** — observation keys with shapes, action space description.
- **Recommended evaluation episodes** — how many episodes per task is standard.
- **Training** — example `lerobot-train` command.
- **Reproducing published results** — link to pretrained model, eval command, results table (if available).
See `docs/source/libero.mdx` and `docs/source/metaworld.mdx` for complete examples.

View File

@@ -151,7 +151,7 @@ observation = {
### Factory Function
The `make_env_pre_post_processors` function follows the same pattern as `make_pre_post_processors` for policies:
The `make_env_pre_post_processors` function delegates to `env_cfg.get_env_processors()`:
```python
from lerobot.envs.factory import make_env_pre_post_processors
@@ -159,47 +159,31 @@ from lerobot.envs.configs import LiberoEnv, PushtEnv
# For LIBERO: Returns LiberoProcessorStep in preprocessor
libero_cfg = LiberoEnv(task="libero_spatial", camera_name=["agentview"])
env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg, policy_cfg)
# For other environments: Returns identity processors (no-op)
pusht_cfg = PushtEnv()
env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg, policy_cfg)
```
### Implementation in `envs/factory.py`
### How It Works
Each `EnvConfig` subclass can override `get_env_processors()` to return benchmark-specific
processor pipelines. The base class returns identity (no-op) processors by default.
```python
def make_env_pre_post_processors(
env_cfg: EnvConfig,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
]:
"""
Create preprocessor and postprocessor pipelines for environment observations.
Args:
env_cfg: The configuration of the environment.
Returns:
A tuple containing:
- preprocessor: Pipeline that processes environment observations
- postprocessor: Pipeline that processes environment outputs
"""
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
else:
# For all other environments, return an identity preprocessor
preprocessor = PolicyProcessorPipeline(steps=[])
# Postprocessor is currently identity for all environments
# Future: Could add environment-specific action transformations
postprocessor = PolicyProcessorPipeline(steps=[])
return preprocessor, postprocessor
# In your EnvConfig subclass:
def get_env_processors(self):
from lerobot.processor.pipeline import PolicyProcessorPipeline
return (
PolicyProcessorPipeline(steps=[MyProcessorStep()]),
PolicyProcessorPipeline(steps=[]),
)
```
The factory function `make_env_pre_post_processors` simply delegates to this method,
with a special case for `XVLAConfig` policies which override the env processors entirely.
### Integration in Evaluation
In `lerobot_eval.py`, the environment processors are created once and used throughout:

View File

@@ -0,0 +1,269 @@
# Human-In-the-Loop Data Collection
Human-In-the-Loop (HIL) data collection lets you improve a trained policy by deploying it on a real robot while a human operator monitors and intervenes when needed. The intervention data (recovery movements and corrections) is recorded alongside autonomous segments, producing a richer training dataset that teaches the policy how to handle failures.
---
## Why Human-In-the-Loop?
Standard behavioral cloning trains policies on successful demonstrations only. During deployment, small errors can compound and push the robot into states never seen during training (distribution shift). HIL data collection addresses this by:
- Running the trained policy on the real robot
- Having a human intervene when the robot is about to fail
- Recording the human's recovery and correction as training data
- Fine-tuning the policy on the combined dataset
This produces a policy that not only knows how to perform the task, but also how to recover when things go wrong.
---
## How It Works
During a HIL session, the human operator follows this loop within each episode:
1. **Watch** the policy run autonomously
2. **Pause** when failure is imminent, the robot holds its position
3. **Take control** and teleoperate the robot back to a good state (recovery), then correct the behavior
4. **Return control to the policy**, the policy resumes autonomous execution
5. Repeat steps 24 as many times as needed during the episode
6. **End the episode** when the task is complete, save and move on to the next rollout
Both autonomous and human-controlled segments are recorded. The policy and human can alternate control multiple times within a single episode, and the episode continues from the current state after each handoff (no reset required just because intervention happened). This captures autonomous execution, recovery, and correction in one continuous trajectory. After collection, the combined dataset (original demonstrations + HIL data) is used to fine-tune the policy.
This process can be repeated iteratively: deploy, collect, fine-tune, repeat. Each round targets the current policy's failure modes.
```
┌─────────────────────────────────────────────────────────────────────────┐
│ Policy v0 (trained on demos) │
│ ↓ │
│ HIL Collection (target current failure modes) → Fine-tune → Policy v1 │
│ ↓ │
│ HIL Collection (target new failure modes) → Fine-tune → Policy v2 │
│ ↓ │
│ ... (repeat until satisfactory performance) │
└─────────────────────────────────────────────────────────────────────────┘
```
---
## Hardware Requirements
### Teleoperator Requirements
The `examples/hil` HIL scripts require **teleoperators with active motors** that can:
- Enable/disable torque programmatically
- Move to target positions (to mirror the robot state when pausing)
**Compatible teleoperators in the current `examples/hil` scripts:**
- `openarm_mini` - OpenArm Mini
- `so_leader` - SO100 / SO101 leader arm
> [!IMPORTANT]
> The provided `examples/hil` commands default to `bi_openarm_follower` + `openarm_mini`.
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
---
## Script
A single script handles both synchronous and RTC-based inference. Toggle RTC with `--rtc.enabled=true`:
| Mode | Flag | Models |
| ------------------------ | -------------------- | --------------------- |
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
| Real-Time Chunking (RTC) | `--rtc.enabled=true` | Pi0, Pi0.5, SmolVLA |
---
## Step-by-Step Guide
### Step 1: Pre-train a Base Policy
First, train a policy on your demonstration dataset:
```bash
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/demo-dataset \
--policy.type=pi0 \
--output_dir=outputs/pretrain \
--batch_size=32 \
--steps=50000
```
### Step 2: Collect HIL Data
**Standard inference (ACT, Diffusion Policy):**
```bash
python examples/hil/hil_data_collection.py \
--robot.type=bi_openarm_follower \
--robot.left_arm_config.port=can1 \
--robot.left_arm_config.side=left \
--robot.right_arm_config.port=can0 \
--robot.right_arm_config.side=right \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
--teleop.type=openarm_mini \
--teleop.port_left=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/hil-dataset \
--dataset.single_task="Fold the T-shirt properly" \
--dataset.fps=30 \
--dataset.episode_time_s=1000 \
--dataset.num_episodes=50 \
--interpolation_multiplier=2
```
**With RTC for large models (Pi0, Pi0.5, SmolVLA):**
For models with high inference latency, enable RTC for smooth execution:
```bash
python examples/hil/hil_data_collection.py \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--rtc.max_guidance_weight=5.0 \
--rtc.prefix_attention_schedule=LINEAR \
--robot.type=bi_openarm_follower \
--robot.left_arm_config.port=can1 \
--robot.left_arm_config.side=left \
--robot.right_arm_config.port=can0 \
--robot.right_arm_config.side=right \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
--teleop.type=openarm_mini \
--teleop.port_left=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/hil-rtc-dataset \
--dataset.single_task="Fold the T-shirt properly" \
--dataset.fps=30 \
--dataset.episode_time_s=1000 \
--dataset.num_episodes=50 \
--interpolation_multiplier=3
```
**Controls (Conceptual):**
The interaction model is:
- **Pause input**: pause autonomous policy execution
- **Takeover input**: transfer control to the human operator and record intervention data
- **Return-to-policy input**: hand control back to the policy and continue the same episode
- **Episode control inputs**: save/re-record/stop/reset as needed
Exact key/pedal bindings can differ across scripts and hardware integrations. Use each script's printed controls as the source of truth for the concrete mapping on your setup.
**The HIL Protocol:**
1. Watch the policy run autonomously (teleop is idle/free)
2. When you see imminent failure, trigger the **pause input**
- Policy stops
- Teleoperator moves to match robot position (torque enabled)
- No frames recorded during pause
3. Trigger the **takeover input** to take control
- Teleoperator torque disabled, free to move
- **Recovery**: Teleoperate the robot back to a good state
- **Correction**: Correct the behavior
- All movements are recorded
4. Trigger the **return-to-policy input**
- Policy resumes autonomous execution from the current state
- You can intervene again at any time (repeat steps 24)
5. End and save the episode when the task is complete (or episode time limit is reached)
6. **Reset**: Teleop moves to robot position, you can move the robot to the starting position
7. Start the next episode
**Foot Pedal Setup (Linux):**
If using a USB foot pedal (PCsensor FootSwitch), ensure access:
```bash
sudo setfacl -m u:$USER:rw /dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd
```
### Step 3: Fine-tune the Policy
Fine-tune on the **combined** dataset (`demo-dataset` + `hil-dataset` merged together):
```bash
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/hil-dataset \
--policy.type=pi0 \
--policy.pretrained_path=outputs/pretrain/checkpoints/last/pretrained_model \
--output_dir=outputs/hil_finetune \
--steps=20000
```
Then deploy the fine-tuned policy and repeat from Step 2 to target its remaining failure modes.
---
## Tips for Effective HIL Collection
### When to Intervene
Intervene when you see:
- Robot about to make an irreversible mistake
- Robot hesitating or showing uncertain behavior
- Robot deviating from the expected trajectory
### Recovery: Teleoperating Back to a Good State
During recovery, teleoperate the robot back to a state where:
- The robot is in a familiar, in-distribution configuration
- The current subtask can still be completed
- The recovery trajectory itself is informative training data
### Quality of Corrections
During correction:
- Provide **confident, clean** trajectories
- Complete the current subtask fully
- Don't overcorrect or add unnecessary movements
---
## Related Work
This HIL data collection approach builds on ideas from interactive imitation learning:
- **DAgger** (Ross et al., 2011) introduced the core idea: instead of only training on expert demonstrations, query the expert for corrections on states the _learner_ visits. This breaks the compounding-error cycle of standard behavioral cloning by iteratively collecting on-policy data.
- **HG-DAgger** (Kelly et al., 2019) made this practical for robotics: a human expert monitors the robot and only intervenes when needed, rather than labeling every state. The gating between autonomous and human control is exactly the pause → takeover → return-to-policy loop used in the scripts here.
- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the HIL scripts in `examples/hil`.
- **π0.6/RECAP** (Physical Intelligence, 2025) applies the same iterative collect-and-finetune loop at scale with VLA models, showing that even large pretrained policies benefit substantially from targeted human corrections on their own failure modes. π0.6 is trained using RECAP.
```bibtex
@article{ross2011dagger,
title={A Reduction of Imitation Learning and Structured Prediction to No-Regret Online Learning},
author={Ross, Stéphane and Gordon, Geoffrey and Bagnell, Drew},
journal={Proceedings of the Fourteenth International Conference on Artificial Intelligence and Statistics},
year={2011}
}
@article{kelly2019hgdagger,
title={HG-DAgger: Interactive Imitation Learning with Human Experts},
author={Kelly, Michael and Sidrane, Chelsea and Driggs-Campbell, Katherine and Kochenderfer, Mykel J},
journal={arXiv preprint arXiv:1810.02890},
year={2019}
}
@article{hu2025rac,
title={RaC: Robot Learning for Long-Horizon Tasks by Scaling Recovery and Correction},
author={Hu, Zheyuan and Wu, Robyn and Enock, Naveen and Li, Jasmine and Kadakia, Riya and Erickson, Zackory and Kumar, Aviral},
journal={arXiv preprint arXiv:2509.07953},
year={2025}
}
@article{pi2025recap,
title={π0.6: a VLA That Learns From Experience},
author={Physical Intelligence},
year={2025}
}
```

View File

@@ -1,36 +1,61 @@
# LIBERO
**LIBERO** is a benchmark designed to study **lifelong robot learning**. The idea is that robots wont just be pretrained once in a factory, theyll need to keep learning and adapting with their human users over time. This ongoing adaptation is called **lifelong learning in decision making (LLDM)**, and its a key step toward building robots that become truly personalized helpers.
LIBERO is a benchmark designed to study **lifelong robot learning** — the idea that robots need to keep learning and adapting with their users over time, not just be pretrained once. It provides a set of standardized manipulation tasks that focus on **knowledge transfer**: how well a robot can apply what it has already learned to new situations. By evaluating on LIBERO, different algorithms can be compared fairly and researchers can build on each other's work.
- 📄 [LIBERO paper](https://arxiv.org/abs/2306.03310)
- 💻 [Original LIBERO repo](https://github.com/Lifelong-Robot-Learning/LIBERO)
To make progress on this challenge, LIBERO provides a set of standardized tasks that focus on **knowledge transfer**: how well a robot can apply what it has already learned to new situations. By evaluating on LIBERO, different algorithms can be compared fairly and researchers can build on each others work.
LIBERO includes **five task suites**:
- **LIBERO-Spatial (`libero_spatial`)** tasks that require reasoning about spatial relations.
- **LIBERO-Object (`libero_object`)** tasks centered on manipulating different objects.
- **LIBERO-Goal (`libero_goal`)** goal-conditioned tasks where the robot must adapt to changing targets.
- **LIBERO-90 (`libero_90`)** 90 short-horizon tasks from the LIBERO-100 collection.
- **LIBERO-Long (`libero_10`)** 10 long-horizon tasks from the LIBERO-100 collection.
Together, these suites cover **130 tasks**, ranging from simple object manipulations to complex multi-step scenarios. LIBERO is meant to grow over time, and to serve as a shared benchmark where the community can test and improve lifelong learning algorithms.
- Paper: [Benchmarking Knowledge Transfer for Lifelong Robot Learning](https://arxiv.org/abs/2306.03310)
- GitHub: [Lifelong-Robot-Learning/LIBERO](https://github.com/Lifelong-Robot-Learning/LIBERO)
- Project website: [libero-project.github.io](https://libero-project.github.io)
![An overview of the LIBERO benchmark](https://libero-project.github.io/assets/img/libero/fig1.png)
## Evaluating with LIBERO
## Available tasks
At **LeRobot**, we ported [LIBERO](https://github.com/Lifelong-Robot-Learning/LIBERO) into our framework and used it mainly to **evaluate [SmolVLA](https://huggingface.co/docs/lerobot/en/smolvla)**, our lightweight Vision-Language-Action model.
LIBERO includes **five task suites** covering **130 tasks**, ranging from simple object manipulations to complex multi-step scenarios:
LIBERO is now part of our **multi-eval supported simulation**, meaning you can benchmark your policies either on a **single suite of tasks** or across **multiple suites at once** with just a flag.
| Suite | CLI name | Tasks | Description |
| -------------- | ---------------- | ----- | -------------------------------------------------- |
| LIBERO-Spatial | `libero_spatial` | 10 | Tasks requiring reasoning about spatial relations |
| LIBERO-Object | `libero_object` | 10 | Tasks centered on manipulating different objects |
| LIBERO-Goal | `libero_goal` | 10 | Goal-conditioned tasks with changing targets |
| LIBERO-90 | `libero_90` | 90 | Short-horizon tasks from the LIBERO-100 collection |
| LIBERO-Long | `libero_10` | 10 | Long-horizon tasks from the LIBERO-100 collection |
To Install LIBERO, after following LeRobot official instructions, just do:
`pip install -e ".[libero]"`
## Installation
After following the LeRobot installation instructions:
```bash
pip install -e ".[libero]"
```
<Tip>
LIBERO requires Linux (`sys_platform == 'linux'`). LeRobot uses MuJoCo for simulation — set the rendering backend before training or evaluation:
```bash
export MUJOCO_GL=egl # for headless servers (HPC, cloud)
```
</Tip>
## Evaluation
### Default evaluation (recommended)
Evaluate across the four standard suites (10 episodes per task):
```bash
lerobot-eval \
--policy.path="your-policy-id" \
--env.type=libero \
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
--eval.batch_size=1 \
--eval.n_episodes=10 \
--env.max_parallel_tasks=1
```
### Single-suite evaluation
Evaluate a policy on one LIBERO suite:
Evaluate on one LIBERO suite:
```bash
lerobot-eval \
@@ -42,15 +67,13 @@ lerobot-eval \
```
- `--env.task` picks the suite (`libero_object`, `libero_spatial`, etc.).
- `--env.task_ids` picks task ids to run (`[0]`, `[1,2,3]`, etc.). Omit this flag (or set it to `null`) to run all tasks in the suite.
- `--env.task_ids` restricts to specific task indices (`[0]`, `[1,2,3]`, etc.). Omit to run all tasks in the suite.
- `--eval.batch_size` controls how many environments run in parallel.
- `--eval.n_episodes` sets how many episodes to run in total.
---
- `--eval.n_episodes` sets how many episodes to run per task.
### Multi-suite evaluation
Benchmark a policy across multiple suites at once:
Benchmark a policy across multiple suites at once by passing a comma-separated list:
```bash
lerobot-eval \
@@ -61,50 +84,49 @@ lerobot-eval \
--eval.n_episodes=2
```
- Pass a comma-separated list to `--env.task` for multi-suite evaluation.
### Control mode
### Control Mode
LIBERO supports two control modes — `relative` (default) and `absolute`. Different VLA checkpoints are trained with different action parameterizations, so make sure the mode matches your policy:
LIBERO now supports two control modes: relative and absolute. This matters because different VLA checkpoints are trained with different mode of action to output hence control parameterizations.
You can switch them with: `env.control_mode = "relative"` and `env.control_mode = "absolute"`
```bash
--env.control_mode=relative # or "absolute"
```
### Policy inputs and outputs
When using LIBERO through LeRobot, policies interact with the environment via **observations** and **actions**:
**Observations:**
- **Observations**
- `observation.state` proprioceptive features (agent state).
- `observation.images.image` main camera view (`agentview_image`).
- `observation.images.image2` wrist camera view (`robot0_eye_in_hand_image`).
- `observation.state` — 8-dim proprioceptive features (eef position, axis-angle orientation, gripper qpos)
- `observation.images.image` — main camera view (`agentview_image`), HWC uint8
- `observation.images.image2` — wrist camera view (`robot0_eye_in_hand_image`), HWC uint8
⚠️ **Note:** LeRobot enforces the `.images.*` prefix for any multi-modal visual features. Always ensure that your policy config `input_features` use the same naming keys, and that your dataset metadata keys follow this convention during evaluation.
If your data contains different keys, you must rename the observations to match what the policy expects, since naming keys are encoded inside the normalization statistics layer.
This will be fixed with the upcoming Pipeline PR.
<Tip warning={true}>
LeRobot enforces the `.images.*` prefix for visual features. Ensure your
policy config `input_features` use the same naming keys, and that your dataset
metadata keys follow this convention. If your data contains different keys,
you must rename the observations to match what the policy expects, since
naming keys are encoded inside the normalization statistics layer.
</Tip>
- **Actions**
- Continuous control values in a `Box(-1, 1, shape=(7,))` space.
**Actions:**
We also provide a notebook for quick testing:
Training with LIBERO
- Continuous control in `Box(-1, 1, shape=(7,))` — 6D end-effector delta + 1D gripper
## Training with LIBERO
### Recommended evaluation episodes
When training on LIBERO tasks, make sure your dataset parquet and metadata keys follow the LeRobot convention.
For reproducible benchmarking, use **10 episodes per task** across all four standard suites (Spatial, Object, Goal, Long). This gives 400 total episodes and matches the protocol used for published results.
The environment expects:
## Training
- `observation.state` → 8-dim agent state
- `observation.images.image` → main camera (`agentview_image`)
- `observation.images.image2` → wrist camera (`robot0_eye_in_hand_image`)
### Dataset
⚠️ Cleaning the dataset upfront is **cleaner and more efficient** than remapping keys inside the code.
To avoid potential mismatches and key errors, we provide a **preprocessed LIBERO dataset** that is fully compatible with the current LeRobot codebase and requires no additional manipulation:
👉 [HuggingFaceVLA/libero](https://huggingface.co/datasets/HuggingFaceVLA/libero)
We provide a preprocessed LIBERO dataset fully compatible with LeRobot:
For reference, here is the **original dataset** published by Physical Intelligence:
👉 [physical-intelligence/libero](https://huggingface.co/datasets/physical-intelligence/libero)
- [HuggingFaceVLA/libero](https://huggingface.co/datasets/HuggingFaceVLA/libero)
---
For reference, the original dataset published by Physical Intelligence:
- [physical-intelligence/libero](https://huggingface.co/datasets/physical-intelligence/libero)
### Example training command
@@ -121,52 +143,39 @@ lerobot-train \
--batch_size=4 \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--eval_freq=1000 \
--eval_freq=1000
```
---
## Reproducing published results
### Note on rendering
We reproduce the results of Pi0.5 on the LIBERO benchmark. We take the Physical Intelligence LIBERO base model (`pi05_libero`) and finetune for an additional 6k steps in bfloat16, with batch size of 256 on 8 H100 GPUs using the [HuggingFace LIBERO dataset](https://huggingface.co/datasets/HuggingFaceVLA/libero).
LeRobot uses MuJoCo for simulation. You need to set the rendering backend before training or evaluation:
The finetuned model: [lerobot/pi05_libero_finetuned](https://huggingface.co/lerobot/pi05_libero_finetuned)
- `export MUJOCO_GL=egl` → for headless servers (e.g. HPC, cloud)
## Reproducing π₀.₅ results
We reproduce the results of π₀.₅ on the LIBERO benchmark using the LeRobot implementation. We take the Physical Intelligence LIBERO base model (`pi05_libero`) and finetune for an additional 6k steps in bfloat16, with batch size of 256 on 8 H100 GPUs using the [HuggingFace LIBERO dataset](https://huggingface.co/datasets/HuggingFaceVLA/libero).
The finetuned model can be found here:
- **π₀.₅ LIBERO**: [lerobot/pi05_libero_finetuned](https://huggingface.co/lerobot/pi05_libero_finetuned)
We then evaluate the finetuned model using the LeRobot LIBERO implementation, by running the following command:
### Evaluation command
```bash
lerobot-eval \
--output_dir=/logs/ \
--output_dir=./eval_logs/ \
--env.type=libero \
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
--eval.batch_size=1 \
--eval.n_episodes=10 \
--policy.path=pi05_libero_finetuned \
--policy.n_action_steps=10 \
--output_dir=./eval_logs/ \
--env.max_parallel_tasks=1
```
**Note:** We set `n_action_steps=10`, similar to the original OpenPI implementation.
We set `n_action_steps=10`, matching the original OpenPI implementation.
### Results
We obtain the following results on the LIBERO benchmark:
| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
| ------------------- | -------------- | ------------- | ----------- | --------- | -------- |
| **Pi0.5 (LeRobot)** | 97.0 | 99.0 | 98.0 | 96.0 | **97.5** |
| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
| -------- | -------------- | ------------- | ----------- | --------- | -------- |
| **π₀.₅** | 97.0 | 99.0 | 98.0 | 96.0 | **97.5** |
These results are consistent with the [original results](https://github.com/Physical-Intelligence/openpi/tree/main/examples/libero#results) reported by Physical Intelligence:
These results are consistent with the original [results](https://github.com/Physical-Intelligence/openpi/tree/main/examples/libero#results) reported by Physical Intelligence:
| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
| -------- | -------------- | ------------- | ----------- | --------- | --------- |
| **π₀.₅** | 98.8 | 98.2 | 98.0 | 92.4 | **96.85** |
| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
| ------------------ | -------------- | ------------- | ----------- | --------- | --------- |
| **Pi0.5 (OpenPI)** | 98.8 | 98.2 | 98.0 | 92.4 | **96.85** |

View File

@@ -1,32 +1,111 @@
# Meta-World
Meta-World is a well-designed, open-source simulation benchmark for multi-task and meta reinforcement learning in continuous-control robotic manipulation. It gives researchers a shared, realistic playground to test whether algorithms can _learn many different tasks_ and _generalize quickly to new ones_ — two central challenges for real-world robotics.
Meta-World is an open-source simulation benchmark for **multi-task and meta reinforcement learning** in continuous-control robotic manipulation. It bundles 50 diverse manipulation tasks using everyday objects and a common tabletop Sawyer arm, providing a standardized playground to test whether algorithms can learn many different tasks and generalize quickly to new ones.
- 📄 [MetaWorld paper](https://arxiv.org/pdf/1910.10897)
- 💻 [Original MetaWorld repo](https://github.com/Farama-Foundation/Metaworld)
- Paper: [Meta-World: A Benchmark and Evaluation for Multi-Task and Meta Reinforcement Learning](https://arxiv.org/abs/1910.10897)
- GitHub: [Farama-Foundation/Metaworld](https://github.com/Farama-Foundation/Metaworld)
- Project website: [metaworld.farama.org](https://metaworld.farama.org)
![MetaWorld MT10 demo](https://meta-world.github.io/figures/ml45.gif)
## Why Meta-World matters
## Available tasks
- **Diverse, realistic tasks.** Meta-World bundles a large suite of simulated manipulation tasks (50 in the MT50 suite) using everyday objects and a common tabletop Sawyer arm. This diversity exposes algorithms to a wide variety of dynamics, contacts and goal specifications while keeping a consistent control and observation structure.
- **Focus on generalization and multi-task learning.** By evaluating across task distributions that share structure but differ in goals and objects, Meta-World reveals whether an agent truly learns transferable skills rather than overfitting to a narrow task.
- **Standardized evaluation protocol.** It provides clear evaluation modes and difficulty splits, so different methods can be compared fairly across easy, medium, hard and very-hard regimes.
- **Empirical insight.** Past evaluations on Meta-World show impressive progress on some fronts, but also highlight that current multi-task and meta-RL methods still struggle with large, diverse task sets. That gap points to important research directions.
Meta-World provides 50 tasks organized into difficulty groups. In LeRobot, you can evaluate on individual tasks, difficulty groups, or the full MT50 suite:
## What it enables in LeRobot
| Group | CLI name | Tasks | Description |
| ---------- | -------------------- | ----- | ------------------------------------------------------ |
| Easy | `easy` | 28 | Tasks with simple dynamics and single-step goals |
| Medium | `medium` | 11 | Tasks requiring multi-step reasoning |
| Hard | `hard` | 6 | Tasks with complex contacts and precise manipulation |
| Very Hard | `very_hard` | 5 | The most challenging tasks in the suite |
| MT50 (all) | Comma-separated list | 50 | All 50 tasks — the most challenging multi-task setting |
In LeRobot, you can evaluate any policy or vision-language-action (VLA) model on Meta-World tasks and get a clear success-rate measure. The integration is designed to be straightforward:
You can also pass individual task names directly (e.g., `assembly-v3`, `dial-turn-v3`).
- We provide a LeRobot-ready dataset for Meta-World (MT50) on the HF Hub: `https://huggingface.co/datasets/lerobot/metaworld_mt50`.
- This dataset is formatted for the MT50 evaluation that uses all 50 tasks (the most challenging multi-task setting).
- MT50 gives the policy a one-hot task vector and uses fixed object/goal positions for consistency.
We provide a LeRobot-ready dataset for Meta-World MT50 on the HF Hub: [lerobot/metaworld_mt50](https://huggingface.co/datasets/lerobot/metaworld_mt50). This dataset is formatted for the MT50 evaluation that uses all 50 tasks with fixed object/goal positions and one-hot task vectors for consistency.
- Task descriptions and the exact keys required for evaluation are available in the repo/dataset — use these to ensure your policy outputs the right success signals.
## Installation
## Quick start, train a SmolVLA policy on Meta-World
After following the LeRobot installation instructions:
Example command to train a SmolVLA policy on a subset of tasks:
```bash
pip install -e ".[metaworld]"
```
<Tip warning={true}>
If you encounter an `AssertionError: ['human', 'rgb_array', 'depth_array']` when running Meta-World environments, this is a mismatch between Meta-World and your Gymnasium version. Fix it with:
```bash
pip install "gymnasium==1.1.0"
```
</Tip>
## Evaluation
### Default evaluation (recommended)
Evaluate on the medium difficulty split (a good balance of coverage and compute):
```bash
lerobot-eval \
--policy.path="your-policy-id" \
--env.type=metaworld \
--env.task=medium \
--eval.batch_size=1 \
--eval.n_episodes=10
```
### Single-task evaluation
Evaluate on a specific task:
```bash
lerobot-eval \
--policy.path="your-policy-id" \
--env.type=metaworld \
--env.task=assembly-v3 \
--eval.batch_size=1 \
--eval.n_episodes=10
```
### Multi-task evaluation
Evaluate across multiple tasks or difficulty groups:
```bash
lerobot-eval \
--policy.path="your-policy-id" \
--env.type=metaworld \
--env.task=assembly-v3,dial-turn-v3,handle-press-side-v3 \
--eval.batch_size=1 \
--eval.n_episodes=10
```
- `--env.task` accepts explicit task lists (comma-separated) or difficulty groups (e.g., `easy`, `medium`, `hard`, `very_hard`).
- `--eval.batch_size` controls how many environments run in parallel.
- `--eval.n_episodes` sets how many episodes to run per task.
### Policy inputs and outputs
**Observations:**
- `observation.image` — single camera view (`corner2`), 480x480 HWC uint8
- `observation.state` — 4-dim proprioceptive state (end-effector position + gripper)
**Actions:**
- Continuous control in `Box(-1, 1, shape=(4,))` — 3D end-effector delta + 1D gripper
### Recommended evaluation episodes
For reproducible benchmarking, use **10 episodes per task**. For the full MT50 suite this gives 500 total episodes. If you care about generalization, run on the full MT50 — it is intentionally challenging and reveals strengths/weaknesses better than a few narrow tasks.
## Training
### Example training command
Train a SmolVLA policy on a subset of Meta-World tasks:
```bash
lerobot-train \
@@ -44,37 +123,8 @@ lerobot-train \
--eval_freq=1000
```
Notes:
- `--env.task` accepts explicit task lists (comma separated) or difficulty groups (e.g., `env.task="hard"`).
- Adjust `batch_size`, `steps`, and `eval_freq` to match your compute budget.
- **Gymnasium Assertion Error**: if you encounter an error like
`AssertionError: ['human', 'rgb_array', 'depth_array']` when running MetaWorld environments, this comes from a mismatch between MetaWorld and your Gymnasium version.
We recommend using:
```bash
pip install "gymnasium==1.1.0"
```
to ensure proper compatibility.
## Quick start — evaluate a trained policy
To evaluate a trained policy on the Meta-World medium difficulty split:
```bash
lerobot-eval \
--policy.path="your-policy-id" \
--env.type=metaworld \
--env.task=medium \
--eval.batch_size=1 \
--eval.n_episodes=2
```
This will run episodes and return per-task success rates using the standard Meta-World evaluation keys.
## Practical tips
- If you care about generalization, run on the full MT50 suite — its intentionally challenging and reveals strengths/weaknesses better than a few narrow tasks.
- Use the one-hot task conditioning for multi-task training (MT10 / MT50 conventions) so policies have explicit task context.
- Use the one-hot task conditioning for multi-task training (MT10/MT50 conventions) so policies have explicit task context.
- Inspect the dataset task descriptions and the `info["is_success"]` keys when writing post-processing or logging so your success metrics line up with the benchmark.
- Adjust `batch_size`, `steps`, and `eval_freq` to match your compute budget.

View File

@@ -0,0 +1,680 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes.
Downloads datasets from HuggingFace, seeks directly into the episode segment
of the source video, draws a progress line on each frame, and writes the result.
Usage:
python examples/dataset/create_progress_videos.py \
--repo-id lerobot-data-collection/level2_final_quality3 \
--episode 1100
python examples/dataset/create_progress_videos.py \
--repo-id lerobot-data-collection/level2_final_quality3 \
--episode 1100 \
--camera-key observation.images.top \
--output-dir ./my_videos \
--gif
"""
from __future__ import annotations
import argparse
import json
import logging
import subprocess
from pathlib import Path
import cv2
import numpy as np
import pandas as pd
from huggingface_hub import snapshot_download
GRAPH_Y_TOP_FRAC = 0.01
GRAPH_Y_BOT_FRAC = 0.99
LINE_THICKNESS = 3
SHADOW_THICKNESS = 6
REF_ALPHA = 0.45
FILL_ALPHA = 0.55
SCORE_FONT_SCALE = 0.8
TASK_FONT_SCALE = 0.55
def download_episode_metadata(repo_id: str, episode: int) -> Path:
"""Download only the metadata and sarm_progress files for a dataset.
Args:
repo_id: HuggingFace dataset repository ID.
episode: Episode index (used for logging only; all meta is fetched).
Returns:
Local cache path for the downloaded snapshot.
"""
logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode)
local_path = Path(
snapshot_download(
repo_id=repo_id,
repo_type="dataset",
allow_patterns=["meta/**", "sarm_progress.parquet"],
ignore_patterns=["*.mp4"],
)
)
return local_path
def load_episode_meta(local_path: Path, episode: int, camera_key: str | None) -> dict:
"""Read info.json and episode parquet to resolve fps, video path, and timestamps.
Args:
local_path: Local cache directory containing meta/.
episode: Episode index to look up.
camera_key: Camera observation key (e.g. "observation.images.base").
If None, the first available video key is used.
Returns:
Dict with keys: fps, camera, video_rel, chunk_index, file_index,
from_ts, to_ts, task_name.
"""
info = json.loads((local_path / "meta" / "info.json").read_text())
fps = info["fps"]
features = info["features"]
video_keys = [k for k, v in features.items() if v.get("dtype") == "video"]
if not video_keys:
raise RuntimeError("No video keys found in dataset features")
if camera_key is not None:
if camera_key not in video_keys:
raise RuntimeError(f"camera_key='{camera_key}' not found. Available: {video_keys}")
selected_camera = camera_key
else:
selected_camera = video_keys[0]
logging.info(" fps=%d camera='%s' all_cams=%s", fps, selected_camera, video_keys)
episode_rows = []
for parquet_file in sorted((local_path / "meta" / "episodes").glob("**/*.parquet")):
episode_rows.append(pd.read_parquet(parquet_file))
episode_df = pd.concat(episode_rows, ignore_index=True)
row = episode_df[episode_df["episode_index"] == episode]
if row.empty:
raise RuntimeError(f"Episode {episode} not found in episode metadata")
row = row.iloc[0]
chunk_col = f"videos/{selected_camera}/chunk_index"
file_col = f"videos/{selected_camera}/file_index"
ts_from_col = f"videos/{selected_camera}/from_timestamp"
ts_to_col = f"videos/{selected_camera}/to_timestamp"
if chunk_col not in row.index:
chunk_col = f"{selected_camera}/chunk_index"
file_col = f"{selected_camera}/file_index"
ts_from_col = f"{selected_camera}/from_timestamp"
ts_to_col = f"{selected_camera}/to_timestamp"
if chunk_col not in row.index:
raise RuntimeError(
f"Cannot find video metadata columns for {selected_camera}.\nAvailable: {list(row.index)}"
)
chunk_index = int(row[chunk_col])
file_index = int(row[file_col])
from_timestamp = float(row[ts_from_col])
to_timestamp = float(row[ts_to_col])
video_template = info.get(
"video_path", "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4"
)
video_rel = video_template.format(
video_key=selected_camera,
chunk_index=chunk_index,
file_index=file_index,
)
task_name = _resolve_task_name(row, local_path)
return {
"fps": fps,
"camera": selected_camera,
"video_rel": video_rel,
"chunk_index": chunk_index,
"file_index": file_index,
"from_ts": from_timestamp,
"to_ts": to_timestamp,
"task_name": task_name,
}
def _resolve_task_name(row: pd.Series, local_path: Path) -> str:
"""Best-effort extraction of the task name for an episode row.
Args:
row: Single-episode row from the episodes parquet.
local_path: Dataset cache root.
Returns:
Task name string, or empty string if unavailable.
"""
try:
if "tasks" in row.index and row["tasks"] is not None:
tasks_val = row["tasks"]
if isinstance(tasks_val, (list, tuple, np.ndarray)) and len(tasks_val) > 0:
return str(tasks_val[0])
return str(tasks_val).strip("[]'")
tasks_parquet = local_path / "meta" / "tasks.parquet"
if tasks_parquet.exists():
tasks_df = pd.read_parquet(tasks_parquet)
task_idx = int(row.get("task_index", 0)) if "task_index" in row.index else 0
match = tasks_df[tasks_df["task_index"] == task_idx]
if not match.empty:
return str(match.index[0])
except Exception as exc:
logging.warning("Could not load task name: %s", exc)
return ""
def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
"""Download the specific video file if not already cached.
Args:
repo_id: HuggingFace dataset repository ID.
local_path: Local cache directory.
video_rel: Relative path to the video file within the dataset.
Returns:
Absolute path to the downloaded video file.
"""
video_path = local_path / video_rel
if video_path.exists():
logging.info(" Video already cached: %s", video_path)
return video_path
logging.info("[2/4] Downloading video file %s ...", video_rel)
snapshot_download(
repo_id=repo_id,
repo_type="dataset",
local_dir=str(local_path),
allow_patterns=[video_rel],
)
if not video_path.exists():
raise RuntimeError(f"Video not found after download: {video_path}")
return video_path
def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None:
"""Load sarm_progress values for an episode.
Args:
local_path: Dataset cache root.
episode: Episode index.
Returns:
Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
"""
parquet_path = local_path / "sarm_progress.parquet"
if not parquet_path.exists():
logging.warning("sarm_progress.parquet not found")
return None
df = pd.read_parquet(parquet_path)
logging.info(" sarm_progress.parquet columns: %s", list(df.columns))
episode_df = df[df["episode_index"] == episode].copy()
if episode_df.empty:
logging.warning("No sarm_progress rows for episode %d", episode)
return None
episode_df = episode_df.sort_values("frame_index")
if "progress_dense" in episode_df.columns and episode_df["progress_dense"].notna().any():
progress_column = "progress_dense"
elif "progress_sparse" in episode_df.columns:
progress_column = "progress_sparse"
else:
progress_columns = [c for c in episode_df.columns if "progress" in c.lower()]
if not progress_columns:
return None
progress_column = progress_columns[0]
logging.info(" Using progress column: '%s'", progress_column)
return episode_df[["frame_index", progress_column]].rename(columns={progress_column: "progress"}).values
def _precompute_pixel_coords(
progress_data: np.ndarray,
num_frames: int,
frame_width: int,
frame_height: int,
) -> np.ndarray:
"""Map progress samples to pixel coordinates for overlay drawing.
Args:
progress_data: (N, 2) array of (frame_index, progress).
num_frames: Total number of video frames.
frame_width: Video width in pixels.
frame_height: Video height in pixels.
Returns:
(N, 2) array of (x, y) pixel coordinates.
"""
frame_indices = progress_data[:, 0].astype(float)
progress_values = np.clip(progress_data[:, 1].astype(float), 0.0, 1.0)
y_top = int(frame_height * GRAPH_Y_TOP_FRAC)
y_bot = int(frame_height * GRAPH_Y_BOT_FRAC)
graph_height = y_bot - y_top
x_coords = (frame_indices / (num_frames - 1) * (frame_width - 1)).astype(int)
y_coords = (y_bot - progress_values * graph_height).astype(int)
return np.stack([x_coords, y_coords], axis=1)
def _progress_color(normalized_position: float) -> tuple[int, int, int]:
"""Interpolate BGR color from red to green based on position in [0, 1].
Args:
normalized_position: Value in [0, 1] indicating how far along the episode.
Returns:
BGR color tuple.
"""
red = int(255 * (1.0 - normalized_position))
green = int(255 * normalized_position)
return (0, green, red)
def _prerender_fill_polygon(
pixel_coords: np.ndarray,
frame_width: int,
frame_height: int,
) -> np.ndarray:
"""Pre-render the grey fill polygon under the progress curve as a BGRA image.
Args:
pixel_coords: (N, 2) array of (x, y) pixel coordinates.
frame_width: Video width in pixels.
frame_height: Video height in pixels.
Returns:
BGRA image array of shape (frame_height, frame_width, 4).
"""
y_bot = int(frame_height * GRAPH_Y_BOT_FRAC)
fill_image = np.zeros((frame_height, frame_width, 4), dtype=np.uint8)
polygon = np.concatenate(
[
pixel_coords,
[[pixel_coords[-1][0], y_bot], [pixel_coords[0][0], y_bot]],
],
axis=0,
).astype(np.int32)
cv2.fillPoly(fill_image, [polygon], color=(128, 128, 128, int(255 * FILL_ALPHA)))
return fill_image
def _alpha_composite_region(base: np.ndarray, overlay_bgra: np.ndarray, x_limit: int) -> None:
"""Blend BGRA overlay onto BGR base in-place, up to x_limit columns.
Args:
base: BGR frame to draw on (modified in-place).
overlay_bgra: BGRA overlay image.
x_limit: Only blend columns [0, x_limit).
"""
if x_limit <= 0:
return
region_base = base[:, :x_limit]
region_overlay = overlay_bgra[:, :x_limit]
alpha = region_overlay[:, :, 3:4].astype(np.float32) / 255.0
region_base[:] = np.clip(
region_overlay[:, :, :3].astype(np.float32) * alpha + region_base.astype(np.float32) * (1.0 - alpha),
0,
255,
).astype(np.uint8)
def _draw_text_outlined(
frame: np.ndarray,
text: str,
position: tuple[int, int],
font_scale: float,
thickness: int = 1,
) -> None:
"""Draw white text with a dark outline for readability on any background.
Args:
frame: BGR image to draw on (modified in-place).
text: String to render.
position: (x, y) bottom-left corner of the text.
font_scale: OpenCV font scale.
thickness: Text stroke thickness.
"""
font = cv2.FONT_HERSHEY_SIMPLEX
cv2.putText(frame, text, position, font, font_scale, (0, 0, 0), thickness + 2, cv2.LINE_AA)
cv2.putText(frame, text, position, font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
def composite_progress_video(
video_path: Path,
from_timestamp: float,
to_timestamp: float,
progress_data: np.ndarray,
output_path: Path,
fps: float,
task_name: str = "",
) -> Path:
"""Read episode frames by seeking into the source video, draw progress overlay, write output.
Uses cv2.CAP_PROP_POS_MSEC to seek directly into the source video,
eliminating the need for an intermediate clip file.
Args:
video_path: Path to the full source video file.
from_timestamp: Start timestamp of the episode in seconds.
to_timestamp: End timestamp of the episode in seconds.
progress_data: (N, 2) array of (frame_index, progress).
output_path: Path to write the output MP4.
fps: Frames per second for the output video.
task_name: Optional task name to display at the top of the video.
Returns:
Path to the written output file (MP4).
"""
capture = cv2.VideoCapture(str(video_path))
try:
capture.set(cv2.CAP_PROP_POS_MSEC, from_timestamp * 1000)
frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
duration_seconds = to_timestamp - from_timestamp
num_frames = int(round(duration_seconds * fps))
logging.info(
" Video: %dx%d, %d frames @ %.1f fps (%.2fs)",
frame_width,
frame_height,
num_frames,
fps,
duration_seconds,
)
pixel_coords = _precompute_pixel_coords(progress_data, num_frames, frame_width, frame_height)
y_ref = int(frame_height * GRAPH_Y_TOP_FRAC)
fill_image = _prerender_fill_polygon(pixel_coords, frame_width, frame_height)
ref_line_image = np.zeros((frame_height, frame_width, 4), dtype=np.uint8)
cv2.line(
ref_line_image,
(0, y_ref),
(frame_width - 1, y_ref),
(200, 200, 200, int(255 * REF_ALPHA)),
1,
cv2.LINE_AA,
)
frame_indices = progress_data[:, 0].astype(int)
progress_values = progress_data[:, 1].astype(float)
logging.info("[3/4] Compositing %d frames ...", num_frames)
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
writer = cv2.VideoWriter(str(output_path), fourcc, fps, (frame_width, frame_height))
for frame_idx in range(num_frames):
ret, frame = capture.read()
if not ret:
break
drawn_count = int(np.searchsorted(frame_indices, frame_idx, side="right"))
x_current = (
int(pixel_coords[min(drawn_count, len(pixel_coords)) - 1][0]) + 1 if drawn_count > 0 else 0
)
_alpha_composite_region(frame, ref_line_image, frame_width)
_alpha_composite_region(frame, fill_image, x_current)
if drawn_count >= 2:
time_position = (drawn_count - 1) / max(len(progress_values) - 1, 1)
line_color = _progress_color(time_position)
points = pixel_coords[:drawn_count].reshape(-1, 1, 2).astype(np.int32)
cv2.polylines(
frame,
[points],
isClosed=False,
color=(255, 255, 255),
thickness=SHADOW_THICKNESS,
lineType=cv2.LINE_AA,
)
cv2.polylines(
frame,
[points],
isClosed=False,
color=line_color,
thickness=LINE_THICKNESS,
lineType=cv2.LINE_AA,
)
if drawn_count > 0:
score = float(progress_values[min(drawn_count, len(progress_values)) - 1])
score_text = f"{score:.2f}"
(text_width, _), _ = cv2.getTextSize(
score_text, cv2.FONT_HERSHEY_SIMPLEX, SCORE_FONT_SCALE, 2
)
score_x = frame_width - text_width - 12
score_y = frame_height - 12
time_position = (drawn_count - 1) / max(len(progress_values) - 1, 1)
score_color = _progress_color(time_position)
cv2.putText(
frame,
score_text,
(score_x, score_y),
cv2.FONT_HERSHEY_SIMPLEX,
SCORE_FONT_SCALE,
(0, 0, 0),
4,
cv2.LINE_AA,
)
cv2.putText(
frame,
score_text,
(score_x, score_y),
cv2.FONT_HERSHEY_SIMPLEX,
SCORE_FONT_SCALE,
score_color,
2,
cv2.LINE_AA,
)
if task_name:
(text_width, _), _ = cv2.getTextSize(task_name, cv2.FONT_HERSHEY_SIMPLEX, TASK_FONT_SCALE, 1)
task_x = max((frame_width - text_width) // 2, 4)
_draw_text_outlined(frame, task_name, (task_x, 22), TASK_FONT_SCALE)
writer.write(frame)
if frame_idx % 100 == 0:
logging.info(" Frame %d/%d ...", frame_idx, num_frames)
writer.release()
finally:
capture.release()
logging.info(" MP4 written: %s", output_path)
return output_path
def convert_mp4_to_gif(mp4_path: Path) -> Path:
"""Convert an MP4 to an optimized GIF using ffmpeg palette generation.
Args:
mp4_path: Path to the source MP4 file.
Returns:
Path to the generated GIF file.
"""
capture = cv2.VideoCapture(str(mp4_path))
frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
capture.release()
gif_path = mp4_path.with_suffix(".gif")
palette_path = mp4_path.parent / "_palette.png"
logging.info("[4/4] Converting to GIF ...")
result_palette = subprocess.run( # nosec B607
[
"ffmpeg",
"-y",
"-i",
str(mp4_path),
"-vf",
f"fps=10,scale={frame_width}:-1:flags=lanczos,palettegen=max_colors=128:stats_mode=diff",
"-update",
"1",
str(palette_path),
],
capture_output=True,
text=True,
)
if result_palette.returncode != 0:
logging.warning("palettegen failed:\n%s", result_palette.stderr[-500:])
result_gif = subprocess.run( # nosec B607
[
"ffmpeg",
"-y",
"-i",
str(mp4_path),
"-i",
str(palette_path),
"-filter_complex",
f"fps=10,scale={frame_width}:-1:flags=lanczos[v];[v][1:v]paletteuse=dither=bayer:bayer_scale=3",
str(gif_path),
],
capture_output=True,
text=True,
)
if result_gif.returncode != 0:
logging.warning("GIF encode failed:\n%s", result_gif.stderr[-500:])
palette_path.unlink(missing_ok=True)
logging.info(" GIF written: %s", gif_path)
return gif_path
def process_dataset(
repo_id: str,
episode: int,
camera_key: str | None,
output_dir: Path,
create_gif: bool = False,
) -> Path | None:
"""Full pipeline: download, extract metadata, composite progress, write output.
Args:
repo_id: HuggingFace dataset repository ID.
episode: Episode index.
camera_key: Camera key to use, or None for auto-selection.
output_dir: Directory to write output files.
create_gif: If True, also generate a GIF from the MP4.
Returns:
Path to the final output file, or None on failure.
"""
safe_name = repo_id.replace("/", "_")
logging.info("Processing: %s | episode %d", repo_id, episode)
local_path = download_episode_metadata(repo_id, episode)
logging.info(" Local cache: %s", local_path)
episode_meta = load_episode_meta(local_path, episode, camera_key)
logging.info(" Episode meta: %s", episode_meta)
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
progress_data = load_progress_data(local_path, episode)
if progress_data is None:
logging.error("Could not load sarm_progress data. Skipping overlay.")
return None
logging.info(" Progress frames: %d", len(progress_data))
output_path = output_dir / f"{safe_name}_ep{episode}_progress.mp4"
final_path = composite_progress_video(
video_path=video_path,
from_timestamp=episode_meta["from_ts"],
to_timestamp=episode_meta["to_ts"],
progress_data=progress_data,
output_path=output_path,
fps=episode_meta["fps"],
task_name=episode_meta.get("task_name", ""),
)
if create_gif:
final_path = convert_mp4_to_gif(final_path)
logging.info("Done: %s", final_path)
return final_path
def main() -> None:
parser = argparse.ArgumentParser(
description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes."
)
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="HuggingFace dataset repository ID (e.g. 'lerobot-data-collection/level2_final_quality3').",
)
parser.add_argument(
"--episode",
type=int,
required=True,
help="Episode index to visualize.",
)
parser.add_argument(
"--camera-key",
type=str,
default=None,
help="Camera observation key (e.g. 'observation.images.base'). Auto-selects first camera if omitted.",
)
parser.add_argument(
"--output-dir",
type=Path,
default=Path("progress_videos"),
help="Directory to write output files (default: ./progress_videos).",
)
parser.add_argument(
"--gif",
action="store_true",
help="Also generate a GIF from the MP4 output.",
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
args.output_dir.mkdir(parents=True, exist_ok=True)
result = process_dataset(
repo_id=args.repo_id,
episode=args.episode,
camera_key=args.camera_key,
output_dir=args.output_dir,
create_gif=args.gif,
)
if result:
logging.info("Output: %s", result)
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

228
examples/hil/hil_utils.py Normal file
View File

@@ -0,0 +1,228 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utilities for Human-in-the-Loop data collection scripts."""
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from lerobot.processor import (
IdentityProcessorStep,
RobotAction,
RobotObservation,
RobotProcessorPipeline,
)
from lerobot.processor.converters import (
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from lerobot.robots import Robot
from lerobot.teleoperators import Teleoperator
from lerobot.utils.control_utils import is_headless
from lerobot.utils.robot_utils import precise_sleep
logger = logging.getLogger(__name__)
@dataclass
class HILDatasetConfig:
repo_id: str
single_task: str
root: str | Path | None = None
fps: int = 30
episode_time_s: float = 120
num_episodes: int = 50
video: bool = True
push_to_hub: bool = True
private: bool = False
tags: list[str] | None = None
num_image_writer_processes: int = 0
num_image_writer_threads_per_camera: int = 4
video_encoding_batch_size: int = 1
vcodec: str = "auto"
streaming_encoding: bool = True
encoder_queue_maxsize: int = 30
encoder_threads: int | None = None
rename_map: dict[str, str] = field(default_factory=dict)
def teleop_has_motor_control(teleop: Teleoperator) -> bool:
"""Check if teleoperator has motor control capabilities."""
return all(hasattr(teleop, attr) for attr in ("enable_torque", "disable_torque", "write_goal_positions"))
def teleop_disable_torque(teleop: Teleoperator) -> None:
"""Disable teleop torque if supported."""
if hasattr(teleop, "disable_torque"):
teleop.disable_torque()
def teleop_enable_torque(teleop: Teleoperator) -> None:
"""Enable teleop torque if supported."""
if hasattr(teleop, "enable_torque"):
teleop.enable_torque()
def teleop_smooth_move_to(teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50):
"""Smoothly move teleop to target position if motor control is available."""
if not teleop_has_motor_control(teleop):
logger.warning("Teleop does not support motor control - cannot mirror robot position")
return
teleop_enable_torque(teleop)
current = teleop.get_action()
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {}
for k in current:
if k in target_pos:
interp[k] = current[k] * (1 - t) + target_pos[k] * t
else:
interp[k] = current[k]
teleop.write_goal_positions(interp)
time.sleep(1 / fps)
def init_keyboard_listener():
"""Initialize keyboard listener with HIL controls."""
events = {
"exit_early": False,
"rerecord_episode": False,
"stop_recording": False,
"policy_paused": False,
"correction_active": False,
"resume_policy": False,
"in_reset": False,
"start_next_episode": False,
}
if is_headless():
logger.warning("Headless environment - keyboard controls unavailable")
return None, events
from pynput import keyboard
def on_press(key):
try:
if events["in_reset"]:
if key in [keyboard.Key.space, keyboard.Key.right]:
logger.info("[HIL] Starting next episode...")
events["start_next_episode"] = True
elif hasattr(key, "char") and key.char == "c":
events["start_next_episode"] = True
elif key == keyboard.Key.esc:
logger.info("[HIL] ESC - Stop recording, pushing to hub...")
events["stop_recording"] = True
events["start_next_episode"] = True
else:
if key == keyboard.Key.space:
if not events["policy_paused"] and not events["correction_active"]:
logger.info("[HIL] PAUSED - Press 'c' to take control or 'p' to resume policy")
events["policy_paused"] = True
elif hasattr(key, "char") and key.char == "c":
if events["policy_paused"] and not events["correction_active"]:
logger.info("[HIL] Taking control...")
events["start_next_episode"] = True
elif hasattr(key, "char") and key.char == "p":
if events["policy_paused"] or events["correction_active"]:
logger.info("[HIL] Resuming policy...")
events["resume_policy"] = True
elif key == keyboard.Key.right:
logger.info("[HIL] End episode")
events["exit_early"] = True
elif key == keyboard.Key.left:
logger.info("[HIL] Re-record episode")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
logger.info("[HIL] ESC - Stop recording...")
events["stop_recording"] = True
events["exit_early"] = True
except Exception as e:
logger.info(f"Key error: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
return listener, events
def make_identity_processors():
"""Create identity processors for recording."""
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[IdentityProcessorStep()],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[IdentityProcessorStep()],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
return teleop_proc, obs_proc
def reset_loop(robot: Robot, teleop: Teleoperator, events: dict, fps: int):
"""Reset period where human repositions environment."""
logger.info("[HIL] RESET")
events["in_reset"] = True
events["start_next_episode"] = False
obs = robot.get_observation()
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
logger.info("Press any key to enable teleoperation")
while not events["start_next_episode"] and not events["stop_recording"]:
precise_sleep(0.05)
if events["stop_recording"]:
return
events["start_next_episode"] = False
teleop_disable_torque(teleop)
logger.info("Teleop enabled - press any key to start episode")
while not events["start_next_episode"] and not events["stop_recording"]:
loop_start = time.perf_counter()
action = teleop.get_action()
robot.send_action(action)
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
events["in_reset"] = False
events["start_next_episode"] = False
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
events["resume_policy"] = False
def print_controls(rtc: bool = False):
"""Print control instructions."""
mode = "Human-in-the-Loop Data Collection" + (" (RTC)" if rtc else "")
logger.info(
"%s\n Controls:\n"
" SPACE - Pause policy\n"
" c - Take control\n"
" p - Resume policy after pause/correction\n"
" → - End episode\n"
" ESC - Stop and push to hub",
mode,
)

View File

@@ -43,13 +43,12 @@ def main():
keyboard.connect()
# Init rerun viewer
init_rerun(session_name="lekiwi_teleop", robot=robot, reset_time=True)
init_rerun(session_name="lekiwi_teleop")
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting teleop loop...")
start = time.perf_counter()
while True:
t0 = time.perf_counter()
@@ -70,7 +69,7 @@ def main():
_ = robot.send_action(action)
# Visualize
log_rerun_data(observation=observation, action=action, log_time=time.perf_counter() - start)
log_rerun_data(observation=observation, action=action)
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))

View File

@@ -90,13 +90,12 @@ def main():
teleop_device.connect()
# Init rerun viewer
init_rerun(session_name="phone_so100_teleop", robot=robot, reset_time=True)
init_rerun(session_name="phone_so100_teleop")
if not robot.is_connected or not teleop_device.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting teleop loop. Move your phone to teleoperate the robot...")
start = time.perf_counter()
while True:
t0 = time.perf_counter()
@@ -113,7 +112,7 @@ def main():
_ = robot.send_action(joint_action)
# Visualize
log_rerun_data(observation=phone_obs, action=joint_action, log_time=time.perf_counter() - start)
log_rerun_data(observation=phone_obs, action=joint_action)
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))

View File

@@ -69,15 +69,20 @@ Usage:
--policy.path=lerobot-data-collection/folding_final \
--robot.type=bi_openarm_follower \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}}' \
--robot.left_arm_config.port=can1 \
--robot.left_arm_config.port=can0 \
--robot.left_arm_config.side=left \
--robot.left_arm_config.can_interface=socketcan \
--robot.right_arm_config.port=can0 \
--robot.left_arm_config.disable_torque_on_disconnect=true \
--robot.left_arm_config.max_relative_target=8.0 \
--robot.right_arm_config.port=can1 \
--robot.right_arm_config.side=right \
--robot.right_arm_config.can_interface=socketcan \
--robot.right_arm_config.disable_torque_on_disconnect=true \
--robot.right_arm_config.max_relative_target=8.0 \
--task="Fold the T-shirt properly" \
--fps=30 \
--duration=2000 \
--interpolation_multiplier=3 \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--rtc.max_guidance_weight=5.0 \
@@ -104,9 +109,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
from lerobot.processor import (
NormalizerProcessorStep,
RelativeActionsProcessorStep,
@@ -181,6 +184,7 @@ class RTCDemoConfig(HubMixin):
# Demo parameters
duration: float = 30.0 # Duration to run the demo (seconds)
fps: float = 10.0 # Action execution frequency (Hz)
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
# Compute device
device: str | None = None # Device to run on (cuda, cpu, auto)
@@ -461,20 +465,23 @@ def actor_control(
action_keys = [k for k in robot.action_features() if k.endswith(".pos")]
action_count = 0
action_interval = 1.0 / cfg.fps
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
action_interval = interpolator.get_control_interval(cfg.fps)
while not shutdown_event.is_set():
start_time = time.perf_counter()
# Try to get an action from the queue with timeout
action = action_queue.get()
if interpolator.needs_new_action():
new_action = action_queue.get()
if new_action is not None:
interpolator.add(new_action.cpu())
action = interpolator.get()
if action is not None:
action = action.cpu()
action_dict = {key: action[i].item() for i, key in enumerate(action_keys)}
action_processed = robot_action_processor((action_dict, None))
robot.send_action(action_processed)
action_count += 1
dt_s = time.perf_counter() - start_time

View File

@@ -95,10 +95,9 @@ def main():
leader.connect()
# Init rerun viewer
init_rerun(session_name="so100_so100_EE_teleop", robot=follower, reset_time=True)
init_rerun(session_name="so100_so100_EE_teleop")
print("Starting teleop loop...")
start = time.perf_counter()
while True:
t0 = time.perf_counter()
@@ -118,9 +117,7 @@ def main():
_ = follower.send_action(follower_joints_act)
# Visualize
log_rerun_data(
observation=leader_ee_act, action=follower_joints_act, log_time=time.perf_counter() - start
)
log_rerun_data(observation=leader_ee_act, action=follower_joints_act)
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))

View File

@@ -164,7 +164,6 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpci
# Features
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
audio = ["sounddevice>=0.5.1,<0.6.0", "soundfile>=0.13.1,<0.14.0", "librosa>=0.11.0,<0.12.0", "torchaudio>=2.6.0,<2.10.0"]
# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
@@ -199,7 +198,6 @@ all = [
"lerobot[xvla]",
"lerobot[hilserl]",
"lerobot[async]",
"lerobot[audio]",
"lerobot[dev]",
"lerobot[test]",
"lerobot[video_benchmark]",

View File

@@ -29,7 +29,6 @@ Example:
print(lerobot.available_policies_per_env)
print(lerobot.available_robots)
print(lerobot.available_cameras)
print(lerobot.available_microphones)
print(lerobot.available_motors)
```
@@ -175,12 +174,6 @@ available_cameras = [
"intelrealsense",
]
# lists all available microphones from `lerobot/microphones`
available_microphones = [
"portaudio",
"touchlab",
]
# lists all available motors from `lerobot/motors`
available_motors = [
"dynamixel",

View File

@@ -49,8 +49,6 @@ import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig # noqa: F401
from lerobot.microphones.touchlab.configuration_touchlab import TouchLabSensorConfig # noqa: F401
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,

View File

@@ -67,7 +67,8 @@ class EvalConfig:
# `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
batch_size: int = 50
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
use_async_envs: bool = False
# Defaults to True; automatically downgraded to SyncVectorEnv when batch_size=1.
use_async_envs: bool = True
def __post_init__(self) -> None:
if self.batch_size > self.n_episodes:

View File

@@ -151,12 +151,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
return {}
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
@property
def audio_features(self) -> dict[str, PolicyFeature]:
if not self.input_features:
return {}
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.AUDIO}
@property
def action_feature(self) -> PolicyFeature | None:
if not self.output_features:

View File

@@ -20,7 +20,6 @@ from enum import Enum
class FeatureType(str, Enum):
STATE = "STATE"
VISUAL = "VISUAL"
AUDIO = "AUDIO"
ENV = "ENV"
ACTION = "ACTION"
REWARD = "REWARD"

View File

@@ -35,8 +35,6 @@ from lerobot.datasets.io_utils import (
write_tasks,
)
from lerobot.datasets.utils import (
DEFAULT_AUDIO_FILE_SIZE_IN_MB,
DEFAULT_AUDIO_PATH,
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
@@ -45,7 +43,7 @@ from lerobot.datasets.utils import (
DEFAULT_VIDEO_PATH,
update_chunk_file_indices,
)
from lerobot.datasets.video_utils import concatenate_media_files, get_media_duration_in_s
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
@@ -114,7 +112,6 @@ def update_meta_data(
meta_idx,
data_idx,
videos_idx,
audios_idx,
):
"""Updates metadata DataFrame with new chunk, file, and timestamp indices.
@@ -130,7 +127,7 @@ def update_meta_data(
meta_idx: Dictionary containing current metadata chunk and file indices.
data_idx: Dictionary containing current data chunk and file indices.
videos_idx: Dictionary containing current video indices and timestamps.
audios_idx: Dictionary containing current audio indices and timestamps.
Returns:
pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
"""
@@ -228,36 +225,6 @@ def update_meta_data(
# Clean up temporary columns
df = df.drop(columns=["_orig_chunk", "_orig_file"])
for key, audio_idx in audios_idx.items():
# Store original audio file indices before updating
orig_chunk_col = f"audio/{key}/chunk_index"
orig_file_col = f"audio/{key}/file_index"
df["_orig_chunk"] = df[orig_chunk_col].copy()
df["_orig_file"] = df[orig_file_col].copy()
# Update chunk and file indices to point to destination
df[orig_chunk_col] = audio_idx["chunk"]
df[orig_file_col] = audio_idx["file"]
# Apply per-source-file timestamp offsets
src_to_offset = audio_idx.get("src_to_offset", {})
if src_to_offset:
# Apply offset based on original source file
for idx in df.index:
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
offset = src_to_offset.get(src_key, 0)
df.at[idx, f"audio/{key}/from_timestamp"] += offset
df.at[idx, f"audio/{key}/to_timestamp"] += offset
else:
# Fallback to simple offset (for backward compatibility)
df[f"audio/{key}/from_timestamp"] = (
df[f"audio/{key}/from_timestamp"] + audio_idx["latest_duration"]
)
df[f"audio/{key}/to_timestamp"] = df[f"audio/{key}/to_timestamp"] + audio_idx["latest_duration"]
# Clean up temporary columns
df = df.drop(columns=["_orig_chunk", "_orig_file"])
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
@@ -272,7 +239,6 @@ def aggregate_datasets(
aggr_root: Path | None = None,
data_files_size_in_mb: float | None = None,
video_files_size_in_mb: float | None = None,
audio_files_size_in_mb: float | None = None,
chunk_size: int | None = None,
):
"""Aggregates multiple LeRobot datasets into a single unified dataset.
@@ -290,7 +256,6 @@ def aggregate_datasets(
aggr_root: Optional root path for the aggregated dataset.
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
audio_files_size_in_mb: Maximum size for audio files in MB (defaults to DEFAULT_AUDIO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
"""
logging.info("Start aggregate_datasets")
@@ -299,8 +264,6 @@ def aggregate_datasets(
data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
if video_files_size_in_mb is None:
video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
if audio_files_size_in_mb is None:
audio_files_size_in_mb = DEFAULT_AUDIO_FILE_SIZE_IN_MB
if chunk_size is None:
chunk_size = DEFAULT_CHUNK_SIZE
@@ -313,7 +276,6 @@ def aggregate_datasets(
)
fps, robot_type, features = validate_all_metadata(all_metadata)
video_keys = [key for key in features if features[key]["dtype"] == "video"]
audio_keys = [key for key in features if features[key]["dtype"] == "audio"]
dst_meta = LeRobotDatasetMetadata.create(
repo_id=aggr_repo_id,
@@ -325,7 +287,6 @@ def aggregate_datasets(
chunks_size=chunk_size,
data_files_size_in_mb=data_files_size_in_mb,
video_files_size_in_mb=video_files_size_in_mb,
audio_files_size_in_mb=audio_files_size_in_mb,
)
logging.info("Find all tasks")
@@ -339,18 +300,14 @@ def aggregate_datasets(
videos_idx = {
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys
}
audios_idx = {
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in audio_keys
}
dst_meta.episodes = {}
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
audios_idx = aggregate_audio(src_meta, dst_meta, audios_idx, audio_files_size_in_mb, chunk_size)
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audios_idx)
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
# Clear the src_to_dst mapping after processing each source dataset
# to avoid interference between different source datasets
@@ -418,7 +375,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
file_index=file_idx,
)
src_duration = get_media_duration_in_s(src_path, media_type="video")
src_duration = get_video_duration_in_s(src_path)
dst_key = (chunk_idx, file_idx)
if not dst_path.exists():
@@ -457,7 +414,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
current_dst_duration = dst_file_durations.get(dst_key, 0)
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
concatenate_media_files(
concatenate_video_files(
[dst_path, src_path],
dst_path,
)
@@ -472,101 +429,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
return videos_idx
def aggregate_audio(src_meta, dst_meta, audios_idx, audio_files_size_in_mb, chunk_size):
"""Aggregates audio files from a source dataset into the destination dataset.
Handles audio file concatenation and rotation based on file size limits.
Creates new audio files when size limits are exceeded.
Args:
src_meta: Source dataset metadata.
dst_meta: Destination dataset metadata.
audio_idx: Dictionary tracking audio chunk and file indices.
audio_files_size_in_mb: Maximum size for audio files in MB (defaults to DEFAULT_AUDIO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
Returns:
dict: Updated audio_idx with current chunk and file indices.
"""
for key in audios_idx:
audios_idx[key]["episode_duration"] = 0
# Track offset for each source (chunk, file) pair
audios_idx[key]["src_to_offset"] = {}
for key, audio_idx in audios_idx.items():
unique_chunk_file_pairs = {
(chunk, file)
for chunk, file in zip(
src_meta.episodes[f"audio/{key}/chunk_index"],
src_meta.episodes[f"audio/{key}/file_index"],
strict=False,
)
}
unique_chunk_file_pairs = sorted(unique_chunk_file_pairs)
chunk_idx = audio_idx["chunk"]
file_idx = audio_idx["file"]
current_offset = audio_idx["latest_duration"]
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
src_path = src_meta.root / DEFAULT_AUDIO_PATH.format(
audio_key=key,
chunk_index=src_chunk_idx,
file_index=src_file_idx,
)
dst_path = dst_meta.root / DEFAULT_AUDIO_PATH.format(
audio_key=key,
chunk_index=chunk_idx,
file_index=file_idx,
)
src_duration = get_media_duration_in_s(src_path, media_type="audio")
if not dst_path.exists():
# Store offset before incrementing
audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(str(src_path), str(dst_path))
audios_idx[key]["episode_duration"] += src_duration
current_offset += src_duration
continue
# Check file sizes before appending
src_size = get_file_size_in_mb(src_path)
dst_size = get_file_size_in_mb(dst_path)
if dst_size + src_size >= audio_files_size_in_mb:
# Rotate to a new file, this source becomes start of new destination
# So its offset should be 0
audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
dst_path = dst_meta.root / DEFAULT_AUDIO_PATH.format(
audio_key=key,
chunk_index=chunk_idx,
file_index=file_idx,
)
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(str(src_path), str(dst_path))
# Reset offset for next file
current_offset = src_duration
else:
# Append to existing video file - use current accumulated offset
audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
concatenate_media_files(
[dst_path, src_path],
dst_path,
)
current_offset += src_duration
audios_idx[key]["episode_duration"] += src_duration
audios_idx[key]["chunk"] = chunk_idx
audios_idx[key]["file"] = file_idx
return audios_idx
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
"""Aggregates data chunks from a source dataset into the destination dataset.
@@ -639,7 +501,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
return data_idx
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audios_idx):
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
"""Aggregates metadata from a source dataset into the destination dataset.
Reads source metadata files, updates all indices and timestamps,
@@ -651,7 +513,6 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audio
meta_idx: Dictionary tracking metadata chunk and file indices.
data_idx: Dictionary tracking data chunk and file indices.
videos_idx: Dictionary tracking video indices and timestamps.
audios_idx: Dictionary tracking audio indices and timestamps.
Returns:
dict: Updated meta_idx with current chunk and file indices.
@@ -675,7 +536,6 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audio
meta_idx,
data_idx,
videos_idx,
audios_idx,
)
meta_idx, _ = append_or_create_parquet_file(
@@ -692,8 +552,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audio
# Increment latest_duration by the total duration added from this source dataset
for k in videos_idx:
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
for k in audios_idx:
audios_idx[k]["latest_duration"] += audios_idx[k]["episode_duration"]
return meta_idx

View File

@@ -1,275 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
import av
import torch
import torchaudio
import torchcodec
from numpy import ceil
CHANNELS_LAYOUTS_MAPPING = {
1: "mono",
2: "stereo",
3: "2.1",
4: "3.1",
5: "4.1",
6: "5.1",
7: "6.1",
8: "7.1",
16: "hexadecagonal",
24: "22.2",
}
def decode_audio(
audio_path: Path | str,
timestamps: list[float],
duration: float,
start_time_s: float | None = 0.0,
backend: str | None = "torchcodec",
) -> torch.Tensor:
"""
Decodes audio using the specified backend.
Args:
audio_path (Path): Path to the audio file.
timestamps (list[float]): List of (starting) timestamps to extract audio chunks.
duration (float): Duration of the audio chunks in seconds.
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec".
Returns:
torch.Tensor: Decoded audio chunks.
Currently supports torchaudio.
"""
if backend == "torchcodec":
return decode_audio_torchcodec(audio_path, timestamps, duration, start_time_s)
elif backend == "torchaudio":
return decode_audio_torchaudio(audio_path, timestamps, duration, start_time_s)
else:
raise ValueError(f"Unsupported video backend: {backend}")
def decode_audio_torchcodec(
audio_path: Path | str,
timestamps: list[float],
duration: float,
start_time_s: float | None = 0.0,
log_loaded_timestamps: bool = False,
) -> torch.Tensor:
# TODO(CarolinePascal) : add channels selection
audio_decoder = torchcodec.decoders.AudioDecoder(audio_path)
audio_sample_rate = audio_decoder.metadata.sample_rate
audio_channels = audio_decoder.metadata.num_channels
# TODO(CarolinePascal) : assert ts < total record duration
audio_chunks = []
timestamps = [
timestamp + start_time_s for timestamp in timestamps
] # Add an offset of start_time_s to each timestamp
for ts in timestamps:
current_audio_chunk = audio_decoder.get_samples_played_in_range(
start_seconds=max(0.0, ts - duration), stop_seconds=ts
)
current_audio_chunk_data = current_audio_chunk.data
# Case where the requested audio chunk starts before the beginning of the audio stream
if ts - duration < 0:
# No useful audio sample has been recorded
if ts < 1 / audio_sample_rate:
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
current_audio_chunk_data = torch.zeros(
(audio_channels, int(ceil(duration * audio_sample_rate)))
)
# At least one useful audio sample has been recorded
else:
# Pad the beginning of the audio chunk with zeros
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
current_audio_chunk_data = torch.nn.functional.pad(
current_audio_chunk_data,
(int(ceil((duration - ts) * audio_sample_rate)), 0, 0, 0), # left, right, top, bottom
)
if log_loaded_timestamps:
logging.info(
f"audio chunk loaded at timestamp={current_audio_chunk.pts_seconds:.4f} with duration={current_audio_chunk.duration_seconds:.4f}"
)
audio_chunks.append(current_audio_chunk_data)
audio_chunks = torch.stack(audio_chunks)
assert len(timestamps) == len(audio_chunks)
return audio_chunks
def decode_audio_torchaudio(
audio_path: Path | str,
timestamps: list[float],
duration: float,
start_time_s: float | None = 0.0,
log_loaded_timestamps: bool = False,
) -> torch.Tensor:
# TODO(CarolinePascal) : add channels selection
audio_path = str(audio_path)
reader = torchaudio.io.StreamReader(src=audio_path)
audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
audio_channels = reader.get_src_stream_info(reader.default_audio_stream).num_channels
# TODO(CarolinePascal) : assert ts < total record duration
# TODO(CarolinePascal) : sort timestamps ?
reader.add_basic_audio_stream(
frames_per_chunk=int(ceil(duration * audio_sample_rate)), # Too much is better than not enough
buffer_chunk_size=-1, # No dropping frames
format="fltp", # Format as float32
)
audio_chunks = []
timestamps = [
timestamp + start_time_s for timestamp in timestamps
] # Add an offset of start_time_s to each timestamp
for ts in timestamps:
reader.seek(max(0.0, ts - duration)) # Default to closest audio sample. Needs to be non-negative !
status = reader.fill_buffer()
if status != 0:
# Should not happen, but just in case
logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
current_audio_chunk = reader.pop_chunks()[0]
current_audio_chunk_data = current_audio_chunk.t() # Channel first format
# Case where the requested audio chunk starts before the beginning of the audio stream
if ts - duration < 0:
# No useful audio sample has been recorded
if ts < 1 / audio_sample_rate:
current_audio_chunk_data = torch.zeros(
(audio_channels, int(ceil(duration * audio_sample_rate)))
)
# At least one useful audio sample has been recorded
else:
# Remove the superfluous last samples of the audio chunk
current_audio_chunk_data = current_audio_chunk_data[:, : int(ceil(ts * audio_sample_rate))]
# Pad the beginning of the audio chunk with zeros
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
current_audio_chunk_data = torch.nn.functional.pad(
current_audio_chunk_data,
(int(ceil((duration - ts) * audio_sample_rate)), 0, 0, 0), # left, right, top, bottom
)
if log_loaded_timestamps:
logging.info(
f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}"
)
audio_chunks.append(current_audio_chunk_data)
audio_chunks = torch.stack(audio_chunks)
assert len(timestamps) == len(audio_chunks)
return audio_chunks
def encode_audio(
input_path: Path | str,
output_path: Path | str,
codec: str = "aac", # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options
bit_rate: int | None = None,
sample_rate: int | None = None,
log_level: int | None = av.logging.ERROR,
overwrite: bool = False,
) -> None:
"""Encodes an audio file using ffmpeg."""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=overwrite)
# Set logging level
if log_level is not None:
# "While less efficient, it is generally preferable to modify logging with Pythons logging"
logging.getLogger("libav").setLevel(log_level)
# Open input file
with av.open(str(input_path), "r") as input:
input_stream = input.streams.audio[0] # Assuming the first stream is the audio stream to be encoded
# Define sub-sampling options
if sample_rate is None:
sample_rate = input_stream.rate
# Create and open output file (overwrite by default)
with av.open(str(output_path), "w") as output:
output_stream = output.add_stream(
codec, rate=sample_rate, layout=CHANNELS_LAYOUTS_MAPPING[input_stream.channels]
)
if bit_rate is not None:
output_stream.bit_rate = bit_rate
# Loop through input WAV packets and encode them
for input_frame in input.decode(
input_stream
): # This step handles both demuxing and decoding under the hood
packet = output_stream.encode(input_frame)
if packet:
output.mux(packet)
# Flush the encoder
packet = output_stream.encode()
if packet:
output.mux(packet)
# Reset logging level
if log_level is not None:
av.logging.restore_default_callback()
if not output_path.exists():
raise OSError(f"Audio encoding did not work. File not found: {output_path}.")
def get_audio_info(video_path: Path | str) -> dict:
# Set logging level
logging.getLogger("libav").setLevel(av.logging.ERROR)
# Getting audio stream information
audio_info = {}
with av.open(str(video_path), "r") as audio_file:
try:
audio_stream = audio_file.streams.audio[0]
except IndexError:
# Reset logging level
av.logging.restore_default_callback()
return {"has_audio": False}
audio_info["audio.channels"] = audio_stream.channels
audio_info["audio.codec"] = audio_stream.codec.canonical_name
# In an ideal loseless case : bit depth x sample rate x channels = bit rate.
# In an actual compressed case, the bit rate is set according to the compression level : the lower the bit rate, the more compression is applied.
audio_info["audio.bit_rate"] = audio_stream.bit_rate
audio_info["audio.sample_rate"] = audio_stream.sample_rate # Number of samples per second
# In an ideal loseless case : fixed number of bits per sample.
# In an actual compressed case : variable number of bits per sample (often reduced to match a given depth rate).
audio_info["audio.bit_depth"] = audio_stream.format.bits
audio_info["audio.channel_layout"] = audio_stream.layout.name
audio_info["has_audio"] = True
# Reset logging level
av.logging.restore_default_callback()
return audio_info

View File

@@ -19,7 +19,8 @@ import logging
import numpy as np
from lerobot.datasets.io_utils import load_audio_from_path, load_image_as_numpy
from lerobot.datasets.io_utils import load_image_as_numpy
from lerobot.utils.constants import ACTION, OBS_STATE
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
@@ -249,20 +250,6 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
return images
def sample_audio_from_path(audio_path: str) -> np.ndarray:
"""Samples audio data from an audio recording stored in a WAV file."""
data = load_audio_from_path(audio_path)
sampled_indices = sample_indices(len(data))
return data[sampled_indices]
def sample_audio_from_data(data: np.ndarray) -> np.ndarray:
"""Samples audio data from an audio recording stored in a numpy array."""
sampled_indices = sample_indices(len(data))
return data[sampled_indices]
def _reshape_stats_by_axis(
stats: dict[str, np.ndarray],
axis: int | tuple[int, ...] | None,
@@ -530,13 +517,6 @@ def compute_episode_stats(
ep_ft_array = sample_images(data)
axes_to_reduce = (0, 2, 3)
keepdims = True
elif features[key]["dtype"] == "audio":
try:
ep_ft_array = sample_audio_from_path(data[0])
except TypeError: # Should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
ep_ft_array = sample_audio_from_data(data)
axes_to_reduce = 0
keepdims = True
else:
ep_ft_array = data
axes_to_reduce = 0

View File

@@ -23,7 +23,6 @@ import pyarrow as pa
import pyarrow.parquet as pq
from huggingface_hub import snapshot_download
from lerobot.datasets.audio_utils import get_audio_info
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.feature_utils import _validate_feature_names, create_empty_dataset_info
from lerobot.datasets.io_utils import (
@@ -41,7 +40,6 @@ from lerobot.datasets.io_utils import (
from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
DEFAULT_FEATURES,
DEFAULT_INITIAL_AUDIO_BUFFER_DURATION,
INFO_PATH,
check_version_compatibility,
flatten_dict,
@@ -271,32 +269,6 @@ class LeRobotDatasetMetadata:
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
return Path(fpath)
def get_audio_file_path(self, ep_index: int, audio_key: str) -> Path:
"""Return the relative audio file path for the given episode and audio key.
Args:
ep_index: Zero-based episode index.
audio_key: Feature key identifying the audio stream
(e.g. ``'observation.audio.microphone'``).
Returns:
Path to the audio file containing this episode's audio.
Raises:
IndexError: If ``ep_index`` is out of range.
"""
if self.episodes is None:
self.episodes = load_episodes(self.root)
if ep_index >= len(self.episodes):
raise IndexError(
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
)
ep = self.episodes[ep_index]
chunk_idx = ep[f"audio/{audio_key}/chunk_index"]
file_idx = ep[f"audio/{audio_key}/file_index"]
fpath = self.audio_path.format(audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx)
return Path(fpath)
@property
def data_path(self) -> str:
"""Formattable string for the parquet files."""
@@ -307,11 +279,6 @@ class LeRobotDatasetMetadata:
"""Formattable string for the video files."""
return self.info["video_path"]
@property
def audio_path(self) -> str | None:
"""Formattable string for the audio files."""
return self.info["audio_path"]
@property
def robot_type(self) -> str | None:
"""Robot type used in recording this dataset."""
@@ -342,11 +309,6 @@ class LeRobotDatasetMetadata:
"""Keys to access visual modalities (regardless of their storage method)."""
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
@property
def audio_keys(self) -> list[str]:
"""Keys to access audio modalities."""
return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]
@property
def names(self) -> dict[str, list | dict]:
"""Names of the various dimensions of vector modalities."""
@@ -387,11 +349,6 @@ class LeRobotDatasetMetadata:
"""Max size of video file in mega bytes."""
return self.info["video_files_size_in_mb"]
@property
def audio_files_size_in_mb(self) -> int:
"""Max size of audio file in mega bytes."""
return self.info["audio_files_size_in_mb"]
def get_task_index(self, task: str) -> int | None:
"""
Given a task in natural language, returns its task_index if the task already exists in the dataset,
@@ -558,27 +515,11 @@ class LeRobotDatasetMetadata:
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
self.info["features"][key]["info"] = get_video_info(video_path)
def update_audio_info(self, audio_key: str | None = None) -> None:
"""
Warning: this function writes info from first episode audio, implicitly assuming that all audio have
been encoded the same way. Also, this means it assumes the first episode exists.
"""
if audio_key is not None and audio_key not in self.audio_keys:
raise ValueError(f"Audio key {audio_key} not found in dataset")
audio_keys = [audio_key] if audio_key is not None else self.audio_keys
for key in audio_keys:
if not self.features[key].get("info", None):
audio_path = self.root / self.audio_path.format(audio_key=key, chunk_index=0, file_index=0)
self.info["features"][key]["info"] = get_audio_info(audio_path)
self.info["features"][key]["info"]["start_time_s"] = DEFAULT_INITIAL_AUDIO_BUFFER_DURATION
def update_chunk_settings(
self,
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
audio_files_size_in_mb: int | None = None,
) -> None:
"""Update chunk and file size settings after dataset creation.
@@ -590,7 +531,6 @@ class LeRobotDatasetMetadata:
chunks_size: Maximum number of files per chunk directory. If None, keeps current value.
data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value.
video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value.
audio_files_size_in_mb: Maximum size for audio files in MB. If None, keeps current value.
"""
if chunks_size is not None:
if chunks_size <= 0:
@@ -607,11 +547,6 @@ class LeRobotDatasetMetadata:
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
self.info["video_files_size_in_mb"] = video_files_size_in_mb
if audio_files_size_in_mb is not None:
if audio_files_size_in_mb <= 0:
raise ValueError(f"audio_files_size_in_mb must be positive, got {audio_files_size_in_mb}")
self.info["audio_files_size_in_mb"] = audio_files_size_in_mb
# Update the info file on disk
write_info(self.info, self.root)
@@ -619,13 +554,12 @@ class LeRobotDatasetMetadata:
"""Get current chunk and file size settings.
Returns:
Dict containing chunks_size, data_files_size_in_mb, video_files_size_in_mb, and audio_files_size_in_mb.
Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb.
"""
return {
"chunks_size": self.chunks_size,
"data_files_size_in_mb": self.data_files_size_in_mb,
"video_files_size_in_mb": self.video_files_size_in_mb,
"audio_files_size_in_mb": self.audio_files_size_in_mb,
}
def __repr__(self):
@@ -652,7 +586,6 @@ class LeRobotDatasetMetadata:
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
audio_files_size_in_mb: int | None = None,
) -> "LeRobotDatasetMetadata":
"""Create metadata for a new LeRobot dataset from scratch.
@@ -703,7 +636,6 @@ class LeRobotDatasetMetadata:
chunks_size,
data_files_size_in_mb,
video_files_size_in_mb,
audio_files_size_in_mb,
)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError(

View File

@@ -21,14 +21,12 @@ from pathlib import Path
import datasets
import torch
from lerobot.datasets.audio_utils import decode_audio
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import (
check_delta_timestamps,
get_delta_indices,
get_hf_features_from_features,
)
from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
from lerobot.datasets.io_utils import (
hf_transform_to_torch,
load_nested_dataset,
@@ -132,7 +130,7 @@ class DatasetReader:
return hf_dataset
def _check_cached_episodes_sufficient(self) -> bool:
"""Check if the cached dataset contains all requested episodes and their video and audio files."""
"""Check if the cached dataset contains all requested episodes and their video files."""
if self.hf_dataset is None or len(self.hf_dataset) == 0:
return False
@@ -156,13 +154,6 @@ class DatasetReader:
if not video_path.exists():
return False
if len(self._meta.audio_keys) > 0:
for ep_idx in requested_episodes:
for audio_key in self._meta.audio_keys:
audio_path = self.root / self._meta.get_compressed_audio_file_path(ep_idx, audio_key)
if not audio_path.exists():
return False
return True
def get_episodes_file_paths(self) -> list[Path]:
@@ -179,15 +170,6 @@ class DatasetReader:
for ep_idx in episodes
]
fpaths += video_files
if len(self._meta.audio_keys) > 0:
audio_files = [
str(self._meta.get_compressed_audio_file_path(ep_idx, audio_key))
for audio_key in self._meta.audio_keys
for ep_idx in episodes
]
fpaths += audio_files
# episodes are stored in the same files, so we return unique paths only
fpaths = list(set(fpaths))
return fpaths
@@ -217,7 +199,7 @@ class DatasetReader:
query_indices: dict[str, list[int]] | None = None,
) -> dict[str, list[float]]:
query_timestamps = {}
for key in self._meta.video_keys + self._meta.audio_keys:
for key in self._meta.video_keys:
if query_indices is not None and key in query_indices:
if self._absolute_to_relative_idx is not None:
relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]]
@@ -231,10 +213,10 @@ class DatasetReader:
return query_timestamps
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
"""Query dataset for indices across keys, skipping video and audio keys."""
"""Query dataset for indices across keys, skipping video keys."""
result: dict = {}
for key, q_idx in query_indices.items():
if key in self._meta.video_keys or key in self._meta.audio_keys:
if key in self._meta.video_keys:
continue
relative_indices = (
q_idx
@@ -264,28 +246,6 @@ class DatasetReader:
return item
# TODO(CarolinePascal): add variable query durations
def _query_audio(
self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int
) -> dict[str, torch.Tensor]:
ep = self.meta.episodes[ep_idx]
item = {}
for audio_key, query_ts in query_timestamps.items():
# Episodes are stored sequentially on a single mp4 to reduce the number of files.
# Thus we load the start timestamp of the episode on this mp4 and,
# shift the query timestamp accordingly.
from_timestamp = ep[f"audio/{audio_key}/from_timestamp"]
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
audio_path = self.root / self.meta.get_audio_file_path(ep_idx, audio_key)
start_time_s = self.meta.features[audio_key]["info"].get("start_time_s", 0.0)
audio_chunk = decode_audio(
audio_path, shifted_query_ts, query_duration, start_time_s, self.audio_backend
)
item[audio_key] = audio_chunk.squeeze(0)
return item
def get_item(self, idx) -> dict:
"""Core __getitem__ logic. Assumes hf_dataset is loaded.
@@ -305,12 +265,11 @@ class DatasetReader:
for key, val in query_result.items():
item[key] = val
if len(self._meta.video_keys) > 0 or len(self._meta.audio_keys) > 0:
if len(self._meta.video_keys) > 0:
current_ts = item["timestamp"].item()
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
video_frames = self._query_videos(query_timestamps, ep_idx)
audio_chunks = self._query_audio(query_timestamps, DEFAULT_AUDIO_CHUNK_DURATION, ep_idx)
item = {**video_frames, **audio_chunks, **item}
item = {**video_frames, **item}
if self._image_transforms is not None:
image_keys = self._meta.camera_keys

View File

@@ -31,7 +31,6 @@ import PIL.Image
import pyarrow.parquet as pq
import torch
from lerobot.datasets.audio_utils import encode_audio
from lerobot.datasets.compute_stats import compute_episode_stats
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import (
@@ -49,17 +48,14 @@ from lerobot.datasets.io_utils import (
from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
DEFAULT_IMAGE_PATH,
DEFAULT_RAW_AUDIO_PATH,
update_chunk_file_indices,
)
from lerobot.datasets.video_utils import (
StreamingVideoEncoder,
concatenate_media_files,
concatenate_video_files,
encode_video_frames,
get_media_duration_in_s,
get_video_duration_in_s,
)
from lerobot.microphones.microphone import Microphone
from lerobot.microphones.utils import async_microphones_start_recording
logger = logging.getLogger(__name__)
@@ -148,10 +144,6 @@ class DatasetWriter:
def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
return self._get_image_file_path(episode_index, image_key, frame_index=0).parent
def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index)
return self._root / fpath
def _save_image(
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1
) -> None:
@@ -216,43 +208,11 @@ class DatasetWriter:
compress_level = 1 if self._meta.features[key]["dtype"] == "video" else 6
self._save_image(frame[key], img_path, compress_level)
self.episode_buffer[key].append(str(img_path))
elif self._meta.features[key]["dtype"] == "audio":
if (
self._meta.robot_type == "lekiwi"
): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
self.episode_buffer[key].append(frame[key])
else: # Otherwise, only the audio file path is stored in the episode buffer
if frame_index == 0:
audio_path = self._get_raw_audio_file_path(
episode_index=self.episode_buffer["episode_index"], audio_key=key
)
self.episode_buffer[key].append(str(audio_path))
else:
self.episode_buffer[key].append(frame[key])
self.episode_buffer["size"] += 1
def add_microphone_recording(self, microphone_key: str, microphone: Microphone) -> None:
"""
Starts recording audio data provided by the microphone and directly writes it in a .wav file.
"""
audio_file = self._get_raw_audio_file_path(self._meta.total_episodes, "observation.audio." + microphone_key)
microphone.start_recording(output_file=audio_file)
def add_microphones_recordings(self, microphones: dict[str, Microphone]) -> None:
"""
Starts recording audio data provided by multiple microphones and directly writes it in appropriate .wav files.
"""
output_files = []
for microphone_key in microphones:
output_files.append(
self._get_raw_audio_file_path(self._meta.total_episodes, "observation.audio." + microphone_key)
)
async_microphones_start_recording(microphones, output_files)
def save_episode(
self,
episode_data: dict | None = None,
@@ -281,19 +241,12 @@ class DatasetWriter:
for key, ft in self._meta.features.items():
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
continue
elif ft["dtype"] == "audio":
if (
self._meta.robot_type == "lekiwi"
): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
episode_buffer[key] = np.concatenate(episode_buffer[key], axis=0)
continue
episode_buffer[key] = np.stack(episode_buffer[key])
# Wait for image writer to end, so that episode stats over images can be computed
self._wait_image_writer()
has_video_keys = len(self._meta.video_keys) > 0
has_audio_keys = len(self._meta.audio_keys) > 0
use_streaming = self._streaming_encoder is not None and has_video_keys
use_batched_encoding = self._batch_encoding_size > 1
@@ -320,7 +273,7 @@ class DatasetWriter:
for k, v in video_stats.items()
}
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
elif (has_video_keys or has_audio_keys) and not use_batched_encoding:
elif has_video_keys and not use_batched_encoding:
num_cameras = len(self._meta.video_keys)
if parallel_encoding and num_cameras > 1:
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cameras) as executor:
@@ -356,28 +309,19 @@ class DatasetWriter:
for video_key in self._meta.video_keys:
ep_metadata.update(self._save_episode_video(video_key, episode_index))
# TODO(Caroline): add parallel encoding for audio as well
for audio_key in self._meta.audio_keys:
ep_metadata.update(self._save_episode_audio(audio_key, episode_index))
# `meta.save_episode` need to be executed after encoding the videos
self._meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
if (has_video_keys or has_audio_keys) and use_batched_encoding:
if has_video_keys and use_batched_encoding:
self._episodes_since_last_encoding += 1
if self._episodes_since_last_encoding == self._batch_encoding_size:
start_ep = self._meta.total_episodes - self._batch_encoding_size
end_ep = self._meta.total_episodes
if has_video_keys:
self._batch_save_episode_video(start_ep, end_ep)
if has_audio_keys:
self._batch_save_episode_audio(start_ep, end_ep)
self._batch_save_episode_video(start_ep, end_ep)
self._episodes_since_last_encoding = 0
if episode_data is None:
self.clear_episode_buffer(
delete_images=len(self._meta.image_keys) > 0, delete_audio=len(self._meta.audio_keys) > 0
)
self.clear_episode_buffer(delete_images=len(self._meta.image_keys) > 0)
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None:
"""Batch save videos for multiple episodes."""
@@ -424,59 +368,6 @@ class DatasetWriter:
episode_df.to_parquet(episode_df_path)
self._meta.episodes = load_episodes(self._root)
def _batch_save_episode_audio(self, start_episode: int, end_episode: int | None = None) -> None:
"""
Batch save audio for multiple episodes.
Args:
start_episode: Starting episode index (inclusive)
end_episode: Ending episode index (exclusive). If None, encodes all episodes from start_episode to the current episode.
"""
if end_episode is None:
end_episode = self._meta.total_episodes
logging.info(
f"Batch encoding {self.batch_encoding_size} audio for episodes {start_episode} to {end_episode - 1}"
)
chunk_idx = self._meta.episodes[start_episode]["data/chunk_index"]
file_idx = self._meta.episodes[start_episode]["data/file_index"]
episode_df_path = self._root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
episode_df = pd.read_parquet(episode_df_path)
for ep_idx in range(start_episode, end_episode):
logging.info(f"Encoding audio for episode {ep_idx}")
if (
self._meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx
or self._meta.episodes[ep_idx]["data/file_index"] != file_idx
):
# The current episode is in a new chunk or file.
# Save previous episode dataframe and update the Hugging Face dataset by reloading it.
episode_df.to_parquet(episode_df_path)
self._meta.episodes = load_episodes(self._root)
# Load new episode dataframe
chunk_idx = self._meta.episodes[ep_idx]["data/chunk_index"]
file_idx = self._meta.episodes[ep_idx]["data/file_index"]
episode_df_path = self._root / DEFAULT_EPISODES_PATH.format(
chunk_index=chunk_idx, file_index=file_idx
)
episode_df = pd.read_parquet(episode_df_path)
# Save the current episode's video metadata to the dataframe
audio_ep_metadata = {}
for audio_key in self._meta.audio_keys:
audio_ep_metadata.update(self._save_episode_audio(audio_key, ep_idx))
audio_ep_metadata.pop("episode_index")
audio_ep_df = pd.DataFrame(audio_ep_metadata, index=[ep_idx]).convert_dtypes(
dtype_backend="pyarrow"
) # allows NaN values along with integers
episode_df = episode_df.combine_first(audio_ep_df)
episode_df.to_parquet(episode_df_path)
self._meta.episodes = load_episodes(self._root)
def _save_episode_data(self, episode_buffer: dict) -> dict:
"""Save episode data to a parquet file."""
# Use metadata features as the authoritative schema
@@ -554,7 +445,7 @@ class DatasetWriter:
ep_path = temp_path
ep_size_in_mb = get_file_size_in_mb(ep_path)
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="video")
ep_duration_in_s = get_video_duration_in_s(ep_path)
if (
episode_index == 0
@@ -594,7 +485,7 @@ class DatasetWriter:
shutil.move(str(ep_path), str(new_path))
latest_duration_in_s = 0.0
else:
concatenate_media_files(
concatenate_video_files(
[latest_path, ep_path],
latest_path,
)
@@ -616,91 +507,7 @@ class DatasetWriter:
}
return metadata
def _encode_temporary_episode_audio(self, audio_key: str, episode_index: int) -> Path:
"""
Use ffmpeg to convert raw audio files into m4a audio files.
Note: `encode_episode_audio` is a blocking call. Making it asynchronous shouldn't speedup encoding,
since audio encoding with ffmpeg is already using multithreading.
"""
temp_path = Path(tempfile.mkdtemp(dir=self._root)) / f"{audio_key}_{episode_index:03d}.m4a"
raw_audio_file = self._get_raw_audio_file_path(episode_index, audio_key)
encode_audio(raw_audio_file, temp_path, overwrite=True)
raw_audio_file.unlink()
return temp_path
def _save_episode_audio(self, audio_key: str, episode_index: int) -> dict:
# Encode episode audio into a temporary audio file
ep_path = self._encode_temporary_episode_audio(audio_key, episode_index)
ep_size_in_mb = get_file_size_in_mb(ep_path)
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="audio")
if (
episode_index == 0
or self._meta.latest_episode is None
or f"audio/{audio_key}/chunk_index" not in self._meta.latest_episode
):
# Initialize indices for a new dataset made of the first episode data
chunk_idx, file_idx = 0, 0
if self._meta.episodes is not None and len(self._meta.episodes) > 0:
# It means we are resuming recording, so we need to load the latest episode
# Update the indices to avoid overwriting the latest episode
old_chunk_idx = self._meta.episodes[-1][f"audio/{audio_key}/chunk_index"]
old_file_idx = self._meta.episodes[-1][f"audio/{audio_key}/file_index"]
chunk_idx, file_idx = update_chunk_file_indices(
old_chunk_idx, old_file_idx, self._meta.chunks_size
)
latest_duration_in_s = 0.0
new_path = self._root / self._meta.audio_path.format(
audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx
)
new_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(ep_path), str(new_path))
else:
# Retrieve information from the latest updated audio file using latest_episode
latest_ep = self._meta.latest_episode
chunk_idx = latest_ep[f"audio/{audio_key}/chunk_index"][0]
file_idx = latest_ep[f"audio/{audio_key}/file_index"][0]
latest_path = self._root / self._meta.audio_path.format(
audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx
)
latest_size_in_mb = get_file_size_in_mb(latest_path)
latest_duration_in_s = latest_ep[f"audio/{audio_key}/to_timestamp"][0]
if latest_size_in_mb + ep_size_in_mb >= self._meta.audio_files_size_in_mb:
# Move temporary episode audio to a new audio file in the dataset
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self._meta.chunks_size)
new_path = self._root / self._meta.audio_path.format(
audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx
)
new_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(ep_path), str(new_path))
latest_duration_in_s = 0.0
else:
# Update latest audio file
concatenate_media_files(
[latest_path, ep_path],
latest_path,
)
# Remove temporary directory
shutil.rmtree(str(ep_path.parent))
# Update audio info (only needed when first episode is encoded since it reads from episode 0)
if episode_index == 0:
self._meta.update_audio_info(audio_key)
write_info(self._meta.info, self._meta.root) # ensure audio info always written properly
metadata = {
"episode_index": episode_index,
f"audio/{audio_key}/chunk_index": chunk_idx,
f"audio/{audio_key}/file_index": file_idx,
f"audio/{audio_key}/from_timestamp": latest_duration_in_s,
f"audio/{audio_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
}
return metadata
def clear_episode_buffer(self, delete_images: bool = True, delete_audio: bool = True) -> None:
def clear_episode_buffer(self, delete_images: bool = True) -> None:
"""Discard the current episode buffer and optionally delete temp images.
Args:
@@ -724,15 +531,6 @@ class DatasetWriter:
if img_dir.is_dir():
shutil.rmtree(img_dir)
if delete_audio:
episode_index = self.episode_buffer["episode_index"]
if isinstance(episode_index, np.ndarray):
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
for audio_key in self._meta.audio_keys:
audio_file = self._get_raw_audio_file_path(episode_index, audio_key)
if audio_file.is_file():
audio_file.unlink()
self.episode_buffer = self._create_episode_buffer()
def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
@@ -798,7 +596,7 @@ class DatasetWriter:
self._streaming_encoder.cancel_episode()
def cleanup_interrupted_episode(self, episode_index: int) -> None:
"""Remove temporary image and audio directories for an interrupted episode."""
"""Remove temporary image directories for an interrupted episode."""
for key in self._meta.video_keys:
img_dir = self._get_image_file_path(
episode_index=episode_index, image_key=key, frame_index=0
@@ -809,14 +607,6 @@ class DatasetWriter:
)
shutil.rmtree(img_dir)
for key in self._meta.audio_keys:
audio_file = self._get_raw_audio_file_path(episode_index=episode_index, audio_key=key)
if audio_file.exists():
logger.debug(
f"Cleaning up interrupted episode audio for episode {episode_index}, microphone {key}"
)
audio_file.unlink()
def finalize(self) -> None:
"""Flush all pending work and release all resources.

View File

@@ -22,8 +22,6 @@ from PIL import Image as PILImage
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.utils import (
DEFAULT_AUDIO_FILE_SIZE_IN_MB,
DEFAULT_AUDIO_PATH,
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
@@ -49,7 +47,7 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
"""
hf_features = {}
for key, ft in features.items():
if ft["dtype"] == "video" or ft["dtype"] == "audio":
if ft["dtype"] == "video":
continue
elif ft["dtype"] == "image":
hf_features[key] = datasets.Image()
@@ -112,12 +110,7 @@ def hw_to_dataset_features(
for key, ftype in hw_features.items()
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
}
cam_fts = {
key: shape for key, shape in hw_features.items() if isinstance(shape, tuple) and len(shape) == 3
}
mic_fts = {
key: shape for key, shape in hw_features.items() if isinstance(shape, tuple) and len(shape) == 2
}
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
if joint_fts and prefix == ACTION:
features[prefix] = {
@@ -140,14 +133,6 @@ def hw_to_dataset_features(
"names": ["height", "width", "channels"],
}
for key, parameters in mic_fts.items():
features[f"{prefix}.audio.{key}"] = {
"dtype": "audio",
"shape": (len(parameters[1]),),
"names": ["channels"],
"info": {"sample_rate": parameters[0]},
}
_validate_feature_names(features)
return features
@@ -177,8 +162,6 @@ def build_dataset_frame(
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
elif ft["dtype"] in ["image", "video"]:
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
elif ft["dtype"] == "audio":
frame[key] = values[key.removeprefix(f"{prefix}.audio.")]
return frame
@@ -212,10 +195,6 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1])
elif ft["dtype"] == "audio":
type = FeatureType.AUDIO
if len(shape) != 2:
raise ValueError(f"Number of dimensions of {key} != 2 (shape={shape})")
elif key == OBS_ENV_STATE:
type = FeatureType.ENV
elif key.startswith(OBS_STR):
@@ -294,7 +273,6 @@ def create_empty_dataset_info(
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
audio_files_size_in_mb: int | None = None,
) -> dict:
"""Create a template dictionary for a new dataset's `info.json`.
@@ -304,10 +282,7 @@ def create_empty_dataset_info(
features (dict): The LeRobot features dictionary for the dataset.
use_videos (bool): Whether the dataset will store videos.
robot_type (str | None): The type of robot used, if any.
chunks_size (int | None): The number of files per chunk.
data_files_size_in_mb (int | None): The maximum size per data file in MB.
video_files_size_in_mb (int | None): The maximum size per video file in MB.
audio_files_size_in_mb (int | None): The maximum size per audio file in MB.
Returns:
dict: A dictionary with the initial dataset metadata.
"""
@@ -320,12 +295,10 @@ def create_empty_dataset_info(
"chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
"data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
"video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
"audio_files_size_in_mb": audio_files_size_in_mb or DEFAULT_AUDIO_FILE_SIZE_IN_MB,
"fps": fps,
"splits": {},
"data_path": DEFAULT_DATA_PATH,
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
"audio_path": DEFAULT_AUDIO_PATH,
"features": features,
}
@@ -462,8 +435,6 @@ def validate_feature_dtype_and_shape(
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
elif expected_dtype in ["image", "video"]:
return validate_feature_image_or_video(name, expected_shape, value)
elif expected_dtype == "audio":
return validate_feature_audio(name, expected_shape, value)
elif expected_dtype == "string":
return validate_feature_string(name, value)
else:
@@ -530,33 +501,6 @@ def validate_feature_image_or_video(
return error_message
def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray):
"""Validate a feature that is expected to be an audio frame.
Args:
name (str): The name of the feature.
expected_shape (list[str]): The expected shape (C,).
value: The audio data to validate.
Returns:
str: An error message if validation fails, otherwise an empty string.
"""
error_message = ""
if isinstance(value, np.ndarray):
actual_shape = value.shape
c = expected_shape
if (len(actual_shape) != 2 and len(actual_shape) != 1) or actual_shape[-1] != c[
-1
]: # The number of frames might be different
error_message += (
f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{c}'.\n"
)
else:
error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n"
return error_message
def validate_feature_string(name: str, value: str) -> str:
"""Validate a feature that is expected to be a string.

View File

@@ -23,7 +23,6 @@ import pandas
import pandas as pd
import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq
import soundfile as sf
import torch
from datasets import Dataset
from datasets.table import embed_table_storage
@@ -281,24 +280,6 @@ def load_image_as_numpy(
return img_array
def load_audio_from_path(fpath: str | Path) -> np.ndarray:
"""Load an audio file from a path into a numpy array.
Args:
fpath (str | Path): Path to the audio file.
Returns:
np.ndarray: The audio as a numpy array.
"""
audio_data, _ = sf.read(fpath, dtype="float32")
# Fill missing channel dimension when loading mono audio data
if audio_data.ndim == 1:
audio_data = np.expand_dims(audio_data, axis=1)
return audio_data
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
"""Convert a batch from a Hugging Face dataset to torch tensors.

View File

@@ -54,9 +54,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
revision: str | None = None,
force_cache_sync: bool = False,
download_videos: bool = True,
download_audio: bool = True,
video_backend: str | None = None,
audio_backend: str | None = None,
batch_encoding_size: int = 1,
vcodec: str = "libsvtav1",
streaming_encoding: bool = False,
@@ -93,7 +91,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
task-conditioned training.
- data (backed by datasets.Dataset), which reads values from parquet files.
- videos (optional) from which frames are loaded to be synchronous with data from parquet files.
- audio (optional) from which audio is loaded to be synchronous with data from parquet files.
A typical LeRobotDataset looks like this from its root path:
.
@@ -119,37 +116,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
│ ├── info.json
│ ├── stats.json
│ └── tasks.parquet
── videos
├── observation.images.laptop
│ │ ├── chunk-000
│ │ │ ├── file-000.mp4
│ │ │ ├── file-001.mp4
│ │ │ └── ...
│ │ ├── chunk-001
│ │ │ └── ...
│ │ └── ...
│ ├── observation.images.phone
│ │ ├── chunk-000
│ │ │ ├── file-000.mp4
│ │ │ ├── file-001.mp4
│ │ │ └── ...
│ │ ├── chunk-001
│ │ │ └── ...
│ │ └── ...
│ └── ...
└── audio
├── observation.audio.laptop
── videos
├── observation.images.laptop
│ ├── chunk-000
│ │ ├── file-000.m4a
│ │ ├── file-001.m4a
│ │ ├── file-000.mp4
│ │ ├── file-001.mp4
│ │ └── ...
│ ├── chunk-001
│ │ └── ...
│ └── ...
├── observation.audio.phone
├── observation.images.phone
│ ├── chunk-000
│ │ ├── file-000.m4a
│ │ ├── file-001.m4a
│ │ ├── file-000.mp4
│ │ ├── file-001.mp4
│ │ └── ...
│ ├── chunk-001
│ │ └── ...
@@ -190,10 +169,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
video files are already present on local disk, they won't be downloaded again. Defaults to
True.
download_audio (bool, optional): Flag to download the audio. Defaults to True.
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'torchcodec'.
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
@@ -221,7 +198,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self._video_backend = video_backend if video_backend else get_safe_default_codec()
self._audio_backend = audio_backend if audio_backend else "torchcodec"
self._batch_encoding_size = batch_encoding_size
self._vcodec = resolve_vcodec(vcodec)
self._encoder_threads = encoder_threads
@@ -243,7 +219,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
episodes=episodes,
tolerance_s=tolerance_s,
video_backend=self._video_backend,
audio_backend=self._audio_backend,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
)
@@ -252,7 +227,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if force_cache_sync or not self.reader.try_load():
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
self._download(download_videos, download_audio)
self._download(download_videos)
self.reader.load_and_activate()
# Detect write-mode params for backward compatibility
@@ -306,7 +281,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
episodes=self.episodes,
tolerance_s=self.tolerance_s,
video_backend=self._video_backend,
audio_backend=self._audio_backend,
delta_timestamps=self.delta_timestamps,
image_transforms=self.image_transforms,
)
@@ -386,14 +360,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
self._require_writer("add_frame")
self.writer.add_frame(frame)
def add_microphones_recordings(self, microphones: dict) -> None:
"""Add microphone recordings to the current episode buffer.
Delegates to :meth:`DatasetWriter.add_microphones_recordings`.
"""
self._require_writer("add_microphones_recordings")
self.writer.add_microphones_recordings(microphones)
def save_episode(self, episode_data: dict | None = None, parallel_encoding: bool = True) -> None:
"""Save the current episode buffer to disk.
@@ -518,7 +484,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
license: str | None = "apache-2.0",
tag_version: bool = True,
push_videos: bool = True,
push_audio: bool = True,
private: bool = False,
allow_patterns: list[str] | str | None = None,
upload_large_folder: bool = False,
@@ -548,8 +513,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
ignore_patterns = ["images/"]
if not push_videos:
ignore_patterns.append("videos/")
if not push_audio:
ignore_patterns.append("audio/")
hub_api = HfApi()
hub_api.create_repo(
@@ -590,15 +553,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
def _download(self, download_videos: bool = True, download_audio: bool = True) -> None:
def _download(self, download_videos: bool = True) -> None:
"""Downloads the dataset from the given 'repo_id' at the provided version."""
ignore_patterns = None if download_videos else "videos/"
files = None
ignore_patterns = []
if not download_videos:
ignore_patterns.append("videos/")
if not download_audio:
ignore_patterns.append("audio/")
if self.episodes is not None:
# Reader is guaranteed to exist here (created in __init__ before _download)
files = self.reader.get_episodes_file_paths()
@@ -645,7 +603,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_writer_processes: int = 0,
image_writer_threads: int = 0,
video_backend: str | None = None,
audio_backend: str | None = None,
batch_encoding_size: int = 1,
vcodec: str = "libsvtav1",
metadata_buffer_size: int = 10,

View File

@@ -73,7 +73,6 @@ class ForwardCompatibilityError(CompatibilityError):
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
DEFAULT_AUDIO_FILE_SIZE_IN_MB = 100 # Max size per file
INFO_PATH = "meta/info.json"
STATS_PATH = "meta/stats.json"
@@ -81,7 +80,6 @@ STATS_PATH = "meta/stats.json"
EPISODES_DIR = "meta/episodes"
DATA_DIR = "data"
VIDEO_DIR = "videos"
AUDIO_DIR = "audio"
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
@@ -89,12 +87,7 @@ DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
DEFAULT_AUDIO_PATH = AUDIO_DIR + "/{audio_key}/" + CHUNK_FILE_PATTERN + ".m4a"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
DEFAULT_RAW_AUDIO_PATH = "raw_audio/{audio_key}/episode_{episode_index:06d}.wav"
DEFAULT_AUDIO_CHUNK_DURATION = 0.5 # seconds
DEFAULT_INITIAL_AUDIO_BUFFER_DURATION = 1.0 # seconds
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"

View File

@@ -486,42 +486,42 @@ def encode_video_frames(
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
def concatenate_media_files(
input_media_paths: list[Path | str], output_media_path: Path, overwrite: bool = True
def concatenate_video_files(
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
):
"""
Concatenate multiple media files (video & audio) into a single media file using pyav.
Concatenate multiple video files into a single video file using pyav.
This function takes a list of input media file paths and concatenates them into a single
output media file. It uses ffmpeg's concat demuxer with stream copy mode for fast
This function takes a list of video input file paths and concatenates them into a single
output video file. It uses ffmpeg's concat demuxer with stream copy mode for fast
concatenation without re-encoding.
Args:
input_media_paths: Ordered list of input media file paths to concatenate.
output_media_path: Path to the output media file.
overwrite: Whether to overwrite the output media file if it already exists. Default is True.
input_video_paths: Ordered list of input video file paths to concatenate.
output_video_path: Path to the output video file.
overwrite: Whether to overwrite the output video file if it already exists. Default is True.
Note:
- Creates a temporary .ffconcat file and container audio/video file that are cleaned up after use.
- Uses ffmpeg's concat demuxer which requires all input media files to have the same
- Creates a temporary directory for intermediate files that is cleaned up after use.
- Uses ffmpeg's concat demuxer which requires all input videos to have the same
codec, resolution, and frame rate for proper concatenation.
"""
output_media_path = Path(output_media_path)
output_video_path = Path(output_video_path)
if output_media_path.exists() and not overwrite:
logging.warning(f"Media file already exists: {output_media_path}. Skipping concatenation.")
if output_video_path.exists() and not overwrite:
logger.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
return
output_media_path.parent.mkdir(parents=True, exist_ok=True)
output_video_path.parent.mkdir(parents=True, exist_ok=True)
if len(input_media_paths) == 0:
raise FileNotFoundError("No input media paths provided.")
if len(input_video_paths) == 0:
raise FileNotFoundError("No input video paths provided.")
# Create a temporary .ffconcat file to list the input media paths
# Create a temporary .ffconcat file to list the input video paths
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
tmp_concatenate_file.write("ffconcat version 1.0\n")
for input_path in input_media_paths:
for input_path in input_video_paths:
tmp_concatenate_file.write(f"file '{str(input_path.resolve())}'\n")
tmp_concatenate_file.flush()
tmp_concatenate_path = tmp_concatenate_file.name
@@ -531,12 +531,11 @@ def concatenate_media_files(
tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"}
) # safe = 0 allows absolute paths as well as relative paths
# Using an intermediate container to store the concatenated media file is necessary to avoid inplace concatenation read-write race conditions.
with tempfile.NamedTemporaryFile(suffix=output_media_path.suffix, delete=False) as tmp_named_file:
tmp_output_media_path = tmp_named_file.name
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
tmp_output_video_path = tmp_named_file.name
output_container = av.open(
tmp_output_media_path, mode="w", options={"movflags": "faststart"}
tmp_output_video_path, mode="w", options={"movflags": "faststart"}
) # faststart is to move the metadata to the beginning of the file to speed up loading
# Replicate input streams in output container
@@ -551,7 +550,6 @@ def concatenate_media_files(
stream_map[input_stream.index].time_base = input_stream.time_base
# Demux + remux packets (no re-encode)
last_dts = None
for packet in input_container.demux():
# Skip packets from un-mapped streams
if packet.stream.index not in stream_map:
@@ -560,16 +558,6 @@ def concatenate_media_files(
# Skip demux flushing packets
if packet.dts is None:
continue
else:
# Enforce strictly increasing decoding timestamps (DTS)
if last_dts is not None and packet.dts <= last_dts:
shift = last_dts - packet.dts + 1
packet.dts += shift
packet.pts += shift # Presenting timestamps (PTS) are the same as DTS here
logging.warning(
f"Non-monotonic DTS; previous: {last_dts}, current: {packet.dts - shift}; changing to {packet.dts}. This may result in incorrect timestamps in the output file."
)
last_dts = packet.dts
output_stream = stream_map[packet.stream.index]
packet.stream = output_stream
@@ -577,7 +565,7 @@ def concatenate_media_files(
input_container.close()
output_container.close()
shutil.move(tmp_output_media_path, output_media_path)
shutil.move(tmp_output_video_path, output_video_path)
Path(tmp_concatenate_path).unlink()
@@ -959,6 +947,38 @@ with warnings.catch_warnings():
register_feature(VideoFrame, "VideoFrame")
def get_audio_info(video_path: Path | str) -> dict:
# Set logging level
logging.getLogger("libav").setLevel(av.logging.WARNING)
# Getting audio stream information
audio_info = {}
with av.open(str(video_path), "r") as audio_file:
try:
audio_stream = audio_file.streams.audio[0]
except IndexError:
# Reset logging level
av.logging.restore_default_callback()
return {"has_audio": False}
audio_info["audio.channels"] = audio_stream.channels
audio_info["audio.codec"] = audio_stream.codec.canonical_name
# In an ideal loseless case : bit depth x sample rate x channels = bit rate.
# In an actual compressed case, the bit rate is set according to the compression level : the lower the bit rate, the more compression is applied.
audio_info["audio.bit_rate"] = audio_stream.bit_rate
audio_info["audio.sample_rate"] = audio_stream.sample_rate # Number of samples per second
# In an ideal loseless case : fixed number of bits per sample.
# In an actual compressed case : variable number of bits per sample (often reduced to match a given depth rate).
audio_info["audio.bit_depth"] = audio_stream.format.bits
audio_info["audio.channel_layout"] = audio_stream.layout.name
audio_info["has_audio"] = True
# Reset logging level
av.logging.restore_default_callback()
return audio_info
def get_video_info(video_path: Path | str) -> dict:
# Set logging level
logging.getLogger("libav").setLevel(av.logging.WARNING)
@@ -988,6 +1008,9 @@ def get_video_info(video_path: Path | str) -> dict:
# Reset logging level
av.logging.restore_default_callback()
# Adding audio stream information
video_info.update(**get_audio_info(video_path))
return video_info
@@ -1002,22 +1025,22 @@ def get_video_pixel_channels(pix_fmt: str) -> int:
raise ValueError("Unknown format")
def get_media_duration_in_s(media_path: Path | str, media_type: str = "video") -> float:
def get_video_duration_in_s(video_path: Path | str) -> float:
"""
Get the duration of a media file (video & audio) in seconds using PyAV.
Get the duration of a video file in seconds using PyAV.
Args:
media_path: Path to the media file.
video_path: Path to the video file.
Returns:
Duration of the media file in seconds.
Duration of the video in seconds.
"""
with av.open(str(media_path)) as container:
# Get the first stream
stream = container.streams.video[0] if media_type == "video" else container.streams.audio[0]
with av.open(str(video_path)) as container:
# Get the first video stream
video_stream = container.streams.video[0]
# Calculate duration: stream.duration * stream.time_base gives duration in seconds
if stream.duration is not None:
duration = float(stream.duration * stream.time_base)
if video_stream.duration is not None:
duration = float(video_stream.duration * video_stream.time_base)
else:
# Fallback to container duration if stream duration is not available
duration = float(container.duration / av.time_base)
@@ -1026,12 +1049,12 @@ def get_media_duration_in_s(media_path: Path | str, media_type: str = "video") -
class VideoEncodingManager:
"""
Context manager that ensures proper video and audio encoding and data cleanup even if exceptions occur.
Context manager that ensures proper video encoding and data cleanup even if exceptions occur.
This manager handles:
- Batch encoding for any remaining episodes when recording interrupted
- Cleaning up temporary image and audio files from interrupted episodes
- Removing empty image and audio directories
- Cleaning up temporary image files from interrupted episodes
- Removing empty image directories
Args:
dataset: The LeRobotDataset instance
@@ -1068,16 +1091,4 @@ class VideoEncodingManager:
else:
logger.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
# Clean up any remaining audio directory if it's empty
audio_dir = self.dataset.root / "raw_audio"
# Check for any remaining WAV files
wav_files = list(audio_dir.rglob("*.wav"))
if len(wav_files) == 0:
# Only remove the raw_audio directory if no WAV files remain
if audio_dir.exists():
shutil.rmtree(audio_dir)
logging.debug("Cleaned up empty audio directory")
else:
logging.debug(f"Audio directory is not empty, containing {len(wav_files)} WAV files")
return False # Don't suppress the original exception

View File

@@ -12,11 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
import importlib
from dataclasses import dataclass, field, fields
from typing import Any
import draccus
import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.robots import RobotConfig
@@ -67,6 +72,45 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
def gym_kwargs(self) -> dict:
raise NotImplementedError()
def create_envs(
self,
n_envs: int,
use_async_envs: bool = True,
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
"""Create {suite: {task_id: VectorEnv}}.
Default: single-task env via gym.make(). Multi-task benchmarks override.
AsyncVectorEnv is the default for n_envs > 1; auto-downgraded to Sync for n_envs=1.
"""
env_cls = gym.vector.AsyncVectorEnv if (use_async_envs and n_envs > 1) else gym.vector.SyncVectorEnv
if self.gym_id not in gym_registry:
print(f"gym id '{self.gym_id}' not found, attempting to import '{self.package_name}'...")
try:
importlib.import_module(self.package_name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Package '{self.package_name}' required for env '{self.type}' not found. "
f"Please install it or check PYTHONPATH."
) from e
if self.gym_id not in gym_registry:
raise gym.error.NameNotFound(
f"Environment '{self.gym_id}' not registered even after importing '{self.package_name}'."
)
def _make_one():
return gym.make(self.gym_id, disable_env_checker=self.disable_env_checker, **self.gym_kwargs)
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
return {self.type: {0: vec}}
def get_env_processors(self):
"""Return (preprocessor, postprocessor) for this env. Default: identity."""
from lerobot.processor.pipeline import PolicyProcessorPipeline
return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[])
@dataclass
class HubEnvConfig(EnvConfig):
@@ -345,6 +389,32 @@ class LiberoEnv(EnvConfig):
kwargs["task_ids"] = self.task_ids
return kwargs
def create_envs(self, n_envs: int, use_async_envs: bool = True):
from lerobot.envs.libero import create_libero_envs
if self.task is None:
raise ValueError("LiberoEnv requires a task to be specified")
env_cls = gym.vector.AsyncVectorEnv if (use_async_envs and n_envs > 1) else gym.vector.SyncVectorEnv
return create_libero_envs(
task=self.task,
n_envs=n_envs,
camera_name=self.camera_name,
init_states=self.init_states,
gym_kwargs=self.gym_kwargs,
env_cls=env_cls,
control_mode=self.control_mode,
episode_length=self.episode_length,
)
def get_env_processors(self):
from lerobot.processor.env_processor import LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
return (
PolicyProcessorPipeline(steps=[LiberoProcessorStep()]),
PolicyProcessorPipeline(steps=[]),
)
@EnvConfig.register_subclass("metaworld")
@dataclass
@@ -387,6 +457,19 @@ class MetaworldEnv(EnvConfig):
"render_mode": self.render_mode,
}
def create_envs(self, n_envs: int, use_async_envs: bool = True):
from lerobot.envs.metaworld import create_metaworld_envs
if self.task is None:
raise ValueError("MetaWorld requires a task to be specified")
env_cls = gym.vector.AsyncVectorEnv if (use_async_envs and n_envs > 1) else gym.vector.SyncVectorEnv
return create_metaworld_envs(
task=self.task,
n_envs=n_envs,
gym_kwargs=self.gym_kwargs,
env_cls=env_cls,
)
@EnvConfig.register_subclass("isaaclab_arena")
@dataclass
@@ -454,3 +537,18 @@ class IsaaclabArenaEnv(HubEnvConfig):
@property
def gym_kwargs(self) -> dict:
return {}
def get_env_processors(self):
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
state_keys = tuple(k.strip() for k in (self.state_keys or "").split(",") if k.strip())
camera_keys = tuple(k.strip() for k in (self.camera_keys or "").split(",") if k.strip())
if not state_keys and not camera_keys:
raise ValueError("At least one of state_keys or camera_keys must be specified.")
return (
PolicyProcessorPipeline(
steps=[IsaaclabArenaProcessorStep(state_keys=state_keys, camera_keys=camera_keys)]
),
PolicyProcessorPipeline(steps=[]),
)

View File

@@ -13,96 +13,52 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from __future__ import annotations
from typing import Any
import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.configs import AlohaEnv, EnvConfig, HubEnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv
from lerobot.envs.configs import EnvConfig, HubEnvConfig
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import ProcessorStep
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
if env_type == "aloha":
return AlohaEnv(**kwargs)
elif env_type == "pusht":
return PushtEnv(**kwargs)
elif env_type == "libero":
return LiberoEnv(**kwargs)
else:
raise ValueError(f"Policy type '{env_type}' is not available.")
try:
cls = EnvConfig.get_choice_class(env_type)
except KeyError as err:
raise ValueError(
f"Environment type '{env_type}' is not registered. "
f"Available: {list(EnvConfig.get_known_choices().keys())}"
) from err
return cls(**kwargs)
def make_env_pre_post_processors(
env_cfg: EnvConfig,
policy_cfg: PreTrainedConfig,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
]:
policy_cfg: Any,
) -> tuple[Any, Any]:
"""
Create preprocessor and postprocessor pipelines for environment observations.
This function creates processor pipelines that transform raw environment
observations and actions. By default, it returns identity processors that do nothing.
For specific environments like LIBERO, it adds environment-specific processing steps.
Args:
env_cfg: The configuration of the environment.
Returns:
A tuple containing:
- preprocessor: Pipeline that processes environment observations
- postprocessor: Pipeline that processes environment outputs (currently identity)
Returns a tuple of (preprocessor, postprocessor). By default, delegates to
``env_cfg.get_env_processors()``. The XVLAConfig policy-specific override
stays here because it depends on the *policy* config, not the env config.
"""
# Preprocessor and Postprocessor steps are Identity for most environments
preprocessor_steps: list[ProcessorStep] = []
postprocessor_steps: list[ProcessorStep] = []
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
if isinstance(policy_cfg, XVLAConfig):
from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors
return make_xvla_libero_pre_post_processors()
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
preprocessor_steps.append(LiberoProcessorStep())
# For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
# Parse comma-separated keys (handle None for state-based policies)
if env_cfg.state_keys:
state_keys = tuple(k.strip() for k in env_cfg.state_keys.split(",") if k.strip())
else:
state_keys = ()
if env_cfg.camera_keys:
camera_keys = tuple(k.strip() for k in env_cfg.camera_keys.split(",") if k.strip())
else:
camera_keys = ()
if not state_keys and not camera_keys:
raise ValueError("At least one of state_keys or camera_keys must be specified.")
preprocessor_steps.append(
IsaaclabArenaProcessorStep(
state_keys=state_keys,
camera_keys=camera_keys,
)
)
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)
return preprocessor, postprocessor
return env_cfg.get_env_processors()
def make_env(
cfg: EnvConfig | str,
n_envs: int = 1,
use_async_envs: bool = False,
use_async_envs: bool = True,
hub_cache_dir: str | None = None,
trust_remote_code: bool = False,
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
@@ -163,57 +119,4 @@ def make_env(
if n_envs < 1:
raise ValueError("`n_envs` must be at least 1")
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
if "libero" in cfg.type:
from lerobot.envs.libero import create_libero_envs
if cfg.task is None:
raise ValueError("LiberoEnv requires a task to be specified")
return create_libero_envs(
task=cfg.task,
n_envs=n_envs,
camera_name=cfg.camera_name,
init_states=cfg.init_states,
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
control_mode=cfg.control_mode,
episode_length=cfg.episode_length,
)
elif "metaworld" in cfg.type:
from lerobot.envs.metaworld import create_metaworld_envs
if cfg.task is None:
raise ValueError("MetaWorld requires a task to be specified")
return create_metaworld_envs(
task=cfg.task,
n_envs=n_envs,
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
)
if cfg.gym_id not in gym_registry:
print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...")
try:
importlib.import_module(cfg.package_name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Package '{cfg.package_name}' required for env '{cfg.type}' not found. "
f"Please install it or check PYTHONPATH."
) from e
if cfg.gym_id not in gym_registry:
raise gym.error.NameNotFound(
f"Environment '{cfg.gym_id}' not registered even after importing '{cfg.package_name}'."
)
def _make_one():
return gym.make(cfg.gym_id, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
# normalize to {suite: {task_id: vec_env}} for consistency
suite_name = cfg.type # e.g., "pusht", "aloha"
return {suite_name: {0: vec}}
return cfg.create_envs(n_envs=n_envs, use_async_envs=use_async_envs)

View File

@@ -150,7 +150,17 @@ class LiberoEnv(gym.Env):
self.init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._env = self._make_envs_task(task_suite, self.task_id)
# Extract task metadata without allocating GPU resources (safe before fork).
task = task_suite.get_task(task_id)
self.task = task.name
self.task_description = task.language
self._task_bddl_file = os.path.join(
get_libero_path("bddl_files"), task.problem_folder, task.bddl_file
)
self._env: OffScreenRenderEnv | None = (
None # deferred — created on first reset() inside the worker subprocess
)
default_steps = 500
self._max_episode_steps = (
TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
@@ -221,28 +231,32 @@ class LiberoEnv(gym.Env):
low=ACTION_LOW, high=ACTION_HIGH, shape=(ACTION_DIM,), dtype=np.float32
)
def _ensure_env(self) -> None:
"""Create the underlying OffScreenRenderEnv on first use.
Called inside the worker subprocess after fork(), so each worker gets
its own clean EGL context rather than inheriting a stale one from the
parent process (which causes EGL_BAD_CONTEXT crashes with AsyncVectorEnv).
"""
if self._env is not None:
return
env = OffScreenRenderEnv(
bddl_file_name=self._task_bddl_file,
camera_heights=self.observation_height,
camera_widths=self.observation_width,
)
env.reset()
self._env = env
def render(self):
self._ensure_env()
raw_obs = self._env.env._get_observations()
image = self._format_raw_obs(raw_obs)["pixels"]["image"]
image = image[::-1, ::-1] # flip both H and W for visualization
return image
def _make_envs_task(self, task_suite: Any, task_id: int = 0):
task = task_suite.get_task(task_id)
self.task = task.name
self.task_description = task.language
task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
env_args = {
"bddl_file_name": task_bddl_file,
"camera_heights": self.observation_height,
"camera_widths": self.observation_width,
}
env = OffScreenRenderEnv(**env_args)
env.reset()
return env
def _format_raw_obs(self, raw_obs: RobotObservation) -> RobotObservation:
assert self._env is not None, "_format_raw_obs called before _ensure_env()"
images = {}
for camera_name in self.camera_name:
image = raw_obs[camera_name]
@@ -294,6 +308,7 @@ class LiberoEnv(gym.Env):
)
def reset(self, seed=None, **kwargs):
self._ensure_env()
super().reset(seed=seed)
self._env.seed(seed)
raw_obs = self._env.reset()
@@ -320,6 +335,8 @@ class LiberoEnv(gym.Env):
return observation, info
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
self._ensure_env()
assert self._env is not None
if action.ndim != 1:
raise ValueError(
f"Expected action to be 1-D (shape (action_dim,)), "
@@ -350,7 +367,8 @@ class LiberoEnv(gym.Env):
return observation, reward, terminated, truncated, info
def close(self):
self._env.close()
if self._env is not None:
self._env.close()
def _make_env_fns(

View File

@@ -97,8 +97,9 @@ class MetaworldEnv(gym.Env):
self.visualization_height = visualization_height
self.camera_name = camera_name
self._env = self._make_envs_task(self.task)
self._max_episode_steps = self._env.max_path_length
self._env_name = self.task # already stripped of "metaworld-" prefix above
self._env = None # deferred — created on first reset() inside the worker subprocess
self._max_episode_steps = 500 # MT1 environments always have max_path_length=500
self.task_description = TASK_DESCRIPTIONS[self.task]
self.expert_policy = TASK_POLICY_MAPPING[self.task]()
@@ -136,6 +137,24 @@ class MetaworldEnv(gym.Env):
self.action_space = spaces.Box(low=-1, high=1, shape=(ACTION_DIM,), dtype=np.float32)
def _ensure_env(self) -> None:
"""Create the underlying MetaWorld env on first use.
Called inside the worker subprocess after fork(), so each worker gets
its own clean rendering context rather than inheriting a stale one from
the parent process (which causes crashes with AsyncVectorEnv).
"""
if self._env is not None:
return
mt1 = metaworld.MT1(self._env_name, seed=42)
env = mt1.train_classes[self._env_name](render_mode="rgb_array", camera_name=self.camera_name)
env.set_task(mt1.train_tasks[0])
if self.camera_name == "corner2":
env.model.cam_pos[2] = [0.75, 0.075, 0.7]
env.reset()
env._freeze_rand_vec = False # otherwise no randomization
self._env = env
def render(self) -> np.ndarray:
"""
Render the current environment frame.
@@ -143,26 +162,13 @@ class MetaworldEnv(gym.Env):
Returns:
np.ndarray: The rendered RGB image from the environment.
"""
self._ensure_env()
image = self._env.render()
if self.camera_name == "corner2":
# Images from this camera are flipped — correct them
image = np.flip(image, (0, 1))
return image
def _make_envs_task(self, env_name: str):
mt1 = metaworld.MT1(env_name, seed=42)
env = mt1.train_classes[env_name](render_mode="rgb_array", camera_name=self.camera_name)
env.set_task(mt1.train_tasks[0])
if self.camera_name == "corner2":
env.model.cam_pos[2] = [
0.75,
0.075,
0.7,
] # corner2 position, similar to https://arxiv.org/pdf/2206.14244
env.reset()
env._freeze_rand_vec = False # otherwise no randomization
return env
def _format_raw_obs(self, raw_obs: np.ndarray) -> RobotObservation:
image = None
if self._env is not None:
@@ -209,6 +215,7 @@ class MetaworldEnv(gym.Env):
observation (RobotObservation): The initial formatted observation.
info (Dict[str, Any]): Additional info about the reset state.
"""
self._ensure_env()
super().reset(seed=seed)
raw_obs, info = self._env.reset(seed=seed)
@@ -232,6 +239,7 @@ class MetaworldEnv(gym.Env):
truncated (bool): Whether the episode was truncated due to a time limit.
info (Dict[str, Any]): Additional environment info.
"""
self._ensure_env()
if action.ndim != 1:
raise ValueError(
f"Expected action to be 1-D (shape (action_dim,)), "
@@ -263,7 +271,8 @@ class MetaworldEnv(gym.Env):
return observation, reward, terminated, truncated, info
def close(self):
self._env.close()
if self._env is not None:
self._env.close()
# ---- Main API ----------------------------------------------------------------

View File

@@ -1,17 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configs import MicrophoneConfig
from .microphone import Microphone
from .utils import make_microphones_from_configs

View File

@@ -1,140 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from pathlib import Path
from threading import Barrier
from typing import Any
import numpy as np
from .configs import MicrophoneConfig
class Microphone(abc.ABC):
"""Base class for microphone implementations.
Defines a standard interface for microphone operations across different backends.
Subclasses must implement all abstract methods.
Manages basic microphone properties (sample rate, channels) and core operations:
- Connection/disconnection
- Start/stop recording
- Audio chunk reading
Attributes:
sample_rate (int | None): Configured sample rate in Hz
channels (list[int] | None): List of channel numbers to record
Example:
class MyMicrophone(Microphone):
def __init__(self, config): ...
@property
def is_connected(self) -> bool: ...
def connect(self): ...
# Plus other required methods
"""
def __init__(self, config: MicrophoneConfig):
"""Initialize the microphone with the given configuration.
Args:
config: Microphone configuration containing sample rate and channels.
"""
self.sample_rate: int | None = config.sample_rate
self.channels: list[int] | None = config.channels
@property
@abc.abstractmethod
def is_connected(self) -> bool:
"""Check if the microphone is currently connected.
Returns:
bool: True if the microphone is connected and ready to start recording,
False otherwise.
"""
pass
@property
@abc.abstractmethod
def is_recording(self) -> bool:
"""Check if the microphone is currently recording.
Returns:
bool: True if the microphone is recording, False otherwise.
"""
pass
@property
@abc.abstractmethod
def is_writing(self) -> bool:
"""Check if the microphone is currently writing to a file.
Returns:
bool: True if the microphone is writing to a file, False otherwise.
"""
pass
@staticmethod
@abc.abstractmethod
def find_microphones() -> list[dict[str, Any]]:
"""Detects available microphones connected to the system.
Returns:
List[Dict[str, Any]]: A list of dictionaries,
where each dictionary contains information about a detected microphone.
"""
pass
@abc.abstractmethod
def connect(self) -> None:
"""Establish connection to the microphone."""
pass
@abc.abstractmethod
def start_recording(
self,
output_file: str | Path | None = None,
multiprocessing: bool | None = False,
overwrite: bool | None = True,
barrier: Barrier | None = None,
) -> None:
"""Start recording audio from the microphone.
Args:
output_file: Optional path to save the recorded audio.
multiprocessing: If True, enables multiprocessing for recording. Defaults to multithreading otherwise.
overwrite: If True, overwrites existing files at output_file path.
barrier: If not None, ensures that multiple microphones start recording at the same time.
"""
pass
@abc.abstractmethod
def read(self) -> np.ndarray:
"""Capture and return a single audio chunk from the microphone.
Returns:
np.ndarray: Captured audio chunk as a numpy array.
"""
pass
@abc.abstractmethod
def stop_recording(self) -> None:
"""Stop recording audio from the microphone."""
pass
@abc.abstractmethod
def disconnect(self) -> None:
"""Disconnect the microphone and release any resources."""
pass

View File

@@ -1,16 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_portaudio import PortAudioMicrophoneConfig
from .microphone_portaudio import PortAudioMicrophone

View File

@@ -1,41 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..configs import MicrophoneConfig
@MicrophoneConfig.register_subclass("portaudio")
@dataclass
class PortAudioMicrophoneConfig(MicrophoneConfig):
"""Configuration class for PortAudio-based microphone devices.
This class provides configuration options for microphones accessed through PortAudio with the sounddevice Python package.
including device index, sample rate and channels.
Example configurations:
```python
# Basic configurations
PortAudioMicrophoneConfig(0, 16000, [1]) # Device index 0, 16000Hz, mono
PortAudioMicrophoneConfig(1, 44100, [1, 2]) # Device index 1, 44100Hz, stereo
```
Attributes:
microphone_index: Device index for the microphone.
sample_rate: Sample rate in Hz for the microphone.
channels: List of channel numbers to use for the microphone.
"""
microphone_index: int

View File

@@ -1,394 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import time
from collections.abc import Callable
from threading import Event, Thread
from typing import Any
import numpy as np
from sounddevice import PortAudioError
from lerobot.utils.robot_utils import precise_sleep
# --- Interface definitions for InputStream ---
class IInputStream(abc.ABC):
@abc.abstractmethod
def __init__(
self,
samplerate: float | None = None,
blocksize: int | None = None,
device: int | str | None = None,
channels: int | None = None,
dtype: str | np.dtype | None = None,
latency: float | str | None = None,
callback: Callable[[Any, int, Any, Any], None] | None = None,
):
pass
@abc.abstractmethod
def start(self) -> None:
pass
@abc.abstractmethod
def stop(self) -> None:
pass
@abc.abstractmethod
def close(self) -> None:
pass
class ISounddeviceSDK(abc.ABC):
"""Interface defining the contract for the Sounddevice SDK."""
InputStream: type[IInputStream]
@abc.abstractmethod
def query_devices(self, device: int | str | None = None, kind: str | None = None) -> list[dict[str, Any]]:
pass
# --- Real SDK Adapter ---
class SounddeviceSDKAdapter(ISounddeviceSDK):
"""Adapts the real sounddevice library to the ISounddeviceSDK interface."""
_sounddevice = None
def __init__(self):
try:
import sounddevice
SounddeviceSDKAdapter._sounddevice = sounddevice
except ImportError as e:
raise ImportError("sounddevice library not found") from e
# --- Inner Class Implementation ---
class RealInputStream(IInputStream):
def __init__(
self,
samplerate: int | None = None,
blocksize: int | None = None,
device: int | None = None,
channels: int | None = None,
dtype: str | np.dtype | None = None,
latency: float | str | None = None,
callback: Callable[[Any, int, Any, Any], None] | None = None,
):
import sounddevice
self._input_stream = sounddevice.InputStream(
samplerate=samplerate,
blocksize=blocksize,
device=device,
channels=channels,
dtype=dtype,
latency=latency,
callback=callback,
)
def start(self) -> None:
self._input_stream.start()
def stop(self) -> None:
self._input_stream.stop()
def close(self) -> None:
self._input_stream.close()
def __del__(self):
self._input_stream.stop()
self._input_stream.close()
@property
def active(self) -> bool:
return self._input_stream.active
@property
def stopped(self) -> bool:
return self._input_stream.stopped
@property
def closed(self) -> bool:
return self._input_stream.closed
InputStream = RealInputStream
def query_devices(self, device: int | str | None = None, kind: str | None = None) -> list[dict[str, Any]]:
return SounddeviceSDKAdapter._sounddevice.query_devices(device, kind)
# Emulates a 48kHz stereo microphone
VALID_DTYPE = {
"float32",
"int32",
"int16",
"int8",
"uint8",
np.float32,
np.int32,
np.int16,
np.int8,
np.uint8,
}
VALID_LATENCY = {"low", "high"}
VALID_DEVICES = [
{
"index": 0,
"name": "Built-in Microphone",
"hostapi": 0,
"max_input_channels": 2,
"max_output_channels": 0,
"default_low_input_latency": 0.01,
"default_low_output_latency": 0.001,
"default_high_input_latency": 0.1,
"default_high_output_latency": 0.01,
"default_samplerate": 48000.0,
},
{
"index": 1,
"name": "Built-in Output",
"hostapi": 0,
"max_input_channels": 0,
"max_output_channels": 2,
"default_low_input_latency": 0.04,
"default_low_output_latency": 0.04,
"default_high_input_latency": 0.12,
"default_high_output_latency": 0.12,
"default_samplerate": 48000.0,
},
{
"index": 2,
"name": "USB Audio Device",
"hostapi": 0,
"max_input_channels": 1,
"max_output_channels": 0,
"default_low_input_latency": 0.03,
"default_low_output_latency": 0.01,
"default_high_input_latency": 0.04,
"default_high_output_latency": 0.03,
"default_samplerate": 16000.0,
},
]
# -- Fake SDK Adapter ---
class FakeSounddeviceSDKAdapter(ISounddeviceSDK):
"""Implements the ISounddeviceSDK interface with fake behaviour for testing."""
# --- Inner Class Implementation ---
class FakeInputStream(IInputStream):
def __init__(
self,
samplerate: float | None = None,
blocksize: int | None = None,
device: int | str | None = None,
channels: int | None = None,
dtype: str | None = None,
latency: str | None = None,
callback: Callable[[Any, int, Any, Any], None] | None = None,
):
self.samplerate = samplerate
self.blocksize = blocksize
self.device = device
self.channels = channels
self.dtype = dtype
self.latency = latency
self.callback = callback
self._validate_settings()
self._active = False
self._closed = False
if self.callback is not None:
self._streaming_thread = Thread(target=self._streaming_loop, daemon=True)
self._streaming_thread_stop_event = Event()
@property
def active(self) -> bool:
"""True when the stream is active, False otherwise."""
return self._active
@property
def stopped(self) -> bool:
"""True when the stream is stopped, False otherwise."""
return not self._active
@property
def closed(self) -> bool:
"""True after a call to close(), False otherwise."""
return self._closed
def _get_device_info(self):
"""Returns the device info for the device."""
for device in VALID_DEVICES:
if (isinstance(self.device, int) and device["index"] == self.device) or (
isinstance(self.device, str) and device["name"] == self.device
):
return device
raise PortAudioError(f"No input device matching {self.device}")
def _validate_device(self):
"""Validates the device against the valid devices."""
valid_device_indices = [device["index"] for device in VALID_DEVICES]
valid_device_names = [device["name"] for device in VALID_DEVICES]
if self.device is not None:
if isinstance(self.device, (int, str)):
# Check if device index is valid
if isinstance(self.device, int) and self.device not in valid_device_indices:
raise PortAudioError(f"Error querying device {self.device}")
# Check if device name is valid
if isinstance(self.device, str) and self.device not in valid_device_names:
raise PortAudioError(f"No input device matching {self.device}")
else:
raise PortAudioError(f"Device must be int or str, got {type(self.device)}")
else:
# Default to first input device
input_devices = [d for d in VALID_DEVICES if d["max_input_channels"] > 0]
if input_devices:
self.device = input_devices[0]["index"]
def _validate_samplerate(self):
"""Validates the samplerate against the device's maximum samplerate."""
device_info = self._get_device_info()
if self.samplerate is None:
self.samplerate = device_info["default_samplerate"]
elif self.samplerate > device_info["default_samplerate"] or self.samplerate < 1000:
raise PortAudioError("Error opening InputStream: Invalid sample rate")
def _validate_channels(self):
"""Validates the channels against the device's maximum channels."""
device_info = self._get_device_info()
if self.channels is None:
self.channels = device_info["max_input_channels"]
elif self.channels > device_info["max_input_channels"] or self.channels < 1:
raise PortAudioError("Error opening InputStream: Invalid number of channels")
def _validate_dtype(self):
"""Validates the dtype against the valid dtypes."""
if self.dtype is not None:
if self.dtype not in VALID_DTYPE:
raise PortAudioError("Invalid input sample format")
else:
self.dtype = "float32" # Default dtype
def _validate_latency(self):
"""Validates the latency against the valid latencies."""
if self.latency is not None:
if self.latency not in VALID_LATENCY:
raise PortAudioError("Invalid latency")
else:
self.latency = "low" # Default latency
if isinstance(self.latency, str):
device_info = self._get_device_info()
if self.latency == "low":
self.latency = device_info["default_low_input_latency"]
elif self.latency == "high":
self.latency = device_info["default_high_input_latency"]
def _validate_settings(self):
"""Validates the input parameters against available devices and valid options."""
self._validate_device()
self._validate_samplerate()
self._validate_channels()
self._validate_dtype()
self._validate_latency()
def _simulated_audio_data(self) -> np.ndarray:
"""Generates a simulated audio signal for testing purposes with proper value ranges."""
duration_samples = int(self.samplerate * self.latency)
# Generate output according to dtype
if self.dtype in {"float32", np.float32}:
# Generate values between -1 and 1 for float32
data = np.random.uniform(-1.0, 1.0, (duration_samples, self.channels)).astype(self.dtype)
else:
# Use np.iinfo to get proper range for integer types
info = np.iinfo(self.dtype)
data = np.random.randint(
info.min, info.max + 1, (duration_samples, self.channels), dtype=self.dtype
)
return data
def _streaming_loop(self):
if self.callback is not None:
while not self._streaming_thread_stop_event.is_set():
precise_sleep(self.latency)
tmp_data = self._simulated_audio_data()
self.callback(
tmp_data,
len(tmp_data),
time.perf_counter(),
None,
)
def start(self) -> None:
"""Start the fake input stream."""
if not self.active and self.callback is not None:
self._streaming_thread.start()
self._active = True
def stop(self) -> None:
"""Stop the fake input stream."""
if self.callback is not None:
self._streaming_thread_stop_event.set()
self._streaming_thread.join()
self._active = False
def close(self) -> None:
"""Close the fake input stream."""
if self.active and self.callback is not None:
self.stop()
self._active = False
self._closed = True
def __del__(self):
self.close()
InputStream = FakeInputStream
def query_devices(self, device: int | str | None = None, kind: str | None = None) -> list[dict[str, Any]]:
"""Returns a realistic list of audio devices including speakers and microphones."""
if device is not None:
# Return specific device
for valid_device in VALID_DEVICES:
if (isinstance(device, int) and valid_device["index"] == device) or (
isinstance(device, str) and valid_device["name"] == device
):
return valid_device
raise PortAudioError(f"Error querying device {device}")
elif kind is not None:
for valid_device in VALID_DEVICES:
if (
valid_device["max_input_channels"] > 0
and kind == "input"
or valid_device["max_output_channels"] > 0
and kind == "output"
):
return valid_device
raise PortAudioError(f"No {kind} device found")
return VALID_DEVICES

View File

@@ -1,566 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Provides the PortAudioMicrophone class for capturing audio from microphones using the PortAudio library through the sounddevice Python package.
"""
import logging
import time
from multiprocessing import (
Event as process_Event,
JoinableQueue as process_Queue,
Process,
)
from pathlib import Path
from queue import Empty
from threading import Barrier, Event, Event as thread_Event, Thread
from typing import Any
import numpy as np
from soundfile import SoundFile
from lerobot.microphones.portaudio.interface_sounddevice_sdk import ISounddeviceSDK, SounddeviceSDKAdapter
from lerobot.utils.errors import (
DeviceAlreadyConnectedError,
DeviceAlreadyRecordingError,
DeviceNotConnectedError,
DeviceNotRecordingError,
)
from lerobot.utils.shared_array import SharedArray
from ..microphone import Microphone
from .configuration_portaudio import PortAudioMicrophoneConfig
logger = logging.getLogger(__name__)
class PortAudioMicrophone(Microphone):
"""
The PortAudioMicrophone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, across all OS (Linux, Mac, Windows).
A PortAudioMicrophone instance requires the sounddevice index of the microphone, which may be obtained using `python -m sounddevice`. It also requires the recording sample rate as well as the list of recorded channels.
Example of usage:
```python
from lerobot.common.robot_devices.microphones.configs import PortAudioMicrophoneConfig
config = PortAudioMicrophoneConfig(microphone_index=0, sample_rate=16000, channels=[1])
microphone = PortAudioMicrophone(config)
microphone.connect()
microphone.start_recording("some/output/file.wav")
...
audio_readings = microphone.read() # Gets all recorded audio data since the last read or since the beginning of the recording. The longer the period the longer the reading time !
...
microphone.stop_recording()
microphone.disconnect()
```
"""
def __init__(self, config: PortAudioMicrophoneConfig, sounddevice_sdk: ISounddeviceSDK = None):
"""
Initializes the PortAudioMicrophone instance.
Args:
config: The configuration settings for the microphone.
"""
super().__init__(config)
if sounddevice_sdk is None:
self.sounddevice_sdk = SounddeviceSDKAdapter()
else:
self.sounddevice_sdk = sounddevice_sdk
# Microphone index
self.microphone_index = config.microphone_index
# Input audio recording process and events
self.record_process = None
self.record_stop_event = process_Event()
self.record_start_event = process_Event()
self.record_close_event = process_Event()
self.record_is_started_event = process_Event()
self.audio_callback_start_event = process_Event()
# Process-safe concurrent queue to send audio from the recording process to the writing process/thread
self.write_queue = process_Queue()
# SharedArray to store audio from the recording process.
self.read_shared_array = None
self.local_read_shared_array = None
# Thread/Process to handle data writing in a separate thread/process (safely)
self.write_thread = None
self.write_stop_event = None
self.write_is_started_event = None
self.logs = {}
def __str__(self) -> str:
return f"{self.__class__.__name__}({self.microphone_index})"
@property
def is_connected(self) -> bool:
return self.record_process is not None and self.record_process.is_alive()
@property
def is_recording(self) -> bool:
return self.record_is_started_event.is_set()
@property
def is_writing(self) -> bool:
return self.write_thread is not None and self.write_is_started_event.is_set()
@staticmethod
def find_microphones(
device: int | str | None = None, sounddevice_sdk: ISounddeviceSDK = None
) -> list[dict[str, Any]] | dict[str, Any]:
"""
Detects available microphones connected to the system.
Args:
device: The device to find microphones for. If None, all microphones are found.
Returns:
List[Dict[str, Any]]: A list of dictionaries,
where each dictionary contains information about a detected microphone : index, name, sample rate, channels.
"""
if sounddevice_sdk is None:
sounddevice_sdk = SounddeviceSDKAdapter()
found_microphones_info = []
devices = sounddevice_sdk.query_devices()
for d in devices:
if d["max_input_channels"] > 0:
microphone_info = {
"index": d["index"],
"name": d["name"],
"sample_rate": int(d["default_samplerate"]),
"channels": np.arange(1, d["max_input_channels"] + 1),
}
if device is None or (
(isinstance(device, int) and d["index"] == device)
or (isinstance(device, str) and d["name"] == device)
):
found_microphones_info.append(microphone_info)
if device is not None:
if len(found_microphones_info) == 0:
raise RuntimeError(f"No microphone found for device {device}")
else:
return found_microphones_info[0]
if len(found_microphones_info) == 0:
logger.warning("No microphone found !")
return found_microphones_info
def _configure_capture_settings(self) -> None:
"""
Validates the microphone index, sample rate and channels settings specified in the constructor's config to the un-connected microphone.
This method actually checks the specified settings and fills the sample rate and channels settings if not specified before attempting to start a PortAudio stream.
Raises:
RuntimeError: If one of the specified settings is not compatible with the microphone.
DeviceAlreadyConnectedError: If the microphone is connected when attempting to configure settings.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(
f"Cannot configure settings for {self} as it is already connected."
)
self._validate_microphone_index()
self._validate_sample_rate()
self._validate_channels()
def _validate_microphone_index(self) -> None:
""" "Validates the microphone index against available devices by checking if it has at least one input channel."""
try:
PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)
except RuntimeError as e:
raise RuntimeError(
f"{e}. Available microphones: {PortAudioMicrophone.find_microphones(sounddevice_sdk=self.sounddevice_sdk)}"
) from e
def _validate_sample_rate(self) -> None:
"""Validates the sample rate against the actual microphone's default sample rate."""
actual_sample_rate = PortAudioMicrophone.find_microphones(
self.microphone_index, self.sounddevice_sdk
)["sample_rate"]
if self.sample_rate is not None:
try:
self.sample_rate = int(self.sample_rate)
except ValueError as e:
raise RuntimeError(
f"Cannot convert the provided sample rate ({self.sample_rate} Hz) to an integer."
) from e
if self.sample_rate > actual_sample_rate or self.sample_rate < 1000:
raise RuntimeError(
f"Provided sample rate {self.sample_rate} is either too low or too high compared to the sample rate of the microphone {actual_sample_rate}."
)
else:
if self.sample_rate < actual_sample_rate:
logger.warning(
"Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted."
)
else:
self.sample_rate = actual_sample_rate
def _validate_channels(self) -> None:
"""Validates the channels against the actual microphone's maximum input channels."""
actual_channels = PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)[
"channels"
]
if self.channels is not None and len(self.channels) > 0:
if not all(channel in actual_channels for channel in self.channels):
raise RuntimeError(
f"Some of the provided channels {self.channels} are outside the possible channel range of the microphone {actual_channels}."
)
else:
self.channels = actual_channels
# Get channels index instead of number for slicing
self.channels_index = np.array(self.channels) - 1
def connect(self) -> None:
"""
Connects the microphone and checks if the requested acquisition parameters are compatible with the microphone.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"Microphone {self.microphone_index} is already connected.")
self._configure_capture_settings()
# Create or reset queue and shared array
self.read_shared_array = SharedArray(
shape=(self.sample_rate * 10, len(self.channels)),
dtype=np.dtype("float32"),
)
self.local_read_shared_array = self.read_shared_array.get_local_array()
self.write_queue = process_Queue()
# Reset events
self.record_start_event.clear()
self.record_stop_event.clear()
self.record_close_event.clear()
self.record_is_started_event.clear()
self.audio_callback_start_event.clear()
# Create and start an audio input stream with a recording callback
# Remark: this is done in a separate process so that audio recording is not impacted by the main thread CPU usage, especially the precise_sleep function.
process_init_event = process_Event()
self.record_process = Process(
target=self._record_process,
args=(
self.microphone_index,
self.sample_rate,
self.channels,
process_init_event,
self.record_start_event,
self.record_stop_event,
self.record_close_event,
self.record_is_started_event,
self.audio_callback_start_event,
self.write_queue,
self.read_shared_array,
self.sounddevice_sdk,
),
)
self.record_process.daemon = True
self.record_process.start()
is_init = process_init_event.wait(
timeout=5.0
) # Wait for the recording process to be started, and to potentially raise an error on failure.
if not self.is_connected or not is_init:
raise RuntimeError(f"Error connecting microphone {self.microphone_index}.")
logger.info(f"{self} connected.")
def disconnect(self) -> None:
"""
Disconnects the microphone and stops the recording.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
if self.is_recording:
self.stop_recording()
self.record_close_event.set()
self.read_shared_array.delete()
self.write_queue.close()
self.record_process.join()
if self.is_connected:
raise RuntimeError(f"Error disconnecting microphone {self.microphone_index}.")
logger.info(f"{self} disconnected.")
def _read(self) -> np.ndarray:
"""
Thread/Process-safe callback to read available audio data
"""
return self.read_shared_array.read(self.local_read_shared_array, flush=True)
def read(self) -> np.ndarray:
"""
Reads the last audio chunk recorded by the microphone, e.g. all samples recorded since the last read or since the beginning of the recording.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
if not self.is_recording:
raise RuntimeError(f"Microphone {self.microphone_index} is not recording.")
start_time = time.perf_counter()
audio_readings = self._read()
# log the number of seconds it took to read the audio chunk
self.logs["delta_timestamp_s"] = time.perf_counter() - start_time
# log the utc time at which the audio chunk was received
self.logs["timestamp_utc"] = time.perf_counter()
return audio_readings
@staticmethod
def _record_process(
microphone_index,
sample_rate,
channels,
process_init_event,
record_start_event,
record_stop_event,
record_close_event,
record_is_started_event,
audio_callback_start_event,
write_queue,
read_shared_array,
sounddevice_sdk,
) -> None:
"""
Process callback used to create an unpickable sounddevice audio input stream with a recording callback and start, stop and close it based on multiprocessing events.
"""
channels_index = np.array(channels) - 1
local_read_shared_array = read_shared_array.get_local_array()
def audio_callback(indata, frames, timestamp, status) -> None:
"""
Low-level sounddevice callback.
"""
if status:
logger.warning(status)
if audio_callback_start_event.is_set():
write_queue.put_nowait(indata[:, channels_index])
read_shared_array.write(local_read_shared_array, indata[:, channels_index])
# Create the audio stream
# InputStream must be instantiated in the process as it is not pickable.
stream = sounddevice_sdk.InputStream(
device=microphone_index,
samplerate=sample_rate,
channels=max(channels),
dtype="float32",
blocksize=0, # Varying input buffer length, but no additional latency
latency="low", # Low latency mode (not enabled by default !)
# never_drop_input=True, # Disabled as it generates an error for some devices
callback=audio_callback,
)
process_init_event.set()
while True:
start_flag = record_start_event.wait(timeout=0.1)
if record_close_event.is_set():
break
elif not start_flag:
continue
stream.start()
record_is_started_event.set()
record_stop_event.wait()
stream.stop() # stream.stop() waits for all buffers to be processed, stream.abort() flushes the buffers !
record_is_started_event.clear()
stream.close()
def start_recording(
self,
output_file: str | None = None,
multiprocessing: bool | None = False,
overwrite: bool | None = True,
barrier: Barrier | None = None,
) -> None:
"""
Starts the recording of the microphone. If output_file is provided, the audio will be written to this file.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
if self.is_recording:
raise DeviceAlreadyRecordingError(f"Microphone {self.microphone_index} is already recording.")
# Reset queue and shared memory
self.read_shared_array.reset()
self._clear_queue(self.write_queue)
# Reset stop event
self.record_stop_event.clear()
# Write recordings into a file if output_file is provided
if output_file is not None:
output_file = Path(output_file)
output_file.parent.mkdir(parents=True, exist_ok=True)
if output_file.exists():
if overwrite:
output_file.unlink()
else:
raise FileExistsError(
f"Output file {output_file} already exists. Set overwrite to True to overwrite it."
)
if multiprocessing:
self.write_stop_event = process_Event()
self.write_is_started_event = process_Event()
self.write_thread = Process(
target=PortAudioMicrophone._write_loop,
args=(
self.write_queue,
self.write_stop_event,
self.write_is_started_event,
self.sample_rate,
self.channels,
output_file,
),
)
else:
self.write_stop_event = thread_Event()
self.write_is_started_event = thread_Event()
self.write_thread = Thread(
target=PortAudioMicrophone._write_loop,
args=(
self.write_queue,
self.write_stop_event,
self.write_is_started_event,
self.sample_rate,
self.channels,
output_file,
),
)
self.write_thread.daemon = True
self.write_thread.start()
self.write_is_started_event.wait() # Wait for the writing thread/process to be started.
self.record_start_event.set() # Start the input audio stream process
self.record_is_started_event.wait() # Wait for the input audio stream process to be actually started
if barrier is not None:
barrier.wait() # Wait for multiple input audio streams to be started at the same time
self.audio_callback_start_event.set()
if not self.is_recording:
raise RuntimeError(f"Error starting recording for microphone {self.microphone_index}.")
if output_file is not None and not self.is_writing:
raise RuntimeError(f"Error starting writing for microphone {self.microphone_index}.")
def stop_recording(self) -> None:
"""
Stops the recording of the microphones.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
if not self.is_recording:
raise DeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.")
self.audio_callback_start_event.clear()
self.record_start_event.clear() # Ensures the audio stream is not started again !
self.record_stop_event.set()
# Wait for the stream to be stopped (might lead to race condition if the stream is not properly stopped on array reset and queue clearing)
timeout = 1.0
while self.is_recording and timeout > 0:
time.sleep(0.01)
timeout -= 0.01
self.read_shared_array.reset()
self._clear_queue(self.write_queue, join_queue=True)
if self.is_writing:
self.write_stop_event.set()
self.write_thread.join()
if self.is_recording:
raise RuntimeError(f"Error stopping recording for microphone {self.microphone_index}.")
if self.is_writing:
raise RuntimeError(f"Error stopping writing for microphone {self.microphone_index}.")
@staticmethod
def _write_loop(
queue,
write_stop_event: Event,
write_is_started_event: Event,
sample_rate: int,
channels: list[int],
output_file: Path,
) -> None:
"""
Thread/Process-safe loop to write audio data into a file.
"""
# Can only be run on a single process/thread for file writing safety
with SoundFile(
output_file,
mode="w",
samplerate=sample_rate,
channels=len(channels),
format="WAV",
subtype="FLOAT", # By default, a much lower quality WAV file is created !
) as file:
write_is_started_event.set()
while not write_stop_event.is_set():
try:
file.write(
queue.get(timeout=0.005)
) # Timeout set as the usual sounddevice buffer size. get_nowait is not possible here as it saturates the thread.
queue.task_done()
except Empty:
continue
write_is_started_event.clear()
def __del__(self) -> None:
if self.is_connected:
self.disconnect()
@staticmethod
def _clear_queue(queue, join_queue: bool = False):
"""
Clears the queue by getting all items until it is empty. The longer the queue, the longer it takes to clear it.
"""
try:
while True:
queue.get_nowait()
queue.task_done()
except Empty:
if join_queue:
queue.join()
return

View File

@@ -1,16 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_touchlab import TouchLabSensorConfig
from .sensor_touchlab import TouchLabSensor

View File

@@ -1,42 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..configs import MicrophoneConfig
@MicrophoneConfig.register_subclass("touchlab")
@dataclass
class TouchLabSensorConfig(MicrophoneConfig):
"""Configuration class for TouchLab tactile sensors (technically not a microphone, but behaves like one acquisition-wise).
This class provides configuration options for TouchLab tactile sensors, including serial port, sample rate and channels.
Example configurations:
```python
# Basic configurations
TouchLabSensorConfig("/dev/ttyACM0", 16000) # Serial port /dev/ttyACM0, 16000Hz
TouchLabSensorConfig("/dev/ttyACM1", 44100) # Serial port /dev/ttyACM1, 44100Hz
```
Attributes:
sensor_port: Serial port of the tactile sensor.
baud_rate: Baud rate of the tactile sensor.
sample_rate: Sample rate in Hz for the tactile sensor.
channels: List of channel numbers to use for the tactile sensor.
"""
sensor_port: str
baud_rate: int = 115_200

View File

@@ -1,469 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Provides the TouchLabSensor class for capturing tactile data from TouchLab tactile sensors.
"""
import logging
import time
from multiprocessing import (
Event as process_Event,
JoinableQueue as process_Queue,
Process,
)
from pathlib import Path
from queue import Empty
from threading import Barrier, Event, Event as thread_Event, Thread
from typing import Any
import numpy as np
from serial import Serial
from soundfile import SoundFile
from lerobot.utils.errors import (
DeviceAlreadyConnectedError,
DeviceAlreadyRecordingError,
DeviceNotConnectedError,
DeviceNotRecordingError,
)
from lerobot.utils.shared_array import SharedArray
from ..microphone import Microphone
from .configuration_touchlab import TouchLabSensorConfig
logger = logging.getLogger(__name__)
MAX_SERIAL_READ_SIZE = 512
class TouchLabSensor(Microphone):
"""
The TouchLabSensor class handles all TouchLab tactile sensors.
A TouchLabSensor instance requires the serial port of the tactile sensor, which may be obtained using `python -m lerobot.find_port`. It also requires the recording sample rate as well as the list of recorded channels.
Example of usage:
```python
from lerobot.common.robot_devices.microphones.configs import TouchLabSensorConfig
config = TouchLabSensorConfig(sensor_port="/dev/ttyACM0", baud_rate=115200, sample_rate=115, channels=[1])
microphone = TouchLabSensor(config)
microphone.connect()
microphone.start_recording("some/output/file.wav")
...
audio_readings = microphone.read() # Gets all recorded audio data since the last read or since the beginning of the recording. The longer the period the longer the reading time !
...
microphone.stop_recording()
microphone.disconnect()
```
"""
def __init__(self, config: TouchLabSensorConfig):
""" "
Initializes the TouchLabSensor instance.
Args:
config: The configuration settings for the sensor.
"""
super().__init__(config)
# Sensor port
self.sensor_port = config.sensor_port
# Baud rate
self.baud_rate = config.baud_rate
# Input audio recording process and events
self.record_process = None
self.record_stop_event = process_Event()
self.record_start_event = process_Event()
self.record_close_event = process_Event()
self.record_is_started_event = process_Event()
self.audio_callback_start_event = process_Event()
# Process-safe concurrent queue to send audio from the recording process to the writing process/thread
self.write_queue = process_Queue()
# SharedArray to store audio from the recording process.
self.read_shared_array = None
self.local_read_shared_array = None
# Thread/Process to handle data writing in a separate thread/process (safely)
self.write_thread = None
self.write_stop_event = None
self.write_is_started_event = None
self.logs = {}
def __str__(self) -> str:
return f"{self.__class__.__name__}({self.sensor_port})"
@property
def is_connected(self) -> bool:
"""Check if the sensor is currently connected.
Returns:
bool: True if the sensor is connected and ready to start recording,
False otherwise.
"""
return self.record_process is not None and self.record_process.is_alive()
@property
def is_recording(self) -> bool:
"""Check if the sensor is currently recording.
Returns:
bool: True if the sensor is recording, False otherwise.
"""
return self.record_is_started_event.is_set()
@property
def is_writing(self) -> bool:
"""Check if the sensor is currently writing to a file.
Returns:
bool: True if the sensor is writing to a file, False otherwise.
"""
return self.write_thread is not None and self.write_is_started_event.is_set()
@staticmethod
def find_microphones() -> list[dict[str, Any]]:
"""Detects available sensors connected to the system.
Returns:
List[Dict[str, Any]]: A list of dictionaries,
where each dictionary contains information about a detected sensor.
"""
pass
def connect(self) -> None:
"""
Establish connection to the sensor.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"Sensor connected to {self.sensor_port} is already connected.")
# Create or reset queue and shared array
self.read_shared_array = SharedArray(
shape=(self.sample_rate * 10, len(self.channels)),
dtype=np.dtype("int16"),
)
self.local_read_shared_array = self.read_shared_array.get_local_array()
self.write_queue = process_Queue()
# Reset events
self.record_start_event.clear()
self.record_stop_event.clear()
self.record_close_event.clear()
self.record_is_started_event.clear()
self.audio_callback_start_event.clear()
# Create and start an audio input stream with a recording callback
# Remark: this is done in a separate process so that audio recording is not impacted by the main thread CPU usage, especially the precise_sleep function.
process_init_event = process_Event()
self.record_process = Process(
target=self._record_process,
args=(
self.sensor_port,
self.baud_rate,
self.channels,
process_init_event,
self.record_start_event,
self.record_stop_event,
self.record_close_event,
self.record_is_started_event,
self.audio_callback_start_event,
self.write_queue,
self.read_shared_array,
),
)
self.record_process.daemon = True
self.record_process.start()
is_init = process_init_event.wait(
timeout=5.0
) # Wait for the recording process to be started, and to potentially raise an error on failure.
if not self.is_connected or not is_init:
raise RuntimeError(f"Error connecting sensor connected to {self.sensor_port}.")
logger.info(f"{self} connected.")
@staticmethod
def _record_process(
sensor_port,
baud_rate,
channels,
process_init_event,
record_start_event,
record_stop_event,
record_close_event,
record_is_started_event,
audio_callback_start_event,
write_queue,
read_shared_array,
) -> None:
channels_index = np.array(channels) - 1
local_read_shared_array = read_shared_array.get_local_array()
def tactile_callback(serial_connection):
"""
Parse the tactile data from the raw input data.
"""
buffer = serial_connection.readline()
if audio_callback_start_event.is_set():
strings = buffer.decode("utf8").split(",")
num_taxels = len(strings)
if num_taxels > 0 and num_taxels < MAX_SERIAL_READ_SIZE: # Make sure we didn't read rubbish
indata = np.empty((1, num_taxels))
for i in range(num_taxels):
indata[0, i] = int(strings[i])
write_queue.put_nowait(indata[:, channels_index])
read_shared_array.write(local_read_shared_array, indata[:, channels_index])
process_init_event.set()
while True:
start_flag = record_start_event.wait(timeout=0.1)
if record_close_event.is_set():
break
elif not start_flag:
continue
with Serial(sensor_port, baud_rate, timeout=0.5) as serial_connection:
serial_connection.flush()
record_is_started_event.set()
while not record_stop_event.is_set():
tactile_callback(serial_connection)
record_is_started_event.clear()
serial_connection.close()
def disconnect(self) -> None:
"""
Disconnect the sensor and release any resources.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
if self.is_recording:
self.stop_recording()
self.record_close_event.set()
self.read_shared_array.delete()
self.write_queue.close()
self.record_process.join()
if self.is_connected:
raise RuntimeError(f"Error disconnecting sensor connected to {self.sensor_port}.")
logger.info(f"{self} disconnected.")
def start_recording(
self,
output_file: str | Path | None = None,
multiprocessing: bool | None = False,
overwrite: bool | None = True,
barrier: Barrier | None = None,
) -> None:
"""
Start recording tactile data from the sensor.
Args:
output_file: Optional path to save the recorded tactile data.
multiprocessing: If True, enables multiprocessing for recording. Defaults to multithreading otherwise.
overwrite: If True, overwrites existing files at output_file path.
barrier: If not None, ensures that multiple sensors start recording at the same time.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
if self.is_recording:
raise DeviceAlreadyRecordingError(f"Sensor connected to {self.sensor_port} is already recording.")
# Reset queue and shared memory
self.read_shared_array.reset()
self._clear_queue(self.write_queue)
# Reset stop event
self.record_stop_event.clear()
# Write recordings into a file if output_file is provided
if output_file is not None:
output_file = Path(output_file)
output_file.parent.mkdir(parents=True, exist_ok=True)
if output_file.exists():
if overwrite:
output_file.unlink()
else:
raise FileExistsError(
f"Output file {output_file} already exists. Set overwrite to True to overwrite it."
)
if multiprocessing:
self.write_stop_event = process_Event()
self.write_is_started_event = process_Event()
self.write_thread = Process(
target=TouchLabSensor._write_loop,
args=(
self.write_queue,
self.write_stop_event,
self.write_is_started_event,
self.sample_rate,
self.channels,
output_file,
),
)
else:
self.write_stop_event = thread_Event()
self.write_is_started_event = thread_Event()
self.write_thread = Thread(
target=TouchLabSensor._write_loop,
args=(
self.write_queue,
self.write_stop_event,
self.write_is_started_event,
self.sample_rate,
self.channels,
output_file,
),
)
self.write_thread.daemon = True
self.write_thread.start()
self.write_is_started_event.wait() # Wait for the writing thread/process to be started.
self.record_start_event.set() # Start the input audio stream process
self.record_is_started_event.wait() # Wait for the input audio stream process to be actually started
if barrier is not None:
barrier.wait() # Wait for multiple input audio streams to be started at the same time
self.audio_callback_start_event.set()
if not self.is_recording:
raise RuntimeError(f"Error starting recording for sensor connected to {self.sensor_port}.")
if output_file is not None and not self.is_writing:
raise RuntimeError(f"Error starting writing for sensor connected to {self.sensor_port}.")
def _read(self) -> np.ndarray:
"""
Thread/Process-safe callback to read available audio data
"""
return self.read_shared_array.read(self.local_read_shared_array, flush=True)
def read(self) -> np.ndarray:
"""Capture and return a single audio chunk from the sensor.
Returns:
np.ndarray: Captured audio chunk as a numpy array.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
if not self.is_recording:
raise RuntimeError(f"Sensor connected to {self.sensor_port} is not recording.")
start_time = time.perf_counter()
tactile_readings = self._read()
# log the number of seconds it took to read the audio chunk
self.logs["delta_timestamp_s"] = time.perf_counter() - start_time
# log the utc time at which the audio chunk was received
self.logs["timestamp_utc"] = time.perf_counter()
return tactile_readings
def _read_loop(self) -> None:
"""Internal loop run by the background thread for asynchronous reading."""
def stop_recording(self) -> None:
"""Stop recording audio from the sensor."""
if not self.is_connected:
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
if not self.is_recording:
raise DeviceNotRecordingError(f"Sensor connected to {self.sensor_port} is not recording.")
self.audio_callback_start_event.clear()
self.record_start_event.clear() # Ensures the audio stream is not started again !
self.record_stop_event.set()
self.read_shared_array.reset()
self._clear_queue(self.write_queue, join_queue=True)
if self.is_writing:
self.write_stop_event.set()
self.write_thread.join()
timeout = 1.0
while self.is_recording and timeout > 0:
time.sleep(0.01)
timeout -= 0.01
if self.is_recording:
raise RuntimeError(f"Error stopping recording for sensor connected to {self.sensor_port}.")
if self.is_writing:
raise RuntimeError(f"Error stopping writing for sensor connected to {self.sensor_port}.")
def __del__(self) -> None:
if self.is_connected:
self.disconnect()
@staticmethod
def _clear_queue(queue, join_queue: bool = False):
"""
Clears the queue by getting all items until it is empty. The longer the queue, the longer it takes to clear it.
"""
try:
while True:
queue.get_nowait()
queue.task_done()
except Empty:
if join_queue:
queue.join()
return
@staticmethod
def _write_loop(
queue,
write_stop_event: Event,
write_is_started_event: Event,
sample_rate: int,
channels: list[int],
output_file: Path,
) -> None:
"""
Thread/Process-safe loop to write audio data into a file.
"""
# Can only be run on a single process/thread for file writing safety
with SoundFile(
output_file,
mode="w",
samplerate=sample_rate,
channels=len(channels),
format="WAV",
subtype="PCM_16", # Subtype for int16 values
) as file:
write_is_started_event.set()
while not write_stop_event.is_set():
try:
file.write(
queue.get(timeout=0.005)
) # Timeout set as the usual sounddevice buffer size. get_nowait is not possible here as it saturates the thread.
queue.task_done()
except Empty:
continue
write_is_started_event.clear()

View File

@@ -1,89 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from multiprocessing import Barrier
from threading import Thread
from .configs import MicrophoneConfig
from .microphone import Microphone
def make_microphones_from_configs(microphone_configs: dict[str, MicrophoneConfig]) -> dict[str, Microphone]:
microphones = {}
for key, cfg in microphone_configs.items():
if cfg.type == "portaudio":
from .portaudio import PortAudioMicrophone
microphones[key] = PortAudioMicrophone(cfg)
elif cfg.type == "touchlab":
from .touchlab import TouchLabSensor
microphones[key] = TouchLabSensor(cfg)
else:
raise ValueError(f"The microphone type '{cfg.type}' is not valid.")
return microphones
def async_microphones_start_recording(
microphones: dict[str, Microphone],
output_files: list[str | None] | None = None,
multiprocessing: bool = False,
overwrite: bool = True,
) -> None:
"""
Starts recording on multiple microphones asynchronously to avoid delays.
Args:
microphones: A dictionary of microphones.
output_files: A list of output files.
multiprocessing: If True, enables multiprocessing for recording.
overwrite: If True, overwrites existing files at output_file path.
"""
start_recording_threads = []
if output_files is None:
output_files = [None] * len(microphones)
barrier = Barrier(len(microphones))
for microphone, output_file in zip(microphones.values(), output_files, strict=False):
start_recording_threads.append(
Thread(target=microphone.start_recording, args=(output_file, multiprocessing, overwrite, barrier))
)
for thread in start_recording_threads:
thread.start()
for thread in start_recording_threads:
thread.join()
def async_microphones_stop_recording(microphones: dict[str, Microphone]) -> None:
"""
Stops recording on multiple microphones asynchronously to avoid delays.
Args:
microphones: A dictionary of microphones.
"""
stop_recording_threads = []
for microphone in microphones.values():
stop_recording_threads.append(Thread(target=microphone.stop_recording))
for thread in stop_recording_threads:
thread.start()
for thread in stop_recording_threads:
thread.join()

View File

@@ -89,7 +89,6 @@ class ACTConfig(PreTrainedConfig):
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"AUDIO": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
@@ -100,10 +99,6 @@ class ACTConfig(PreTrainedConfig):
vision_backbone: str = "resnet18"
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
replace_final_stride_with_dilation: int = False
# Audio backbone.
audio_backbone: str = vision_backbone
pretrained_backbone_weights_audio: str | None = None
replace_final_stride_with_dilation_audio: int = False
# Transformer layers.
pre_norm: bool = False
dim_model: int = 512
@@ -166,10 +161,8 @@ class ACTConfig(PreTrainedConfig):
return None
def validate_features(self) -> None:
if not (self.image_features or self.audio_features) and not self.env_state_feature:
raise ValueError(
"You must provide at least one image/audio or the environment state among the inputs."
)
if not self.image_features and not self.env_state_feature:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
@property
def observation_delta_indices(self) -> None:

View File

@@ -35,7 +35,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION, OBS_AUDIO, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
class ACTPolicy(PreTrainedPolicy):
@@ -106,8 +106,6 @@ class ACTPolicy(PreTrainedPolicy):
"""
self.eval() # keeping the policy in eval mode as it could be set to train mode while queue is consumed
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
# we are ensembling over.
if self.config.temporal_ensemble_coeff is not None:
actions = self.predict_action_chunk(batch)
action = self.temporal_ensembler.update(actions)
@@ -333,26 +331,12 @@ class ACT(nn.Module):
# Note: The forward method of this returns a dict: {"feature_map": output}.
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
# Backbone for audio feature extraction.
if self.config.audio_features:
audio_backbone_model = getattr(torchvision.models, config.audio_backbone)(
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation_audio],
weights=config.pretrained_backbone_weights_audio,
norm_layer=FrozenBatchNorm2d,
)
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
# feature map).
# Note: The forward method of this returns a dict: {"feature_map": output}.
self.audio_backbone = IntermediateLayerGetter(
audio_backbone_model, return_layers={"layer4": "feature_map"}
)
# Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = ACTEncoder(config)
self.decoder = ACTDecoder(config)
# Transformer encoder input projections. The tokens will be structured like
# [latent, (robot_state), (env_state), (image_feature_map_pixels), (audio_feature)].
# [latent, (robot_state), (env_state), (image_feature_map_pixels)].
if self.config.robot_state_feature:
self.encoder_robot_state_input_proj = nn.Linear(
self.config.robot_state_feature.shape[0], config.dim_model
@@ -366,10 +350,6 @@ class ACT(nn.Module):
self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, config.dim_model, kernel_size=1
)
if self.config.audio_features:
self.encoder_audio_feat_input_proj = nn.Conv2d(
audio_backbone_model.fc.in_features, config.dim_model, kernel_size=1
)
# Transformer encoder positional embeddings.
n_1d_tokens = 1 # for the latent
if self.config.robot_state_feature:
@@ -379,8 +359,6 @@ class ACT(nn.Module):
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
if self.config.image_features:
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
if self.config.audio_features:
self.encoder_audio_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
# Transformer decoder.
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
@@ -505,21 +483,6 @@ class ACT(nn.Module):
encoder_in_tokens.extend(list(cam_features))
encoder_in_pos_embed.extend(list(cam_pos_embed))
if self.config.audio_features:
for audio in batch[OBS_AUDIO]:
audio_features = self.audio_backbone(audio)["feature_map"]
audio_pos_embed = self.encoder_audio_feat_pos_embed(audio_features).to(
dtype=audio_features.dtype
)
audio_features = self.encoder_audio_feat_input_proj(audio_features)
# Rearrange features to (sequence, batch, dim).
audio_features = einops.rearrange(audio_features, "b c h w -> (h w) b c")
audio_pos_embed = einops.rearrange(audio_pos_embed, "b c h w -> (h w) b c")
encoder_in_tokens.extend(list(audio_features))
encoder_in_pos_embed.extend(list(audio_pos_embed))
# Stack all tokens along the sequence dimension.
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0)

View File

@@ -17,11 +17,9 @@ from typing import Any
import torch
from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
AudioProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
@@ -65,15 +63,6 @@ def make_act_pre_post_processors(
stats=dataset_stats,
device=config.device,
),
AudioProcessorStep(
output_height=224,
output_width=224,
output_channels=3,
input_audio_chunk_duration=DEFAULT_AUDIO_CHUNK_DURATION,
input_sample_rate=48000,
intermediate_sample_rate=16000,
n_fft=1024,
),
]
output_steps = [
UnnormalizerProcessorStep(

View File

@@ -12,17 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from dataclasses import dataclass
"""Real-Time Chunking (RTC) utilities for action-chunking policies."""
import draccus
from lerobot.policies.rtc.action_interpolator import ActionInterpolator
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
@dataclass(kw_only=True)
class MicrophoneConfig(draccus.ChoiceRegistry, abc.ABC):
sample_rate: int | None = None
channels: list[int] | None = None
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
__all__ = [
"ActionInterpolator",
"ActionQueue",
"LatencyTracker",
"RTCConfig",
"RTCProcessor",
]

View File

@@ -0,0 +1,116 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Action interpolation for smoother robot control.
Provides configurable Nx control rate by interpolating between consecutive actions.
Useful with RTC and action-chunking policies to reduce jerkiness.
"""
from torch import Tensor
class ActionInterpolator:
"""Interpolates between consecutive actions for smoother control.
When enabled with multiplier N, produces N actions per policy action
by linearly interpolating between the previous and current action.
Example with multiplier=3:
prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
This effectively multiplies the control rate for smoother motion.
Usage:
interpolator = ActionInterpolator(multiplier=2) # 2x control rate
# In control loop:
if interpolator.needs_new_action():
new_action = queue.get()
if new_action:
interpolator.add(new_action.cpu())
action = interpolator.get()
if action:
robot.send_action(action)
"""
def __init__(self, multiplier: int = 1):
"""Initialize the interpolator.
Args:
multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
"""
if multiplier < 1:
raise ValueError(f"multiplier must be >= 1, got {multiplier}")
self.multiplier = multiplier
self._prev: Tensor | None = None
self._buffer: list[Tensor] = []
self._idx = 0
@property
def enabled(self) -> bool:
"""Whether interpolation is active (multiplier > 1)."""
return self.multiplier > 1
def reset(self):
"""Reset interpolation state (call between episodes)."""
self._prev = None
self._buffer = []
self._idx = 0
def needs_new_action(self) -> bool:
"""Check if a new action is needed from the queue."""
return self._idx >= len(self._buffer)
def add(self, action: Tensor) -> None:
"""Add a new action and compute interpolated sequence.
Args:
action: New action tensor from policy/queue (already on CPU).
"""
if self.multiplier > 1 and self._prev is not None:
self._buffer = []
for i in range(1, self.multiplier + 1):
t = i / self.multiplier
interp = self._prev + t * (action - self._prev)
self._buffer.append(interp)
else:
# First step: no previous action yet, so run at base FPS without interpolation.
self._buffer = [action.clone()]
self._prev = action.clone()
self._idx = 0
def get(self) -> Tensor | None:
"""Get the next interpolated action.
Returns:
Next action tensor, or None if buffer is exhausted.
"""
if self._idx >= len(self._buffer):
return None
action = self._buffer[self._idx]
self._idx += 1
return action
def get_control_interval(self, fps: float) -> float:
"""Get the control interval based on interpolation multiplier.
Args:
fps: Base frames per second.
Returns:
Control interval in seconds (divided by multiplier).
"""
return 1.0 / (fps * self.multiplier)

View File

@@ -79,6 +79,13 @@ class ActionQueue:
self.last_index += 1
return action.clone()
def clear(self) -> None:
"""Clear queued actions and reset consumption index."""
with self.lock:
self.queue = None
self.original_queue = None
self.last_index = 0
def qsize(self) -> int:
"""Get the number of remaining actions in the queue.
@@ -123,14 +130,26 @@ class ActionQueue:
with self.lock:
if self.original_queue is None:
return None
return self.original_queue[self.last_index :]
return self.original_queue[self.last_index :].clone()
def get_processed_left_over(self) -> Tensor | None:
"""Get leftover processed actions (the actions currently executed by the robot).
Returns:
Tensor | None: Remaining processed actions (remaining_steps, action_dim),
or None if no processed queue exists.
"""
with self.lock:
if self.queue is None:
return None
return self.queue[self.last_index :].clone()
def merge(
self,
original_actions: Tensor,
processed_actions: Tensor,
real_delay: int,
action_index_before_inference: int | None = 0,
action_index_before_inference: int | None = None,
):
"""Merge new actions into the queue.
@@ -145,10 +164,10 @@ class ActionQueue:
action_index_before_inference: Index before inference started, for validation.
"""
with self.lock:
self._check_delays(real_delay, action_index_before_inference)
delay = self._check_and_resolve_delays(real_delay, action_index_before_inference)
if self.cfg.enabled:
self._replace_actions_queue(original_actions, processed_actions, real_delay)
self._replace_actions_queue(original_actions, processed_actions, delay)
return
self._append_actions_queue(original_actions, processed_actions)
@@ -164,12 +183,13 @@ class ActionQueue:
processed_actions: Post-processed actions for robot.
real_delay: Number of time steps to skip due to inference delay.
"""
self.original_queue = original_actions[real_delay:].clone()
self.queue = processed_actions[real_delay:].clone()
clamped_delay = max(0, min(real_delay, len(original_actions), len(processed_actions)))
self.original_queue = original_actions[clamped_delay:].clone()
self.queue = processed_actions[clamped_delay:].clone()
logger.debug(f"original_actions shape: {self.original_queue.shape}")
logger.debug(f"processed_actions shape: {self.queue.shape}")
logger.debug(f"real_delay: {real_delay}")
logger.debug(f"real_delay: {real_delay}, clamped_delay: {clamped_delay}")
self.last_index = 0
@@ -196,7 +216,9 @@ class ActionQueue:
self.last_index = 0
def _check_delays(self, real_delay: int, action_index_before_inference: int | None = None):
def _check_and_resolve_delays(
self, real_delay: int, action_index_before_inference: int | None = None
) -> int:
"""Validate that computed delays match expectations.
Compares the delay computed from inference latency with the actual
@@ -205,15 +227,20 @@ class ActionQueue:
Args:
real_delay: Delay computed from inference latency.
action_index_before_inference: Action index when inference started.
"""
if action_index_before_inference is None:
return
indexes_diff = self.last_index - action_index_before_inference
if indexes_diff != real_delay:
# Let's check that action index difference (real delay calculated based on action queue)
# is the same as delay calculated based on inference latency
logger.warning(
f"[ACTION_QUEUE] Indexes diff is not equal to real delay. "
f"Indexes diff: {indexes_diff}, real delay: {real_delay}"
)
Returns:
int: Delay to use.
"""
effective_delay = max(0, real_delay)
if action_index_before_inference is not None:
indexes_diff = max(0, self.last_index - action_index_before_inference)
if indexes_diff != real_delay:
logger.warning(
"Indexes diff is not equal to real delay. indexes_diff=%d, real_delay=%d",
indexes_diff,
real_delay,
)
return real_delay
return effective_delay

View File

@@ -106,7 +106,7 @@ def prepare_observation_for_inference(
This function takes a dictionary of NumPy arrays, performs necessary
preprocessing, and prepares it for model inference. The steps include:
1. Converting NumPy arrays to PyTorch tensors.
2. Normalizing and permuting image data and audio data (if any).
2. Normalizing and permuting image data (if any).
3. Adding a batch dimension to each tensor.
4. Moving all tensors to the specified compute device.
5. Adding task and robot type information to the dictionary.
@@ -129,9 +129,6 @@ def prepare_observation_for_inference(
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
elif "audio" in name:
observation[name] = observation[name].type(torch.float32)
observation[name] = observation[name].permute(1, 0).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)

View File

@@ -23,7 +23,6 @@ from lerobot.types import (
TransitionKey,
)
from .audio_processor import AudioProcessorStep
from .batch_processor import AddBatchDimensionProcessorStep
from .converters import (
batch_to_transition,
@@ -89,7 +88,6 @@ __all__ = [
"ActionProcessorStep",
"AddTeleopActionAsComplimentaryDataStep",
"AddTeleopEventsAsInfoStep",
"AudioProcessorStep",
"ComplementaryDataProcessorStep",
"batch_to_transition",
"create_transition",

View File

@@ -1,130 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from torch import Tensor
from torchaudio.functional import amplitude_to_DB
from torchaudio.transforms import MelSpectrogram, Resample
from torchvision.transforms import Compose, Lambda, Resize
from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
from lerobot.utils.constants import OBS_AUDIO
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="audio_processor")
class AudioProcessorStep(ObservationProcessorStep):
"""
Processes audio waveform data into a mel-spectrogram image representation.
**Audio Processing:**
- Averages waveform data over all channels.
- Resamples the waveform to 16kHz.
- Converts the waveform to a mel-spectrogram.
- Converts the mel-spectrogram to decibels.
- Resizes the mel-spectrogram to 224×224.
- Converts the mel-spectrogram to a channel-first, normalized tensor.
Attributes:
output_height: Height of the output mel-spectrogram image in pixels.
output_width: Width of the output mel-spectrogram image in pixels.
output_channels: Number of channels in the output image (3 for RGB-like format).
input_audio_chunk_duration: Duration of the input audio chunk in seconds.
input_sample_rate: Original sample rate of the input audio in Hz.
intermediate_sample_rate: Reduced intermediate sample rate in Hz.
Downsampling improves the temporal resolution but reduces the frequency range.
n_fft: Size of the FFT window for spectrogram computation.
Increasing the window size increases the frequency resolution but decreases the temporal resolution.
hop_length: Number of samples between successive frames, computed automatically to match the output_width.
Decreasing the hop length increases the temporal resolution but decreases the frequency resolution.
n_mels: Number of mel filter banks, computed automatically to match the output_height.
Increasing the number of banks increases the number of rows in the spectrogram and the frequency resolution.
mel_spectrogram_transform: The complete audio processing pipeline.
"""
output_height: int = 224
output_width: int = 224
output_channels: int = 3
input_audio_chunk_duration: float = DEFAULT_AUDIO_CHUNK_DURATION
input_sample_rate: int = 48000
intermediate_sample_rate: int = 16000
n_fft: int = 1024
# Parameters computed from other parameters at initialization
hop_length: int = field(init=False)
n_mels: int = field(init=False)
mel_spectrogram_transform: Compose = field(init=False, repr=False)
def __post_init__(self):
self.hop_length = int(
self.intermediate_sample_rate * self.input_audio_chunk_duration
- self.n_fft // self.output_width
- 1
)
self.n_mels = self.output_height
self.mel_spectrogram_transform = Compose(
[
Lambda(lambda x: x.mean(dim=1)), # Average over all channels (second dimension after batch)
Resample(orig_freq=self.input_sample_rate, new_freq=self.intermediate_sample_rate),
MelSpectrogram(
sample_rate=self.intermediate_sample_rate,
n_fft=self.n_fft,
hop_length=self.hop_length,
n_mels=self.n_mels,
power=2, # Power spectrum
),
Lambda(
lambda x: amplitude_to_DB(x, multiplier=10, amin=1e-10, db_multiplier=0)
), # Convert to decibels
Resize(
(self.output_height, self.output_width)
), # Resize spectrogram to output_height×output_width
Lambda(
lambda x: x.unsqueeze(1).expand(-1, self.output_channels, -1, -1)
), # Duplicate across 3 channels to mimic RGB images. Dimensions are [batch, rgb, height, width].
]
)
def _process_observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
"""
Processes audio data contained in the provided observation.
"""
processed_obs = observation.copy()
# Process single audio observation
if OBS_AUDIO in processed_obs:
audio_data = processed_obs[OBS_AUDIO]
if isinstance(audio_data, Tensor) and audio_data.dim() == 3: # Batch, Channels, Samples
processed_obs[OBS_AUDIO] = self.mel_spectrogram_transform(audio_data)
# Process multiple audio observations
for key, value in processed_obs.items():
if (
key.startswith(f"{OBS_AUDIO}.") and isinstance(value, Tensor) and value.dim() == 3
): # Batch, Channels, Samples
processed_obs[key] = self.mel_spectrogram_transform(value)
return processed_obs
def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
return self._process_observation(observation)

View File

@@ -25,7 +25,8 @@ from dataclasses import dataclass, field
from torch import Tensor
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_AUDIO, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.types import EnvTransition, PolicyAction
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from .pipeline import (
ComplementaryDataProcessorStep,
@@ -35,7 +36,6 @@ from .pipeline import (
ProcessorStepRegistry,
TransitionKey,
)
from lerobot.types import PolicyAction, EnvTransition
@dataclass
@@ -88,8 +88,6 @@ class AddBatchDimensionObservationStep(ObservationProcessorStep):
- State vectors (1D tensors).
- Single images (3D tensors).
- Dictionaries of multiple images (3D tensors).
- Single audio waveforms (2D tensors).
- Dictionaries of multiple audio waveforms (2D tensors).
"""
def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
@@ -119,18 +117,6 @@ class AddBatchDimensionObservationStep(ObservationProcessorStep):
for key, value in observation.items():
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
observation[key] = value.unsqueeze(0)
# Process single audio observation - add batch dim if 2D
if OBS_AUDIO in observation:
audio_value = observation[OBS_AUDIO]
if isinstance(audio_value, Tensor) and audio_value.dim() == 2:
observation[OBS_AUDIO] = audio_value.unsqueeze(0)
# Process multiple audio observations - add batch dim if 2D
for key, value in observation.items():
if key.startswith(f"{OBS_AUDIO}.") and isinstance(value, Tensor) and value.dim() == 2:
observation[key] = value.unsqueeze(0)
return observation
def transform_features(

View File

@@ -96,9 +96,11 @@ class BiOpenArmFollower(Robot):
left_arm_motors_ft = self.left_arm._motors_ft
right_arm_motors_ft = self.right_arm._motors_ft
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
# and the dataset feature names recorded during data collection.
return {
**{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
**{f"right_{k}": v for k, v in right_arm_motors_ft.items()},
**{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
}
@property
@@ -150,14 +152,16 @@ class BiOpenArmFollower(Robot):
left_cam_keys = set(self.left_arm.cameras.keys())
right_cam_keys = set(self.right_arm.cameras.keys())
left_obs = self.left_arm.get_observation()
for key, value in left_obs.items():
obs_dict[key if key in left_cam_keys else f"left_{key}"] = value
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
# and the dataset feature names recorded during data collection.
right_obs = self.right_arm.get_observation()
for key, value in right_obs.items():
obs_dict[key if key in right_cam_keys else f"right_{key}"] = value
left_obs = self.left_arm.get_observation()
for key, value in left_obs.items():
obs_dict[key if key in left_cam_keys else f"left_{key}"] = value
return obs_dict
@check_if_not_connected
@@ -183,7 +187,7 @@ class BiOpenArmFollower(Robot):
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
return {**prefixed_sent_action_right, **prefixed_sent_action_left}
@check_if_not_connected
def disconnect(self):

View File

@@ -23,10 +23,12 @@ from ..config import RobotConfig
@RobotConfig.register_subclass("bi_openarm_follower")
@dataclass
@dataclass(kw_only=True)
class BiOpenArmFollowerConfig(RobotConfig):
"""Configuration class for Bi OpenArm Follower robots."""
id: str | None = "bi_openarm_follower"
left_arm_config: OpenArmFollowerConfigBase
right_arm_config: OpenArmFollowerConfigBase

View File

@@ -34,13 +34,6 @@ class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
raise ValueError(
f"Specifying '{attr}' is required for the camera to be used in a robot"
)
if hasattr(self, "microphones") and self.microphones:
for _, config in self.microphones.items():
for attr in ["sample_rate", "channels"]:
if getattr(config, attr) is None:
raise ValueError(
f"Specifying '{attr}' is required for the microphone to be used in a robot"
)
@property
def type(self) -> str:

View File

@@ -15,7 +15,6 @@
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from lerobot.microphones import MicrophoneConfig
from ..config import RobotConfig
@@ -36,8 +35,5 @@ class KochFollowerConfig(RobotConfig):
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)
# microphones
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
# Set to `True` for backward compatibility with previous policies/dataset
use_degrees: bool = False

View File

@@ -19,7 +19,6 @@ import time
from functools import cached_property
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.microphones.utils import make_microphones_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.dynamixel import (
DynamixelMotorsBus,
@@ -62,7 +61,6 @@ class KochFollower(Robot):
calibration=self.calibration,
)
self.cameras = make_cameras_from_configs(config.cameras)
self.microphones = make_microphones_from_configs(config.microphones)
@property
def _motors_ft(self) -> dict[str, type]:
@@ -74,16 +72,9 @@ class KochFollower(Robot):
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
}
@property
def _microphones_ft(self) -> dict[str, tuple]:
return {
mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels)
for mic in self.microphones
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
return {**self._motors_ft, **self._cameras_ft, **self._microphones_ft}
return {**self._motors_ft, **self._cameras_ft}
@cached_property
def action_features(self) -> dict[str, type]:
@@ -91,11 +82,7 @@ class KochFollower(Robot):
@property
def is_connected(self) -> bool:
return (
self.bus.is_connected
and all(cam.is_connected for cam in self.cameras.values())
and all(mic.is_connected for mic in self.microphones.values())
)
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
@@ -114,9 +101,6 @@ class KochFollower(Robot):
for cam in self.cameras.values():
cam.connect()
for mic in self.microphones.values():
mic.connect()
self.configure()
logger.info(f"{self} connected.")
@@ -213,13 +197,6 @@ class KochFollower(Robot):
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
# Read audio frames from microphones
for mic_key, mic in self.microphones.items():
start = time.perf_counter()
obs_dict[mic_key] = mic.read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {mic_key}: {dt_ms:.1f}ms")
return obs_dict
@check_if_not_connected
@@ -255,7 +232,5 @@ class KochFollower(Robot):
self.bus.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values():
cam.disconnect()
for mic in self.microphones.values():
mic.disconnect()
logger.info(f"{self} disconnected.")

View File

@@ -16,7 +16,6 @@ from dataclasses import dataclass, field
from lerobot.cameras.configs import CameraConfig, Cv2Rotation
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.microphones import MicrophoneConfig
from ..config import RobotConfig
@@ -46,8 +45,6 @@ class LeKiwiConfig(RobotConfig):
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
# Set to `True` for backward compatibility with previous policies/dataset
use_degrees: bool = False
@@ -95,7 +92,5 @@ class LeKiwiClientConfig(RobotConfig):
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
polling_timeout_ms: int = 15
connect_timeout_s: int = 5

View File

@@ -23,7 +23,6 @@ from typing import Any
import numpy as np
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.microphones.utils import make_microphones_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.feetech import (
FeetechMotorsBus,
@@ -74,7 +73,6 @@ class LeKiwi(Robot):
self.arm_motors = [motor for motor in self.bus.motors if motor.startswith("arm")]
self.base_motors = [motor for motor in self.bus.motors if motor.startswith("base")]
self.cameras = make_cameras_from_configs(config.cameras)
self.microphones = make_microphones_from_configs(config.microphones)
@property
def _state_ft(self) -> dict[str, type]:
@@ -99,16 +97,9 @@ class LeKiwi(Robot):
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
}
@property
def _microphones_ft(self) -> dict[str, tuple]:
return {
mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels)
for mic in self.microphones
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
return {**self._state_ft, **self._cameras_ft, **self._microphones_ft}
return {**self._state_ft, **self._cameras_ft}
@cached_property
def action_features(self) -> dict[str, type]:
@@ -116,11 +107,7 @@ class LeKiwi(Robot):
@property
def is_connected(self) -> bool:
return (
self.bus.is_connected
and all(cam.is_connected for cam in self.cameras.values())
and all(mic.is_connected for mic in self.microphones.values())
)
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
@@ -134,9 +121,6 @@ class LeKiwi(Robot):
for cam in self.cameras.values():
cam.connect()
for mic in self.microphones.values():
mic.connect()
self.configure()
logger.info(f"{self} connected.")
@@ -380,13 +364,6 @@ class LeKiwi(Robot):
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
# Read audio frames from microphones
for mic_key, mic in self.microphones.items():
start = time.perf_counter()
obs_dict[mic_key] = mic.read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {mic_key}: {dt_ms:.1f}ms")
return obs_dict
@check_if_not_connected
@@ -436,7 +413,5 @@ class LeKiwi(Robot):
self.bus.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values():
cam.disconnect()
for mic in self.microphones.values():
mic.disconnect()
logger.info(f"{self} disconnected.")

View File

@@ -18,7 +18,6 @@ import base64
import json
import logging
from functools import cached_property
from time import perf_counter
import cv2
import numpy as np
@@ -59,9 +58,8 @@ class LeKiwiClient(Robot):
self.zmq_observation_socket = None
self.last_frames = {}
self.last_remote_state = {}
self.last_frame_timestamp = None
self.last_frame_delay = 0.0
# Define three speed levels and a current index
self.speed_levels = [
@@ -99,13 +97,9 @@ class LeKiwiClient(Robot):
def _cameras_ft(self) -> dict[str, tuple[int, int, int]]:
return {name: (cfg.height, cfg.width, 3) for name, cfg in self.config.cameras.items()}
@cached_property
def _microphones_ft(self) -> dict[str, tuple]:
return {name: (cfg.sample_rate, cfg.channels) for name, cfg in self.config.microphones.items()}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
return {**self._state_ft, **self._cameras_ft, **self._microphones_ft}
return {**self._state_ft, **self._cameras_ft}
@cached_property
def action_features(self) -> dict[str, type]:
@@ -141,7 +135,6 @@ class LeKiwiClient(Robot):
if self.zmq_observation_socket not in socks or socks[self.zmq_observation_socket] != zmq.POLLIN:
raise DeviceNotConnectedError("Timeout waiting for LeKiwi Host to connect expired.")
self.last_frame_timestamp = perf_counter()
self._is_connected = True
def calibrate(self) -> None:
@@ -174,8 +167,6 @@ class LeKiwiClient(Robot):
if last_msg is None:
logging.warning("Poller indicated data, but failed to retrieve message.")
self.last_frame_delay = perf_counter() - self.last_frame_timestamp
self.last_frame_timestamp = perf_counter()
return last_msg
def _parse_observation_json(self, obs_string: str) -> RobotObservation | None:
@@ -212,16 +203,14 @@ class LeKiwiClient(Robot):
obs_dict: RobotObservation = {**flat_state, OBS_STATE: state_vec}
# Decode images and audio data
# Decode images
current_frames: dict[str, np.ndarray] = {}
for frame_name, frame_data in observation.items():
if frame_name in self._cameras_ft:
image = self._decode_image_from_b64(frame_data)
if image is not None:
current_frames[frame_name] = image
elif frame_name in self._microphones_ft:
if frame_data is not None:
current_frames[frame_name] = frame_data
for cam_name, image_b64 in observation.items():
if cam_name not in self._cameras_ft:
continue
frame = self._decode_image_from_b64(image_b64)
if frame is not None:
current_frames[cam_name] = frame
return current_frames, obs_dict
@@ -265,27 +254,17 @@ class LeKiwiClient(Robot):
"""
Capture observations from the remote robot: current follower arm positions,
present wheel speeds (converted to body-frame velocities: x, y, theta),
and cameras and microphones data. Receives over ZMQ, translate to body-frame vel
and a camera frame. Receives over ZMQ, translate to body-frame vel
"""
frames, obs_dict = self._get_data()
# Loop over each configured camera and microphone
for frame_name, frame_data in frames.items():
if frame_data is None:
if frame_name in self._cameras_ft:
logging.warning("Image frame is None")
image = np.zeros((640, 480, 3), dtype=np.uint8)
obs_dict[frame_name] = image
elif frame_name in self._microphones_ft:
logging.warning("Audio frame is None")
obs_dict[frame_name] = np.zeros(
(
int(self._microphones_ft[frame_name][0] * self.last_frame_delay),
self._microphones_ft[frame_name][1],
),
dtype=np.float32,
)
# Loop over each configured camera
for cam_name, frame in frames.items():
if frame is None:
logging.warning("Frame is None")
frame = np.zeros((640, 480, 3), dtype=np.uint8)
obs_dict[cam_name] = frame
return obs_dict

View File

@@ -17,7 +17,6 @@
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from lerobot.microphones import MicrophoneConfig
from ..config import RobotConfig
@@ -39,9 +38,6 @@ class SOFollowerConfig:
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)
# microphones
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
# Set to `True` for backward compatibility with previous policies/dataset
use_degrees: bool = True

View File

@@ -19,7 +19,6 @@ import time
from functools import cached_property
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.microphones.utils import make_microphones_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.feetech import (
FeetechMotorsBus,
@@ -62,7 +61,6 @@ class SOFollower(Robot):
calibration=self.calibration,
)
self.cameras = make_cameras_from_configs(config.cameras)
self.microphones = make_microphones_from_configs(config.microphones)
@property
def _motors_ft(self) -> dict[str, type]:
@@ -74,16 +72,9 @@ class SOFollower(Robot):
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
}
@property
def _microphones_ft(self) -> dict[str, tuple]:
return {
mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels)
for mic in self.microphones
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
return {**self._motors_ft, **self._cameras_ft, **self._microphones_ft}
return {**self._motors_ft, **self._cameras_ft}
@cached_property
def action_features(self) -> dict[str, type]:
@@ -91,11 +82,7 @@ class SOFollower(Robot):
@property
def is_connected(self) -> bool:
return (
self.bus.is_connected
and all(cam.is_connected for cam in self.cameras.values())
and all(mic.is_connected for mic in self.microphones.values())
)
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
@@ -114,9 +101,6 @@ class SOFollower(Robot):
for cam in self.cameras.values():
cam.connect()
for mic in self.microphones.values():
mic.connect()
self.configure()
logger.info(f"{self} connected.")
@@ -206,13 +190,6 @@ class SOFollower(Robot):
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
# Read audio frames from microphones
for mic_key, mic in self.microphones.items():
start = time.perf_counter()
obs_dict[mic_key] = mic.read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {mic_key}: {dt_ms:.1f}ms")
return obs_dict
@check_if_not_connected
@@ -248,8 +225,6 @@ class SOFollower(Robot):
self.bus.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values():
cam.disconnect()
for mic in self.microphones.values():
mic.disconnect()
logger.info(f"{self} disconnected.")

View File

@@ -85,7 +85,7 @@ from lerobot.datasets.utils import (
flatten_dict,
update_chunk_file_indices,
)
from lerobot.datasets.video_utils import concatenate_media_files, get_media_duration_in_s
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.utils import init_logging
@@ -318,12 +318,12 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
ep_size_in_mb = get_file_size_in_mb(ep_path)
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="video")
ep_duration_in_s = get_video_duration_in_s(ep_path)
# Check if adding this episode would exceed the limit
if size_in_mb + ep_size_in_mb >= video_file_size_in_mb and len(paths_to_cat) > 0:
# Size limit would be exceeded, save current accumulation WITHOUT this episode
concatenate_media_files(
concatenate_video_files(
paths_to_cat,
new_root
/ DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
@@ -359,7 +359,7 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
# Write remaining videos if any
if paths_to_cat:
concatenate_media_files(
concatenate_video_files(
paths_to_cat,
new_root
/ DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
@@ -402,12 +402,7 @@ def generate_episode_metadata_dict(
if len(ep_ids_set) != 1:
raise ValueError(f"Number of episodes is not the same ({ep_ids_set}).")
ep_dict = {
**ep_metadata,
**ep_video,
**ep_legacy_metadata,
**flatten_dict({"stats": ep_stats}),
}
ep_dict = {**ep_metadata, **ep_video, **ep_legacy_metadata, **flatten_dict({"stats": ep_stats})}
ep_dict["meta/episodes/chunk_index"] = 0
ep_dict["meta/episodes/file_index"] = 0
yield ep_dict
@@ -428,10 +423,7 @@ def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_
ds_episodes = Dataset.from_generator(
lambda: generate_episode_metadata_dict(
episodes_legacy_metadata,
episodes_metadata,
episodes_stats,
episodes_video_metadata,
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_video_metadata
)
)
write_episodes(ds_episodes, new_root)

View File

@@ -33,8 +33,6 @@ import draccus
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig # noqa: F401
from lerobot.microphones.touchlab.configuration_touchlab import TouchLabSensorConfig # noqa: F401
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,

View File

@@ -47,6 +47,7 @@ You can learn about the CLI options for this script in the `EvalPipelineConfig`
"""
import concurrent.futures as cf
import copy
import json
import logging
import threading
@@ -56,7 +57,6 @@ from collections.abc import Callable
from contextlib import nullcontext
from copy import deepcopy
from dataclasses import asdict
from functools import partial
from pathlib import Path
from pprint import pformat
from typing import Any, TypedDict
@@ -73,7 +73,6 @@ from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.factory import make_env, make_env_pre_post_processors
from lerobot.envs.utils import (
add_envs_task,
check_env_attributes_and_types,
close_envs,
preprocess_observation,
@@ -166,9 +165,9 @@ def rollout(
if return_observations:
all_observations.append(deepcopy(observation))
# Infer "task" from attributes of environments.
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
observation = add_envs_task(env, observation)
# Infer "task" from sub-environments.
# env.call() works with both SyncVectorEnv and AsyncVectorEnv.
observation["task"] = env.call("task")
# Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO)
observation = env_preprocessor(observation)
@@ -734,34 +733,48 @@ def eval_policy_all(
group_acc[group]["video_paths"].extend(paths)
overall["video_paths"].extend(paths)
def _make_thread_policy(p: PreTrainedPolicy) -> PreTrainedPolicy:
"""Shallow copy sharing weight tensors, with independent per-thread state.
copy.copy() gives a new Python object whose _parameters dict is a shared
reference (same tensor storage, zero extra VRAM). reset() then rebinds
mutable state (action queues etc.) to fresh per-thread objects.
Note: does NOT work for ACT with temporal_ensemble_coeff — that policy's
reset() mutates a shared sub-object. Use max_parallel_tasks=1 for that config.
"""
thread_p = copy.copy(p)
thread_p.reset()
return thread_p
# Choose runner (sequential vs threaded)
task_runner = partial(
run_one,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
max_episodes_rendered=max_episodes_rendered,
videos_dir=videos_dir,
return_episode_data=return_episode_data,
start_seed=start_seed,
)
_runner_kwargs = {
"env_preprocessor": env_preprocessor,
"env_postprocessor": env_postprocessor,
"preprocessor": preprocessor,
"postprocessor": postprocessor,
"n_episodes": n_episodes,
"max_episodes_rendered": max_episodes_rendered,
"videos_dir": videos_dir,
"return_episode_data": return_episode_data,
"start_seed": start_seed,
}
if max_parallel_tasks <= 1:
# sequential path (single accumulator path on the main thread)
# NOTE: keeping a single-threaded accumulator avoids concurrent list appends or locks
for task_group, task_id, env in tasks:
tg, tid, metrics = task_runner(task_group, task_id, env)
tg, tid, metrics = run_one(task_group, task_id, env, policy=policy, **_runner_kwargs)
_accumulate_to(tg, metrics)
per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics})
else:
# threaded path: submit all tasks, consume completions on main thread and accumulate there
# threaded path: each thread gets a shallow policy copy (shared weights, independent state)
with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
fut2meta = {}
for task_group, task_id, env in tasks:
fut = executor.submit(task_runner, task_group, task_id, env)
fut = executor.submit(
run_one, task_group, task_id, env, policy=_make_thread_policy(policy), **_runner_kwargs
)
fut2meta[fut] = (task_group, task_id)
for fut in cf.as_completed(fut2meta):
tg, tid, metrics = fut.result()

View File

@@ -69,13 +69,12 @@ lerobot-record \
import logging
import time
from copy import copy
from dataclasses import asdict, dataclass, field
from pathlib import Path
from pprint import pformat
from typing import Any
import numpy as np
import torch
from lerobot.cameras import ( # noqa: F401
CameraConfig, # noqa: F401
@@ -90,22 +89,10 @@ from lerobot.datasets.feature_utils import build_dataset_frame, combine_feature_
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import (
DEFAULT_AUDIO_CHUNK_DURATION,
DEFAULT_INITIAL_AUDIO_BUFFER_DURATION,
)
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.microphones import (
MicrophoneConfig, # noqa: F401
)
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig # noqa: F401
from lerobot.microphones.touchlab.configuration_touchlab import TouchLabSensorConfig # noqa: F401
from lerobot.microphones.utils import (
async_microphones_start_recording,
async_microphones_stop_recording,
)
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc import ActionInterpolator
from lerobot.policies.utils import make_robot_action
from lerobot.processor import (
PolicyAction,
@@ -147,7 +134,6 @@ from lerobot.teleoperators import ( # noqa: F401
unitree_g1,
)
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
from lerobot.utils.audio_utils import rolling_vstack
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import (
init_keyboard_listener,
@@ -243,6 +229,9 @@ class RecordConfig:
play_sounds: bool = True
# Resume recording on an existing dataset.
resume: bool = False
# Action interpolation multiplier for smoother policy control (1=off, 2=2x, 3=3x)
# Only applies when using a policy (not teleop)
interpolation_multiplier: int = 1
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
@@ -315,15 +304,9 @@ def record_loop(
control_time_s: int | None = None,
single_task: str | None = None,
display_data: bool = False,
interpolator: ActionInterpolator | None = None,
display_compressed_images: bool = False,
):
if display_data:
init_rerun(
session_name="recording",
robot=robot,
reset_time=True,
)
if dataset is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
@@ -358,35 +341,15 @@ def record_loop(
preprocessor.reset()
postprocessor.reset()
# Create a buffer for audio observations (shifting window of fixed size over audio samples)
if robot.microphones and (policy is not None or dataset is not None):
audio_buffer = {
microphone_name: np.zeros(
(int(microphone.sample_rate * DEFAULT_AUDIO_CHUNK_DURATION), len(microphone.channels))
)
for microphone_name, microphone in robot.microphones.items()
}
# Reset interpolator if provided
if interpolator is not None:
interpolator.reset()
if (
dataset is not None and robot.name != "lekiwi"
): # For now, LeKiwi only supports frame audio recording (which may lead to audio chunks loss, extended post-processing, increased memory usage)
dataset.add_microphones_recordings(robot.microphones)
else:
async_microphones_start_recording(robot.microphones)
# Fill audio buffers if needed
if (
robot.microphones
and (policy is not None or dataset is not None)
and DEFAULT_INITIAL_AUDIO_BUFFER_DURATION > 0.0
):
# This initial wait might be longer than the audio chunk duration to
# (1) ensure that the audio buffers are filled with enough data
# (2) add additional initial samples to the dataset in case of variable audio chunk duration during training
precise_sleep(DEFAULT_INITIAL_AUDIO_BUFFER_DURATION)
for microphone_name, microphone in robot.microphones.items():
audio_chunk = microphone.read()
audio_buffer[microphone_name] = rolling_vstack(audio_buffer[microphone_name], audio_chunk)
# Calculate control interval based on interpolation
use_interpolation = interpolator is not None and interpolator.enabled and policy is not None
control_interval = interpolator.get_control_interval(fps) if interpolator else 1 / fps
# Pre-compute action key order outside the hot loop — it won't change mid-episode.
action_keys = sorted(robot.action_features) if use_interpolation else []
no_action_count = 0
timestamp = 0
@@ -407,34 +370,67 @@ def record_loop(
if policy is not None or dataset is not None:
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
# Track whether this iteration should be recorded to the dataset.
# Interpolated-only iterations send actions to the robot but don't record frames,
# keeping the dataset at the original fps while the robot moves at the higher rate.
is_record_frame = True
# Get action from either policy or teleop
if policy is not None and preprocessor is not None and postprocessor is not None:
# Transform instantaneous audio samples into a buffer of fixed size
buffered_observation_frame = copy(observation_frame)
for name in audio_buffer:
# Add the audio buffer to the observation
buffered_observation_frame[name] = rolling_vstack(audio_buffer[name], observation_frame[name])
# With interpolation: only call policy when interpolator needs new action
if use_interpolation:
ran_inference = False
action_values = predict_action(
observation=buffered_observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
if interpolator.needs_new_action():
action_values = predict_action(
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
act_processed_policy = make_robot_action(action_values, dataset.features)
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
action_tensor = torch.tensor([robot_action_to_send[k] for k in action_keys])
interpolator.add(action_tensor)
ran_inference = True
interp_action = interpolator.get()
if interp_action is not None:
robot_action_to_send = {k: interp_action[i].item() for i, k in enumerate(action_keys)}
action_values = robot_action_to_send
else:
continue
is_record_frame = ran_inference
else:
action_values = predict_action(
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
# Applies a pipeline to the action, default is IdentityProcessor
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
elif policy is None and isinstance(teleop, Teleoperator):
act = teleop.get_action()
if robot.name == "unitree_g1":
teleop.send_feedback(obs)
act = teleop.get_action()
# Applies a pipeline to the raw teleop action, default is IdentityProcessor
act_processed_teleop = teleop_action_processor((act, obs))
action_values = act_processed_teleop
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
elif policy is None and isinstance(teleop, list):
arm_action = teleop_arm.get_action()
@@ -443,6 +439,8 @@ def record_loop(
base_action = robot._from_keyboard_to_base_action(keyboard_action)
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
act_processed_teleop = teleop_action_processor((act, obs))
action_values = act_processed_teleop
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
else:
no_action_count += 1
if no_action_count == 1 or no_action_count % 10 == 0:
@@ -453,37 +451,26 @@ def record_loop(
)
continue
# Applies a pipeline to the action, default is IdentityProcessor
if policy is not None and act_processed_policy is not None:
action_values = act_processed_policy
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
else:
action_values = act_processed_teleop
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
# Send action to robot
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
# TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
_sent_action = robot.send_action(robot_action_to_send)
# Write to dataset
if dataset is not None:
# Write to dataset (only on real policy frames, not interpolated-only iterations)
if dataset is not None and is_record_frame:
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
frame = {**observation_frame, **action_frame, "task": single_task}
dataset.add_frame(frame)
if display_data:
log_rerun_data(
observation=obs_processed,
action=action_values,
compress_images=display_compressed_images,
log_time=time.perf_counter() - start_episode_t,
observation=obs_processed, action=action_values, compress_images=display_compressed_images
)
dt_s = time.perf_counter() - start_loop_t
sleep_time_s: float = 1 / fps - dt_s
sleep_time_s: float = control_interval - dt_s
if sleep_time_s < 0:
logging.warning(
f"Record loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
@@ -493,8 +480,6 @@ def record_loop(
timestamp = time.perf_counter() - start_episode_t
async_microphones_stop_recording(robot.microphones)
@parser.wrap()
def record(cfg: RecordConfig) -> LeRobotDataset:
@@ -571,6 +556,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
preprocessor = None
postprocessor = None
interpolator = None
if cfg.policy is not None:
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
@@ -581,6 +567,10 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
},
)
# Create interpolator for smoother policy control
if cfg.interpolation_multiplier > 1:
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
logging.info(f"Action interpolation enabled: {cfg.interpolation_multiplier}x control rate")
robot.connect()
if teleop is not None:
@@ -612,6 +602,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
control_time_s=cfg.dataset.episode_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
interpolator=interpolator,
display_compressed_images=display_compressed_images,
)

View File

@@ -62,8 +62,6 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # no
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig # noqa: F401
from lerobot.microphones.touchlab.configuration_touchlab import TouchLabSensorConfig # noqa: F401
from lerobot.processor import (
RobotAction,
RobotObservation,
@@ -153,18 +151,8 @@ def teleop_loop(
robot_action_processor: An optional pipeline to process actions before they are sent to the robot.
robot_observation_processor: An optional pipeline to process raw observations from the robot.
"""
if display_data:
init_rerun(
session_name="teleoperation",
robot=robot,
reset_time=True,
)
display_len = max(len(key) for key in robot.action_features)
for _, microphone in robot.microphones.items():
microphone.start_recording()
start = time.perf_counter()
while True:
loop_start = time.perf_counter()
@@ -198,7 +186,6 @@ def teleop_loop(
observation=obs_transition,
action=teleop_action,
compress_images=display_compressed_images,
log_time=time.perf_counter() - start,
)
print("\n" + "-" * (display_len + 10))
@@ -215,10 +202,7 @@ def teleop_loop(
move_cursor_up(1)
if duration is not None and time.perf_counter() - start >= duration:
break
for _, microphone in robot.microphones.items():
microphone.stop_recording()
return
@parser.wrap()

View File

@@ -32,9 +32,15 @@ from .config_openarm_mini import OpenArmMiniConfig
logger = logging.getLogger(__name__)
# Motors whose direction is inverted during readout
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5"]
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"]
LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"]
# Leader joint 6 maps to follower joint 7 and vice versa
JOINT_REMAP = {"joint_6": "joint_7", "joint_7": "joint_6"}
JOINT_REMAP_REVERSE = {"joint_7": "joint_6", "joint_6": "joint_7"}
GRIPPER_TELEOP_TO_DEGREES = -0.65
class OpenArmMini(Teleoperator):
"""
@@ -95,6 +101,8 @@ class OpenArmMini(Teleoperator):
@property
def action_features(self) -> dict[str, type]:
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
# and the dataset feature names recorded during data collection.
features: dict[str, type] = {}
for motor in self.bus_right.motors:
features[f"right_{motor}.pos"] = float
@@ -276,16 +284,70 @@ class OpenArmMini(Teleoperator):
right_positions = self.bus_right.sync_read("Present_Position")
left_positions = self.bus_left.sync_read("Present_Position")
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
# and the dataset feature names recorded during data collection.
# Joint 6↔7 remap: leader joint_6 → follower joint_7 and vice versa.
action: dict[str, Any] = {}
for motor, val in right_positions.items():
action[f"right_{motor}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
target = JOINT_REMAP.get(motor, motor)
if motor == "gripper":
# Convert gripper from teleop 0-100 to openarms degrees: 0→0°, 100→-65°
action[f"right_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
else:
action[f"right_{target}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
for motor, val in left_positions.items():
action[f"left_{motor}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
target = JOINT_REMAP.get(motor, motor)
if motor == "gripper":
action[f"left_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
else:
action[f"left_{target}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return action
def enable_torque(self) -> None:
"""Enable torque on both arms for position control."""
self.bus_right.enable_torque()
self.bus_left.enable_torque()
def disable_torque(self) -> None:
"""Disable torque on both arms for free movement."""
self.bus_right.disable_torque()
self.bus_left.disable_torque()
def write_goal_positions(self, positions: dict[str, float]) -> None:
"""Write goal positions to motors (inverse of get_action flip/gripper/remap logic)."""
right_goals: dict[str, float] = {}
left_goals: dict[str, float] = {}
for key, val in positions.items():
if not key.endswith(".pos"):
continue
motor_name = key.removesuffix(".pos")
if motor_name.startswith("right_"):
base = motor_name.removeprefix("right_")
# Reverse remap: follower joint_7 → leader joint_6 and vice versa
target = JOINT_REMAP_REVERSE.get(base, base)
if base == "gripper":
# Convert robot degrees to teleop 0-100: 0°→0, -65°→100
right_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
else:
# Un-flip using the ORIGINAL motor name (target = leader motor)
right_goals[target] = -val if target in RIGHT_MOTORS_TO_FLIP else val
elif motor_name.startswith("left_"):
base = motor_name.removeprefix("left_")
target = JOINT_REMAP_REVERSE.get(base, base)
if base == "gripper":
left_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
else:
left_goals[target] = -val if target in LEFT_MOTORS_TO_FLIP else val
if right_goals:
self.bus_right.sync_write("Goal_Position", right_goals)
if left_goals:
self.bus_left.sync_write("Goal_Position", left_goals)
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError("Feedback is not yet implemented for OpenArm Mini.")

View File

@@ -1,37 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
def rolling_vstack(buffer: np.ndarray, new_data: np.ndarray) -> np.ndarray:
"""
Rolling implementation of numpy.vstack to add new data in at the end of a fixed shape buffer in a rolling fashion.
Args:
buffer: The *fixed* shape buffer to update.
new_data: The new data to add to the buffer.
Returns:
The updated buffer.
"""
buffer_size = buffer.shape[0]
# Remove as many old audio samples as needed
buffer[: -len(new_data)] = buffer[len(new_data) :]
# Add new audio samples, only the newest if the buffer is already full
buffer[-len(new_data) :] = new_data[-buffer_size:]
return buffer

View File

@@ -23,7 +23,6 @@ OBS_ENV_STATE = OBS_STR + ".environment_state"
OBS_STATE = OBS_STR + ".state"
OBS_IMAGE = OBS_STR + ".image"
OBS_IMAGES = OBS_IMAGE + "s"
OBS_AUDIO = OBS_STR + ".audio"
OBS_LANGUAGE = OBS_STR + ".language"
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"

View File

@@ -103,7 +103,7 @@ def predict_action(
torch.inference_mode(),
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
):
# Convert to pytorch format: normalizing and permuting (channel first)
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
observation = prepare_observation_for_inference(observation, device, task, robot_type)
observation = preprocessor(observation)

View File

@@ -30,22 +30,3 @@ class DeviceAlreadyConnectedError(ConnectionError):
):
self.message = message
super().__init__(self.message)
class DeviceNotRecordingError(Exception):
"""Exception raised when the robot device is not recording."""
def __init__(self, message="This robot device is not recording. Try calling `start_recording()` first."):
self.message = message
super().__init__(self.message)
class DeviceAlreadyRecordingError(Exception):
"""Exception raised when the robot device is already recording."""
def __init__(
self,
message="This robot device is already recording. Try not calling `start_recording()` twice.",
):
self.message = message
super().__init__(self.message)

View File

@@ -1,105 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from multiprocessing import Lock, Value, shared_memory
import numpy as np
class SharedArray:
"""
A SharedArray is a numpy array shared between multiple processes in a shared_memory object.
- Data is written to the array using the `write` method, which appends data to the array.
- Data is read from the array (and eventually flushed) using the `read` method, which copies the _whole_ array.
SharedArray offers quasi-instantaneous array-wide read and flush capabilities in comparison to Queues, but has a limited size defined at initialization.
Example:
_Main_process_
shared_array = SharedArray(shape=(10, 10), dtype=np.dtype("float32"))
local_array = shared_array.get_local_array()
shared_array.write(local_array, np.array([[1, 2, 3], [4, 5, 6]]))
_Child_process_
local_array = shared_array.get_local_array()
data = shared_array.read(local_array, flush=True)
"""
def __init__(self, shape: tuple[int], dtype: np.dtype | str):
"""
Initialize a SharedArray.
Args:
shape: The shape of the shared array.
dtype: The dtype of the shared array.
"""
self.shape = shape
self.dtype = dtype
self.shared_memory = shared_memory.SharedMemory(
create=True, size=np.prod(shape) * np.dtype(dtype).itemsize
)
self.read_index = Value("i", 0)
self.lock = Lock()
def get_local_array(self) -> np.ndarray:
"""
Get a process-local instance of the shared array.
Returns:
A process-local instance of the shared array.
"""
return np.ndarray(self.shape, dtype=np.dtype(self.dtype), buffer=self.shared_memory.buf)
def delete(self):
"""
Delete the shared array.
"""
self.shared_memory.close()
self.shared_memory.unlink()
def write(self, local_array: np.ndarray, data: np.ndarray):
"""
Write data to the shared array.
Args:
local_array: The process-local instance of the shared array to write to.
data: The data to write to the shared array.
"""
with self.lock:
local_array[self.read_index.value : self.read_index.value + len(data)] = data
self.read_index.value += len(data)
def read(self, local_array: np.ndarray, flush: bool = True) -> np.ndarray:
"""
Read data from the shared array.
Args:
local_array: The process-local instance of the shared array to read from.
flush: Whether to flush the shared array after reading.
"""
with self.lock:
data = np.copy(local_array[: self.read_index.value])
if flush:
self.read_index.value = 0
return data
def reset(self):
"""
Reset the read index to 0.
"""
with self.lock:
self.read_index.value = 0

View File

@@ -14,25 +14,17 @@
import numbers
import os
import time
from uuid import uuid4
import numpy as np
import rerun as rr
from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
from lerobot.processor import RobotAction, RobotObservation
from lerobot.robots import Robot
from lerobot.types import RobotAction, RobotObservation
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
def init_rerun(
session_name: str = "lerobot_control_loop",
ip: str | None = None,
port: int | None = None,
robot: Robot | None = None,
reset_time: bool = False,
session_name: str = "lerobot_control_loop", ip: str | None = None, port: int | None = None
) -> None:
"""
Initializes the Rerun SDK for visualizing the control loop.
@@ -41,26 +33,16 @@ def init_rerun(
session_name: Name of the Rerun session.
ip: Optional IP for connecting to a Rerun server.
port: Optional port for connecting to a Rerun server.
robot: A Robot object. If provided, Rerun will be initialized with a blueprint that includes the object's cameras and microphones.
reset_time: Whether to reset the timer "episode_time" to 0.
"""
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
rr.init(
application_id=session_name,
recording_id=uuid4(),
)
if robot is not None:
rr.send_blueprint(build_rerun_blueprint(robot))
rr.init(session_name)
memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%")
if ip and port:
rr.connect_grpc(url=f"rerun+http://{ip}:{port}/proxy")
else:
rr.spawn(memory_limit=memory_limit)
if reset_time:
rr.set_time("episode_time", timestamp=0.0)
def _is_scalar(x):
return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or (
@@ -68,47 +50,10 @@ def _is_scalar(x):
)
def build_rerun_blueprint(robot: Robot) -> rr.blueprint.Grid:
""" "
Builds a Rerun blueprint for optimized visualization of the robot's observations and actions :
- Time series views for all scalar observations and actions (e.g. position, velocity, torque, etc.).
- Spatial 2D views for all camera observations.
- Time series views for all microphone observations.
Args:
robot: A Robot object.
Returns:
A Rerun blueprint.
"""
contents = [
rr.blueprint.TimeSeriesView(
origin="data",
plot_legend=rr.blueprint.PlotLegend(visible=True),
)
]
if robot.microphones:
contents += [
rr.blueprint.TimeSeriesView(
origin="audio",
plot_legend=rr.blueprint.PlotLegend(visible=True),
)
]
if robot.cameras:
contents += [
rr.blueprint.Spatial2DView(
origin=OBS_PREFIX + camera_name,
)
for camera_name in robot.cameras
]
return rr.blueprint.Grid(*contents)
def log_rerun_data(
observation: RobotObservation | None = None,
action: RobotAction | None = None,
compress_images: bool = False,
log_time: float | None = None,
) -> None:
"""
Logs observation and action data to Rerun for real-time visualization.
@@ -127,13 +72,7 @@ def log_rerun_data(
observation: An optional dictionary containing observation data to log.
action: An optional dictionary containing action data to log.
compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality.
log_time: The time to log the data in the "episode_time" timeline.
If None, the current time is used in Rerun's default timeline.
"""
if log_time is None:
log_time = time.perf_counter()
rr.set_time("episode_time", timestamp=log_time)
if observation:
for k, v in observation.items():
if v is None:
@@ -141,41 +80,15 @@ def log_rerun_data(
key = k if str(k).startswith(OBS_PREFIX) else f"{OBS_STR}.{k}"
if _is_scalar(v):
rr.log("data/" + key, rr.Scalars(float(v)))
rr.log(key, rr.Scalars(float(v)))
elif isinstance(v, np.ndarray):
arr = v
# Convert CHW -> HWC when needed
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
arr = np.transpose(arr, (1, 2, 0))
# Convert samples x channels -> channels x samples when needed
elif arr.ndim == 2 and arr.shape[1] < arr.shape[0]:
arr = np.transpose(arr, (1, 0))
if arr.ndim == 1:
for i, vi in enumerate(arr):
rr.log("data/" + f"{key}_{i}", rr.Scalars(float(vi)))
elif arr.ndim == 2:
for i, channel_arr in enumerate(arr):
rr.send_columns(
"audio/"
+ key
+ f"_channel_{i}", # TODO(CarolinePascal): Get actual channel number/name
indexes=[
rr.TimeColumn(
"episode_time",
timestamp=log_time
+ np.linspace(
-DEFAULT_AUDIO_CHUNK_DURATION,
0,
len(channel_arr),
endpoint=False,
),
)
],
columns=rr.Scalars.columns(scalars=channel_arr),
)
elif arr.ndim == 3:
rr.log(key, rr.Image(arr), static=True)
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
else:
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
rr.log(key, entity=img_entity, static=True)
@@ -187,13 +100,13 @@ def log_rerun_data(
key = k if str(k).startswith(ACTION_PREFIX) else f"{ACTION}.{k}"
if _is_scalar(v):
rr.log("data/" + key, rr.Scalars(float(v)))
rr.log(key, rr.Scalars(float(v)))
elif isinstance(v, np.ndarray):
if v.ndim == 1:
for i, vi in enumerate(v):
rr.log("data/" + f"{key}_{i}", rr.Scalars(float(vi)))
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
else:
# Fall back to flattening higher-dimensional arrays
flat = v.flatten()
for i, vi in enumerate(flat):
rr.log("data/" + f"{key}_{i}", rr.Scalars(float(vi)))
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))

View File

@@ -57,8 +57,6 @@ def _check_component_availability(component_type, available_components, make_com
print("\nNo physical device detected.")
elif isinstance(e, ValueError) and "camera_index" in str(e):
print("\nNo physical camera detected.")
elif isinstance(e, ValueError) and "microphone_index" in str(e):
print("\nNo physical microphone detected.")
else:
traceback.print_exc()

View File

@@ -26,22 +26,16 @@ from lerobot.datasets.compute_stats import (
compute_episode_stats,
estimate_num_samples,
get_feature_stats,
sample_audio_from_data,
sample_audio_from_path,
sample_images,
sample_indices,
)
from lerobot.utils.constants import OBS_AUDIO, OBS_IMAGE, OBS_STATE
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
def mock_load_image_as_numpy(path, dtype, channel_first):
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
def mock_load_audio(path):
return np.ones((16000, 2), dtype=np.float32)
@pytest.fixture
def sample_array():
return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
@@ -79,25 +73,6 @@ def test_sample_images(mock_load):
assert len(images) == estimate_num_samples(100)
@patch("lerobot.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio)
def test_sample_audio_from_path(mock_load):
audio_path = "audio.wav"
audio_samples = sample_audio_from_path(audio_path)
assert isinstance(audio_samples, np.ndarray)
assert audio_samples.shape[1] == 2
assert audio_samples.dtype == np.float32
assert len(audio_samples) == estimate_num_samples(16000)
def test_sample_audio_from_data():
audio_data = np.ones((16000, 2), dtype=np.float32)
audio_samples = sample_audio_from_data(audio_data)
assert isinstance(audio_samples, np.ndarray)
assert audio_samples.shape[1] == 2
assert audio_samples.dtype == np.float32
assert len(audio_samples) == estimate_num_samples(16000)
def test_get_feature_stats_images():
data = np.random.rand(100, 3, 32, 32)
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
@@ -106,14 +81,6 @@ def test_get_feature_stats_images():
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
def test_get_feature_stats_audio():
data = np.random.uniform(-1, 1, (16000, 2))
stats = get_feature_stats(data, axis=0, keepdims=True)
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
np.testing.assert_equal(stats["count"], np.array([16000]))
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
def test_get_feature_stats_axis_0_keepdims(sample_array):
expected = {
"min": np.array([[1, 2, 3]]),
@@ -178,27 +145,20 @@ def test_get_feature_stats_single_value():
def test_compute_episode_stats():
episode_data = {
OBS_IMAGE: [f"image_{i}.jpg" for i in range(100)],
OBS_AUDIO: "audio.wav",
OBS_STATE: np.random.rand(100, 10),
}
features = {
OBS_IMAGE: {"dtype": "image"},
OBS_AUDIO: {"dtype": "audio"},
OBS_STATE: {"dtype": "numeric"},
}
with (
patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy),
patch("lerobot.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio),
):
with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy):
stats = compute_episode_stats(episode_data, features)
assert OBS_IMAGE in stats and OBS_AUDIO in stats and OBS_STATE in stats
assert stats[OBS_IMAGE]["count"].item() == estimate_num_samples(100)
assert stats[OBS_AUDIO]["count"].item() == estimate_num_samples(16000)
assert stats[OBS_STATE]["count"].item() == estimate_num_samples(100)
assert OBS_IMAGE in stats and OBS_STATE in stats
assert stats[OBS_IMAGE]["count"].item() == 100
assert stats[OBS_STATE]["count"].item() == 100
assert stats[OBS_IMAGE]["mean"].shape == (3, 1, 1)
assert stats[OBS_AUDIO]["mean"].shape == (1, 2)
def test_assert_type_and_shape_valid():

View File

@@ -24,7 +24,6 @@ import torch
from huggingface_hub import HfApi
from PIL import Image
from safetensors.torch import load_file
from soundfile import write
import lerobot
from lerobot.configs.default import DatasetConfig
@@ -36,7 +35,6 @@ from lerobot.datasets.io_utils import hf_transform_to_torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
from lerobot.datasets.utils import (
DEFAULT_AUDIO_CHUNK_DURATION,
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
@@ -47,13 +45,7 @@ from lerobot.envs.factory import make_env_config
from lerobot.policies.factory import make_policy_config
from lerobot.robots import make_robot_from_config
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
from tests.fixtures.constants import (
DEFAULT_SAMPLE_RATE,
DUMMY_AUDIO_CHANNELS,
DUMMY_CHW,
DUMMY_HWC,
DUMMY_REPO_ID,
)
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
from tests.mocks.mock_robot import MockRobotConfig
from tests.utils import require_x86_64_kernel
@@ -74,36 +66,6 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
@pytest.fixture
def audio_dataset_le_kiwi(tmp_path, empty_lerobot_dataset_factory):
features = {
"audio": {
"dtype": "audio",
"shape": (1, DUMMY_AUDIO_CHANNELS),
"names": [
"channels",
],
"info": {"sample_rate": DEFAULT_SAMPLE_RATE},
}
}
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, robot_type="lekiwi")
@pytest.fixture
def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
features = {
"audio": {
"dtype": "audio",
"shape": (1, DUMMY_AUDIO_CHANNELS),
"names": [
"channels",
],
"info": {"sample_rate": DEFAULT_SAMPLE_RATE},
}
}
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
"""
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
@@ -458,78 +420,6 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
)
def test_add_frame_audio_array(audio_dataset_le_kiwi):
dataset = audio_dataset_le_kiwi
dataset.add_frame(
{
"audio": np.random.rand(
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS
)
},
task="Dummy task",
)
dataset.save_episode()
assert dataset[0]["audio"].shape == torch.Size(
(
DUMMY_AUDIO_CHANNELS,
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE),
)
)
def test_add_frame_audio_array_wrong_shape(audio_dataset_le_kiwi):
dataset = audio_dataset_le_kiwi
with pytest.raises(ValueError):
dataset.add_frame(
{
"audio": np.random.rand(
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS, 99
)
},
task="Dummy task",
)
def test_add_frame_audio_array_wrong_channels_number(audio_dataset_le_kiwi):
dataset = audio_dataset_le_kiwi
with pytest.raises(ValueError):
dataset.add_frame(
{"audio": np.random.rand(int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), 99)},
task="Dummy task",
)
def test_add_frame_audio_file(audio_dataset):
dataset = audio_dataset
dataset.add_frame(
{
"audio": np.random.rand(
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS
)
},
task="Dummy task",
)
# Create the audio file that should be created in the background by the Microphone class
for audio_key in dataset.meta.audio_keys:
fpath = dataset.writer._get_raw_audio_file_path(0, audio_key)
fpath.parent.mkdir(parents=True, exist_ok=True)
write(
fpath,
np.random.rand(int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS),
DEFAULT_SAMPLE_RATE,
)
dataset.save_episode()
assert dataset[0]["audio"].shape == torch.Size(
(
DUMMY_AUDIO_CHANNELS,
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE),
)
)
# TODO(aliberts):
# - [ ] test various attributes & state from init and create
# - [ ] test init with episodes and check num_frames
@@ -569,7 +459,6 @@ def test_factory(env_name, repo_id, policy_name):
dataset = make_dataset(cfg)
delta_timestamps = dataset.delta_timestamps
camera_keys = dataset.meta.camera_keys
audio_keys = dataset.meta.audio_keys
item = dataset[0]
@@ -612,11 +501,6 @@ def test_factory(env_name, repo_id, policy_name):
# test c,h,w
assert item[key].shape[0] == 3, f"{key}"
for key in audio_keys:
assert item[key].dtype == torch.float32, f"{key}"
assert item[key].max() <= 1.0, f"{key}"
assert item[key].min() >= -1.0, f"{key}"
if delta_timestamps is not None:
# test missing keys in delta_timestamps
for key in delta_timestamps:

143
tests/envs/test_dispatch.py Normal file
View File

@@ -0,0 +1,143 @@
"""Tests for the benchmark dispatch refactor (create_envs / get_env_processors on EnvConfig)."""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
import gymnasium as gym
import pytest
from gymnasium.envs.registration import register, registry as gym_registry
from lerobot.configs.types import PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.envs.factory import make_env, make_env_config, make_env_pre_post_processors
logger = logging.getLogger(__name__)
def test_registry_all_types():
"""make_env_config should resolve every registered EnvConfig subclass via the registry."""
known = list(EnvConfig.get_known_choices().keys())
assert len(known) >= 6
for t in known:
cfg = make_env_config(t)
assert cfg.type == t
def test_unknown_type():
with pytest.raises(ValueError, match="not registered"):
make_env_config("nonexistent")
def test_identity_processors():
"""Base class get_env_processors() returns identity pipelines."""
cfg = make_env_config("aloha")
pre, post = cfg.get_env_processors()
assert len(pre.steps) == 0 and len(post.steps) == 0
def test_delegation():
"""make_env() should call cfg.create_envs(), not use if/elif dispatch."""
sentinel = {"delegated": {0: "marker"}}
fake = type(
"Fake",
(),
{
"hub_path": None,
"create_envs": lambda self, n_envs, use_async_envs=False: sentinel,
},
)()
result = make_env(fake, n_envs=1)
assert result is sentinel
def test_processors_delegation():
"""make_env_pre_post_processors delegates to cfg.get_env_processors()."""
from lerobot.configs.policies import PreTrainedConfig
cfg = make_env_config("aloha")
pre, post = make_env_pre_post_processors(cfg, PreTrainedConfig())
assert len(pre.steps) == 0
def test_base_create_envs():
"""Base class create_envs() should build a single-task VectorEnv via gym.make()."""
gym_id = "_dispatch_test/CartPole-v99"
if gym_id not in gym_registry:
register(id=gym_id, entry_point="gymnasium.envs.classic_control:CartPoleEnv")
@EnvConfig.register_subclass("_dispatch_base_test")
@dataclass
class _Env(EnvConfig):
task: str = "CartPole-v99"
fps: int = 10
features: dict[str, PolicyFeature] = field(default_factory=dict)
@property
def package_name(self):
return "_dispatch_test"
@property
def gym_id(self):
return gym_id
@property
def gym_kwargs(self):
return {}
try:
envs = _Env().create_envs(n_envs=2)
assert "_dispatch_base_test" in envs
env = envs["_dispatch_base_test"][0]
assert isinstance(env, gym.vector.SyncVectorEnv)
assert env.num_envs == 2
env.close()
finally:
if gym_id in gym_registry:
del gym_registry[gym_id]
def test_custom_create_envs_override():
"""A custom EnvConfig subclass can override create_envs()."""
mock_vec = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")])
@EnvConfig.register_subclass("_dispatch_custom_test")
@dataclass
class _Env(EnvConfig):
task: str = "x"
features: dict[str, PolicyFeature] = field(default_factory=dict)
@property
def gym_kwargs(self):
return {}
def create_envs(self, n_envs, use_async_envs=False):
return {"custom_suite": {0: mock_vec}}
try:
result = make_env(_Env(), n_envs=1)
assert "custom_suite" in result
finally:
mock_vec.close()
def test_custom_get_env_processors_override():
"""A custom EnvConfig subclass can override get_env_processors()."""
from lerobot.processor.pipeline import PolicyProcessorPipeline
@EnvConfig.register_subclass("_dispatch_proc_test")
@dataclass
class _Env(EnvConfig):
task: str = "x"
features: dict[str, PolicyFeature] = field(default_factory=dict)
@property
def gym_kwargs(self):
return {}
def get_env_processors(self):
return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[])
pre, post = _Env().get_env_processors()
assert isinstance(pre, PolicyProcessorPipeline)

View File

@@ -40,18 +40,5 @@ DUMMY_VIDEO_INFO = {
"video.is_depth_map": False,
"has_audio": False,
}
DUMMY_MICROPHONE_FEATURES = {
"laptop": {"dtype": "audio", "shape": (1, 2), "names": ["channels"], "info": None},
"phone": {"dtype": "audio", "shape": (1, 2), "names": ["channels"], "info": None},
}
DEFAULT_SAMPLE_RATE = 48000
DUMMY_AUDIO_CHANNELS = 2
DUMMY_AUDIO_INFO = {
"has_audio": True,
"audio.sample_rate": DEFAULT_SAMPLE_RATE,
"audio.codec": "aac",
"audio.channels": DUMMY_AUDIO_CHANNELS,
"audio.channel_layout": "stereo",
}
DUMMY_CHW = (3, 96, 128)
DUMMY_HWC = (96, 128, 3)

View File

@@ -31,7 +31,6 @@ from lerobot.datasets.feature_utils import get_hf_features_from_features
from lerobot.datasets.io_utils import hf_transform_to_torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import (
DEFAULT_AUDIO_PATH,
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
@@ -44,7 +43,6 @@ from lerobot.datasets.video_utils import encode_video_frames
from tests.fixtures.constants import (
DEFAULT_FPS,
DUMMY_CAMERA_FEATURES,
DUMMY_MICROPHONE_FEATURES,
DUMMY_MOTOR_FEATURES,
DUMMY_REPO_ID,
DUMMY_ROBOT_TYPE,
@@ -133,7 +131,6 @@ def features_factory():
def _create_features(
motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES,
audio_features: dict = DUMMY_MICROPHONE_FEATURES,
use_videos: bool = True,
) -> dict:
if use_videos:
@@ -145,7 +142,6 @@ def features_factory():
return {
**motor_features,
**camera_ft,
**audio_features,
**DEFAULT_FEATURES,
}
@@ -162,19 +158,16 @@ def info_factory(features_factory):
total_frames: int = 0,
total_tasks: int = 0,
total_videos: int = 0,
total_audio: int = 0,
chunks_size: int = DEFAULT_CHUNK_SIZE,
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
video_files_size_in_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB,
data_path: str = DEFAULT_DATA_PATH,
video_path: str = DEFAULT_VIDEO_PATH,
audio_path: str = DEFAULT_AUDIO_PATH,
motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES,
audio_features: dict = DUMMY_MICROPHONE_FEATURES,
use_videos: bool = True,
) -> dict:
features = features_factory(motor_features, camera_features, audio_features, use_videos)
features = features_factory(motor_features, camera_features, use_videos)
return {
"codebase_version": codebase_version,
"robot_type": robot_type,
@@ -182,7 +175,6 @@ def info_factory(features_factory):
"total_frames": total_frames,
"total_tasks": total_tasks,
"total_videos": total_videos,
"total_audio": total_audio,
"chunks_size": chunks_size,
"data_files_size_in_mb": data_files_size_in_mb,
"video_files_size_in_mb": video_files_size_in_mb,
@@ -190,7 +182,6 @@ def info_factory(features_factory):
"splits": {},
"data_path": data_path,
"video_path": video_path if use_videos else None,
"audio_path": audio_path,
"features": features,
}
@@ -214,14 +205,6 @@ def stats_factory():
"std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
"count": [10],
}
elif dtype == "audio":
stats[key] = {
"mean": np.full((shape[0],), 0.0, dtype=np.float32).tolist(),
"max": np.full((shape[0],), 1, dtype=np.float32).tolist(),
"min": np.full((shape[0],), -1, dtype=np.float32).tolist(),
"std": np.full((shape[0],), 0.5, dtype=np.float32).tolist(),
"count": [10],
}
else:
stats[key] = {
"max": np.full(shape, 1, dtype=dtype).tolist(),

View File

@@ -1,532 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from copy import deepcopy
from pathlib import Path
import numpy as np
import pytest
from soundfile import read
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig
from lerobot.microphones.portaudio.interface_sounddevice_sdk import (
FakeSounddeviceSDKAdapter,
SounddeviceSDKAdapter,
)
from lerobot.microphones.portaudio.microphone_portaudio import PortAudioMicrophone
from lerobot.microphones.utils import async_microphones_start_recording, async_microphones_stop_recording
from lerobot.utils.errors import (
DeviceAlreadyConnectedError,
DeviceAlreadyRecordingError,
DeviceNotConnectedError,
DeviceNotRecordingError,
)
from lerobot.utils.robot_utils import precise_sleep
MODULE_PATH = "lerobot.microphones.portaudio.microphone_portaudio"
RECORDING_DURATION = 1.0
LEROBOT_USE_REAL_PORTAUDIO_MICROPHONE_TESTS = (
os.getenv("LEROBOT_USE_REAL_PORTAUDIO_MICROPHONE_TESTS", "False").lower() == "true"
)
@pytest.fixture
def test_sdk():
"""Fixture to provide either real or fake SDK based on environment variable."""
if LEROBOT_USE_REAL_PORTAUDIO_MICROPHONE_TESTS:
return SounddeviceSDKAdapter()
else:
return FakeSounddeviceSDKAdapter()
# Configuration Tests
def test_config_creation():
"""Test creating a valid configuration."""
config = PortAudioMicrophoneConfig(microphone_index=0, sample_rate=48000, channels=[1, 2])
assert config.microphone_index == 0
assert config.sample_rate == 48000
assert config.channels == [1, 2]
def test_config_creation_missing_microphone_index():
"""Test creating a configuration with missing microphone index."""
with pytest.raises(TypeError):
PortAudioMicrophoneConfig(sample_rate=48000, channels=[1, 2])
def test_config_creation_missing_sample_rate():
"""Test creating a configuration with missing sample rate."""
config = PortAudioMicrophoneConfig(microphone_index=0, channels=[1, 2])
assert config.sample_rate is None
def test_config_creation_missing_channels():
"""Test creating a configuration with missing channels."""
config = PortAudioMicrophoneConfig(microphone_index=0, sample_rate=48000)
assert config.channels is None
@pytest.fixture
def default_config(test_sdk):
"""Fixture to provide a default configuration for input devices."""
device_info = test_sdk.query_devices(kind="input")
return PortAudioMicrophoneConfig(
microphone_index=device_info["index"],
sample_rate=device_info["default_samplerate"],
channels=np.arange(device_info["max_input_channels"]) + 1,
)
# Microphone Tests
def test_find_microphones(test_sdk):
"""Test finding microphones."""
microphones = PortAudioMicrophone.find_microphones(sounddevice_sdk=test_sdk)
for microphone in microphones:
assert isinstance(microphone["index"], int)
assert isinstance(microphone["name"], str)
assert isinstance(microphone["sample_rate"], int)
assert isinstance(microphone["channels"], np.ndarray)
assert len(microphone["channels"]) > 0
def test_init_defaults(default_config, test_sdk):
"""Test microphone initialization with defaults."""
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
device_info = test_sdk.query_devices(kind="input")
assert microphone is not None
assert microphone.microphone_index == device_info["index"]
assert microphone.sample_rate == device_info["default_samplerate"]
np.testing.assert_array_equal(microphone.channels, np.arange(device_info["max_input_channels"]) + 1)
assert not microphone.is_connected
assert not microphone.is_recording
def test_connect_success(default_config, test_sdk):
"""Test successful connection."""
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
microphone.connect()
assert microphone.is_connected
assert not microphone.is_recording
assert not microphone.is_writing
def test_connect_empty_config(default_config, test_sdk):
"""Test connection with empty config values."""
config = deepcopy(default_config)
config.sample_rate = None
config.channels = None
microphone = PortAudioMicrophone(config, sounddevice_sdk=test_sdk)
microphone.connect()
device_info = test_sdk.query_devices(kind="input")
assert microphone.sample_rate == device_info["default_samplerate"]
np.testing.assert_array_equal(microphone.channels, np.arange(device_info["max_input_channels"]) + 1)
def test_connect_already_connected(default_config, test_sdk):
"""Test connecting when already connected."""
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
microphone.connect()
with pytest.raises(DeviceAlreadyConnectedError):
microphone.connect()
def test_connect_invalid_device(test_sdk):
"""Test connecting with invalid device (output device)."""
device_info = test_sdk.query_devices(kind="output")
config = PortAudioMicrophoneConfig(
microphone_index=device_info["index"],
sample_rate=device_info["default_samplerate"],
channels=np.arange(device_info["max_input_channels"]) + 1,
)
microphone = PortAudioMicrophone(config, sounddevice_sdk=test_sdk)
with pytest.raises(RuntimeError):
microphone.connect()
def test_connect_invalid_index(default_config, test_sdk):
"""Test connecting with invalid device index."""
config = deepcopy(default_config)
config.microphone_index = -1
microphone = PortAudioMicrophone(config, sounddevice_sdk=test_sdk)
with pytest.raises(RuntimeError):
microphone.connect()
def test_connect_invalid_sample_rate(default_config, test_sdk):
"""Test connecting with invalid sample rate."""
config = deepcopy(default_config)
config.sample_rate = -1
microphone = PortAudioMicrophone(config, sounddevice_sdk=test_sdk)
with pytest.raises(RuntimeError):
microphone.connect()
def test_connect_float_sample_rate(default_config, test_sdk):
"""Test connecting with float sample rate."""
config = deepcopy(default_config)
config.sample_rate = int(config.sample_rate) - 0.5
microphone = PortAudioMicrophone(config, sounddevice_sdk=test_sdk)
microphone.connect()
assert isinstance(microphone.sample_rate, int)
assert microphone.sample_rate == int(config.sample_rate)
def test_connect_lower_sample_rate(default_config, test_sdk):
"""Test connecting with lower sample rate."""
config = deepcopy(default_config)
config.sample_rate = 1000 # Lowest possible sample rate
microphone = PortAudioMicrophone(config, sounddevice_sdk=test_sdk)
microphone.connect()
assert microphone.sample_rate == 1000
def test_connect_invalid_channels(default_config, test_sdk):
"""Test connecting with invalid channels."""
config = deepcopy(default_config)
config.channels = np.append(default_config.channels, -1)
microphone = PortAudioMicrophone(config, sounddevice_sdk=test_sdk)
with pytest.raises(RuntimeError):
microphone.connect()
def test_disconnect_success(default_config, test_sdk):
"""Test successful disconnection."""
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
microphone.connect()
microphone.disconnect()
assert not microphone.is_connected
assert not microphone.is_recording
assert not microphone.is_writing
def test_disconnect_not_connected(default_config, test_sdk):
"""Test disconnecting when not connected."""
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
with pytest.raises(DeviceNotConnectedError):
microphone.disconnect()
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_start_recording_success(default_config, test_sdk, multiprocessing):
"""Test successful recording start."""
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
microphone.connect()
microphone.start_recording(multiprocessing=multiprocessing)
assert microphone.is_recording
assert microphone.is_connected
assert not microphone.is_writing
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_recording_not_connected(default_config, test_sdk, multiprocessing):
"""Test starting recording when not connected."""
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
with pytest.raises(DeviceNotConnectedError):
microphone.start_recording(multiprocessing=multiprocessing)
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_start_recording_already_recording(default_config, test_sdk, multiprocessing):
"""Test starting recording when already recording."""
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
microphone.connect()
microphone.start_recording(multiprocessing=multiprocessing)
with pytest.raises(DeviceAlreadyRecordingError):
microphone.start_recording(multiprocessing=multiprocessing)
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_start_writing_success(tmp_path, default_config, test_sdk, multiprocessing):
"""Test successful writing start."""
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
microphone.connect()
microphone.start_recording(output_file=tmp_path / "test.wav", multiprocessing=multiprocessing)
assert microphone.is_recording
assert microphone.is_connected
assert microphone.is_writing
assert (tmp_path / "test.wav").exists()
(tmp_path / "test.wav").unlink()
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_start_writing_file_already_exists_no_overwrite(tmp_path, default_config, test_sdk, multiprocessing):
"""Test writing with file that already exists."""
(tmp_path / "test.wav").touch()
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
microphone.connect()
with pytest.raises(FileExistsError):
microphone.start_recording(
output_file=tmp_path / "test.wav", multiprocessing=multiprocessing, overwrite=False
)
(tmp_path / "test.wav").unlink()
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_stop_recording_success(default_config, test_sdk, multiprocessing):
"""Test successful recording stop."""
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
microphone.connect()
microphone.start_recording(multiprocessing=multiprocessing)
precise_sleep(RECORDING_DURATION)
microphone.stop_recording()
assert not microphone.is_recording
assert microphone.is_connected
assert not microphone.is_writing
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_stop_writing_success(tmp_path, default_config, test_sdk, multiprocessing):
"""Test successful writing stop."""
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
microphone.connect()
microphone.start_recording(output_file=tmp_path / "test.wav", multiprocessing=multiprocessing)
precise_sleep(RECORDING_DURATION)
microphone.stop_recording()
assert not microphone.is_recording
assert microphone.is_connected
assert not microphone.is_writing
assert (tmp_path / "test.wav").exists()
(tmp_path / "test.wav").unlink()
def test_stop_recording_not_connected(default_config, test_sdk):
"""Test stopping recording when not connected."""
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
with pytest.raises(DeviceNotConnectedError):
microphone.stop_recording()
def test_stop_recording_not_recording(default_config, test_sdk):
"""Test stopping recording when not recording."""
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
microphone.connect()
with pytest.raises(DeviceNotRecordingError):
microphone.stop_recording()
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_disconnect_while_recording(default_config, test_sdk, multiprocessing):
"""Test disconnecting while recording."""
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
microphone.connect()
microphone.start_recording(multiprocessing=multiprocessing)
precise_sleep(RECORDING_DURATION)
microphone.disconnect()
assert not microphone.is_connected
assert not microphone.is_recording
assert not microphone.is_writing
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_disconnect_while_writing(tmp_path, default_config, test_sdk, multiprocessing):
"""Test disconnecting while writing."""
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
microphone.connect()
microphone.start_recording(output_file=tmp_path / "test.wav", multiprocessing=multiprocessing)
precise_sleep(RECORDING_DURATION)
microphone.disconnect()
assert not microphone.is_connected
assert not microphone.is_recording
assert not microphone.is_writing
assert Path(tmp_path / "test.wav").exists()
(tmp_path / "test.wav").unlink()
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_read_success(default_config, test_sdk, multiprocessing):
"""Test successful reading of audio data."""
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
microphone.connect()
microphone.start_recording(multiprocessing=multiprocessing)
precise_sleep(RECORDING_DURATION)
data = microphone.read()
device_info = test_sdk.query_devices(kind="input")
assert data is not None
assert data.shape[1] == len(default_config.channels)
assert (
abs(data.shape[0] - RECORDING_DURATION * default_config.sample_rate)
<= 2 * default_config.sample_rate * device_info["default_low_input_latency"]
)
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_writing_success(tmp_path, default_config, test_sdk, multiprocessing):
"""Test successful writing to file."""
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
microphone.connect()
microphone.start_recording(output_file=tmp_path / "test.wav", multiprocessing=multiprocessing)
precise_sleep(RECORDING_DURATION)
microphone.stop_recording()
data, samplerate = read(tmp_path / "test.wav")
device_info = test_sdk.query_devices(kind="input")
assert samplerate == default_config.sample_rate
assert data.shape[1] == len(default_config.channels)
assert (
abs(data.shape[0] - RECORDING_DURATION * default_config.sample_rate)
<= 2 * default_config.sample_rate * device_info["default_low_input_latency"]
)
(tmp_path / "test.wav").unlink()
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_read_while_writing(tmp_path, default_config, test_sdk, multiprocessing):
"""Test reading while writing."""
microphone = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
microphone.connect()
microphone.start_recording(output_file=tmp_path / "test.wav", multiprocessing=multiprocessing)
precise_sleep(RECORDING_DURATION)
read_data = microphone.read()
microphone.stop_recording()
writing_data, _ = read(tmp_path / "test.wav")
device_info = test_sdk.query_devices(kind="input")
assert (
abs(writing_data.shape[0] - RECORDING_DURATION * default_config.sample_rate)
<= 2 * default_config.sample_rate * device_info["default_low_input_latency"]
)
assert (
abs(read_data.shape[0] - RECORDING_DURATION * default_config.sample_rate)
<= 2 * default_config.sample_rate * device_info["default_low_input_latency"]
)
(tmp_path / "test.wav").unlink()
def test_async_start_recording(default_config, test_sdk):
"""Test async recording start."""
microphones = {
"microphone_1": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
"microphone_2": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
}
for microphone in microphones.values():
microphone.connect()
async_microphones_start_recording(microphones)
for microphone in microphones.values():
assert microphone.is_recording
assert microphone.is_connected
assert not microphone.is_writing
def test_async_start_writing(tmp_path, default_config, test_sdk):
"""Test async writing start."""
microphones = {
"microphone_1": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
"microphone_2": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
}
for microphone in microphones.values():
microphone.connect()
async_microphones_start_recording(
microphones, output_files=[tmp_path / "test_1.wav", tmp_path / "test_2.wav"]
)
for microphone in microphones.values():
assert microphone.is_recording
assert microphone.is_connected
assert microphone.is_writing
assert Path(tmp_path / "test_1.wav").exists()
assert Path(tmp_path / "test_2.wav").exists()
(tmp_path / "test_1.wav").unlink()
(tmp_path / "test_2.wav").unlink()
def test_async_stop_recording(default_config, test_sdk):
"""Test async recording stop."""
microphones = {
"microphone_1": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
"microphone_2": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
}
for microphone in microphones.values():
microphone.connect()
async_microphones_start_recording(microphones)
async_microphones_stop_recording(microphones)
for microphone in microphones.values():
assert not microphone.is_recording
assert microphone.is_connected
assert not microphone.is_writing
def test_async_stop_writing(tmp_path, default_config, test_sdk):
"""Test async writing stop."""
microphones = {
"microphone_1": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
"microphone_2": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
}
for microphone in microphones.values():
microphone.connect()
async_microphones_start_recording(
microphones, output_files=[tmp_path / "test_1.wav", tmp_path / "test_2.wav"]
)
async_microphones_stop_recording(microphones)
for microphone in microphones.values():
assert not microphone.is_recording
assert microphone.is_connected
assert not microphone.is_writing
assert Path(tmp_path / "test_1.wav").exists()
assert Path(tmp_path / "test_2.wav").exists()
(tmp_path / "test_1.wav").unlink()
(tmp_path / "test_2.wav").unlink()

View File

@@ -1,508 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
import time
from multiprocessing import Event, Process, Queue
import numpy as np
import pytest
from lerobot.utils.shared_array import SharedArray
def writer_process(shared_array, data_queue, stop_event, barrier, process_id):
"""Writer process that continuously writes data to shared array."""
local_array = shared_array.get_local_array()
# Wait for all processes to be ready
barrier.wait()
write_count = 0
while not stop_event.is_set() and write_count < 10:
# Generate unique data for this process and write iteration
data = np.full((5, 2), process_id * 100 + write_count, dtype=np.float32)
try:
shared_array.write(local_array, data)
data_queue.put(f"writer_{process_id}_wrote_{write_count}")
write_count += 1
time.sleep(0.01) # Small delay to allow race conditions
except IndexError:
# Array is full, stop writing
break
def reader_process(shared_array, data_queue, stop_event, barrier, process_id):
"""Reader process that continuously reads data from shared array."""
local_array = shared_array.get_local_array()
# Wait for all processes to be ready
barrier.wait()
read_count = 0
while not stop_event.is_set() and read_count < 5:
time.sleep(0.02) # Allow some writes to accumulate
data = shared_array.read(local_array, flush=True)
data_queue.put(f"reader_{process_id}_read_{len(data)}_items")
read_count += 1
def stress_writer_process(shared_array, data_queue, stop_event, barrier, process_id):
"""High-frequency writer process for stress testing."""
local_array = shared_array.get_local_array()
barrier.wait()
write_count = 0
while not stop_event.is_set() and write_count < 50:
# Write single row at a time for more frequent operations
data = np.array([[process_id, write_count]], dtype=np.float32)
try:
shared_array.write(local_array, data)
write_count += 1
# No sleep - stress test
except IndexError:
break
data_queue.put(f"stress_writer_{process_id}_completed_{write_count}")
# Basic functionality tests
def test_shared_array_creation():
"""Test basic SharedArray creation and properties."""
shape = (100, 4)
dtype = np.float32
shared_array = SharedArray(shape=shape, dtype=dtype)
assert shared_array.shape == shape
assert shared_array.dtype == dtype
assert shared_array.read_index.value == 0
# Clean up
shared_array.delete()
def test_local_array_access():
"""Test getting local array instances."""
shape = (50, 2)
shared_array = SharedArray(shape=shape, dtype=np.float32)
local_array = shared_array.get_local_array()
assert local_array.shape == shape
assert local_array.dtype == np.float32
assert isinstance(local_array, np.ndarray)
# Test that we can get multiple local array instances
local_array2 = shared_array.get_local_array()
assert local_array2.shape == shape
shared_array.delete()
def test_write_and_read_single_process():
"""Test basic write and read operations in single process."""
shape = (20, 3)
shared_array = SharedArray(shape=shape, dtype=np.float32)
local_array = shared_array.get_local_array()
# Write some data
data1 = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
shared_array.write(local_array, data1)
assert shared_array.read_index.value == 2
# Write more data
data2 = np.array([[7, 8, 9]], dtype=np.float32)
shared_array.write(local_array, data2)
assert shared_array.read_index.value == 3
# Read all data
read_data = shared_array.read(local_array, flush=False)
expected = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
np.testing.assert_array_equal(read_data, expected)
# Read with flush
read_data_flush = shared_array.read(local_array, flush=True)
np.testing.assert_array_equal(read_data_flush, expected)
assert shared_array.read_index.value == 0
shared_array.delete()
def test_array_overflow():
"""Test behavior when writing more data than array capacity."""
shape = (5, 2) # Small array
shared_array = SharedArray(shape=shape, dtype=np.float32)
local_array = shared_array.get_local_array()
# Fill the array
data = np.ones((5, 2), dtype=np.float32)
shared_array.write(local_array, data)
# Try to write more data - should raise IndexError
with pytest.raises(ValueError):
extra_data = np.ones((2, 2), dtype=np.float32)
shared_array.write(local_array, extra_data)
shared_array.delete()
def test_reset_functionality():
"""Test the reset method."""
shape = (10, 2)
shared_array = SharedArray(shape=shape, dtype=np.float32)
local_array = shared_array.get_local_array()
# Write some data
data = np.ones((3, 2), dtype=np.float32)
shared_array.write(local_array, data)
assert shared_array.read_index.value == 3
# Reset
shared_array.reset()
assert shared_array.read_index.value == 0
# Read should return empty array
read_data = shared_array.read(local_array, flush=False)
assert len(read_data) == 0
shared_array.delete()
# Multi-process tests
def test_single_writer_single_reader():
"""Test basic writer-reader scenario with one process each."""
shape = (100, 2)
shared_array = SharedArray(shape=shape, dtype=np.float32)
data_queue = Queue()
stop_event = Event()
barrier = multiprocessing.Barrier(2) # Writer + reader
# Start writer process
writer = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, 1))
# Start reader process
reader = Process(target=reader_process, args=(shared_array, data_queue, stop_event, barrier, 1))
writer.start()
reader.start()
# Let them run for a bit
time.sleep(0.5)
stop_event.set()
# Wait for completion
writer.join(timeout=2.0)
reader.join(timeout=2.0)
# Verify both processes completed
assert not writer.is_alive()
assert not reader.is_alive()
# Check that we got messages from both processes
messages = []
while not data_queue.empty():
messages.append(data_queue.get())
writer_messages = [msg for msg in messages if msg.startswith("writer_")]
reader_messages = [msg for msg in messages if msg.startswith("reader_")]
assert len(writer_messages) > 0
assert len(reader_messages) > 0
shared_array.delete()
def test_multiple_writers_single_reader():
"""Test multiple writers with single reader - check for race conditions."""
shape = (200, 2)
shared_array = SharedArray(shape=shape, dtype=np.float32)
data_queue = Queue()
stop_event = Event()
num_writers = 3
barrier = multiprocessing.Barrier(num_writers + 1) # Writers + reader
processes = []
# Start multiple writer processes
for i in range(num_writers):
writer = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, i + 1))
processes.append(writer)
writer.start()
# Start reader process
reader = Process(target=reader_process, args=(shared_array, data_queue, stop_event, barrier, 1))
processes.append(reader)
reader.start()
# Let them run
time.sleep(1.0)
stop_event.set()
# Wait for all processes
for process in processes:
process.join(timeout=3.0)
assert not process.is_alive()
# Verify we got messages from all processes
messages = []
while not data_queue.empty():
messages.append(data_queue.get())
writer_messages = [msg for msg in messages if msg.startswith("writer_")]
reader_messages = [msg for msg in messages if msg.startswith("reader_")]
# Should have messages from all writers
assert len(writer_messages) >= num_writers
assert len(reader_messages) > 0
shared_array.delete()
def test_data_integrity_with_concurrent_access():
"""Test that data integrity is maintained under concurrent access using standard reader/writer processes."""
shape = (500, 2) # Use standard 2-column format
shared_array = SharedArray(shape=shape, dtype=np.float32)
data_queue = Queue()
stop_event = Event()
barrier = multiprocessing.Barrier(3) # 2 writers + 1 reader
# Start two writer processes
writer1 = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, 1))
writer2 = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, 2))
# Start one reader process
reader = Process(target=reader_process, args=(shared_array, data_queue, stop_event, barrier, 1))
writer1.start()
writer2.start()
reader.start()
# Let them run for integrity test duration
time.sleep(1.0)
stop_event.set()
# Wait for completion
writer1.join(timeout=3.0)
writer2.join(timeout=3.0)
reader.join(timeout=3.0)
# Verify all processes completed successfully
assert not writer1.is_alive()
assert not writer2.is_alive()
assert not reader.is_alive()
# Verify data integrity by checking messages
messages = []
while not data_queue.empty():
messages.append(data_queue.get())
writer1_messages = [msg for msg in messages if "writer_1_wrote" in msg]
writer2_messages = [msg for msg in messages if "writer_2_wrote" in msg]
reader_messages = [msg for msg in messages if "reader_1_read" in msg]
# Verify both writers wrote data
assert len(writer1_messages) > 0
assert len(writer2_messages) > 0
# Verify reader read data
assert len(reader_messages) > 0
# Verify the shared array is in a consistent state
local_array = shared_array.get_local_array()
final_data = shared_array.read(local_array, flush=False)
# Should have some data written by the writers
assert len(final_data) >= 0 # Could be empty if reader flushed everything
# Should not exceed array capacity
assert len(final_data) <= shape[0]
# If there's data, verify it contains the expected writer signatures
if len(final_data) > 0:
# Data should contain values like 100, 101, 102... (writer 1) or 200, 201, 202... (writer 2)
unique_values = np.unique(final_data.flatten())
writer1_values = unique_values[(unique_values >= 100) & (unique_values < 200)]
writer2_values = unique_values[(unique_values >= 200) & (unique_values < 300)]
# Should have data from at least one writer
assert len(writer1_values) > 0 or len(writer2_values) > 0
shared_array.delete()
def test_stress_test_high_frequency_operations():
"""Stress test with high frequency read/write operations."""
shape = (1000, 2)
shared_array = SharedArray(shape=shape, dtype=np.float32)
data_queue = Queue()
stop_event = Event()
num_writers = 4
barrier = multiprocessing.Barrier(num_writers)
processes = []
# Start multiple high-frequency writers
for i in range(num_writers):
writer = Process(
target=stress_writer_process, args=(shared_array, data_queue, stop_event, barrier, i + 1)
)
processes.append(writer)
writer.start()
# Let them run for stress test duration
time.sleep(0.5)
stop_event.set()
# Wait for completion
for process in processes:
process.join(timeout=3.0)
assert not process.is_alive()
# Verify all writers completed successfully
messages = []
while not data_queue.empty():
messages.append(data_queue.get())
completed_messages = [msg for msg in messages if "completed" in msg]
assert len(completed_messages) == num_writers
# Verify the shared array is in a consistent state
local_array = shared_array.get_local_array()
final_data = shared_array.read(local_array, flush=False)
# Should have some data written
assert len(final_data) > 0
# Should not exceed array capacity
assert len(final_data) <= shape[0]
shared_array.delete()
def test_concurrent_readers():
"""Test multiple concurrent readers with writers to ensure thread safety."""
shape = (200, 2)
shared_array = SharedArray(shape=shape, dtype=np.float32)
data_queue = Queue()
stop_event = Event()
num_readers = 3
num_writers = 2
barrier = multiprocessing.Barrier(num_readers + num_writers)
processes = []
# Start multiple writer processes to generate data
for i in range(num_writers):
writer = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, i + 1))
processes.append(writer)
writer.start()
# Start multiple reader processes
for i in range(num_readers):
reader = Process(target=reader_process, args=(shared_array, data_queue, stop_event, barrier, i + 1))
processes.append(reader)
reader.start()
# Let them run to test concurrent access
time.sleep(1.0)
stop_event.set()
# Wait for all processes to complete
for process in processes:
process.join(timeout=3.0)
assert not process.is_alive()
# Verify all readers and writers completed
messages = []
while not data_queue.empty():
messages.append(data_queue.get())
reader_messages = [msg for msg in messages if msg.startswith("reader_")]
writer_messages = [msg for msg in messages if msg.startswith("writer_")]
# Should have messages from all readers and writers
assert len(reader_messages) >= num_readers
assert len(writer_messages) >= num_writers
# Verify different readers generated different messages (proving they ran concurrently)
reader_ids = set()
for msg in reader_messages:
# Extract reader ID from message like "reader_1_read_5_items"
parts = msg.split("_")
if len(parts) >= 2:
reader_ids.add(parts[1])
assert len(reader_ids) == num_readers # All readers should have participated
shared_array.delete()
def test_edge_case_empty_reads():
"""Test reading from empty array and after flushes."""
shape = (10, 2)
shared_array = SharedArray(shape=shape, dtype=np.float32)
local_array = shared_array.get_local_array()
# Read from empty array
empty_data = shared_array.read(local_array, flush=False)
assert len(empty_data) == 0
# Write some data
data = np.ones((3, 2), dtype=np.float32)
shared_array.write(local_array, data)
# Read with flush
read_data = shared_array.read(local_array, flush=True)
assert len(read_data) == 3
# Read again after flush - should be empty
empty_again = shared_array.read(local_array, flush=False)
assert len(empty_again) == 0
shared_array.delete()
def test_different_dtypes():
"""Test SharedArray with different numpy dtypes."""
dtypes_to_test = [np.float32, np.float64, np.int32, np.int16]
for dtype in dtypes_to_test:
shape = (20, 2)
shared_array = SharedArray(shape=shape, dtype=dtype)
local_array = shared_array.get_local_array()
assert local_array.dtype == dtype
# Write and read data of this dtype
data = np.ones((5, 2), dtype=dtype)
shared_array.write(local_array, data)
read_data = shared_array.read(local_array, flush=True)
assert read_data.dtype == dtype
assert len(read_data) == 5
shared_array.delete()

View File

@@ -0,0 +1,559 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for ActionInterpolator and its interaction with ActionQueue (RTC)."""
import pytest
import torch
from lerobot.policies.rtc.action_interpolator import ActionInterpolator
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
# ====================== Fixtures ======================
@pytest.fixture
def interp2():
"""Create an ActionInterpolator with multiplier=2."""
return ActionInterpolator(multiplier=2)
@pytest.fixture
def interp3():
"""Create an ActionInterpolator with multiplier=3."""
return ActionInterpolator(multiplier=3)
# ====================== Initialization Tests ======================
def test_interpolator_multiplier_1_no_interpolation():
"""Test multiplier=1 creates a disabled interpolator."""
interp = ActionInterpolator(multiplier=1)
assert interp.multiplier == 1
assert not interp.enabled
def test_interpolator_multiplier_2_enabled():
"""Test multiplier=2 creates an enabled interpolator."""
interp = ActionInterpolator(multiplier=2)
assert interp.multiplier == 2
assert interp.enabled
def test_interpolator_multiplier_0_raises():
"""Test multiplier=0 raises ValueError."""
with pytest.raises(ValueError, match="multiplier must be >= 1"):
ActionInterpolator(multiplier=0)
def test_interpolator_negative_multiplier_raises():
"""Test negative multiplier raises ValueError."""
with pytest.raises(ValueError, match="multiplier must be >= 1"):
ActionInterpolator(multiplier=-1)
def test_interpolator_default_multiplier_is_1():
"""Test default multiplier is 1 (disabled)."""
interp = ActionInterpolator()
assert interp.multiplier == 1
assert not interp.enabled
# ====================== needs_new_action Tests ======================
def test_needs_new_action_true_initially(interp2):
"""Test needs_new_action() returns True before any action is added."""
assert interp2.needs_new_action()
def test_needs_new_action_false_after_add(interp2):
"""Test needs_new_action() returns False right after add()."""
interp2.add(torch.tensor([1.0, 2.0]))
assert not interp2.needs_new_action()
def test_needs_new_action_true_after_buffer_exhausted(interp2):
"""Test needs_new_action() returns True after consuming all buffered actions."""
interp2.add(torch.tensor([1.0, 2.0]))
interp2.get()
assert interp2.needs_new_action()
def test_needs_new_action_true_after_all_interpolated_consumed(interp2):
"""Test needs_new_action() tracks interpolated sub-steps correctly."""
interp2.add(torch.tensor([0.0, 0.0]))
interp2.get()
assert interp2.needs_new_action()
interp2.add(torch.tensor([2.0, 4.0]))
interp2.get()
assert not interp2.needs_new_action()
interp2.get()
assert interp2.needs_new_action()
# ====================== Passthrough Tests (multiplier=1) ======================
def test_passthrough_single_action_returned_as_is():
"""Test multiplier=1 returns the action unchanged."""
interp = ActionInterpolator(multiplier=1)
action = torch.tensor([3.0, 5.0])
interp.add(action)
result = interp.get()
assert result is not None
torch.testing.assert_close(result, action)
def test_passthrough_none_after_single_get():
"""Test multiplier=1 returns None after consuming the single action."""
interp = ActionInterpolator(multiplier=1)
interp.add(torch.tensor([1.0]))
interp.get()
assert interp.get() is None
def test_passthrough_sequential_actions():
"""Test multiplier=1 passes through consecutive actions one at a time."""
interp = ActionInterpolator(multiplier=1)
for val in [1.0, 2.0, 3.0]:
action = torch.tensor([val])
interp.add(action)
result = interp.get()
torch.testing.assert_close(result, action)
assert interp.get() is None
# ====================== Interpolation Tests (multiplier=2) ======================
def test_interpolation_2x_first_action_no_interpolation(interp2):
"""Test first action has no previous, so buffer is just [action]."""
interp2.add(torch.tensor([0.0, 0.0]))
result = interp2.get()
torch.testing.assert_close(result, torch.tensor([0.0, 0.0]))
assert interp2.get() is None
def test_interpolation_2x_second_action_produces_two_steps(interp2):
"""Test second action produces 2 interpolated sub-steps."""
interp2.add(torch.tensor([0.0, 0.0]))
interp2.get()
interp2.add(torch.tensor([2.0, 4.0]))
step1 = interp2.get()
step2 = interp2.get()
torch.testing.assert_close(step1, torch.tensor([1.0, 2.0]))
torch.testing.assert_close(step2, torch.tensor([2.0, 4.0]))
assert interp2.get() is None
def test_interpolation_2x_three_consecutive_actions(interp2):
"""Test interpolation across three consecutive actions."""
a0 = torch.tensor([0.0])
a1 = torch.tensor([4.0])
a2 = torch.tensor([10.0])
interp2.add(a0)
torch.testing.assert_close(interp2.get(), a0)
interp2.add(a1)
torch.testing.assert_close(interp2.get(), torch.tensor([2.0]))
torch.testing.assert_close(interp2.get(), torch.tensor([4.0]))
interp2.add(a2)
torch.testing.assert_close(interp2.get(), torch.tensor([7.0]))
torch.testing.assert_close(interp2.get(), torch.tensor([10.0]))
# ====================== Interpolation Tests (multiplier=3) ======================
def test_interpolation_3x_produces_three_steps(interp3):
"""Test multiplier=3 produces 3 interpolated sub-steps."""
interp3.add(torch.tensor([0.0, 0.0]))
interp3.get()
interp3.add(torch.tensor([3.0, 6.0]))
s1 = interp3.get()
s2 = interp3.get()
s3 = interp3.get()
torch.testing.assert_close(s1, torch.tensor([1.0, 2.0]))
torch.testing.assert_close(s2, torch.tensor([2.0, 4.0]))
torch.testing.assert_close(s3, torch.tensor([3.0, 6.0]))
assert interp3.get() is None
def test_interpolation_3x_last_step_equals_target(interp3):
"""Test last interpolated step equals the target action exactly."""
interp3.add(torch.tensor([10.0]))
interp3.get()
target = torch.tensor([100.0])
interp3.add(target)
interp3.get()
interp3.get()
last = interp3.get()
torch.testing.assert_close(last, target)
# ====================== Reset Tests ======================
def test_reset_clears_buffer(interp2):
"""Test reset() clears the action buffer."""
interp2.add(torch.tensor([1.0]))
interp2.reset()
assert interp2.needs_new_action()
assert interp2.get() is None
def test_reset_clears_prev(interp2):
"""Test after reset, next add produces single-element buffer (no prev)."""
interp2.add(torch.tensor([0.0]))
interp2.get()
interp2.add(torch.tensor([10.0]))
interp2.get()
interp2.get()
interp2.reset()
interp2.add(torch.tensor([5.0]))
result = interp2.get()
torch.testing.assert_close(result, torch.tensor([5.0]))
assert interp2.get() is None
def test_reset_episode_boundary(interp2):
"""Test reset between two simulated episodes."""
interp2.add(torch.tensor([0.0]))
interp2.get()
interp2.add(torch.tensor([10.0]))
interp2.get()
interp2.get()
interp2.reset()
interp2.add(torch.tensor([100.0]))
result = interp2.get()
torch.testing.assert_close(result, torch.tensor([100.0]))
assert interp2.get() is None
# ====================== get_control_interval Tests ======================
def test_control_interval_30fps_multiplier_1():
"""Test control interval at 30fps with no interpolation."""
interp = ActionInterpolator(multiplier=1)
assert interp.get_control_interval(30.0) == pytest.approx(1.0 / 30.0)
def test_control_interval_30fps_multiplier_2(interp2):
"""Test control interval at 30fps with 2x interpolation."""
assert interp2.get_control_interval(30.0) == pytest.approx(1.0 / 60.0)
def test_control_interval_30fps_multiplier_3(interp3):
"""Test control interval at 30fps with 3x interpolation."""
assert interp3.get_control_interval(30.0) == pytest.approx(1.0 / 90.0)
def test_control_interval_60fps_multiplier_2(interp2):
"""Test control interval at 60fps with 2x interpolation."""
assert interp2.get_control_interval(60.0) == pytest.approx(1.0 / 120.0)
# ====================== get() on Empty Tests ======================
def test_get_returns_none_before_any_add():
"""Test get() returns None when no action has been added."""
interp = ActionInterpolator(multiplier=2)
assert interp.get() is None
def test_get_returns_none_after_reset(interp2):
"""Test get() returns None after reset."""
interp2.add(torch.tensor([1.0]))
interp2.reset()
assert interp2.get() is None
# ====================== Multi-Dimensional Action Tests ======================
def test_6dof_interpolation(interp2):
"""Test interpolation works correctly with 6-dimensional actions."""
prev = torch.zeros(6)
target = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
interp2.add(prev)
interp2.get()
interp2.add(target)
mid = interp2.get()
end = interp2.get()
torch.testing.assert_close(mid, target / 2)
torch.testing.assert_close(end, target)
# ====================== Simulated Control Loop Tests ======================
def test_control_loop_produces_correct_action_count():
"""Test N policy actions with multiplier M yields 1 + (N-1)*M robot commands."""
multiplier = 3
n_policy_actions = 5
interp = ActionInterpolator(multiplier=multiplier)
robot_commands = 0
for i in range(n_policy_actions):
action = torch.tensor([float(i)])
if interp.needs_new_action():
interp.add(action)
while True:
a = interp.get()
if a is None:
break
robot_commands += 1
expected = 1 + (n_policy_actions - 1) * multiplier
assert robot_commands == expected
def test_control_loop_monotonic_increase():
"""Test actions [0, 1, 2, 3] with multiplier=2 produce monotonically increasing values."""
interp = ActionInterpolator(multiplier=2)
all_values = []
for i in range(4):
interp.add(torch.tensor([float(i)]))
while True:
a = interp.get()
if a is None:
break
all_values.append(a.item())
for i in range(1, len(all_values)):
assert all_values[i] >= all_values[i - 1]
# ====================== ActionQueue + ActionInterpolator Integration Tests ======================
def _make_chunk(n_steps: int, action_dim: int = 2, offset: float = 0.0) -> torch.Tensor:
"""Create a simple action chunk: each row is [offset + step_idx, offset + step_idx]."""
return torch.arange(n_steps, dtype=torch.float32).unsqueeze(1).expand(-1, action_dim) + offset
def test_queue_interpolator_consumption_rate_matches_base_fps():
"""Test queue.get() is called at base fps rate, not multiplied fps."""
cfg = RTCConfig(enabled=True, execution_horizon=10)
queue = ActionQueue(cfg)
interp = ActionInterpolator(multiplier=3)
chunk = _make_chunk(10)
queue.merge(chunk, chunk.clone(), real_delay=0)
queue_gets = 0
control_ticks = 0
while True:
if interp.needs_new_action():
if queue.empty():
break
action = queue.get()
if action is None:
break
interp.add(action)
queue_gets += 1
result = interp.get()
if result is not None:
control_ticks += 1
assert queue_gets == 10
assert control_ticks == 1 + 9 * 3
def test_queue_interpolator_leftover_decreases_only_on_queue_get():
"""Test get_left_over() shrinks only on queue.get(), not on interpolator sub-steps."""
cfg = RTCConfig(enabled=True, execution_horizon=10)
queue = ActionQueue(cfg)
interp = ActionInterpolator(multiplier=3)
chunk = _make_chunk(6)
queue.merge(chunk, chunk.clone(), real_delay=0)
assert interp.needs_new_action()
interp.add(queue.get())
leftover_after_first_get = queue.get_left_over()
assert leftover_after_first_get is not None
assert len(leftover_after_first_get) == 5
interp.get()
assert len(queue.get_left_over()) == 5
interp.add(queue.get())
assert len(queue.get_left_over()) == 4
for _ in range(3):
assert interp.get() is not None
assert len(queue.get_left_over()) == 4
def test_queue_interpolator_processed_leftover_tracks_queue_index():
"""Test get_processed_left_over() reflects queue's last_index, not interpolator state."""
cfg = RTCConfig(enabled=True, execution_horizon=10)
queue = ActionQueue(cfg)
interp = ActionInterpolator(multiplier=2)
original = _make_chunk(8, offset=0.0)
processed = _make_chunk(8, offset=100.0)
queue.merge(original, processed, real_delay=0)
left = queue.get_processed_left_over()
assert len(left) == 8
for _ in range(3):
if interp.needs_new_action():
action = queue.get()
if action is not None:
interp.add(action)
interp.get()
proc_left = queue.get_processed_left_over()
orig_left = queue.get_left_over()
assert proc_left is not None and orig_left is not None
assert len(proc_left) == len(orig_left)
assert proc_left[0, 0].item() >= 100.0
assert orig_left[0, 0].item() < 100.0
def test_queue_interpolator_merge_resets_queue_but_interpolator_keeps_prev():
"""Test queue merge doesn't affect interpolator's prev, enabling smooth transitions."""
cfg = RTCConfig(enabled=True, execution_horizon=10)
queue = ActionQueue(cfg)
interp = ActionInterpolator(multiplier=2)
chunk1 = torch.tensor([[0.0], [2.0], [4.0], [6.0], [8.0]])
queue.merge(chunk1, chunk1.clone(), real_delay=0)
consumed = []
for _ in range(5):
if interp.needs_new_action():
a = queue.get()
if a is not None:
interp.add(a)
r = interp.get()
if r is not None:
consumed.append(r.item())
assert interp.needs_new_action()
assert consumed[-1] == pytest.approx(4.0)
idx_before = queue.get_action_index()
chunk2 = torch.tensor([[10.0], [12.0], [14.0]])
queue.merge(chunk2, chunk2.clone(), real_delay=0, action_index_before_inference=idx_before)
first_action = queue.get()
assert first_action is not None
interp.add(first_action)
first_from_new = interp.get()
assert first_from_new is not None
assert first_from_new.item() == pytest.approx(7.0)
def test_queue_interpolator_reset_does_not_affect_queue():
"""Test interpolator reset leaves queue state untouched."""
cfg = RTCConfig(enabled=True, execution_horizon=10)
queue = ActionQueue(cfg)
interp = ActionInterpolator(multiplier=2)
chunk = _make_chunk(5)
queue.merge(chunk, chunk.clone(), real_delay=0)
interp.add(queue.get())
interp.get()
interp.add(queue.get())
interp.get()
interp.get()
assert queue.qsize() == 3
interp.reset()
assert queue.qsize() == 3
assert len(queue.get_left_over()) == 3
interp.add(queue.get())
result = interp.get()
assert result is not None
assert queue.qsize() == 2
def test_queue_interpolator_no_interpolation_1_to_1():
"""Test multiplier=1 produces exactly 1 robot command per queue.get()."""
cfg = RTCConfig(enabled=True, execution_horizon=10)
queue = ActionQueue(cfg)
interp = ActionInterpolator(multiplier=1)
chunk = _make_chunk(5)
queue.merge(chunk, chunk.clone(), real_delay=0)
robot_commands = 0
while not queue.empty():
if interp.needs_new_action():
action = queue.get()
if action is not None:
interp.add(action)
result = interp.get()
if result is not None:
robot_commands += 1
assert robot_commands == 5
def test_queue_interpolator_delay_skips_stale_actions():
"""Test merge with delay correctly skips stale actions for the interpolator."""
cfg = RTCConfig(enabled=True, execution_horizon=10)
queue = ActionQueue(cfg)
interp = ActionInterpolator(multiplier=2)
chunk1 = _make_chunk(10)
queue.merge(chunk1, chunk1.clone(), real_delay=0)
for _ in range(5):
if interp.needs_new_action():
a = queue.get()
if a is not None:
interp.add(a)
interp.get()
assert queue.get_action_index() == 3
chunk2 = _make_chunk(10, offset=100.0)
queue.merge(chunk2, chunk2.clone(), real_delay=3, action_index_before_inference=0)
first_action = queue.get()
assert first_action is not None
torch.testing.assert_close(first_action, torch.tensor([103.0, 103.0]))

View File

@@ -20,7 +20,7 @@ from functools import wraps
import pytest
import torch
from lerobot import available_cameras, available_microphones, available_motors, available_robots
from lerobot import available_cameras, available_motors, available_robots
from lerobot.utils.device_utils import auto_select_torch_device
from lerobot.utils.import_utils import is_package_available
@@ -34,10 +34,6 @@ TEST_CAMERA_TYPES = []
for camera_type in available_cameras:
TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)]
TEST_MICROPHONE_TYPES = []
for microphone_type in available_microphones:
TEST_MICROPHONE_TYPES += [(microphone_type, True), (microphone_type, False)]
TEST_MOTOR_TYPES = []
for motor_type in available_motors:
TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)]
@@ -46,9 +42,6 @@ for motor_type in available_motors:
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614))
# Microphone indices used for connecting physical microphones
MICROPHONE_INDEX = int(os.environ.get("LEROBOT_TEST_MICROPHONE_INDEX", 0))
DYNAMIXEL_PORT = os.environ.get("LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081")
DYNAMIXEL_MOTORS = {
"shoulder_pan": [1, "xl430-w250"],

View File

@@ -37,14 +37,6 @@ def mock_rerun(monkeypatch):
def __init__(self, value):
self.value = float(value)
@staticmethod
def columns(scalars):
return DummyScalarsColumn(scalars)
class DummyScalarsColumn:
def __init__(self, values):
self.values = values
class DummyImage:
def __init__(self, arr):
self.arr = arr
@@ -55,46 +47,12 @@ def mock_rerun(monkeypatch):
obj = kwargs.pop("entity")
calls.append((key, obj, kwargs))
def dummy_send_columns(key, indexes, columns, **kwargs):
calls.append((key, columns, kwargs))
def dummy_time_column(timeline, timestamp):
return timestamp
def dummy_set_time(timeline, timestamp):
return None
class DummyTimeSeriesView:
def __call__(self, origin, plot_legend=None):
return None
class DummySpatial2DView:
def __call__(self, origin):
return None
class DummyGrid:
def __call__(self, *args):
return None
class DummyPlotLegend:
def __call__(self, visible=True):
return None
dummy_rr = SimpleNamespace(
Scalars=DummyScalar,
Image=DummyImage,
log=dummy_log,
TimeColumn=dummy_time_column,
send_columns=dummy_send_columns,
set_time=dummy_set_time,
init=lambda *a, **k: None,
spawn=lambda *a, **k: None,
blueprint=SimpleNamespace(
TimeSeriesView=DummyTimeSeriesView,
Spatial2DView=DummySpatial2DView,
Grid=DummyGrid,
PlotLegend=DummyPlotLegend,
),
)
# Inject fake module into sys.modules
@@ -129,7 +87,7 @@ def _kwargs_for(calls, key):
raise KeyError(f"Key {key} not found in calls: {calls}")
def test_log_rerun_data_envtransition_scalars_image_audio(mock_rerun):
def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
vu, calls = mock_rerun
# Build EnvTransition dict
@@ -137,8 +95,6 @@ def test_log_rerun_data_envtransition_scalars_image_audio(mock_rerun):
f"{OBS_STATE}.temperature": np.float32(25.0),
# CHW image should be converted to HWC for rr.Image
"observation.camera": np.zeros((3, 10, 20), dtype=np.uint8),
# Multiple channels audio data should be split into separate channels and logged as rr.Scalars.columns
"observation.audio": np.zeros((100, 2), dtype=np.float32),
}
act = {
"action.throttle": 0.7,
@@ -161,27 +117,25 @@ def test_log_rerun_data_envtransition_scalars_image_audio(mock_rerun):
# - action.throttle -> Scalars
# - action.vector_0, action.vector_1 -> Scalars
expected_keys = {
"data/" + f"{OBS_STATE}.temperature",
f"{OBS_STATE}.temperature",
"observation.camera",
"data/action.throttle",
"data/action.vector_0",
"data/action.vector_1",
"audio/observation.audio_channel_0",
"audio/observation.audio_channel_1",
"action.throttle",
"action.vector_0",
"action.vector_1",
}
assert set(_keys(calls)) == expected_keys
# Check scalar types and values
temp_obj = _obj_for(calls, f"data/{OBS_STATE}.temperature")
temp_obj = _obj_for(calls, f"{OBS_STATE}.temperature")
assert type(temp_obj).__name__ == "DummyScalar"
assert temp_obj.value == pytest.approx(25.0)
throttle_obj = _obj_for(calls, "data/action.throttle")
throttle_obj = _obj_for(calls, "action.throttle")
assert type(throttle_obj).__name__ == "DummyScalar"
assert throttle_obj.value == pytest.approx(0.7)
v0 = _obj_for(calls, "data/action.vector_0")
v1 = _obj_for(calls, "data/action.vector_1")
v0 = _obj_for(calls, "action.vector_0")
v1 = _obj_for(calls, "action.vector_1")
assert type(v0).__name__ == "DummyScalar"
assert type(v1).__name__ == "DummyScalar"
assert v0.value == pytest.approx(1.0)
@@ -193,14 +147,6 @@ def test_log_rerun_data_envtransition_scalars_image_audio(mock_rerun):
assert img_obj.arr.shape == (10, 20, 3) # transposed
assert _kwargs_for(calls, "observation.camera").get("static", False) is True # static=True for images
# Check audio handling: split channels + rr.Scalars.columns
audio_obj_0 = _obj_for(calls, "audio/observation.audio_channel_0")
audio_obj_1 = _obj_for(calls, "audio/observation.audio_channel_1")
assert type(audio_obj_0).__name__ == "DummyScalarsColumn"
assert type(audio_obj_1).__name__ == "DummyScalarsColumn"
assert audio_obj_0.values.shape == (100,)
assert audio_obj_1.values.shape == (100,)
def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
vu, calls = mock_rerun
@@ -211,8 +157,6 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
"temp": 1.5,
# Already HWC image => should stay as-is
"img": np.zeros((5, 6, 3), dtype=np.uint8),
# Multiple channels audio data should be split into separate channels
"audio": np.zeros((100, 2), dtype=np.float32),
"none": None, # should be skipped
}
act_plain = {
@@ -226,24 +170,22 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
# Expected keys with auto-prefixes
expected = {
"data/observation.temp",
"observation.temp",
"observation.img",
"data/action.throttle",
"data/action.vec_0",
"data/action.vec_1",
"data/action.vec_2",
"audio/observation.audio_channel_0",
"audio/observation.audio_channel_1",
"action.throttle",
"action.vec_0",
"action.vec_1",
"action.vec_2",
}
logged = set(_keys(calls))
assert logged == expected
# Scalars
t = _obj_for(calls, "data/observation.temp")
t = _obj_for(calls, "observation.temp")
assert type(t).__name__ == "DummyScalar"
assert t.value == pytest.approx(1.5)
throttle = _obj_for(calls, "data/action.throttle")
throttle = _obj_for(calls, "action.throttle")
assert type(throttle).__name__ == "DummyScalar"
assert throttle.value == pytest.approx(0.3)
@@ -255,39 +197,25 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
# Vectors
for i, val in enumerate([9, 8, 7]):
o = _obj_for(calls, f"data/action.vec_{i}")
o = _obj_for(calls, f"action.vec_{i}")
assert type(o).__name__ == "DummyScalar"
assert o.value == pytest.approx(val)
# Audio
audio_obj_0 = _obj_for(calls, "audio/observation.audio_channel_0")
audio_obj_1 = _obj_for(calls, "audio/observation.audio_channel_1")
assert type(audio_obj_0).__name__ == "DummyScalarsColumn"
assert type(audio_obj_1).__name__ == "DummyScalarsColumn"
assert audio_obj_0.values.shape == (100,)
assert audio_obj_1.values.shape == (100,)
def test_log_rerun_data_kwargs_only(mock_rerun):
vu, calls = mock_rerun
vu.log_rerun_data(
observation={
"observation.temp": 10.0,
"observation.gray": np.zeros((8, 8, 1), dtype=np.uint8),
"observation.audio": np.zeros((100, 2), dtype=np.float32),
},
observation={"observation.temp": 10.0, "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8)},
action={"action.a": 1.0},
)
keys = set(_keys(calls))
assert "data/observation.temp" in keys
assert "observation.temp" in keys
assert "observation.gray" in keys
assert "data/action.a" in keys
assert "audio/observation.audio_channel_0" in keys
assert "audio/observation.audio_channel_1" in keys
assert "action.a" in keys
temp = _obj_for(calls, "data/observation.temp")
temp = _obj_for(calls, "observation.temp")
assert type(temp).__name__ == "DummyScalar"
assert temp.value == pytest.approx(10.0)
@@ -296,13 +224,6 @@ def test_log_rerun_data_kwargs_only(mock_rerun):
assert img.arr.shape == (8, 8, 1) # remains HWC
assert _kwargs_for(calls, "observation.gray").get("static", False) is True
a = _obj_for(calls, "data/action.a")
a = _obj_for(calls, "action.a")
assert type(a).__name__ == "DummyScalar"
assert a.value == pytest.approx(1.0)
audio_obj_0 = _obj_for(calls, "audio/observation.audio_channel_0")
audio_obj_1 = _obj_for(calls, "audio/observation.audio_channel_1")
assert type(audio_obj_0).__name__ == "DummyScalarsColumn"
assert type(audio_obj_1).__name__ == "DummyScalarsColumn"
assert audio_obj_0.values.shape == (100,)
assert audio_obj_1.values.shape == (100,)