Compare commits

..

1 Commits

Author SHA1 Message Date
AdilZouitine
ab94626b92 fix normalization for dtype 2025-08-03 18:07:08 +02:00
89 changed files with 245 additions and 9340 deletions

View File

@@ -30,7 +30,7 @@ pytest -sx tests/test_stuff.py::test_something
```
```bash
lerobot-train --some.option=true
python -m lerobot.scripts.train --some.option=true
```
## SECTION TO REMOVE BEFORE SUBMITTING YOUR PR

View File

@@ -29,8 +29,8 @@ on:
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.10"
DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-cpu:latest
DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-gpu:latest
DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-gpu:latest
DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-cpu:latest
# Ensures that only the latest commit is built, canceling older runs.
concurrency:

View File

@@ -44,7 +44,7 @@ test-end-to-end:
${MAKE} DEVICE=$(DEVICE) test-smolvla-ete-eval
test-act-ete-train:
lerobot-train \
python -m lerobot.scripts.train \
--policy.type=act \
--policy.dim_model=64 \
--policy.n_action_steps=20 \
@@ -68,12 +68,12 @@ test-act-ete-train:
--output_dir=tests/outputs/act/
test-act-ete-train-resume:
lerobot-train \
python -m lerobot.scripts.train \
--config_path=tests/outputs/act/checkpoints/000002/pretrained_model/train_config.json \
--resume=true
test-act-ete-eval:
lerobot-eval \
python -m lerobot.scripts.eval \
--policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \
--policy.device=$(DEVICE) \
--env.type=aloha \
@@ -82,7 +82,7 @@ test-act-ete-eval:
--eval.batch_size=1
test-diffusion-ete-train:
lerobot-train \
python -m lerobot.scripts.train \
--policy.type=diffusion \
--policy.down_dims='[64,128,256]' \
--policy.diffusion_step_embed_dim=32 \
@@ -106,7 +106,7 @@ test-diffusion-ete-train:
--output_dir=tests/outputs/diffusion/
test-diffusion-ete-eval:
lerobot-eval \
python -m lerobot.scripts.eval \
--policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \
--policy.device=$(DEVICE) \
--env.type=pusht \
@@ -115,7 +115,7 @@ test-diffusion-ete-eval:
--eval.batch_size=1
test-tdmpc-ete-train:
lerobot-train \
python -m lerobot.scripts.train \
--policy.type=tdmpc \
--policy.device=$(DEVICE) \
--policy.push_to_hub=false \
@@ -137,7 +137,7 @@ test-tdmpc-ete-train:
--output_dir=tests/outputs/tdmpc/
test-tdmpc-ete-eval:
lerobot-eval \
python -m lerobot.scripts.eval \
--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
--policy.device=$(DEVICE) \
--env.type=xarm \
@@ -148,7 +148,7 @@ test-tdmpc-ete-eval:
test-smolvla-ete-train:
lerobot-train \
python -m lerobot.scripts.train \
--policy.type=smolvla \
--policy.n_action_steps=20 \
--policy.chunk_size=20 \
@@ -171,7 +171,7 @@ test-smolvla-ete-train:
--output_dir=tests/outputs/smolvla/
test-smolvla-ete-eval:
lerobot-eval \
python -m lerobot.scripts.eval \
--policy.path=tests/outputs/smolvla/checkpoints/000004/pretrained_model \
--policy.device=$(DEVICE) \
--env.type=aloha \

View File

@@ -6,7 +6,7 @@
<div align="center">
[![Tests](https://github.com/huggingface/lerobot/actions/workflows/nightly.yml/badge.svg?branch=main)](https://github.com/huggingface/lerobot/actions/workflows/nightly.yml?query=branch%3Amain)
[![Tests](https://github.com/huggingface/lerobot/actions/workflows/nightly.yml/badge.svg?branch=main)](https://github.com/huggingface/lerobot/actions/workflows/nighty.yml?query=branch%3Amain)
[![Python versions](https://img.shields.io/pypi/pyversions/lerobot)](https://www.python.org/downloads/)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/huggingface/lerobot/blob/main/LICENSE)
[![Status](https://img.shields.io/pypi/status/lerobot)](https://pypi.org/project/lerobot/)
@@ -101,9 +101,6 @@
## Installation
LeRobot works with Python 3.10+ and PyTorch 2.2+.
### Environment Setup
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
```bash
@@ -127,21 +124,10 @@ conda install ffmpeg -c conda-forge
>
> - _[On Linux only]_ Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
### Install LeRobot 🤗
#### From Source
First, clone the repository and navigate into the directory:
Install 🤗 LeRobot:
```bash
git clone https://github.com/huggingface/lerobot.git
cd lerobot
```
Then, install the library in editable mode. This is useful if you plan to contribute to the code.
```bash
pip install -e .
pip install lerobot
```
> **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run:
@@ -159,34 +145,6 @@ For instance, to install 🤗 LeRobot with aloha and pusht, use:
pip install -e ".[aloha, pusht]"
```
### Installation from PyPI
**Core Library:**
Install the base package with:
```bash
pip install lerobot
```
_This installs only the default dependencies._
**Extra Features:**
To install additional functionality, use one of the following:
```bash
pip install 'lerobot[all]' # All available features
pip install 'lerobot[aloha,pusht]' # Specific features (Aloha & Pusht)
pip install 'lerobot[feetech]' # Feetech motor support
```
_Replace `[...]` with your desired features._
**Available Tags:**
For a full list of optional dependencies, see:
https://pypi.org/project/lerobot/
### Weights & Biases
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
```bash
@@ -276,7 +234,7 @@ Check out [example 2](https://github.com/huggingface/lerobot/blob/main/examples/
We also provide a more capable script to parallelize the evaluation over multiple environments during the same rollout. Here is an example with a pretrained model hosted on [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht):
```bash
lerobot-eval \
python -m lerobot.scripts.eval \
--policy.path=lerobot/diffusion_pusht \
--env.type=pusht \
--eval.batch_size=10 \
@@ -288,10 +246,10 @@ lerobot-eval \
Note: After training your own policy, you can re-evaluate the checkpoints with:
```bash
lerobot-eval --policy.path={OUTPUT_DIR}/checkpoints/last/pretrained_model
python -m lerobot.scripts.eval --policy.path={OUTPUT_DIR}/checkpoints/last/pretrained_model
```
See `lerobot-eval --help` for more instructions.
See `python -m lerobot.scripts.eval --help` for more instructions.
### Train your own policy
@@ -303,7 +261,7 @@ A link to the wandb logs for the run will also show up in yellow in your termina
\<img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/wandb.png" alt="WandB logs example"\>
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `--eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `lerobot-eval --help` for more instructions.
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `--eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python -m lerobot.scripts.eval --help` for more instructions.
#### Reproduce state-of-the-art (SOTA)
@@ -311,7 +269,7 @@ We provide some pretrained policies on our [hub page](https://huggingface.co/ler
You can reproduce their training by loading the config from their run. Simply running:
```bash
lerobot-train --config_path=lerobot/diffusion_pusht
python -m lerobot.scripts.train --config_path=lerobot/diffusion_pusht
```
reproduces SOTA results for Diffusion Policy on the PushT task.
@@ -353,7 +311,7 @@ If you want, you can cite this work with:
```bibtex
@misc{cadene2024lerobot,
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascale, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
howpublished = "\url{https://github.com/huggingface/lerobot}",
year = {2024}

View File

@@ -1,26 +0,0 @@
#!/usr/bin/env python
"""Simple script to check buffer naming in the transformed model."""
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
# Load the model with strict=False to see what buffers we have
print("Loading model...")
policy = PI0Policy.from_pretrained("pepijn223/pi0_libero_lerobot", strict=False)
# Check what buffer keys exist
state_dict = policy.state_dict()
buffer_keys = [k for k in state_dict.keys() if "buffer" in k]
normalize_keys = [k for k in state_dict.keys() if "normalize" in k]
print("\nAll buffer keys:")
for key in buffer_keys:
print(f" {key}")
print("\nAll normalize keys:")
for key in normalize_keys:
print(f" {key}")
print("\nAll keys (first 20):")
for i, key in enumerate(state_dict.keys()):
if i < 20:
print(f" {key}")

View File

@@ -29,7 +29,7 @@ ENV DEBIAN_FRONTEND=noninteractive \
# Install system dependencies and uv (as root)
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential git curl libglib2.0-0 libegl1-mesa-dev ffmpeg \
build-essential git curl libglib2.0-0 libegl1-mesa ffmpeg \
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
&& mv /root/.local/bin/uv /usr/local/bin/uv \

View File

@@ -35,14 +35,10 @@
title: Koch v1.1
- local: lekiwi
title: LeKiwi
- local: reachy2
title: Reachy 2
title: "Robots"
- sections:
- local: notebooks
title: Notebooks
- local: feetech
title: Updating Feetech Firmware
title: "Resources"
- sections:
- local: contributing

View File

@@ -9,7 +9,7 @@ To instantiate a camera, you need a camera identifier. This identifier might cha
To find the camera indices of the cameras plugged into your system, run the following script:
```bash
lerobot-find-cameras opencv # or realsense for Intel Realsense cameras
python -m lerobot.find_cameras opencv # or realsense for Intel Realsense cameras
```
The output will look something like this if you have two cameras connected:

View File

@@ -1,71 +0,0 @@
# Feetech Motor Firmware Update
This tutorial guides you through updating the firmware of Feetech motors using the official Feetech software.
## Prerequisites
- Windows computer (Feetech software is only available for Windows)
- Feetech motor control board
- USB cable to connect the control board to your computer
- Feetech motors connected to the control board
## Step 1: Download Feetech Software
1. Visit the official Feetech software download page: [https://www.feetechrc.com/software.html](https://www.feetechrc.com/software.html)
2. Download the latest version of the Feetech debugging software (FD)
3. Install the software on your Windows computer
## Step 2: Hardware Setup
1. Connect your Feetech motors to the motor control board
2. Connect the motor control board to your Windows computer via USB cable
3. Ensure power is supplied to the motors
## Step 3: Configure Connection
1. Launch the Feetech debugging software
2. Select the correct COM port from the port dropdown menu
- If unsure which port to use, check Windows Device Manager under "Ports (COM & LPT)"
3. Set the appropriate baud rate (typically 1000000 for most Feetech motors)
4. Click "Open" to establish communication with the control board
## Step 4: Scan for Motors
1. Once connected, click the "Search" button to detect all connected motors
2. The software will automatically discover and list all motors on the bus
3. Each motor will appear with its ID number
## Step 5: Update Firmware
For each motor you want to update:
1. **Select the motor** from the list by clicking on it
2. **Click on Upgrade tab**:
3. **Click on Online button**:
- If an potential firmware update is found, it will be displayed in the box
4. **Click on Upgrade button**:
- The update progress will be displayed
## Step 6: Verify Update
1. After the update completes, the software should automatically refresh the motor information
2. Verify that the firmware version has been updated to the expected version
## Important Notes
⚠️ **Warning**: Do not disconnect power or USB during firmware updates, it will potentially brick the motor.
## Bonus: Motor Debugging on Linux/macOS
For debugging purposes only, you can use the open-source Feetech Debug Tool:
- **Repository**: [FT_SCServo_Debug_Qt](https://github.com/CarolinePascal/FT_SCServo_Debug_Qt/tree/fix/port-search-timer)
### Installation Instructions
Follow the instructions in the repository to install the tool, for Ubuntu you can directly install it, for MacOS you need to build it from source.
**Limitations:**
- This tool is for debugging and parameter adjustment only
- Firmware updates must still be done on Windows with official Feetech software

View File

@@ -412,7 +412,7 @@ Example configuration for training the [reward classifier](https://huggingface.c
To train the classifier, use the `train.py` script with your configuration:
```bash
lerobot-train --config_path path/to/reward_classifier_train_config.json
python -m lerobot.scripts.train --config_path path/to/reward_classifier_train_config.json
```
**Deploying and Testing the Model**
@@ -458,7 +458,7 @@ The reward classifier will automatically provide rewards based on the visual inp
3. **Train the classifier**:
```bash
lerobot-train --config_path src/lerobot/configs/reward_classifier_train_config.json
python -m lerobot.scripts.train --config_path src/lerobot/configs/reward_classifier_train_config.json
```
4. **Test the classifier**:

View File

@@ -19,7 +19,7 @@ pip install -e ".[hopejr]"
Before starting calibration and operation, you need to identify the USB ports for each HopeJR component. Run this script to find the USB ports for the arm, hand, glove, and exoskeleton:
```bash
lerobot-find-port
python -m lerobot.find_port
```
This will display the available USB ports and their associated devices. Make note of the port paths (e.g., `/dev/tty.usbmodem58760433331`, `/dev/tty.usbmodem11301`) as you'll need to specify them in the `--robot.port` and `--teleop.port` parameters when recording data, replaying episodes, or running teleoperation scripts.
@@ -31,7 +31,7 @@ Before performing teleoperation, HopeJR's limbs need to be calibrated. Calibrati
### 1.1 Calibrate Robot Hand
```bash
lerobot-calibrate \
python -m lerobot.calibrate \
--robot.type=hope_jr_hand \
--robot.port=/dev/tty.usbmodem58760432281 \
--robot.id=blue \
@@ -81,7 +81,7 @@ Once you have set the appropriate boundaries for all joints, click "Save" to sav
### 1.2 Calibrate Teleoperator Glove
```bash
lerobot-calibrate \
python -m lerobot.calibrate \
--teleop.type=homunculus_glove \
--teleop.port=/dev/tty.usbmodem11201 \
--teleop.id=red \
@@ -120,7 +120,7 @@ Once calibration is complete, the system will save the calibration to `/Users/yo
### 1.3 Calibrate Robot Arm
```bash
lerobot-calibrate \
python -m lerobot.calibrate \
--robot.type=hope_jr_arm \
--robot.port=/dev/tty.usbserial-1110 \
--robot.id=white
@@ -146,7 +146,7 @@ Use the calibration interface to set the range boundaries for each joint. Move e
### 1.4 Calibrate Teleoperator Exoskeleton
```bash
lerobot-calibrate \
python -m lerobot.calibrate \
--teleop.type=homunculus_arm \
--teleop.port=/dev/tty.usbmodem11201 \
--teleop.id=black
@@ -178,7 +178,7 @@ Due to global variable conflicts in the Feetech middleware, teleoperation for ar
### Hand
```bash
lerobot-teleoperate \
python -m lerobot.teleoperate \
--robot.type=hope_jr_hand \
--robot.port=/dev/tty.usbmodem58760432281 \
--robot.id=blue \
@@ -194,7 +194,7 @@ lerobot-teleoperate \
### Arm
```bash
lerobot-teleoperate \
python -m lerobot.teleoperate \
--robot.type=hope_jr_arm \
--robot.port=/dev/tty.usbserial-1110 \
--robot.id=white \
@@ -214,7 +214,7 @@ Record, Replay and Train with Hope-JR is still experimental.
This step records the dataset, which can be seen as an example [here](https://huggingface.co/datasets/nepyope/hand_record_test_with_video_data/settings).
```bash
lerobot-record \
python -m lerobot.record \
--robot.type=hope_jr_hand \
--robot.port=/dev/tty.usbmodem58760432281 \
--robot.id=right \
@@ -236,7 +236,7 @@ lerobot-record \
### Replay
```bash
lerobot-replay \
python -m lerobot.replay \
--robot.type=hope_jr_hand \
--robot.port=/dev/tty.usbmodem58760432281 \
--robot.id=right \
@@ -248,7 +248,7 @@ lerobot-replay \
### Train
```bash
lerobot-train \
python -m lerobot.scripts.train \
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
--policy.type=act \
--output_dir=outputs/train/hopejr_hand \
@@ -263,7 +263,7 @@ lerobot-train \
This training run can be viewed as an example [here](https://wandb.ai/tino/lerobot/runs/rp0k8zvw?nw=nwusertino).
```bash
lerobot-record \
python -m lerobot.record \
--robot.type=hope_jr_hand \
--robot.port=/dev/tty.usbmodem58760432281 \
--robot.id=right \

View File

@@ -45,7 +45,7 @@ Note that the `id` associated with a robot is used to store the calibration file
<hfoptions id="teleoperate_so101">
<hfoption id="Command">
```bash
lerobot-teleoperate \
python -m lerobot.teleoperate \
--robot.type=so101_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=my_awesome_follower_arm \
@@ -101,7 +101,7 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
<hfoptions id="teleoperate_koch_camera">
<hfoption id="Command">
```bash
lerobot-teleoperate \
python -m lerobot.teleoperate \
--robot.type=koch_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=my_awesome_follower_arm \
@@ -174,7 +174,7 @@ Now you can record a dataset. To record 5 episodes and upload your dataset to th
<hfoptions id="record">
<hfoption id="Command">
```bash
lerobot-record \
python -m lerobot.record \
--robot.type=so101_follower \
--robot.port=/dev/tty.usbmodem585A0076841 \
--robot.id=my_awesome_follower_arm \
@@ -294,7 +294,7 @@ dataset.push_to_hub()
#### Dataset upload
Locally, your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}`. At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. `https://huggingface.co/datasets/${HF_USER}/so101_test`) that you can obtain by running:
Locally, your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}`. At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/so101_test) that you can obtain by running:
```bash
echo https://huggingface.co/datasets/${HF_USER}/so101_test
@@ -376,7 +376,7 @@ You can replay the first episode on your robot with either the command below or
<hfoptions id="replay">
<hfoption id="Command">
```bash
lerobot-replay \
python -m lerobot.replay \
--robot.type=so101_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=my_awesome_follower_arm \
@@ -428,10 +428,10 @@ Your robot should replicate movements similar to those you recorded. For example
## Train a policy
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
lerobot-train \
python -m lerobot.scripts.train \
--dataset.repo_id=${HF_USER}/so101_test \
--policy.type=act \
--output_dir=outputs/train/act_so101_test \
@@ -444,7 +444,7 @@ lerobot-train \
Let's explain the command:
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so101_test`.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
3. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
@@ -453,7 +453,7 @@ Training should take several hours. You will find checkpoints in `outputs/train/
To resume training from a checkpoint, below is an example command to resume from `last` checkpoint of the `act_so101_test` policy:
```bash
lerobot-train \
python -m lerobot.scripts.train \
--config_path=outputs/train/act_so101_test/checkpoints/last/pretrained_model/train_config.json \
--resume=true
```
@@ -490,7 +490,7 @@ You can use the `record` script from [`lerobot/record.py`](https://github.com/hu
<hfoptions id="eval">
<hfoption id="Command">
```bash
lerobot-record \
python -m lerobot.record \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM1 \
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \

View File

@@ -96,10 +96,10 @@ If you uploaded your dataset to the hub you can [visualize your dataset online](
## Train a policy
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
lerobot-train \
python -m lerobot.scripts.train \
--dataset.repo_id=${HF_USER}/il_gym \
--policy.type=act \
--output_dir=outputs/train/il_sim_test \
@@ -111,7 +111,7 @@ lerobot-train \
Let's explain the command:
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/il_gym`.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
3. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.

View File

@@ -1,6 +1,15 @@
# Installation
## Environment Setup
## Install LeRobot
Currently only available from source.
Download our source code:
```bash
git clone https://github.com/huggingface/lerobot.git
cd lerobot
```
Create a virtual environment with Python 3.10, using [`Miniconda`](https://docs.anaconda.com/miniconda/install/#quick-command-line-install)
@@ -31,49 +40,12 @@ conda install ffmpeg -c conda-forge
>
> - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
## Install LeRobot 🤗
### From Source
First, clone the repository and navigate into the directory:
```bash
git clone https://github.com/huggingface/lerobot.git
cd lerobot
```
Then, install the library in editable mode. This is useful if you plan to contribute to the code.
Install 🤗 LeRobot:
```bash
pip install -e .
```
### Installation from PyPI
**Core Library:**
Install the base package with:
```bash
pip install lerobot
```
_This installs only the default dependencies._
**Extra Features:**
To install additional functionality, use one of the following:
```bash
pip install 'lerobot[all]' # All available features
pip install 'lerobot[aloha,pusht]' # Specific features (Aloha & Pusht)
pip install 'lerobot[feetech]' # Feetech motor support
```
_Replace `[...]` with your desired features._
**Available Tags:**
For a full list of optional dependencies, see:
https://pypi.org/project/lerobot/
### Troubleshooting
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.

View File

@@ -31,7 +31,7 @@ pip install -e ".[dynamixel]"
To find the port for each bus servo adapter, run this script:
```bash
lerobot-find-port
python -m lerobot.find_port
```
<hfoptions id="example">
@@ -98,7 +98,7 @@ For a visual reference on how to set the motor ids please refer to [this video](
<hfoption id="Command">
```bash
lerobot-setup-motors \
python -m lerobot.setup_motors \
--robot.type=koch_follower \
--robot.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step
```
@@ -174,7 +174,7 @@ Do the same steps for the leader arm but modify the command or script accordingl
<hfoption id="Command">
```bash
lerobot-setup-motors \
python -m lerobot.setup_motors \
--teleop.type=koch_leader \
--teleop.port=/dev/tty.usbmodem575E0031751 \ # <- paste here the port found at previous step
```
@@ -211,7 +211,7 @@ Run the following command or API example to calibrate the follower arm:
<hfoption id="Command">
```bash
lerobot-calibrate \
python -m lerobot.calibrate \
--robot.type=koch_follower \
--robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
--robot.id=my_awesome_follower_arm # <- Give the robot a unique name
@@ -249,7 +249,7 @@ Do the same steps to calibrate the leader arm, run the following command or API
<hfoption id="Command">
```bash
lerobot-calibrate \
python -m lerobot.calibrate \
--teleop.type=koch_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
--teleop.id=my_awesome_leader_arm # <- Give the robot a unique name

View File

@@ -60,7 +60,7 @@ First, we will assemble the two SO100/SO101 arms. One to attach to the mobile ba
To find the port for each bus servo adapter, run this script:
```bash
lerobot-find-port
python -m lerobot.find_port
```
<hfoptions id="example">
@@ -116,7 +116,7 @@ The instructions for configuring the motors can be found in the SO101 [docs](./s
You can run this command to setup motors for LeKiwi. It will first setup the motors for arm (id 6..1) and then setup motors for wheels (9,8,7)
```bash
lerobot-setup-motors \
python -m lerobot.setup_motors \
--robot.type=lekiwi \
--robot.port=/dev/tty.usbmodem58760431551 # <- paste here the port found at previous step
```
@@ -174,7 +174,7 @@ The calibration process is very important because it allows a neural network tra
Make sure the arm is connected to the Raspberry Pi and run this script or API example (on the Raspberry Pi via SSH) to launch calibration of the follower arm:
```bash
lerobot-calibrate \
python -m lerobot.calibrate \
--robot.type=lekiwi \
--robot.id=my_awesome_kiwi # <- Give the robot a unique name
```
@@ -193,7 +193,7 @@ Then, to calibrate the leader arm (which is attached to the laptop/pc). Run the
<hfoption id="Command">
```bash
lerobot-calibrate \
python -m lerobot.calibrate \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
--teleop.id=my_awesome_leader_arm # <- Give the robot a unique name

View File

@@ -1,288 +0,0 @@
# Reachy 2
Reachy 2 is an open-source humanoid robot made by Pollen Robotics, specifically designed for the development of embodied AI and real-world applications.
Check out [Pollen Robotics website](https://www.pollen-robotics.com/reachy/), or access [Reachy 2 documentation](https://docs.pollen-robotics.com/) for more information on the platform!
## Teleoperate Reachy 2
Currently, there are two ways to teleoperate Reachy 2:
- Pollen Robotics VR teleoperation (not included in LeRobot).
- Robot-to-robot teleoperation (use one Reachy 2 to control another).
## Reachy 2 Simulation
**(Linux only)** You can run Reachy 2 in simulation (Gazebo or MuJoCo) using the provided [Docker image](https://hub.docker.com/r/pollenrobotics/reachy2_core).
1. Install [Docker Engine](https://docs.docker.com/engine/).
2. Run (for MuJoCo):
```
docker run --rm -it \
--name reachy \
--privileged \
--network host \
--ipc host \
--device-cgroup-rule='c 189:* rwm' \
--group-add audio \
-e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \
-e DISPLAY="$DISPLAY" \
-e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \
-e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \
-v /dev:/dev \
-v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \
-v "$HOME/.reachy.log":/home/reachy/.ros/log \
-v /usr/lib/x86_64-linux-gnu:/opt/host-libs \
--entrypoint /package/launch.sh \
pollenrobotics/reachy2_core:1.7.5.9_deploy \
start_rviz:=true start_sdk_server:=true mujoco:=true
```
> If MuJoCo runs slowly (low simulation frequency), append `-e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \` to the previous command to improve performance:
>
> ```
> docker run --rm -it \
> --name reachy \
> --privileged \
> --network host \
> --ipc host \
> --device-cgroup-rule='c 189:* rwm' \
> --group-add audio \
> -e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \
> -e DISPLAY="$DISPLAY" \
> -e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \
> -e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \
> -e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \
> -v /dev:/dev \
> -v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \
> -v "$HOME/.reachy.log":/home/reachy/.ros/log \
> -v /usr/lib/x86_64-linux-gnu:/opt/host-libs \
> --entrypoint /package/launch.sh \
> pollenrobotics/reachy2_core:1.7.5.9_deploy \
> start_rviz:=true start_sdk_server:=true mujoco:=true
> ```
## Setup
### Prerequisites
- On your robot, check the **service images** meet the minimum versions:
- **reachy2-core >= 1.7.5.2**
- **webrtc >= 2.0.1.1**
Then, if you want to use VR teleoperation:
- Install the [Reachy 2 teleoperation application](https://docs.pollen-robotics.com/teleoperation/teleoperation-introduction/discover-teleoperation/).
Use version **>=v1.2.0**
We recommend using two computers: one for teleoperation (Windows required) and another for recording with LeRobot.
### Install LeRobot
Follow the [installation instructions](https://github.com/huggingface/lerobot#installation) to install LeRobot.
Install LeRobot with Reachy 2 dependencies:
```bash
pip install -e ".[reachy2]"
```
### (Optional but recommended) Install pollen_data_acquisition_server
How you manage Reachy 2 recording sessions is up to you, but the **easiest** way is to use this server so you can control sessions directly from the VR teleoperation app.
> **Note:** Currently, only the VR teleoperation application works as a client for this server, so this step primarily targets teleoperation. Youre free to develop custom clients to manage sessions to your needs.
In your LeRobot environment, install the server from source:
```bash
git clone https://github.com/pollen-robotics/pollen_data_acquisition_server.git
cd pollen_data_acquisition_server
pip install -e .
```
Find the [pollen_data_acquisition_server documentation here](https://github.com/pollen-robotics/pollen_data_acquisition_server).
## Step 1: Recording
### Get Reachy 2 IP address
Before starting teleoperation and data recording, find the [robot's IP address](https://docs.pollen-robotics.com/getting-started/setup-reachy2/connect-reachy2/).
We strongly recommend connecting all devices (PC and robot) via **Ethernet**.
### Launch recording
There are two ways to manage recording sessions when using the Reachy 2 VR teleoperation application:
- **Using the data acquisition server (recommended for VR teleop)**: The VR app orchestrates sessions (via the server it tells LeRobot when to create datasets, start/stop episodes) while also controlling the robots motions.
- **Using LeRobots record script**: LeRobot owns session control and decides when to start/stop episodes. If you also use the VR teleop app, its only for motion control.
### Option 1: Using Pollen data acquisition server (recommended for VR teleop)
Make sure you have installed pollen_data_acquisition_server, as explained in the Setup section.
Launch the data acquisition server to be able to manage your session directly from the teleoperation application:
```bash
python -m pollen_data_acquisition_server.server
```
Then get into the teleoperation application and choose "Data acquisition session".
You can finally setup your session by following the screens displayed.
> Even without the VR app, you can use the `pollen_data_acquisition_server` with your own client implementation.
### Option 2: Using lerobot.record
Reachy 2 is fully supported by LeRobots recording features.
If you choose this option but still want to use the VR teleoperation application, select "Standard session" in the app.
**Example: start a recording without the mobile base:**
First add reachy2 and reachy2_teleoperator to the imports of the record script. Then you can use the following command:
```bash
python -m lerobot.record \
--robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \
--robot.id=r2-0000 \
--robot.use_external_commands=true \
--robot.with_mobile_base=false \
--teleop.type=reachy2_teleoperator \
--teleop.ip_address=192.168.0.200 \
--teleop.with_mobile_base=false \
--dataset.repo_id=pollen_robotics/record_test \
--dataset.single_task="Reachy 2 recording test" \
--dataset.num_episodes=1 \
--dataset.episode_time_s=5 \
--dataset.fps=15 \
--dataset.push_to_hub=true \
--dataset.private=true \
--display_data=true
```
#### Specific Options
**Extended setup overview (all options included):**
```bash
python -m lerobot.record \
--robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \
--robot.use_external_commands=true \
--robot.with_mobile_base=true \
--robot.with_l_arm=true \
--robot.with_r_arm=true \
--robot.with_neck=true \
--robot.with_antennas=true \
--robot.with_left_teleop_camera=true \
--robot.with_right_teleop_camera=true \
--robot.with_torso_camera=false \
--robot.disable_torque_on_disconnect=false \
--robot.max_relative_target=5.0 \
--teleop.type=reachy2_teleoperator \
--teleop.ip_address=192.168.0.200 \
--teleop.use_present_position=false \
--teleop.with_mobile_base=false \
--teleop.with_l_arm=true \
--teleop.with_r_arm=true \
--teleop.with_neck=true \
--teleop.with_antennas=true \
--dataset.repo_id=pollen_robotics/record_test \
--dataset.single_task="Reachy 2 recording test" \
--dataset.num_episodes=1 \
--dataset.episode_time_s=5 \
--dataset.fps=15 \
--dataset.push_to_hub=true \
--dataset.private=true \
--display_data=true
```
##### `--robot.use_external_commands`
Determine whether LeRobot robot.send_action() sends commands to the robot.
**Must** be set to false while using the VR teleoperation application, as the app already sends commands.
##### `--teleop.use_present_position`
Determine whether the teleoperator reads the goal or present position of the robot.
Must be set to true if a compliant Reachy 2 is used to control another one.
##### Use the relevant parts
From our initial tests, recording **all** joints when only some are moving can reduce model quality with certain policies.
To avoid this, you can exclude specific parts from recording and replay using:
````
--robot.with_<part>=false
```,
with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
It determine whether the corresponding part is recorded in the observations. True if not set.
By default, **all parts are recorded**.
The same per-part mechanism is available in `reachy2_teleoperator` as well.
````
--teleop.with\_<part>
```
with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
Determine whether the corresponding part is recorded in the actions. True if not set.
> **Important:** In a given session, the **enabled parts must match** on both the robot and the teleoperator.
For example, if the robot runs with `--robot.with_mobile_base=false`, the teleoperator must disable the same part `--teleoperator.with_mobile_base=false`.
##### Use the relevant cameras
You can do the same for **cameras**. By default, only the **teleoperation cameras** are recorded (both `left_teleop_camera` and `right_teleop_camera`). Enable or disable each camera with:
```
--robot.with_left_teleop_camera=<true|false>
--robot.with_right_teleop_camera=<true|false>
--robot.with_torso_camera=<true|false>
````
## Step 2: Replay
Make sure the robot is configured with the same parts as the dataset:
```bash
python -m lerobot.replay \
--robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \
--robot.use_external_commands=false \
--robot.with_mobile_base=false \
--dataset.repo_id=pollen_robotics/record_test \
--dataset.episode=0
--display_data=true
````
## Step 3: Train
```bash
python -m lerobot.scripts.train \
--dataset.repo_id=pollen_robotics/record_test \
--policy.type=act \
--output_dir=outputs/train/reachy2_test \
--job_name=reachy2 \
--policy.device=mps \
--wandb.enable=true \
--policy.repo_id=pollen_robotics/record_test_policy
```
## Step 4: Evaluate
```bash
python -m lerobot.record \
--robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \
--display_data=false \
--dataset.repo_id=pollen_robotics/eval_record_test \
--dataset.single_task="Evaluate reachy2 policy" \
--dataset.num_episodes=10 \
--policy.path=outputs/train/reachy2_test/checkpoints/last/pretrained_model
```

View File

@@ -54,7 +54,7 @@ If you don't have a gpu device, you can train using our notebook on [![Google Co
Pass your dataset to the training script using `--dataset.repo_id`. If you want to test your installation, run the following command where we use one of the datasets we collected for the [SmolVLA Paper](https://huggingface.co/papers/2506.01844).
```bash
cd lerobot && lerobot-train \
cd lerobot && python -m lerobot.scripts.train \
--policy.path=lerobot/smolvla_base \
--dataset.repo_id=${HF_USER}/mydataset \
--batch_size=64 \
@@ -73,7 +73,7 @@ cd lerobot && lerobot-train \
Fine-tuning is an art. For a complete overview of the options for finetuning, run
```bash
lerobot-train --help
python -m lerobot.scripts.train --help
```
<p align="center">
@@ -97,7 +97,7 @@ Similarly for when recording an episode, it is recommended that you are logged i
Once you are logged in, you can run inference in your setup by doing:
```bash
lerobot-record \
python -m lerobot.record \
--robot.type=so101_follower \
--robot.port=/dev/ttyACM0 \ # <- Use your port
--robot.id=my_blue_follower_arm \ # <- Use your robot id

View File

@@ -26,7 +26,7 @@ Unlike the SO-101, the motor connectors are not easily accessible once the arm i
To find the port for each bus servo adapter, run this script:
```bash
lerobot-find-port
python -m lerobot.find_port
```
<hfoptions id="example">
@@ -93,7 +93,7 @@ For a visual reference on how to set the motor ids please refer to [this video](
<hfoption id="Command">
```bash
lerobot-setup-motors \
python -m lerobot.setup_motors \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step
```
@@ -168,7 +168,7 @@ Do the same steps for the leader arm.
<hfoptions id="setup_motors">
<hfoption id="Command">
```bash
lerobot-setup-motors \
python -m lerobot.setup_motors \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step
```
@@ -568,7 +568,7 @@ Run the following command or API example to calibrate the follower arm:
<hfoption id="Command">
```bash
lerobot-calibrate \
python -m lerobot.calibrate \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
--robot.id=my_awesome_follower_arm # <- Give the robot a unique name
@@ -606,7 +606,7 @@ Do the same steps to calibrate the leader arm, run the following command or API
<hfoption id="Command">
```bash
lerobot-calibrate \
python -m lerobot.calibrate \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
--teleop.id=my_awesome_leader_arm # <- Give the robot a unique name

View File

@@ -162,7 +162,7 @@ It is advisable to install one 3-pin cable in the motor after placing them befor
To find the port for each bus servo adapter, connect MotorBus to your computer via USB and power. Run the following script and disconnect the MotorBus when prompted:
```bash
lerobot-find-port
python -m lerobot.find_port
```
<hfoptions id="example">
@@ -240,7 +240,7 @@ Connect the usb cable from your computer and the power supply to the follower ar
<hfoption id="Command">
```bash
lerobot-setup-motors \
python -m lerobot.setup_motors \
--robot.type=so101_follower \
--robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step
```
@@ -316,7 +316,7 @@ Do the same steps for the leader arm.
<hfoption id="Command">
```bash
lerobot-setup-motors \
python -m lerobot.setup_motors \
--teleop.type=so101_leader \
--teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step
```
@@ -353,7 +353,7 @@ Run the following command or API example to calibrate the follower arm:
<hfoption id="Command">
```bash
lerobot-calibrate \
python -m lerobot.calibrate \
--robot.type=so101_follower \
--robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
--robot.id=my_awesome_follower_arm # <- Give the robot a unique name
@@ -402,7 +402,7 @@ Do the same steps to calibrate the leader arm, run the following command or API
<hfoption id="Command">
```bash
lerobot-calibrate \
python -m lerobot.calibrate \
--teleop.type=so101_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
--teleop.id=my_awesome_leader_arm # <- Give the robot a unique name

View File

@@ -62,7 +62,7 @@ By default, every field takes its default value specified in the dataclass. If a
Let's say that we want to train [Diffusion Policy](../src/lerobot/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this:
```bash
lerobot-train \
python -m lerobot.scripts.train \
--dataset.repo_id=lerobot/pusht \
--policy.type=diffusion \
--env.type=pusht
@@ -77,7 +77,7 @@ Let's break this down:
Let's see another example. Let's say you've been training [ACT](../src/lerobot/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with:
```bash
lerobot-train \
python -m lerobot.scripts.train \
--policy.type=act \
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
--env.type=aloha \
@@ -90,7 +90,7 @@ We now want to train a different policy for aloha on another task. We'll change
Looking at the [`AlohaEnv`](../src/lerobot/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using:
```bash
lerobot-train \
python -m lerobot.scripts.train \
--policy.type=act \
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
--env.type=aloha \
@@ -127,7 +127,7 @@ Now, let's assume that we want to reproduce the run just above. That run has pro
We can then simply load the config values from this file using:
```bash
lerobot-train \
python -m lerobot.scripts.train \
--config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \
--output_dir=outputs/train/act_aloha_transfer_2
```
@@ -137,7 +137,7 @@ lerobot-train \
Similarly to Hydra, we can still override some parameters in the CLI if we want to, e.g.:
```bash
lerobot-train \
python -m lerobot.scripts.train \
--config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \
--output_dir=outputs/train/act_aloha_transfer_2
--policy.n_action_steps=80
@@ -148,7 +148,7 @@ lerobot-train \
`--config_path` can also accept the repo_id of a repo on the hub that contains a `train_config.json` file, e.g. running:
```bash
lerobot-train --config_path=lerobot/diffusion_pusht
python -m lerobot.scripts.train --config_path=lerobot/diffusion_pusht
```
will start a training run with the same configuration used for training [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)
@@ -160,7 +160,7 @@ Being able to resume a training run is important in case it crashed or aborted f
Let's reuse the command from the previous run and add a few more options:
```bash
lerobot-train \
python -m lerobot.scripts.train \
--policy.type=act \
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
--env.type=aloha \
@@ -179,7 +179,7 @@ INFO 2025-01-24 16:10:56 ts/train.py:263 Checkpoint policy after step 100
Now let's simulate a crash by killing the process (hit `ctrl`+`c`). We can then simply resume this run from the last checkpoint available with:
```bash
lerobot-train \
python -m lerobot.scripts.train \
--config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
--resume=true
```
@@ -190,7 +190,7 @@ Another reason for which you might want to resume a run is simply to extend trai
You could double the number of steps of the previous run with:
```bash
lerobot-train \
python -m lerobot.scripts.train \
--config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
--resume=true \
--steps=200000
@@ -224,7 +224,7 @@ In addition to the features currently in Draccus, we've added a special `.path`
For example, we could fine-tune a [policy pre-trained on the aloha transfer task](https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human) on the aloha insertion task. We can achieve this with:
```bash
lerobot-train \
python -m lerobot.scripts.train \
--policy.path=lerobot/act_aloha_sim_transfer_cube_human \
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
--env.type=aloha \
@@ -270,7 +270,7 @@ We'll summarize here the main use cases to remember from this tutorial.
#### Train a policy from scratch CLI
```bash
lerobot-train \
python -m lerobot.scripts.train \
--policy.type=act \ # <- select 'act' policy
--env.type=pusht \ # <- select 'pusht' environment
--dataset.repo_id=lerobot/pusht # <- train on this dataset
@@ -279,7 +279,7 @@ lerobot-train \
#### Train a policy from scratch - config file + CLI
```bash
lerobot-train \
python -m lerobot.scripts.train \
--config_path=path/to/pretrained_model \ # <- can also be a repo_id
--policy.n_action_steps=80 # <- you may still override values
```
@@ -287,7 +287,7 @@ lerobot-train \
#### Resume/continue a training run
```bash
lerobot-train \
python -m lerobot.scripts.train \
--config_path=checkpoint/pretrained_model/ \
--resume=true \
--steps=200000 # <- you can change some training parameters
@@ -296,7 +296,7 @@ lerobot-train \
#### Fine-tuning
```bash
lerobot-train \
python -m lerobot.scripts.train \
--policy.path=lerobot/act_aloha_sim_transfer_cube_human \ # <- can also be a local path to a checkpoint
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
--env.type=aloha \

View File

@@ -18,7 +18,7 @@ Replays the actions of an episode from a dataset on a robot.
Example:
```shell
lerobot-replay \
python -m lerobot.replay \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=black \

View File

@@ -1,347 +0,0 @@
#!/usr/bin/env python
"""Script for Pi0 pretrained policy inference and Hub upload."""
import argparse
from datetime import datetime
import numpy as np
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
# Set seed
torch.manual_seed(42)
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="Pi0 policy inference and Hub upload")
parser.add_argument(
"--source-model-id",
type=str,
default="pepijn223/pi0_libero_lerobot",
help="Source model repository ID on Hugging Face Hub",
)
parser.add_argument(
"--dataset-id", type=str, default="pepijn223/libero", help="Dataset repository ID on Hugging Face Hub"
)
parser.add_argument(
"--output-model-id",
type=str,
required=True,
help="Output model repository ID to upload to (e.g., 'your-username/pi0-libero-fixed')",
)
parser.add_argument(
"--device", type=str, default="cpu", choices=["cpu", "cuda", "mps"], help="Device to run inference on"
)
parser.add_argument("--episode", type=int, default=0, help="Episode index to load from dataset")
parser.add_argument(
"--sample-idx", type=int, default=10, help="Sample index within episode to use for inference"
)
parser.add_argument("--private", action="store_true", help="Make the uploaded model private")
parser.add_argument(
"--commit-message", type=str, default=None, help="Custom commit message for the upload"
)
return parser.parse_args()
def _inject_normalization_stats(policy: PI0Policy, dataset_meta: LeRobotDatasetMetadata, key_mapping: dict):
"""Recreate normalization layers with proper stats from the dataset."""
from lerobot.policies.normalize import Normalize, Unnormalize
# Convert numpy stats to the format expected by normalization layers and remap keys
stats = {}
for dataset_key, stat_dict in dataset_meta.stats.items():
# Use mapped key if available, otherwise use original key
policy_key = key_mapping.get(dataset_key, dataset_key)
stats[policy_key] = {
stat_type: torch.from_numpy(stat_array) if isinstance(stat_array, np.ndarray) else stat_array
for stat_type, stat_array in stat_dict.items()
}
print(f"Available stats keys: {list(stats.keys())}")
print(
f"Policy expects keys: input={list(policy.config.input_features.keys())}, output={list(policy.config.output_features.keys())}"
)
# Recreate normalization layers with proper stats
normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats)
normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats)
unnormalize_outputs = Unnormalize(
policy.config.output_features, policy.config.normalization_mapping, stats
)
# Replace the normalization layers on the policy
policy.normalize_inputs = normalize_inputs
policy.normalize_targets = normalize_targets
policy.unnormalize_outputs = unnormalize_outputs
print("Normalization layers recreated with dataset stats.")
def configure_policy_features(policy: PI0Policy, dataset: LeRobotDataset):
"""Configure policy input and output features based on dataset metadata."""
print(f"Dataset features: {list(dataset.meta.features.keys())}")
# Create a proper mapping from dataset keys to policy keys
dataset_to_policy_mapping = {}
# Handle images
if "image" in dataset.meta.features:
dataset_to_policy_mapping["image"] = "observation.images.image"
if "wrist_image" in dataset.meta.features:
dataset_to_policy_mapping["wrist_image"] = "observation.images.image2"
# Handle state
if "state" in dataset.meta.features:
dataset_to_policy_mapping["state"] = "observation.state"
# Handle actions
if "actions" in dataset.meta.features:
dataset_to_policy_mapping["actions"] = "action"
print(f"Key mapping: {dataset_to_policy_mapping}")
# Clear existing input features and reconfigure with proper mapping
policy.config.input_features = {}
policy.config.output_features = {}
# Map visual features
for dataset_key, policy_key in dataset_to_policy_mapping.items():
if dataset_key in ["image", "wrist_image"]:
feature_info = dataset.meta.features[dataset_key]
# Convert HWC to CHW format and resize
shape = (3, 224, 224) # Pi0 expects CHW format
policy.config.input_features[policy_key] = PolicyFeature(type=FeatureType.VISUAL, shape=shape)
# Map state features
for dataset_key, policy_key in dataset_to_policy_mapping.items():
if dataset_key == "state":
feature_info = dataset.meta.features[dataset_key]
shape = tuple(feature_info["shape"])
policy.config.input_features[policy_key] = PolicyFeature(type=FeatureType.STATE, shape=shape)
# Map action features
for dataset_key, policy_key in dataset_to_policy_mapping.items():
if dataset_key == "actions":
feature_info = dataset.meta.features[dataset_key]
shape = tuple(feature_info["shape"])
policy.config.output_features[policy_key] = PolicyFeature(type=FeatureType.ACTION, shape=shape)
print(f"Policy input_features: {list(policy.config.input_features.keys())}")
print(f"Policy output_features: {list(policy.config.output_features.keys())}")
print(f"Policy image_features: {list(policy.config.image_features.keys())}")
print(f"Policy action_feature: {policy.config.action_feature}")
return dataset_to_policy_mapping
def fix_buffer_naming(policy: PI0Policy):
"""Fix buffer naming issues in the loaded policy state dict."""
print("Fixing normalization buffer naming issues...")
state_dict = policy.state_dict()
corrected_state_dict = {}
fixes_applied = 0
for key, value in state_dict.items():
new_key = key
# Fix buffer naming: buffer_observation_state_mean -> buffer_observation_state.mean
if "buffer_observation_state_mean" in key:
new_key = key.replace("buffer_observation_state_mean", "buffer_observation_state.mean")
fixes_applied += 1
print(f" Fixed: {key} -> {new_key}")
elif "buffer_observation_state_std" in key:
new_key = key.replace("buffer_observation_state_std", "buffer_observation_state.std")
fixes_applied += 1
print(f" Fixed: {key} -> {new_key}")
# Remove image buffers that aren't expected (they cause conflicts)
elif "buffer_observation_image_mean" in key or "buffer_observation_image_std" in key:
print(f" Removed unexpected buffer: {key}")
continue # Skip this buffer
corrected_state_dict[new_key] = value
# Add missing action buffers with dummy values (will be replaced by dataset stats)
missing_buffers = [
"normalize_targets.buffer_action.mean",
"normalize_targets.buffer_action.std",
"unnormalize_outputs.buffer_action.mean",
"unnormalize_outputs.buffer_action.std",
]
for buffer_key in missing_buffers:
if buffer_key not in corrected_state_dict:
# Use dummy values - these will be overwritten by proper dataset stats later
if "mean" in buffer_key:
corrected_state_dict[buffer_key] = torch.zeros(8) # Assume 8-dim action
else: # std
corrected_state_dict[buffer_key] = torch.ones(8) # Assume 8-dim action
fixes_applied += 1
print(f" Added missing buffer: {buffer_key}")
print(f"Applied {fixes_applied} buffer fixes")
# Load the corrected state dict back into the policy
policy.load_state_dict(corrected_state_dict)
return policy
def main():
"""Main function to run the Pi0 inference and upload."""
args = parse_args()
# Load pretrained Pi0 model directly from Hugging Face Hub
print(f"Loading pretrained Pi0 model from {args.source_model_id}...")
# Load with strict=False to allow missing/unexpected keys, then fix them manually
policy = PI0Policy.from_pretrained(args.source_model_id, strict=False)
policy = fix_buffer_naming(policy)
policy.eval()
policy.to(args.device)
# Load dataset and get a sample
print(f"Loading dataset: {args.dataset_id}")
dataset = LeRobotDataset(args.dataset_id, episodes=[args.episode])
meta: LeRobotDatasetMetadata = dataset.meta
sample = dataset[args.sample_idx]
# Configure policy features
key_mapping = configure_policy_features(policy, dataset)
# Inject normalization stats with proper key mapping
_inject_normalization_stats(policy, meta, key_mapping)
# Prepare batch for PI0 (handle temporal dimensions)
batch = {}
# Map dataset sample keys to policy keys
reverse_mapping = {v: k for k, v in key_mapping.items()}
for policy_key in policy.config.input_features:
# Find the corresponding dataset key
dataset_key = reverse_mapping.get(policy_key, policy_key)
if dataset_key in sample:
data = sample[dataset_key]
# Handle image data: convert from HWC to CHW and normalize
if policy_key.startswith("observation.images."):
if data.dim() == 3 and data.shape[-1] == 3: # HWC format
data = data.permute(2, 0, 1) # Convert to CHW
# Normalize to [0, 1] range if needed
if data.dtype == torch.uint8:
data = data.float() / 255.0
# Resize to expected size if needed
if data.shape[-2:] != (224, 224):
import torch.nn.functional as F # noqa: N812
data = F.interpolate(
data.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False
)[0]
# Remove temporal dimension if present
if data.dim() > len(policy.config.input_features[policy_key].shape):
data = data[0]
batch[policy_key] = data.unsqueeze(0) # Add batch dimension
# Debug: print what's in the sample
print(f"Sample keys: {list(sample.keys())}")
print(f"Batch keys prepared: {list(batch.keys())}")
# Pi0 requires task description - add a default if not available
if "task" in sample:
batch["task"] = [sample["task"]] # Keep as list of strings
else:
print("No task in sample, using default task description")
batch["task"] = ["Complete the manipulation task"]
print(f"Task: {batch['task'][0]}")
print(f"Final batch keys: {list(batch.keys())}")
# Run inference
with torch.no_grad():
action = policy.select_action(batch)
print(f"Predicted action shape: {action.shape}")
print(f"Predicted action: {action.tolist()}")
print("✅ Pi0 pretrained inference completed successfully!")
# Upload to Hugging Face Hub
print(f"\n📤 Uploading model to Hugging Face Hub: {args.output_model_id}")
# Create commit message
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
commit_message = (
args.commit_message
or f"Pi0 model with injected normalization stats from {args.dataset_id} - {timestamp}"
)
# Update model configuration with dataset info
policy.config.push_to_hub = True
policy.config.repo_id = args.output_model_id
policy.config.private = args.private
# Add metadata about the adaptation
adaptation_info = {
"source_model": args.source_model_id,
"dataset_used": args.dataset_id,
"adaptation_date": timestamp,
"stats_injected": True,
"key_mapping": key_mapping,
"inference_test_passed": True,
"sample_action_shape": list(action.shape),
}
try:
# Push to hub
policy.push_to_hub(
repo_id=args.output_model_id,
private=args.private,
commit_message=commit_message,
create_pr=False,
)
# Also save the adaptation info as a separate file
import json
import os
import tempfile
from huggingface_hub import HfApi
api = HfApi()
# Create a temporary file with adaptation info
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(adaptation_info, f, indent=2)
temp_path = f.name
try:
api.upload_file(
path_or_fileobj=temp_path,
path_in_repo="adaptation_info.json",
repo_id=args.output_model_id,
commit_message=f"Add adaptation metadata - {timestamp}",
)
finally:
os.unlink(temp_path)
print(f"✅ Model successfully uploaded to: https://huggingface.co/{args.output_model_id}")
print("📋 Adaptation info:")
for key, value in adaptation_info.items():
print(f" {key}: {value}")
except Exception as e:
print(f"❌ Error uploading to Hub: {e}")
raise
if __name__ == "__main__":
main()

View File

@@ -1,704 +0,0 @@
import json
import os
import random
from datetime import datetime
import numpy as np
import torch
from huggingface_hub import hf_hub_download # noqa: E402
from safetensors.torch import load_file # noqa: E402
from transformers.model_debugging_utils import model_addition_debugger_context
from lerobot.configs.policies import FeatureType, PolicyFeature
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
RANDOM_SEED = 42 # Set to fixed value for reproducible results
def set_all_seeds(seed=42):
"""Set all random seeds for reproducible results."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
print(f"All random seeds set to {seed} for reproducible results (deterministic mode enabled)")
# Set seeds at the start
set_all_seeds(RANDOM_SEED)
config_model_path = "lerobot/pi0" # Use config from official model
official_model_path = "lerobot/pi0" # Official model
custom_model_path = "pepijn223/pi0_base_fp32" # Custom model to compare # pepijn223/pi0_base_fp32
device = "mps"
USE_FULL_TENSORS = True
SAVE_TENSORS_TO_DISK = False
# Model transformation and upload settings
SAVE_TRANSFORMED_MODEL = True # Set to True to save the transformed model
UPLOAD_TO_HUB = True # Set to True to upload to HuggingFace Hub
TRANSFORMED_MODEL_NAME = "pepijn223/pi0_base_fp32_lerobot_format" # Target repo name
COMMIT_MESSAGE = "Add transformed PI0 model with correct key format for lerobot"
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
debug_path = os.path.join("debug_outputs", f"pi0_debug_direct_{timestamp}")
os.makedirs(debug_path, exist_ok=True)
print(f"Model debugging enabled - outputs will be saved to: {debug_path}")
# Download and load the config manually to avoid draccus parsing issues
config_file = hf_hub_download(repo_id=config_model_path, filename="config.json")
with open(config_file) as f:
config_dict = json.load(f)
# Remove the 'type' field that causes draccus issues
if "type" in config_dict:
config_dict.pop("type")
print("Removed 'type' field from config")
# Create shared PI0Config
print("Creating shared PI0Config...")
shared_config = PI0Config(**config_dict)
def load_policy_with_weights(
model_path: str, config: PI0Config, model_name: str, apply_transformations: bool = False
):
"""Load a policy with specified weights but shared config."""
print(f"\n=== Loading {model_name} from {model_path} ===")
# Set deterministic seed before creating the policy to ensure identical initialization
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
policy = PI0Policy(config)
# Download and load weights
model_file = hf_hub_download(repo_id=model_path, filename="model.safetensors")
print(f"Downloaded {model_name} weights to: {model_file}")
# Load state dict and apply transformations
print(f"Investigating safetensors file: {model_file}")
# First, check what's in the metadata
try:
from safetensors import safe_open
with safe_open(model_file, framework="pt", device="cpu") as f:
metadata = f.metadata()
all_keys_in_file = f.keys()
print(f" Total keys in safetensors file: {len(list(all_keys_in_file))}")
# Check for embed_tokens in the file keys
embed_keys_in_file = [k for k in f.keys() if "embed_tokens" in k]
print(f" embed_tokens keys in safetensors: {embed_keys_in_file}")
if metadata:
print(f" Metadata exists: {list(metadata.keys()) if metadata else 'None'}")
except Exception as e:
print(f" Could not inspect safetensors file directly: {e}")
# Now load normally and see what we get
state_dict = load_file(model_file)
print(f" Keys loaded by load_file(): {len(state_dict)} keys")
# Check for embed_tokens in loaded state_dict
loaded_embed_keys = [k for k in state_dict.keys() if "embed_tokens" in k]
print(f" embed_tokens keys in loaded state_dict: {loaded_embed_keys}")
# Check if we need to add "model." prefix (for custom models that don't have it)
sample_key = next(iter(state_dict.keys()))
if not sample_key.startswith("model."):
print(f"Adding 'model.' prefix to all keys (detected format: {sample_key})")
state_dict = {f"model.{k}": v for k, v in state_dict.items()}
# IMPORTANT: Call PI0Policy._transform_state_dict_keys AFTER adding model. prefix
# This ensures tied weights logic can find the correct key pattern
transformed_state_dict = PI0Policy._transform_state_dict_keys(state_dict)
# Apply specific PaliGemma key transformations only for custom models
if apply_transformations:
print("Applying custom model key transformations...")
# First, let's debug what keys we actually have
all_keys = list(transformed_state_dict.keys())
sample_keys = all_keys[:10]
print(f"Sample keys to transform: {sample_keys}")
# Look for specific keys we need to transform and missing keys
embed_tokens_keys = [k for k in all_keys if "embed_tokens" in k]
embedding_keys = [k for k in all_keys if "embed" in k]
lm_head_keys = [k for k in all_keys if "lm_head" in k]
paligemma_keys = [
k for k in all_keys if "paligemma_with_expert.paligemma" in k and "gemma_expert" not in k
]
language_model_keys = [k for k in all_keys if "language_model" in k]
print(f"Found embed_tokens keys: {embed_tokens_keys}")
print(f"Found any embedding keys: {embedding_keys}")
print(f"Found lm_head keys: {lm_head_keys}")
print(
f"Found paligemma keys (non-expert): {paligemma_keys[:5]}{'...' if len(paligemma_keys) > 5 else ''}"
)
print(
f"Found language_model keys: {language_model_keys[:5]}{'...' if len(language_model_keys) > 5 else ''}"
)
print(f"Total keys in model: {len(all_keys)}")
# Check if the embed_tokens is in gemma_expert instead
gemma_expert_embed = [k for k in all_keys if "gemma_expert" in k and "embed_tokens" in k]
print(f"Found gemma_expert embed_tokens keys: {gemma_expert_embed}")
# Check what we're missing and what we actually have
expected_embed_key = "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
if expected_embed_key not in all_keys:
print(f" Missing expected embed_tokens key: {expected_embed_key}")
# Let's see what keys we actually have for debugging
print("Debugging: Looking for any embedding-related keys...")
all_embed_related = [k for k in all_keys if "embed" in k.lower()]
print(f"Keys containing 'embed': {all_embed_related}")
# Look for any keys that might contain embeddings
potential_embed_keys = [
k for k in all_keys if any(word in k for word in ["embed", "embedding", "token"])
]
print(f" Potential embedding keys: {potential_embed_keys}")
# Try to find a suitable replacement
if gemma_expert_embed:
print(f" Will try to copy from: {gemma_expert_embed[0]}")
else:
print(" No gemma_expert embed_tokens found either!")
# Check if there's an embed_tokens in the gemma_expert that we missed
gemma_keys = [k for k in all_keys if "gemma_expert" in k]
print(f" First 10 gemma_expert keys: {gemma_keys[:10]}")
# Check if there are any token-related keys in gemma_expert
token_keys = [k for k in all_keys if "gemma_expert" in k and "token" in k.lower()]
print(f" Gemma expert token-related keys: {token_keys}")
# Check for any keys that look like they might be embeddings
possible_embeds = [
k
for k in all_keys
if any(
pattern in k.lower() for pattern in ["embed_token", "embedding", "wte", "word_embed"]
)
]
print(f" Possible embedding alternatives: {possible_embeds}")
final_state_dict = {}
transformation_count = 0
for key, value in transformed_state_dict.items():
new_key = key
original_key = key
# Transform vision tower keys: ADD .model between paligemma and vision_tower
if "paligemma_with_expert.paligemma.vision_tower.vision_model" in new_key:
new_key = new_key.replace(
"paligemma_with_expert.paligemma.vision_tower.vision_model",
"paligemma_with_expert.paligemma.model.vision_tower.vision_model",
)
print(f"Transformed vision key: {original_key} -> {new_key}")
transformation_count += 1
# Transform multi_modal_projector keys: ADD .model between paligemma and multi_modal_projector
elif "paligemma_with_expert.paligemma.multi_modal_projector" in new_key:
new_key = new_key.replace(
"paligemma_with_expert.paligemma.multi_modal_projector",
"paligemma_with_expert.paligemma.model.multi_modal_projector",
)
print(f"Transformed multi_modal_projector key: {original_key} -> {new_key}")
transformation_count += 1
# NO transformation needed for language_model keys - they're already correct!
# The custom model already has: paligemma.model.language_model.* which is what we need
# NO transformation needed for lm_head - it should stay as paligemma.lm_head
final_state_dict[new_key] = value
print(f"Applied {transformation_count} key transformations")
transformed_state_dict = final_state_dict
else:
print("No transformations applied (official model format)")
# Debug: show what keys the policy expects vs what we have
policy_keys = set(policy.state_dict().keys())
provided_keys = set(transformed_state_dict.keys())
missing_in_provided = policy_keys - provided_keys
extra_in_provided = provided_keys - policy_keys
print(f"Policy expects {len(policy_keys)} keys, we provide {len(provided_keys)} keys")
if missing_in_provided:
print(
f" Missing from provided: {list(missing_in_provided)[:5]}{'...' if len(missing_in_provided) > 5 else ''}"
)
if extra_in_provided:
print(
f" Extra in provided: {list(extra_in_provided)[:5]}{'...' if len(extra_in_provided) > 5 else ''}"
)
# Load the weights into the policy
msg = policy.load_state_dict(transformed_state_dict, strict=True)
print(
f"{model_name} - Missing keys: {len(msg.missing_keys)}, Unexpected keys: {len(msg.unexpected_keys)}"
)
if msg.missing_keys:
print(
f" Actually missing keys: {list(msg.missing_keys)[:3]}{'...' if len(msg.missing_keys) > 3 else ''}"
)
if msg.unexpected_keys:
print(
f" Actually unexpected keys: {list(msg.unexpected_keys)[:3]}{'...' if len(msg.unexpected_keys) > 3 else ''}"
)
# Set deterministic mode and move to device
policy = policy.to(device)
policy.eval()
# Reset the policy to ensure identical internal state
policy.reset()
return policy
# Load both models with shared config
print("Loading both models with shared config...")
official_policy = load_policy_with_weights(
official_model_path, shared_config, "Official Model", apply_transformations=False
)
custom_policy = load_policy_with_weights(
custom_model_path, shared_config, "Custom Model", apply_transformations=True
)
print("\nBoth models loaded successfully!")
print(f"Shared config: {shared_config}")
print(f"Device: {device}")
# Configure input features for both policies since they're not set by default in pretrained models
def configure_policy_features(policy: PI0Policy):
"""Configure input and output features for a policy."""
policy.config.input_features[OBS_IMAGE] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224), # Channel-first RGB image
)
policy.config.input_features[OBS_STATE] = PolicyFeature(
type=FeatureType.STATE,
shape=(8,), # 8-dimensional state vector
)
policy.config.output_features[ACTION] = PolicyFeature(
type=FeatureType.ACTION,
shape=(8,), # 8-dimensional action vector
)
# Add dummy normalization buffers to the policy (like openpi does with norm_stats)
if hasattr(policy, "normalize_inputs"):
# For observation.state (8-dim state vector)
policy.normalize_inputs.register_buffer(
f"buffer_{OBS_STATE.replace('.', '_')}_mean", torch.zeros(8, device=device)
)
policy.normalize_inputs.register_buffer(
f"buffer_{OBS_STATE.replace('.', '_')}_std", torch.ones(8, device=device)
)
# For observation.image (3x224x224 image)
policy.normalize_inputs.register_buffer(
f"buffer_{OBS_IMAGE.replace('.', '_')}_mean", torch.zeros(3, 224, 224, device=device)
)
policy.normalize_inputs.register_buffer(
f"buffer_{OBS_IMAGE.replace('.', '_')}_std", torch.ones(3, 224, 224, device=device)
)
print("Configuring features for both policies...")
configure_policy_features(official_policy)
configure_policy_features(custom_policy)
# Verify that the models have identical parameters
print("\n=== Model Parameter Comparison ===")
official_params = dict(official_policy.named_parameters())
custom_params = dict(custom_policy.named_parameters())
param_differences = []
for name in official_params.keys():
if name not in custom_params:
param_differences.append(f"Missing parameter in custom model: {name}")
else:
diff = torch.abs(official_params[name] - custom_params[name]).max().item()
if diff > 1e-8:
param_differences.append(f"Parameter {name}: max difference = {diff:.2e}")
for name in custom_params.keys():
if name not in official_params:
param_differences.append(f"Extra parameter in custom model: {name}")
if param_differences:
print("Parameter differences found:")
for diff in param_differences[:10]: # Show first 10 differences
print(f" {diff}")
if len(param_differences) > 10:
print(f" ... and {len(param_differences) - 10} more differences")
else:
print("All model parameters are identical!")
# Get the raw models for direct comparison
official_raw_model = official_policy.model
custom_raw_model = custom_policy.model
print("\n=== Model Details ===")
print(f"Official raw model type: {type(official_raw_model)}")
print(f"Custom raw model type: {type(custom_raw_model)}")
print(f"Official model device: {next(official_raw_model.parameters()).device}")
print(f"Custom model device: {next(custom_raw_model.parameters()).device}")
# Create lerobot-format input data (similar to DROID format from openpi example)
example = {
"joint_position": np.zeros(7, dtype=np.float32),
"gripper_position": np.array([0.0], dtype=np.float32),
"image": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
"task": "pick up the object",
}
print(f"\nProvided input keys: {list(example.keys())}")
print("\nPreparing inputs for direct model call...")
# Apply input transformation (similar to openpi's policy._input_transform)
transformed_example = {}
# Combine joint and gripper positions into state
transformed_example[OBS_STATE] = np.concatenate([example["joint_position"], example["gripper_position"]])
transformed_example[OBS_IMAGE] = example["image"]
transformed_example["task"] = example["task"]
# Convert to PyTorch tensors and add batch dimension (as openpi example does)
# Device is already defined above, use the official model device for consistency
pytorch_inputs = {}
for key, value in transformed_example.items():
if isinstance(value, np.ndarray):
tensor_value = torch.from_numpy(value).to(device)
# Add batch dimension
if tensor_value.dim() > 0:
tensor_value = tensor_value.unsqueeze(0)
pytorch_inputs[key] = tensor_value
elif isinstance(value, str):
pytorch_inputs[key] = [value] # Convert to list format expected by policy
else:
pytorch_inputs[key] = value
# Convert image from HWC to CHW format for lerobot
if OBS_IMAGE in pytorch_inputs:
img = pytorch_inputs[OBS_IMAGE]
if img.dim() == 4 and img.shape[-1] == 3: # BHWC -> BCHW
img = img.permute(0, 3, 1, 2)
# Convert to float and normalize to [0, 1] range
img = img.float() / 255.0
pytorch_inputs[OBS_IMAGE] = img
print(f"Transformed input keys: {list(pytorch_inputs.keys())}")
for key, value in pytorch_inputs.items():
if isinstance(value, torch.Tensor):
print(f" {key}: {value.shape} {value.dtype}")
else:
print(f" {key}: {type(value)} - {value}")
# Reset both policies (clears the action queue)
official_policy.reset()
custom_policy.reset()
# Prepare inputs using the official policy (both models should have same preprocessing)
print("Preparing inputs for both models...")
images, img_masks = official_policy.prepare_images(pytorch_inputs)
lang_tokens, lang_masks = official_policy.prepare_language(pytorch_inputs)
state = official_policy.prepare_state(pytorch_inputs)
print("Prepared inputs:")
print(f" Images: {len(images)} images")
print(f" Language tokens shape: {lang_tokens.shape}")
print(f" State shape: {state.shape}")
for i, img in enumerate(images):
print(f" Image {i} shape: {img.shape}")
for i, mask in enumerate(img_masks):
print(f" Image mask {i} shape: {mask.shape}")
# Compare both models with identical inputs
print("\n🚀 Running MODEL COMPARISON...")
# Force torch.no_grad for consistent comparison
with torch.no_grad():
# Ensure reproducible noise generation for both models
torch.manual_seed(RANDOM_SEED)
# Generate synthetic noise and time for the forward call
batch_size = 1
actions_shape = (
batch_size,
official_raw_model.config.n_action_steps,
official_raw_model.config.max_action_dim,
)
# Generate noise and time using direct PyTorch operations instead of model methods
# This avoids any potential model-specific randomness
torch.manual_seed(RANDOM_SEED)
noise = torch.normal(
mean=0.0,
std=1.0,
size=actions_shape,
dtype=torch.float32,
device=device,
)
# Generate time using the same distribution as PI0FlowMatching.sample_time
torch.manual_seed(RANDOM_SEED) # Reset for consistent time
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
time_beta = beta_dist.sample((batch_size,)).to(device=device, dtype=torch.float32)
time = time_beta * 0.999 + 0.001
print("\n=== Generated Inputs ===")
print(f" Action shape: {actions_shape}")
print(f" Noise shape: {noise.shape}")
print(f" Time value: {time.item():.6f}")
print(f" Noise sample (first 5 values): {noise.flatten()[:5].tolist()}")
# Create dummy actions for forward pass (required for training forward)
dummy_actions = torch.zeros(actions_shape, dtype=torch.float32, device=device)
print("\n=== Running Forward Passes ===")
print("Running with model_addition_debugger_context for detailed analysis...")
# Create separate debug paths for each model
official_debug_path = os.path.join(debug_path, "official_model")
custom_debug_path = os.path.join(debug_path, "custom_model")
os.makedirs(official_debug_path, exist_ok=True)
os.makedirs(custom_debug_path, exist_ok=True)
# Set deterministic mode for forward pass
torch.manual_seed(RANDOM_SEED)
# Run official model with debugger
print("Running Official Model forward pass with debugger...")
with model_addition_debugger_context(
official_raw_model,
debug_path=official_debug_path,
do_prune_layers=False, # Output ALL layers
use_repr=not SAVE_TENSORS_TO_DISK,
):
official_loss = official_raw_model.forward(
images=images,
img_masks=img_masks,
lang_tokens=lang_tokens,
lang_masks=lang_masks,
state=state,
actions=dummy_actions,
noise=noise,
time=time,
)
# Reset seed before second forward pass to ensure any internal randomness is identical
torch.manual_seed(RANDOM_SEED)
# Run custom model with debugger
print("Running Custom Model forward pass with debugger...")
with model_addition_debugger_context(
custom_raw_model,
debug_path=custom_debug_path,
do_prune_layers=False, # Output ALL layers
use_repr=not SAVE_TENSORS_TO_DISK,
):
custom_loss = custom_raw_model.forward(
images=images,
img_masks=img_masks,
lang_tokens=lang_tokens,
lang_masks=lang_masks,
state=state,
actions=dummy_actions,
noise=noise,
time=time,
)
print(f"Official model debug outputs saved to: {official_debug_path}")
print(f"Custom model debug outputs saved to: {custom_debug_path}")
print("\n=== Output Comparison ===")
print(f"Official model loss shape: {official_loss.shape}")
print(f"Custom model loss shape: {custom_loss.shape}")
# Compare outputs
loss_diff = torch.abs(official_loss - custom_loss)
print("\n=== Detailed Comparison ===")
print("Loss difference stats:")
print(f" Mean absolute difference: {loss_diff.mean().item():.8f}")
print(f" Max absolute difference: {loss_diff.max().item():.8f}")
print(f" Min absolute difference: {loss_diff.min().item():.8f}")
print(f" Standard deviation of difference: {loss_diff.std().item():.8f}")
# Show some actual values for comparison
print("\nSample output values:")
print(f" Official model (first 5): {official_loss.flatten()[:5].tolist()}")
print(f" Custom model (first 5): {custom_loss.flatten()[:5].tolist()}")
print(f" Difference (first 5): {loss_diff.flatten()[:5].tolist()}")
# Determine if models are equivalent
are_equivalent = loss_diff.max().item() < 1e-6
print(f"\nModels are {'EQUIVALENT' if are_equivalent else 'DIFFERENT'}")
print(f" (Max difference: {loss_diff.max().item():.8f}, Threshold: 1e-6)")
print(f"\nDetailed debugging outputs saved to: {debug_path}")
# Save comparison results
comparison_results = {
"official_loss_stats": {
"shape": list(official_loss.shape),
"mean": official_loss.mean().item(),
"std": official_loss.std().item(),
"min": official_loss.min().item(),
"max": official_loss.max().item(),
},
"custom_loss_stats": {
"shape": list(custom_loss.shape),
"mean": custom_loss.mean().item(),
"std": custom_loss.std().item(),
"min": custom_loss.min().item(),
"max": custom_loss.max().item(),
},
"difference_stats": {
"mean_abs_diff": loss_diff.mean().item(),
"max_abs_diff": loss_diff.max().item(),
"min_abs_diff": loss_diff.min().item(),
"std_diff": loss_diff.std().item(),
"are_equivalent": are_equivalent,
},
}
comparison_file = os.path.join(debug_path, "model_comparison_results.json")
with open(comparison_file, "w") as f:
json.dump(comparison_results, f, indent=2)
print(f" Comparison results saved to: {comparison_file}")
# Save and upload transformed model if requested
if SAVE_TRANSFORMED_MODEL:
print("\nSaving Transformed Model...")
if are_equivalent:
print("Models are equivalent - proceeding with transformation and upload")
else:
print("Models are NOT equivalent, but proceeding with upload anyway")
print(f" Max difference: {loss_diff.max().item():.2e}")
print(" This might be useful for debugging or partial transformations")
# Create timestamp for README
transformation_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
try:
# Use the already working custom policy as the base for transformation
print("Using already working custom policy as base for transformed model...")
# Deep copy the custom policy to create the transformed version
from copy import deepcopy
transformed_policy = deepcopy(custom_policy)
print("Custom policy copied successfully - no additional configuration needed")
# Save locally first
local_save_path = "./transformed_pi0_model"
print(f"Saving transformed model locally to: {local_save_path}")
transformed_policy.save_pretrained(local_save_path, safe_serialization=True)
# Save the tokenizer as well (required for complete model)
transformed_policy.language_tokenizer.save_pretrained(local_save_path)
# Create a README with transformation details
readme_content = f"""
# PI0 Model - LeRobot Compatible Format
This model is a transformed version of `{custom_model_path}` with key names corrected to match the official LeRobot PI0 format.
## Transformation Applied
The original model had a different key naming convention. This model applies the following transformations:
1. **Model prefix**: Added `model.` prefix to all parameter keys
2. **Tied weights**: Applied PI0Policy's built-in tied weights logic to create `embed_tokens.weight` from `lm_head.weight`
3. **Key structure**: Applied standard PI0 key transformations for compatibility
## Verification
{"This transformed model produces **identical outputs**" if are_equivalent else "This transformed model has **slightly different outputs**"} (max difference = {loss_diff.max().item():.2e}) compared to the official model `{official_model_path}` when tested with the same inputs.
{"**Models are EQUIVALENT** (difference < 1e-6)" if are_equivalent else "**Models are NOT equivalent** (difference >= 1e-6) - use with caution"}
## Usage
```python
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
# Load the model
policy = PI0Policy.from_pretrained("{TRANSFORMED_MODEL_NAME}")
# Use for inference
action = policy.select_action(observation_batch)
```
## Original Model
- **Source**: {custom_model_path}
- **Verified Against**: {official_model_path}
## Technical Details
- **Total Parameters**: {sum(p.numel() for p in transformed_policy.parameters()):,}
- **Model Type**: PI0FlowMatching with PaliGemma + Expert Gemma
- **Configuration**: Matches official PI0 configuration
"""
readme_path = os.path.join(local_save_path, "README.md")
with open(readme_path, "w") as f:
f.write(readme_content.strip())
print(f"Model saved locally to: {local_save_path}")
# Upload to HuggingFace Hub if requested
if UPLOAD_TO_HUB:
print(f"\nUploading to HuggingFace Hub: {TRANSFORMED_MODEL_NAME}")
try:
# Push to hub
transformed_policy.push_to_hub(
repo_id=TRANSFORMED_MODEL_NAME,
commit_message=COMMIT_MESSAGE,
private=False, # Make it public
safe_serialization=True,
)
print(f"Model successfully uploaded to: https://huggingface.co/{TRANSFORMED_MODEL_NAME}")
print("You can now use this model directly without any transformations!")
print("\n Usage:")
print(" from lerobot.policies.pi0.modeling_pi0 import PI0Policy")
print(f" policy = PI0Policy.from_pretrained('{TRANSFORMED_MODEL_NAME}')")
except Exception as upload_error:
print(f"Failed to upload to HuggingFace Hub: {upload_error}")
print(f"You can manually upload the model from: {local_save_path}")
print(" Or set UPLOAD_TO_HUB = False and upload later")
except Exception as e:
import traceback
print(f"Error saving transformed model: {str(e)}")
print("Full traceback:")
traceback.print_exc()
print("The model transformation logic works, but saving failed")
else:
print("\nModel transformation and upload disabled (SAVE_TRANSFORMED_MODEL = False)")

View File

@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
[project]
name = "lerobot"
version = "0.3.4"
version = "0.3.2"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
readme = "README.md"
license = { text = "Apache-2.0" }
@@ -68,16 +68,15 @@ dependencies = [
"einops>=0.8.0",
"opencv-python-headless>=4.9.0",
"av>=14.2.0",
"torch>=2.2.1",
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
"torchvision>=0.21.0",
"jsonlines>=4.0.0",
"packaging>=24.2",
"pynput>=1.7.7",
"pyserial>=3.5",
"wandb>=0.20.0",
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
"draccus==0.10.0", # TODO: Remove ==
"gymnasium>=0.29.1,<1.0.0", # TODO: Bumb dependency
"rerun-sdk>=0.21.0,<0.23.0", # TODO: Bumb dependency
@@ -106,7 +105,6 @@ dynamixel = ["dynamixel-sdk>=3.7.31"]
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0"]
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1"]
reachy2 = ["reachy2_sdk>=1.0.14"]
kinematics = ["lerobot[placo-dep]"]
intelrealsense = [
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
@@ -142,7 +140,6 @@ all = [
"lerobot[gamepad]",
"lerobot[hopejr]",
"lerobot[lekiwi]",
"lerobot[reachy2]",
"lerobot[kinematics]",
"lerobot[intelrealsense]",
"lerobot[pi0]",

View File

@@ -18,7 +18,7 @@ Helper to recalibrate your device (robot or teleoperator).
Example:
```shell
lerobot-calibrate \
python -m lerobot.calibrate \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--teleop.id=blue

View File

@@ -60,7 +60,7 @@ class OpenCVCamera(Camera):
or port changes, especially on Linux. Use the provided utility script to find
available camera indices or paths:
```bash
lerobot-find-cameras opencv
python -m lerobot.find_cameras opencv
```
The camera's default settings (FPS, resolution, color mode) are used unless
@@ -165,7 +165,8 @@ class OpenCVCamera(Camera):
self.videocapture.release()
self.videocapture = None
raise ConnectionError(
f"Failed to open {self}.Run `lerobot-find-cameras opencv` to find available cameras."
f"Failed to open {self}."
f"Run `python -m lerobot.find_cameras opencv` to find available cameras."
)
self._configure_capture_settings()

View File

@@ -1,16 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_reachy2_camera import Reachy2CameraConfig
from .reachy2_camera import Reachy2Camera

View File

@@ -1,78 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..configs import CameraConfig, ColorMode
@CameraConfig.register_subclass("reachy2_camera")
@dataclass
class Reachy2CameraConfig(CameraConfig):
"""Configuration class for Reachy 2 camera devices.
This class provides configuration options for Reachy 2 cameras,
supporting both the teleop and depth cameras. It includes settings
for resolution, frame rate, color mode, and the selection of the cameras.
Example configurations:
```python
# Basic configurations
Reachy2CameraConfig(
name="teleop",
image_type="left",
ip_address="192.168.0.200", # IP address of the robot
fps=15,
width=640,
height=480,
color_mode=ColorMode.RGB,
) # Left teleop camera, 640x480 @ 15FPS
```
Attributes:
name: Name of the camera device. Can be "teleop" or "depth".
image_type: Type of image stream. For "teleop" camera, can be "left" or "right".
For "depth" camera, can be "rgb" or "depth". (depth is not supported yet)
fps: Requested frames per second for the color stream.
width: Requested frame width in pixels for the color stream.
height: Requested frame height in pixels for the color stream.
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
ip_address: IP address of the robot. Defaults to "localhost".
port: Port number for the camera server. Defaults to 50065.
Note:
- Only 3-channel color output (RGB/BGR) is currently supported.
"""
name: str
image_type: str
color_mode: ColorMode = ColorMode.RGB
ip_address: str | None = "localhost"
port: int = 50065
# use_depth: bool = False
def __post_init__(self):
if self.name not in ["teleop", "depth"]:
raise ValueError(f"`name` is expected to be 'teleop' or 'depth', but {self.name} is provided.")
if (self.name == "teleop" and self.image_type not in ["left", "right"]) or (
self.name == "depth" and self.image_type not in ["rgb", "depth"]
):
raise ValueError(
f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided."
)
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)

View File

@@ -1,288 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Provides the Reachy2Camera class for capturing frames from Reachy 2 cameras using Reachy 2's CameraManager.
"""
import logging
import os
import platform
import time
from threading import Event, Lock, Thread
from typing import Any
# Fix MSMF hardware transform compatibility for Windows before importing cv2
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2
import numpy as np
from reachy2_sdk.media.camera import CameraView
from reachy2_sdk.media.camera_manager import CameraManager
from lerobot.errors import DeviceNotConnectedError
from ..camera import Camera
from .configuration_reachy2_camera import ColorMode, Reachy2CameraConfig
logger = logging.getLogger(__name__)
class Reachy2Camera(Camera):
"""
Manages Reachy 2 camera using Reachy 2 CameraManager.
This class provides a high-level interface to connect to, configure, and read
frames from Reachy 2 cameras. It supports both synchronous and asynchronous
frame reading.
An Reachy2Camera instance requires a camera name (e.g., "teleop") and an image
type (e.g., "left") to be specified in the configuration.
The camera's default settings (FPS, resolution, color mode) are used unless
overridden in the configuration.
"""
def __init__(self, config: Reachy2CameraConfig):
"""
Initializes the Reachy2Camera instance.
Args:
config: The configuration settings for the camera.
"""
super().__init__(config)
self.config = config
self.fps = config.fps
self.color_mode = config.color_mode
self.cam_manager: CameraManager | None = None
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: np.ndarray | None = None
self.new_frame_event: Event = Event()
def __str__(self) -> str:
return f"{self.__class__.__name__}({self.config.name}, {self.config.image_type})"
@property
def is_connected(self) -> bool:
"""Checks if the camera is currently connected and opened."""
if self.config.name == "teleop":
return self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
elif self.config.name == "depth":
return self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
else:
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
def connect(self, warmup: bool = True):
"""
Connects to the Reachy2 CameraManager as specified in the configuration.
"""
self.cam_manager = CameraManager(host=self.config.ip_address, port=self.config.port)
self.cam_manager.initialize_cameras()
logger.info(f"{self} connected.")
@staticmethod
def find_cameras(ip_address: str = "localhost", port: int = 50065) -> list[dict[str, Any]]:
"""
Detects available Reachy 2 cameras.
Returns:
List[Dict[str, Any]]: A list of dictionaries,
where each dictionary contains 'name', 'stereo',
and the default profile properties (width, height, fps).
"""
initialized_cameras = []
camera_manager = CameraManager(host=ip_address, port=port)
for camera in [camera_manager.teleop, camera_manager.depth]:
if camera is None:
continue
height, width, _, _, _, _, _ = camera.get_parameters()
camera_info = {
"name": camera._cam_info.name,
"stereo": camera._cam_info.stereo,
"default_profile": {
"width": width,
"height": height,
"fps": 30,
},
}
initialized_cameras.append(camera_info)
camera_manager.disconnect()
return initialized_cameras
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
"""
Reads a single frame synchronously from the camera.
This is a blocking call.
Args:
color_mode (Optional[ColorMode]): If specified, overrides the default
color mode (`self.color_mode`) for this read operation (e.g.,
request RGB even if default is BGR).
Returns:
np.ndarray: The captured frame as a NumPy array in the format
(height, width, channels), using the specified or default
color mode and applying any configured rotation.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
start_time = time.perf_counter()
frame = None
if self.cam_manager is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
else:
if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"):
if self.config.image_type == "left":
frame = self.cam_manager.teleop.get_frame(CameraView.LEFT, size=(640, 480))[0]
elif self.config.image_type == "right":
frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT, size=(640, 480))[0]
elif self.config.name == "depth" and hasattr(self.cam_manager, "depth"):
if self.config.image_type == "depth":
frame = self.cam_manager.depth.get_depth_frame()[0]
elif self.config.image_type == "rgb":
frame = self.cam_manager.depth.get_frame(size=(640, 480))[0]
if frame is None:
return np.empty((0, 0, 3), dtype=np.uint8)
if self.config.color_mode == "rgb":
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return frame
def _read_loop(self):
"""
Internal loop run by the background thread for asynchronous reading.
On each iteration:
1. Reads a color frame
2. Stores result in latest_frame (thread-safe)
3. Sets new_frame_event to notify listeners
Stops on DeviceNotConnectedError, logs other errors and continues.
"""
while not self.stop_event.is_set():
try:
color_image = self.read()
with self.frame_lock:
self.latest_frame = color_image
self.new_frame_event.set()
except DeviceNotConnectedError:
break
except Exception as e:
logger.warning(f"Error reading frame in background thread for {self}: {e}")
def _start_read_thread(self) -> None:
"""Starts or restarts the background read thread if it's not running."""
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=0.1)
if self.stop_event is not None:
self.stop_event.set()
self.stop_event = Event()
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
self.thread.daemon = True
self.thread.start()
def _stop_read_thread(self) -> None:
"""Signals the background read thread to stop and waits for it to join."""
if self.stop_event is not None:
self.stop_event.set()
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
self.thread = None
self.stop_event = None
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
"""
Reads the latest available frame asynchronously.
This method retrieves the most recent frame captured by the background
read thread. It does not block waiting for the camera hardware directly,
but may wait up to timeout_ms for the background thread to provide a frame.
Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame
to become available. Defaults to 200ms (0.2 seconds).
Returns:
np.ndarray: The latest captured frame as a NumPy array in the format
(height, width, channels), processed according to configuration.
Raises:
DeviceNotConnectedError: If the camera is not connected.
TimeoutError: If no frame becomes available within the specified timeout.
RuntimeError: If an unexpected error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
self._start_read_thread()
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
thread_alive = self.thread is not None and self.thread.is_alive()
raise TimeoutError(
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
f"Read thread alive: {thread_alive}."
)
with self.frame_lock:
frame = self.latest_frame
self.new_frame_event.clear()
if frame is None:
raise RuntimeError(f"Internal error: Event set but no frame available for {self}.")
return frame
def disconnect(self):
"""
Stops the background read thread (if running).
Raises:
DeviceNotConnectedError: If the camera is already disconnected.
"""
if not self.is_connected and self.thread is None:
raise DeviceNotConnectedError(f"{self} not connected.")
if self.thread is not None:
self._stop_read_thread()
if self.cam_manager is not None:
self.cam_manager.disconnect()
logger.info(f"{self} disconnected.")

View File

@@ -51,7 +51,7 @@ class RealSenseCamera(Camera):
Use the provided utility script to find available camera indices and default profiles:
```bash
lerobot-find-cameras realsense
python -m lerobot.find_cameras realsense
```
A `RealSenseCamera` instance requires a configuration object specifying the
@@ -176,7 +176,8 @@ class RealSenseCamera(Camera):
self.rs_profile = None
self.rs_pipeline = None
raise ConnectionError(
f"Failed to open {self}.Run `lerobot-find-cameras realsense` to find available cameras."
f"Failed to open {self}."
"Run `python -m lerobot.find_cameras realsense` to find available cameras."
) from e
self._configure_capture_settings()

View File

@@ -37,14 +37,8 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
from .realsense.camera_realsense import RealSenseCamera
cameras[key] = RealSenseCamera(cfg)
elif cfg.type == "reachy2_camera":
from .reachy2_camera.reachy2_camera import Reachy2Camera
cameras[key] = Reachy2Camera(cfg)
else:
raise ValueError(f"The camera type '{cfg.type}' is not valid.")
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
return cameras

View File

@@ -27,7 +27,6 @@ from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.optim.optimizers import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.hub import HubMixin
@@ -120,8 +119,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
@property
def robot_state_feature(self) -> PolicyFeature | None:
for ft_name, ft in self.input_features.items():
if ft.type is FeatureType.STATE and ft_name == OBS_STATE:
for _, ft in self.input_features.items():
if ft.type is FeatureType.STATE:
return ft
return None
@@ -138,8 +137,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
@property
def action_feature(self) -> PolicyFeature | None:
for ft_name, ft in self.output_features.items():
if ft.type is FeatureType.ACTION and ft_name == ACTION:
for _, ft in self.output_features.items():
if ft.type is FeatureType.ACTION:
return ft
return None

View File

@@ -825,8 +825,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
if not episode_data:
episode_buffer = self.episode_buffer
else:
episode_buffer = episode_data
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)

View File

@@ -13,22 +13,20 @@
# limitations under the License.
"""
This script will help you download any LeRobot dataset from the hub, convert it to the latest format, and
upload it to your own repository. It will:
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
2.1. It will:
- Download the dataset from any source repository
- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
- Update codebase_version in `info.json` to the latest version
- Create proper version tags
- Push the converted dataset to your specified destination repository
- Check consistency between these new stats and the old ones.
- Remove the deprecated `stats.json`.
- Update codebase_version in `info.json`.
- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
Usage:
```bash
python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 \
--source-repo-id=IPEC-COMMUNITY/libero_spatial_no_noops_1.0.0_lerobot \
--dest-repo-id=your-username/libero_spatial_converted \
--episodes=0,1,2,3,4
--repo-id=aliberts/koch_tutorial
```
"""
@@ -39,8 +37,8 @@ import logging
from huggingface_hub import HfApi
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, write_info
from lerobot.datasets.v21.convert_stats import convert_stats
from lerobot.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
from lerobot.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
V20 = "v2.0"
V21 = "v2.1"
@@ -56,133 +54,48 @@ class SuppressWarnings:
def convert_dataset(
source_repo_id: str,
dest_repo_id: str | None = None,
episodes: str | None = None,
repo_id: str,
branch: str | None = None,
num_workers: int = 4,
force_cache_sync: bool = True,
):
"""
Download a dataset from source_repo_id, convert it, and upload to dest_repo_id.
with SuppressWarnings():
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
Args:
source_repo_id: Source repository to download from
dest_repo_id: Destination repository to upload to (defaults to source_repo_id)
episodes: Comma-separated list of episode indices to include (e.g. "0,1,2,3")
branch: Branch to upload to
num_workers: Number of workers for stats computation
force_cache_sync: Whether to force cache synchronization
"""
if dest_repo_id is None:
dest_repo_id = source_repo_id
# Parse episodes list if provided
episode_list = None
if episodes:
try:
episode_list = [int(ep.strip()) for ep in episodes.split(",")]
print(f"Loading episodes: {episode_list}")
except ValueError as e:
raise ValueError(
f"Invalid episodes format '{episodes}'. Use comma-separated integers like '0,1,2,3'"
) from e
print(f"Downloading dataset from: {source_repo_id}")
# Try to load the dataset with different approaches to handle versioning issues
dataset = None
load_attempts = [
{"revision": None}, # Try latest first
{"revision": V20}, # Try v2.0
{"revision": "main"}, # Try main branch
]
for attempt in load_attempts:
try:
print(f"Attempting to load with revision: {attempt['revision']}")
with SuppressWarnings():
dataset = LeRobotDataset(
source_repo_id, episodes=episode_list, force_cache_sync=force_cache_sync, **attempt
)
print("Successfully loaded dataset!")
break
except Exception as e:
print(f"Failed with revision {attempt['revision']}: {e}")
continue
if dataset is None:
raise RuntimeError(f"Could not load dataset {source_repo_id} with any revision")
# Clean up old stats if present
if (dataset.root / EPISODES_STATS_PATH).is_file():
(dataset.root / EPISODES_STATS_PATH).unlink()
print("Removed existing episodes_stats.jsonl")
print("Converting stats to new format...")
convert_stats(dataset, num_workers=num_workers)
ref_stats = load_stats(dataset.root)
check_aggregate_stats(dataset, ref_stats)
# Update dataset info
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
write_info(dataset.meta.info, dataset.root)
print(f"Updated codebase_version to {CODEBASE_VERSION}")
# Change repo_id for destination if different
if dest_repo_id != source_repo_id:
print(f"Changing repository from {source_repo_id} to {dest_repo_id}")
dataset.repo_id = dest_repo_id
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
print(f"Pushing converted dataset to: {dest_repo_id}")
dataset.push_to_hub(branch=branch, tag_version=False)
# Clean up old stats.json file locally and on hub
if (dataset.root / STATS_PATH).is_file():
# delete old stats.json file
if (dataset.root / STATS_PATH).is_file:
(dataset.root / STATS_PATH).unlink()
print("Removed local stats.json file")
hub_api = HfApi()
try:
if hub_api.file_exists(
repo_id=dest_repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
):
hub_api.delete_file(
path_in_repo=STATS_PATH, repo_id=dest_repo_id, revision=branch, repo_type="dataset"
)
print("Removed stats.json from hub")
except Exception as e:
print(f"Warning: Could not remove stats.json from hub: {e}")
if hub_api.file_exists(
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
):
hub_api.delete_file(
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
)
# Create version tag
try:
hub_api.create_tag(dest_repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
print(f"Created tag {CODEBASE_VERSION} for {dest_repo_id}")
except Exception as e:
print(f"Warning: Could not create tag: {e}")
print(f"✅ Successfully converted and uploaded dataset to {dest_repo_id}")
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Download, convert, and re-upload LeRobot datasets with proper versioning"
)
parser = argparse.ArgumentParser()
parser.add_argument(
"--source-repo-id",
"--repo-id",
type=str,
required=True,
help="Source repository identifier to download from (e.g. 'IPEC-COMMUNITY/libero_spatial_no_noops_1.0.0_lerobot')",
)
parser.add_argument(
"--dest-repo-id",
type=str,
default=None,
help="Destination repository identifier to upload to. Defaults to source-repo-id if not specified.",
)
parser.add_argument(
"--episodes",
type=str,
default=None,
help="Comma-separated list of episode indices to include (e.g. '0,1,2,3,4'). If not specified, all episodes are included.",
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
)
parser.add_argument(
"--branch",
@@ -196,22 +109,6 @@ if __name__ == "__main__":
default=4,
help="Number of workers for parallelizing stats compute. Defaults to 4.",
)
parser.add_argument(
"--no-cache-sync",
action="store_true",
help="Skip forcing cache synchronization (faster but may use cached data)",
)
args = parser.parse_args()
# Convert args to match function signature
convert_args = {
"source_repo_id": args.source_repo_id,
"dest_repo_id": args.dest_repo_id,
"episodes": args.episodes,
"branch": args.branch,
"num_workers": args.num_workers,
"force_cache_sync": not args.no_cache_sync,
}
convert_dataset(**convert_args)
convert_dataset(**vars(args))

View File

@@ -20,7 +20,7 @@ Helper to find the camera devices available in your system.
Example:
```shell
lerobot-find-cameras
python -m lerobot.find_cameras
```
"""

View File

@@ -18,7 +18,7 @@ Helper to find the USB port associated with your MotorsBus.
Example:
```shell
lerobot-find-port
python -m lerobot.find_port
```
"""

View File

@@ -107,8 +107,6 @@ X_SERIES_ENCODINGS_TABLE = {
"Goal_PWM": X_SERIES_CONTROL_TABLE["Goal_PWM"][1],
"Goal_Current": X_SERIES_CONTROL_TABLE["Goal_Current"][1],
"Goal_Velocity": X_SERIES_CONTROL_TABLE["Goal_Velocity"][1],
"Goal_Position": X_SERIES_CONTROL_TABLE["Goal_Position"][1],
"Present_Position": X_SERIES_CONTROL_TABLE["Present_Position"][1],
"Present_PWM": X_SERIES_CONTROL_TABLE["Present_PWM"][1],
"Present_Current": X_SERIES_CONTROL_TABLE["Present_Current"][1],
"Present_Velocity": X_SERIES_CONTROL_TABLE["Present_Velocity"][1],

View File

@@ -222,7 +222,7 @@ class MotorsBus(abc.ABC):
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
To find the port, you can run our utility script:
```bash
lerobot-find-port.py
python -m lerobot.find_port.py
>>> Finding all available ports for the MotorsBus.
>>> ["/dev/tty.usbmodem575E0032081", "/dev/tty.usbmodem575E0031751"]
>>> Remove the usb cable from your MotorsBus and press Enter when done.
@@ -446,7 +446,7 @@ class MotorsBus(abc.ABC):
except (FileNotFoundError, OSError, serial.SerialException) as e:
raise ConnectionError(
f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port."
"\nTry running `lerobot-find-port`\n"
"\nTry running `python -m lerobot.find_port`\n"
) from e
@abc.abstractmethod

View File

@@ -24,6 +24,7 @@ def create_stats_buffers(
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
dtype: torch.dtype = torch.float32,
) -> dict[str, dict[str, nn.ParameterDict]]:
"""
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
@@ -60,8 +61,8 @@ def create_stats_buffers(
buffer = {}
if norm_mode is NormalizationMode.MEAN_STD:
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
std = torch.ones(shape, dtype=torch.float32) * torch.inf
mean = torch.ones(shape, dtype=dtype) * torch.inf
std = torch.ones(shape, dtype=dtype) * torch.inf
buffer = nn.ParameterDict(
{
"mean": nn.Parameter(mean, requires_grad=False),
@@ -69,8 +70,8 @@ def create_stats_buffers(
}
)
elif norm_mode is NormalizationMode.MIN_MAX:
min = torch.ones(shape, dtype=torch.float32) * torch.inf
max = torch.ones(shape, dtype=torch.float32) * torch.inf
min = torch.ones(shape, dtype=dtype) * torch.inf
max = torch.ones(shape, dtype=dtype) * torch.inf
buffer = nn.ParameterDict(
{
"min": nn.Parameter(min, requires_grad=False),
@@ -82,22 +83,22 @@ def create_stats_buffers(
if stats:
if isinstance(stats[key]["mean"], np.ndarray):
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=dtype)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=dtype)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=dtype)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=dtype)
elif isinstance(stats[key]["mean"], torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=dtype)
buffer["std"].data = stats[key]["std"].clone().to(dtype=dtype)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
buffer["min"].data = stats[key]["min"].clone().to(dtype=dtype)
buffer["max"].data = stats[key]["max"].clone().to(dtype=dtype)
else:
type_ = type(stats[key]["mean"])
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
@@ -121,6 +122,7 @@ class Normalize(nn.Module):
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
dtype: torch.dtype = torch.float32,
):
"""
Args:
@@ -144,7 +146,7 @@ class Normalize(nn.Module):
self.features = features
self.norm_map = norm_map
self.stats = stats
stats_buffers = create_stats_buffers(features, norm_map, stats)
stats_buffers = create_stats_buffers(features, norm_map, stats, dtype)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
@@ -195,6 +197,7 @@ class Unnormalize(nn.Module):
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
dtype: torch.dtype = torch.float32,
):
"""
Args:
@@ -219,7 +222,7 @@ class Unnormalize(nn.Module):
self.norm_map = norm_map
self.stats = stats
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
stats_buffers = create_stats_buffers(features, norm_map, stats)
stats_buffers = create_stats_buffers(features, norm_map, stats, dtype)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
@@ -262,6 +265,7 @@ def _initialize_stats_buffers(
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
dtype: torch.dtype = torch.float32,
) -> None:
"""Register statistics buffers (mean/std or min/max) on the given *module*.
@@ -282,8 +286,8 @@ def _initialize_stats_buffers(
prefix = key.replace(".", "_")
if norm_mode is NormalizationMode.MEAN_STD:
mean = torch.full(shape, torch.inf, dtype=torch.float32)
std = torch.full(shape, torch.inf, dtype=torch.float32)
mean = torch.full(shape, torch.inf, dtype=dtype)
std = torch.full(shape, torch.inf, dtype=dtype)
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
mean_data = stats[key]["mean"]
@@ -293,8 +297,8 @@ def _initialize_stats_buffers(
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
mean = mean_data.clone().to(dtype=torch.float32)
std = std_data.clone().to(dtype=torch.float32)
mean = mean_data.clone().to(dtype=dtype)
std = std_data.clone().to(dtype=dtype)
else:
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
@@ -303,15 +307,15 @@ def _initialize_stats_buffers(
continue
if norm_mode is NormalizationMode.MIN_MAX:
min_val = torch.full(shape, torch.inf, dtype=torch.float32)
max_val = torch.full(shape, torch.inf, dtype=torch.float32)
min_val = torch.full(shape, torch.inf, dtype=dtype)
max_val = torch.full(shape, torch.inf, dtype=dtype)
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
min_data = stats[key]["min"]
max_data = stats[key]["max"]
if isinstance(min_data, torch.Tensor):
min_val = min_data.clone().to(dtype=torch.float32)
max_val = max_data.clone().to(dtype=torch.float32)
min_val = min_data.clone().to(dtype=dtype)
max_val = max_data.clone().to(dtype=dtype)
else:
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
@@ -330,12 +334,13 @@ class NormalizeBuffer(nn.Module):
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
dtype: torch.dtype = torch.float32,
):
super().__init__()
self.features = features
self.norm_map = norm_map
_initialize_stats_buffers(self, features, norm_map, stats)
_initialize_stats_buffers(self, features, norm_map, stats, dtype)
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch)
@@ -379,12 +384,13 @@ class UnnormalizeBuffer(nn.Module):
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
dtype: torch.dtype = torch.float32,
):
super().__init__()
self.features = features
self.norm_map = norm_map
_initialize_stats_buffers(self, features, norm_map, stats)
_initialize_stats_buffers(self, features, norm_map, stats, dtype)
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# batch = dict(batch)

View File

@@ -30,7 +30,7 @@ pip install -e ".[pi0]"
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
```bash
lerobot-train \
python -m lerobot.scripts.train \
--policy.path=lerobot/pi0 \
--dataset.repo_id=danaaubakirova/koch_test
```
@@ -38,7 +38,7 @@ lerobot-train \
Example of finetuning the pi0 neural network with PaliGemma and expert Gemma
pretrained with VLM default parameters before pi0 finetuning:
```bash
lerobot-train \
python -m lerobot.scripts.train \
--policy.type=pi0 \
--dataset.repo_id=danaaubakirova/koch_test
```

View File

@@ -25,14 +25,14 @@ Disclaimer: It is not expected to perform as well as the original implementation
Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
```bash
lerobot-train \
python -m lerobot.scripts.train \
--policy.path=lerobot/pi0fast_base \
--dataset.repo_id=danaaubakirova/koch_test
```
Example of training the pi0+FAST neural network with from scratch:
```bash
lerobot-train \
python -m lerobot.scripts.train \
--policy.type=pi0fast \
--dataset.repo_id=danaaubakirova/koch_test
```

View File

@@ -28,7 +28,7 @@ pip install -e ".[smolvla]"
Example of finetuning the smolvla pretrained model (`smolvla_base`):
```bash
lerobot-train \
python -m lerobot.scripts.train \
--policy.path=lerobot/smolvla_base \
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
--batch_size=64 \
@@ -38,7 +38,7 @@ lerobot-train \
Example of finetuning a smolVLA. SmolVLA is composed of a pretrained VLM,
and an action expert.
```bash
lerobot-train \
python -m lerobot.scripts.train \
--policy.type=smolvla \
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
--batch_size=64 \
@@ -673,19 +673,19 @@ class VLAFlowMatching(nn.Module):
for params in self.state_proj.parameters():
params.requires_grad = self.config.train_state_proj
def sample_noise(self, shape, device):
def sample_noise(self, shape, device, dtype):
noise = torch.normal(
mean=0.0,
std=1.0,
size=shape,
dtype=torch.float32,
dtype=dtype,
device=device,
)
return noise
def sample_time(self, bsize, device):
def sample_time(self, bsize, device, dtype):
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=dtype)
time = time_beta * 0.999 + 0.001
return time
@@ -831,10 +831,10 @@ class VLAFlowMatching(nn.Module):
) -> Tensor:
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
if noise is None:
noise = self.sample_noise(actions.shape, actions.device)
noise = self.sample_noise(actions.shape, actions.device, actions.dtype)
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
time = self.sample_time(actions.shape[0], actions.device, actions.dtype)
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
@@ -868,10 +868,11 @@ class VLAFlowMatching(nn.Module):
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = state.shape[0]
device = state.device
dtype = state.dtype
if noise is None:
actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim)
noise = self.sample_noise(actions_shape, device)
noise = self.sample_noise(actions_shape, device, dtype)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks, state=state
@@ -888,18 +889,13 @@ class VLAFlowMatching(nn.Module):
fill_kv_cache=True,
)
dt = -1.0 / self.config.num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
dt = torch.tensor(dt, dtype=dtype, device=device)
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
time = torch.tensor(1.0, dtype=dtype, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
v_t = self.denoise_step(prefix_pad_masks, past_key_values, x_t, expanded_time, dtype)
# Euler step
x_t += dt * v_t
time += dt
@@ -911,6 +907,7 @@ class VLAFlowMatching(nn.Module):
past_key_values,
x_t,
timestep,
dtype,
):
"""Apply one denoising step of the noise `x_t` at a given timestep."""
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, timestep)
@@ -936,6 +933,6 @@ class VLAFlowMatching(nn.Module):
)
suffix_out = outputs_embeds[1]
suffix_out = suffix_out[:, -self.config.chunk_size :]
suffix_out = suffix_out.to(dtype=torch.float32)
suffix_out = suffix_out.to(dtype=dtype)
v_t = self.action_out_proj(suffix_out)
return v_t

View File

@@ -1,54 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .device_processor import DeviceProcessor
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
from .observation_processor import VanillaObservationProcessor
from .pipeline import (
ActionProcessor,
DoneProcessor,
EnvTransition,
IdentityProcessor,
InfoProcessor,
ObservationProcessor,
ProcessorStep,
ProcessorStepRegistry,
RewardProcessor,
RobotProcessor,
TransitionKey,
TruncatedProcessor,
)
from .rename_processor import RenameProcessor
__all__ = [
"ActionProcessor",
"DeviceProcessor",
"DoneProcessor",
"EnvTransition",
"IdentityProcessor",
"InfoProcessor",
"NormalizerProcessor",
"UnnormalizerProcessor",
"ObservationProcessor",
"ProcessorStep",
"ProcessorStepRegistry",
"RenameProcessor",
"RewardProcessor",
"RobotProcessor",
"TransitionKey",
"TruncatedProcessor",
"VanillaObservationProcessor",
]

View File

@@ -1,82 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import EnvTransition, TransitionKey
from lerobot.utils.utils import get_safe_torch_device
@dataclass
class DeviceProcessor:
"""Processes transitions by moving tensors to the specified device.
This processor ensures that all tensors in the transition are moved to the
specified device (CPU or GPU) before they are returned.
"""
device: torch.device = "cpu"
def __post_init__(self):
self.device = get_safe_torch_device(self.device)
self.non_blocking = "cuda" in str(self.device)
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Create a copy of the transition
new_transition = transition.copy()
# Process observation tensors
observation = transition.get(TransitionKey.OBSERVATION)
if observation is not None:
new_observation = {
k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v
for k, v in observation.items()
}
new_transition[TransitionKey.OBSERVATION] = new_observation
# Process action tensor
action = transition.get(TransitionKey.ACTION)
if action is not None and isinstance(action, torch.Tensor):
new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking)
# Process reward tensor
reward = transition.get(TransitionKey.REWARD)
if reward is not None and isinstance(reward, torch.Tensor):
new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking)
# Process done tensor
done = transition.get(TransitionKey.DONE)
if done is not None and isinstance(done, torch.Tensor):
new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking)
# Process truncated tensor
truncated = transition.get(TransitionKey.TRUNCATED)
if truncated is not None and isinstance(truncated, torch.Tensor):
new_transition[TransitionKey.TRUNCATED] = truncated.to(
self.device, non_blocking=self.non_blocking
)
return new_transition
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {"device": self.device}
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features

View File

@@ -1,331 +0,0 @@
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any
import numpy as np
import torch
from torch import Tensor
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
"""Convert numpy arrays and other types to torch tensors."""
tensor_stats: dict[str, dict[str, Tensor]] = {}
for key, sub in stats.items():
tensor_stats[key] = {}
for stat_name, value in sub.items():
if isinstance(value, np.ndarray):
tensor_val = torch.from_numpy(value.astype(np.float32))
elif isinstance(value, torch.Tensor):
tensor_val = value.to(dtype=torch.float32)
elif isinstance(value, (int, float, list, tuple)):
tensor_val = torch.tensor(value, dtype=torch.float32)
else:
raise TypeError(f"Unsupported type for stats['{key}']['{stat_name}']: {type(value)}")
tensor_stats[key][stat_name] = tensor_val
return tensor_stats
@dataclass
@ProcessorStepRegistry.register(name="normalizer_processor")
class NormalizerProcessor:
"""Normalizes observations and actions in a single processor step.
This processor handles normalization of both observation and action tensors
using either mean/std normalization or min/max scaling to a [-1, 1] range.
For each tensor key in the stats dictionary, the processor will:
- Use mean/std normalization if those statistics are provided: (x - mean) / std
- Use min/max scaling if those statistics are provided: 2 * (x - min) / (max - min) - 1
The processor can be configured to normalize only specific keys by setting
the normalize_keys parameter.
"""
# Features and normalisation map are mandatory to match the design of normalize.py
features: dict[str, PolicyFeature]
norm_map: dict[FeatureType, NormalizationMode]
# Pre-computed statistics coming from dataset.meta.stats for instance.
stats: dict[str, dict[str, Any]] | None = None
# Explicit subset of keys to normalise. If ``None`` every key (except
# "action") found in ``stats`` will be normalised. Using a ``set`` makes
# membership checks O(1).
normalize_keys: set[str] | None = None
eps: float = 1e-8
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
@classmethod
def from_lerobot_dataset(
cls,
dataset: LeRobotDataset,
features: dict[str, PolicyFeature],
norm_map: dict[FeatureType, NormalizationMode],
*,
normalize_keys: set[str] | None = None,
eps: float = 1e-8,
) -> NormalizerProcessor:
"""Factory helper that pulls statistics from a :class:`LeRobotDataset`.
The features and norm_map parameters are mandatory to match the design
pattern used in normalize.py.
"""
return cls(
features=features,
norm_map=norm_map,
stats=dataset.meta.stats,
normalize_keys=normalize_keys,
eps=eps,
)
def __post_init__(self):
# Handle deserialization from JSON config
if self.features and isinstance(list(self.features.values())[0], dict):
# Features came from JSON - need to reconstruct PolicyFeature objects
reconstructed_features = {}
for key, ft_dict in self.features.items():
reconstructed_features[key] = PolicyFeature(
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
)
self.features = reconstructed_features
if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
# norm_map came from JSON - need to reconstruct enum keys and values
reconstructed_norm_map = {}
for ft_type_str, norm_mode_str in self.norm_map.items():
reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
self.norm_map = reconstructed_norm_map
# Convert statistics once so we avoid repeated numpy→Tensor conversions
# during runtime.
self.stats = self.stats or {}
self._tensor_stats = _convert_stats_to_tensors(self.stats)
# Ensure *normalize_keys* is a set for fast look-ups and compare by
# value later when returning the configuration.
if self.normalize_keys is not None and not isinstance(self.normalize_keys, set):
self.normalize_keys = set(self.normalize_keys)
def _normalize_obs(self, observation):
if observation is None:
return None
# Decide which keys should be normalised for this call.
if self.normalize_keys is not None:
keys_to_norm = self.normalize_keys
else:
# Use feature map to skip action keys.
keys_to_norm = {k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION}
processed = dict(observation)
for key in keys_to_norm:
if key not in processed or key not in self._tensor_stats:
continue
orig_val = processed[key]
tensor = (
orig_val.to(dtype=torch.float32)
if isinstance(orig_val, torch.Tensor)
else torch.as_tensor(orig_val, dtype=torch.float32)
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = (tensor - mean) / (std + self.eps)
elif "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
return processed
def _normalize_action(self, action):
if action is None or "action" not in self._tensor_stats:
return action
tensor = (
action.to(dtype=torch.float32)
if isinstance(action, torch.Tensor)
else torch.as_tensor(action, dtype=torch.float32)
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
return (tensor - mean) / (std + self.eps)
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION))
action = self._normalize_action(transition.get(TransitionKey.ACTION))
# Create a new transition with normalized values
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = observation
new_transition[TransitionKey.ACTION] = action
return new_transition
def get_config(self) -> dict[str, Any]:
config = {
"eps": self.eps,
"features": {
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
},
"norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
}
if self.normalize_keys is not None:
# Serialise as a list for YAML / JSON friendliness
config["normalize_keys"] = sorted(self.normalize_keys)
return config
def state_dict(self) -> dict[str, Tensor]:
flat = {}
for key, sub in self._tensor_stats.items():
for stat_name, tensor in sub.items():
flat[f"{key}.{stat_name}"] = tensor
return flat
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
self._tensor_stats.clear()
for flat_key, tensor in state.items():
key, stat_name = flat_key.rsplit(".", 1)
self._tensor_stats.setdefault(key, {})[stat_name] = tensor
def reset(self):
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register(name="unnormalizer_processor")
class UnnormalizerProcessor:
"""Inverse normalisation for observations and actions.
Exactly mirrors :class:`NormalizerProcessor` but applies the inverse
transform.
"""
features: dict[str, PolicyFeature]
norm_map: dict[FeatureType, NormalizationMode]
stats: dict[str, dict[str, Any]] | None = None
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
@classmethod
def from_lerobot_dataset(
cls,
dataset: LeRobotDataset,
features: dict[str, PolicyFeature],
norm_map: dict[FeatureType, NormalizationMode],
) -> UnnormalizerProcessor:
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats)
def __post_init__(self):
# Handle deserialization from JSON config
if self.features and isinstance(list(self.features.values())[0], dict):
# Features came from JSON - need to reconstruct PolicyFeature objects
reconstructed_features = {}
for key, ft_dict in self.features.items():
reconstructed_features[key] = PolicyFeature(
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
)
self.features = reconstructed_features
if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
# norm_map came from JSON - need to reconstruct enum keys and values
reconstructed_norm_map = {}
for ft_type_str, norm_mode_str in self.norm_map.items():
reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
self.norm_map = reconstructed_norm_map
self.stats = self.stats or {}
self._tensor_stats = _convert_stats_to_tensors(self.stats)
def _unnormalize_obs(self, observation):
if observation is None:
return None
keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
processed = dict(observation)
for key in keys:
if key not in processed or key not in self._tensor_stats:
continue
orig_val = processed[key]
tensor = (
orig_val.to(dtype=torch.float32)
if isinstance(orig_val, torch.Tensor)
else torch.as_tensor(orig_val, dtype=torch.float32)
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = tensor * std + mean
elif "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
return processed
def _unnormalize_action(self, action):
if action is None or "action" not in self._tensor_stats:
return action
tensor = (
action.to(dtype=torch.float32)
if isinstance(action, torch.Tensor)
else torch.as_tensor(action, dtype=torch.float32)
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
return tensor * std + mean
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
return (tensor + 1) / 2 * (max_val - min_val) + min_val
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION))
action = self._unnormalize_action(transition.get(TransitionKey.ACTION))
# Create a new transition with unnormalized values
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = observation
new_transition[TransitionKey.ACTION] = action
return new_transition
def get_config(self) -> dict[str, Any]:
return {
"features": {
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
},
"norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
}
def state_dict(self) -> dict[str, Tensor]:
flat = {}
for key, sub in self._tensor_stats.items():
for stat_name, tensor in sub.items():
flat[f"{key}.{stat_name}"] = tensor
return flat
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
self._tensor_stats.clear()
for flat_key, tensor in state.items():
key, stat_name = flat_key.rsplit(".", 1)
self._tensor_stats.setdefault(key, {})[stat_name] = tensor
def reset(self):
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features

View File

@@ -1,157 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import einops
import numpy as np
import torch
from torch import Tensor
from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor.pipeline import ObservationProcessor, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="observation_processor")
class VanillaObservationProcessor(ObservationProcessor):
"""
Processes environment observations into the LeRobot format by handling both images and states.
Image processing:
- Converts channel-last (H, W, C) images to channel-first (C, H, W)
- Normalizes uint8 images ([0, 255]) to float32 ([0, 1])
- Adds a batch dimension if missing
- Supports single images and image dictionaries
State processing:
- Maps 'environment_state' to observation.environment_state
- Maps 'agent_pos' to observation.state
- Converts numpy arrays to tensors
- Adds a batch dimension if missing
"""
def _process_single_image(self, img: np.ndarray) -> Tensor:
"""Process a single image array."""
# Convert to tensor
img_tensor = torch.from_numpy(img)
# Add batch dimension if needed
if img_tensor.ndim == 3:
img_tensor = img_tensor.unsqueeze(0)
# Validate image format
_, h, w, c = img_tensor.shape
if not (c < h and c < w):
raise ValueError(f"Expected channel-last images, but got shape {img_tensor.shape}")
if img_tensor.dtype != torch.uint8:
raise ValueError(f"Expected torch.uint8 images, but got {img_tensor.dtype}")
# Convert to channel-first format
img_tensor = einops.rearrange(img_tensor, "b h w c -> b c h w").contiguous()
# Convert to float32 and normalize to [0, 1]
img_tensor = img_tensor.type(torch.float32) / 255.0
return img_tensor
def _process_observation(self, observation):
"""
Processes both image and state observations.
"""
processed_obs = observation.copy()
if "pixels" in processed_obs:
pixels = processed_obs.pop("pixels")
if isinstance(pixels, dict):
imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()}
else:
imgs = {OBS_IMAGE: pixels}
for imgkey, img in imgs.items():
processed_obs[imgkey] = self._process_single_image(img)
if "environment_state" in processed_obs:
env_state_np = processed_obs.pop("environment_state")
env_state = torch.from_numpy(env_state_np).float()
if env_state.dim() == 1:
env_state = env_state.unsqueeze(0)
processed_obs[OBS_ENV_STATE] = env_state
if "agent_pos" in processed_obs:
agent_pos_np = processed_obs.pop("agent_pos")
agent_pos = torch.from_numpy(agent_pos_np).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
processed_obs[OBS_STATE] = agent_pos
return processed_obs
def observation(self, observation):
return self._process_observation(observation)
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transforms feature keys to a standardized contract.
This method handles several renaming patterns:
- Exact matches (e.g., 'pixels' -> 'OBS_IMAGE').
- Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE').
- Prefix matches (e.g., 'pixels.cam1' -> 'OBS_IMAGES.cam1').
- Prefixed prefix matches (e.g., 'observation.pixels.cam1' -> 'OBS_IMAGES.cam1').
- environment_state -> OBS_ENV_STATE,
- agent_pos -> OBS_STATE,
- observation.environment_state -> OBS_ENV_STATE,
- observation.agent_pos -> OBS_STATE
"""
exact_pairs = {
"pixels": OBS_IMAGE,
"environment_state": OBS_ENV_STATE,
"agent_pos": OBS_STATE,
}
prefix_pairs = {
"pixels.": f"{OBS_IMAGES}.",
}
for key in list(features.keys()):
matched_prefix = False
for old_prefix, new_prefix in prefix_pairs.items():
prefixed_old = f"observation.{old_prefix}"
if key.startswith(prefixed_old):
suffix = key[len(prefixed_old) :]
features[f"{new_prefix}{suffix}"] = features.pop(key)
matched_prefix = True
break
if key.startswith(old_prefix):
suffix = key[len(old_prefix) :]
features[f"{new_prefix}{suffix}"] = features.pop(key)
matched_prefix = True
break
if matched_prefix:
continue
for old, new in exact_pairs.items():
if key == old or key == f"observation.{old}":
if key in features:
features[new] = features.pop(key)
break
return features

File diff suppressed because it is too large Load Diff

View File

@@ -1,51 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any
from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import (
ObservationProcessor,
ProcessorStepRegistry,
)
@dataclass
@ProcessorStepRegistry.register(name="rename_processor")
class RenameProcessor(ObservationProcessor):
"""Rename processor that renames keys in the observation."""
rename_map: dict[str, str] = field(default_factory=dict)
def observation(self, observation):
processed_obs = {}
for key, value in observation.items():
if key in self.rename_map:
processed_obs[self.rename_map[key]] = value
else:
processed_obs[key] = value
return processed_obs
def get_config(self) -> dict[str, Any]:
return {"rename_map": self.rename_map}
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transforms:
- Each key in the observation that appears in `rename_map` is renamed to its value.
- Keys not in `rename_map` remain unchanged.
"""
return {self.rename_map.get(k, k): v for k, v in features.items()}

View File

@@ -18,7 +18,7 @@ Records a dataset. Actions for the robot can be either generated by teleoperatio
Example:
```shell
lerobot-record \
python -m lerobot.record \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{laptop: {type: opencv, camera_index: 0, width: 640, height: 480}}" \
@@ -36,7 +36,7 @@ lerobot-record \
Example recording with bimanual so100:
```shell
lerobot-record \
python -m lerobot.record \
--robot.type=bi_so100_follower \
--robot.left_arm_port=/dev/tty.usbmodem5A460851411 \
--robot.right_arm_port=/dev/tty.usbmodem5A460812391 \
@@ -209,14 +209,7 @@ def record_loop(
(
t
for t in teleop
if isinstance(
t,
(
so100_leader.SO100Leader,
so101_leader.SO101Leader,
koch_leader.KochLeader,
),
)
if isinstance(t, (so100_leader.SO100Leader, so101_leader.SO101Leader, koch_leader.KochLeader))
),
None,
)

View File

@@ -18,7 +18,7 @@ Replays the actions of an episode from a dataset on a robot.
Examples:
```shell
lerobot-replay \
python -m lerobot.replay \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=black \
@@ -28,7 +28,7 @@ lerobot-replay \
Example replay with bimanual so100:
```shell
lerobot-replay \
python -m lerobot.replay \
--robot.type=bi_so100_follower \
--robot.left_arm_port=/dev/tty.usbmodem5A460851411 \
--robot.right_arm_port=/dev/tty.usbmodem5A460812391 \
@@ -55,7 +55,6 @@ from lerobot.robots import ( # noqa: F401
hope_jr,
koch_follower,
make_robot_from_config,
reachy2,
so100_follower,
so101_follower,
)

View File

@@ -29,10 +29,10 @@ class BiSO100FollowerConfig(RobotConfig):
# Optional
left_arm_disable_torque_on_disconnect: bool = True
left_arm_max_relative_target: float | dict[str, float] | None = None
left_arm_max_relative_target: int | None = None
left_arm_use_degrees: bool = False
right_arm_disable_torque_on_disconnect: bool = True
right_arm_max_relative_target: float | dict[str, float] | None = None
right_arm_max_relative_target: int | None = None
right_arm_use_degrees: bool = False
# cameras (shared between both arms)

View File

@@ -44,8 +44,8 @@ class HopeJrArmConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
max_relative_target: float | dict[str, float] | None = None
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
cameras: dict[str, CameraConfig] = field(default_factory=dict)

View File

@@ -28,9 +28,9 @@ class KochFollowerConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
max_relative_target: float | dict[str, float] | None = None
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)

View File

@@ -110,7 +110,6 @@ class KochFollower(Robot):
return self.bus.is_calibrated
def calibrate(self) -> None:
self.bus.disable_torque()
if self.calibration:
# Calibration file exists, ask user whether to use it or run new calibration
user_input = input(
@@ -121,6 +120,7 @@ class KochFollower(Robot):
self.bus.write_calibration(self.calibration)
return
logger.info(f"\nRunning calibration of {self}")
self.bus.disable_torque()
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)

View File

@@ -39,9 +39,9 @@ class LeKiwiConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
max_relative_target: float | dict[str, float] | None = None
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)

View File

@@ -1,25 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_reachy2 import Reachy2RobotConfig
from .robot_reachy2 import (
REACHY2_ANTENNAS_JOINTS,
REACHY2_L_ARM_JOINTS,
REACHY2_NECK_JOINTS,
REACHY2_R_ARM_JOINTS,
REACHY2_VEL,
Reachy2Robot,
)

View File

@@ -1,107 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from lerobot.cameras.configs import ColorMode
from lerobot.cameras.reachy2_camera import Reachy2CameraConfig
from ..config import RobotConfig
@RobotConfig.register_subclass("reachy2")
@dataclass
class Reachy2RobotConfig(RobotConfig):
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors.
max_relative_target: float | None = None
# IP address of the Reachy 2 robot
ip_address: str | None = "localhost"
# If True, turn_off_smoothly() will be sent to the robot before disconnecting.
disable_torque_on_disconnect: bool = False
# Tag for external commands control
# Set to True if you use an external commands system to control the robot,
# such as the official teleoperation application: https://github.com/pollen-robotics/Reachy2Teleoperation
# If True, robot.send_action() will not send commands to the robot.
use_external_commands: bool = False
# Robot parts
# Set to False to not add the corresponding joints part to the robot list of joints.
# By default, all parts are set to True.
with_mobile_base: bool = True
with_l_arm: bool = True
with_r_arm: bool = True
with_neck: bool = True
with_antennas: bool = True
# Robot cameras
# Set to True if you want to use the corresponding cameras in the observations.
# By default, only the teleop cameras are used.
with_left_teleop_camera: bool = True
with_right_teleop_camera: bool = True
with_torso_camera: bool = False
cameras: dict[str, CameraConfig] = field(default_factory=dict)
def __post_init__(self) -> None:
# Add cameras with same ip_address as the robot
if self.with_left_teleop_camera:
self.cameras["teleop_left"] = Reachy2CameraConfig(
name="teleop",
image_type="left",
ip_address=self.ip_address,
fps=15,
width=640,
height=480,
color_mode=ColorMode.RGB,
)
if self.with_right_teleop_camera:
self.cameras["teleop_right"] = Reachy2CameraConfig(
name="teleop",
image_type="right",
ip_address=self.ip_address,
fps=15,
width=640,
height=480,
color_mode=ColorMode.RGB,
)
if self.with_torso_camera:
self.cameras["torso_rgb"] = Reachy2CameraConfig(
name="depth",
image_type="rgb",
ip_address=self.ip_address,
fps=15,
width=640,
height=480,
color_mode=ColorMode.RGB,
)
super().__post_init__()
if not (
self.with_mobile_base
or self.with_l_arm
or self.with_r_arm
or self.with_neck
or self.with_antennas
):
raise ValueError(
"No Reachy2Robot part used.\n"
"At least one part of the robot must be set to True "
"(with_mobile_base, with_l_arm, with_r_arm, with_neck, with_antennas)"
)

View File

@@ -1,230 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import Any
import numpy as np
from reachy2_sdk import ReachySDK
from lerobot.cameras.utils import make_cameras_from_configs
from ..robot import Robot
from ..utils import ensure_safe_goal_position
from .configuration_reachy2 import Reachy2RobotConfig
# {lerobot_keys: reachy2_sdk_keys}
REACHY2_NECK_JOINTS = {
"neck_yaw.pos": "head.neck.yaw",
"neck_pitch.pos": "head.neck.pitch",
"neck_roll.pos": "head.neck.roll",
}
REACHY2_ANTENNAS_JOINTS = {
"l_antenna.pos": "head.l_antenna",
"r_antenna.pos": "head.r_antenna",
}
REACHY2_R_ARM_JOINTS = {
"r_shoulder_pitch.pos": "r_arm.shoulder.pitch",
"r_shoulder_roll.pos": "r_arm.shoulder.roll",
"r_elbow_yaw.pos": "r_arm.elbow.yaw",
"r_elbow_pitch.pos": "r_arm.elbow.pitch",
"r_wrist_roll.pos": "r_arm.wrist.roll",
"r_wrist_pitch.pos": "r_arm.wrist.pitch",
"r_wrist_yaw.pos": "r_arm.wrist.yaw",
"r_gripper.pos": "r_arm.gripper",
}
REACHY2_L_ARM_JOINTS = {
"l_shoulder_pitch.pos": "l_arm.shoulder.pitch",
"l_shoulder_roll.pos": "l_arm.shoulder.roll",
"l_elbow_yaw.pos": "l_arm.elbow.yaw",
"l_elbow_pitch.pos": "l_arm.elbow.pitch",
"l_wrist_roll.pos": "l_arm.wrist.roll",
"l_wrist_pitch.pos": "l_arm.wrist.pitch",
"l_wrist_yaw.pos": "l_arm.wrist.yaw",
"l_gripper.pos": "l_arm.gripper",
}
REACHY2_VEL = {
"mobile_base.vx": "vx",
"mobile_base.vy": "vy",
"mobile_base.vtheta": "vtheta",
}
class Reachy2Robot(Robot):
"""
[Reachy 2](https://www.pollen-robotics.com/reachy/), by Pollen Robotics.
"""
config_class = Reachy2RobotConfig
name = "reachy2"
def __init__(self, config: Reachy2RobotConfig):
super().__init__(config)
self.config = config
self.robot_type = self.config.type
self.use_external_commands = self.config.use_external_commands
self.reachy: None | ReachySDK = None
self.cameras = make_cameras_from_configs(config.cameras)
self.logs: dict[str, float] = {}
self.joints_dict: dict[str, str] = self._generate_joints_dict()
@property
def observation_features(self) -> dict[str, Any]:
return {**self.motors_features, **self.camera_features}
@property
def action_features(self) -> dict[str, type]:
return self.motors_features
@property
def camera_features(self) -> dict[str, tuple[int | None, int | None, int]]:
return {cam: (self.cameras[cam].height, self.cameras[cam].width, 3) for cam in self.cameras}
@property
def motors_features(self) -> dict[str, type]:
if self.config.with_mobile_base:
return {
**dict.fromkeys(
self.joints_dict.keys(),
float,
),
**dict.fromkeys(
REACHY2_VEL.keys(),
float,
),
}
else:
return dict.fromkeys(self.joints_dict.keys(), float)
@property
def is_connected(self) -> bool:
return self.reachy.is_connected() if self.reachy is not None else False
def connect(self, calibrate: bool = False) -> None:
self.reachy = ReachySDK(self.config.ip_address)
if not self.is_connected:
raise ConnectionError()
for cam in self.cameras.values():
cam.connect()
self.configure()
def configure(self) -> None:
if self.reachy is not None:
self.reachy.turn_on()
self.reachy.reset_default_limits()
@property
def is_calibrated(self) -> bool:
return True
def calibrate(self) -> None:
pass
def _generate_joints_dict(self) -> dict[str, str]:
joints = {}
if self.config.with_neck:
joints.update(REACHY2_NECK_JOINTS)
if self.config.with_l_arm:
joints.update(REACHY2_L_ARM_JOINTS)
if self.config.with_r_arm:
joints.update(REACHY2_R_ARM_JOINTS)
if self.config.with_antennas:
joints.update(REACHY2_ANTENNAS_JOINTS)
return joints
def _get_state(self) -> dict[str, float]:
if self.reachy is not None:
pos_dict = {k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()}
if not self.config.with_mobile_base:
return pos_dict
vel_dict = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
return {**pos_dict, **vel_dict}
else:
return {}
def get_observation(self) -> dict[str, np.ndarray]:
obs_dict: dict[str, Any] = {}
# Read Reachy 2 state
before_read_t = time.perf_counter()
obs_dict.update(self._get_state())
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
# Capture images from cameras
for cam_key, cam in self.cameras.items():
obs_dict[cam_key] = cam.async_read()
return obs_dict
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
if self.reachy is not None:
if not self.is_connected:
raise ConnectionError()
before_write_t = time.perf_counter()
vel = {}
goal_pos = {}
for key, val in action.items():
if key not in self.joints_dict:
if key not in REACHY2_VEL:
raise KeyError(f"Key '{key}' is not a valid motor key in Reachy 2.")
else:
vel[REACHY2_VEL[key]] = float(val)
else:
if not self.use_external_commands and self.config.max_relative_target is not None:
goal_pos[key] = float(val)
goal_present_pos = {
key: (
goal_pos[key],
self.reachy.joints[self.joints_dict[key]].present_position,
)
}
safe_goal_pos = ensure_safe_goal_position(
goal_present_pos, float(self.config.max_relative_target)
)
val = safe_goal_pos[key]
self.reachy.joints[self.joints_dict[key]].goal_position = float(val)
if self.config.with_mobile_base:
self.reachy.mobile_base.set_goal_speed(vel["vx"], vel["vy"], vel["vtheta"])
# We don't send the goal positions if we control Reachy 2 externally
if not self.use_external_commands:
self.reachy.send_goal_positions()
if self.config.with_mobile_base:
self.reachy.mobile_base.send_speed_command()
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
return action
def disconnect(self) -> None:
if self.reachy is not None:
for cam in self.cameras.values():
cam.disconnect()
if self.config.disable_torque_on_disconnect:
self.reachy.turn_off_smoothly()
self.reachy.disconnect()

View File

@@ -30,9 +30,9 @@ class SO100FollowerConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
max_relative_target: float | dict[str, float] | None = None
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)

View File

@@ -161,11 +161,6 @@ class SO100Follower(Robot):
self.bus.write("I_Coefficient", motor, 0)
self.bus.write("D_Coefficient", motor, 32)
if motor == "gripper":
self.bus.write("Max_Torque_Limit", motor, 500) # 50% of max torque to avoid burnout
self.bus.write("Protection_Current", motor, 250) # 50% of max current to avoid burnout
self.bus.write("Overload_Torque", motor, 25) # 25% torque when overloaded
def setup_motors(self) -> None:
for motor in reversed(self.bus.motors):
input(f"Connect the controller board to the '{motor}' motor only and press enter.")

View File

@@ -30,9 +30,9 @@ class SO101FollowerConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
max_relative_target: float | dict[str, float] | None = None
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)

View File

@@ -157,13 +157,6 @@ class SO101Follower(Robot):
self.bus.write("I_Coefficient", motor, 0)
self.bus.write("D_Coefficient", motor, 32)
if motor == "gripper":
self.bus.write(
"Max_Torque_Limit", motor, 500
) # 50% of the max torque limit to avoid burnout
self.bus.write("Protection_Current", motor, 250) # 50% of max current to avoid burnout
self.bus.write("Overload_Torque", motor, 25) # 25% torque when overloaded
def setup_motors(self) -> None:
for motor in reversed(self.bus.motors):
input(f"Connect the controller board to the '{motor}' motor only and press enter.")

View File

@@ -24,6 +24,11 @@ from ..config import RobotConfig
@RobotConfig.register_subclass("stretch3")
@dataclass
class Stretch3RobotConfig(RobotConfig):
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# cameras
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {

View File

@@ -61,10 +61,6 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
from .bi_so100_follower import BiSO100Follower
return BiSO100Follower(config)
elif config.type == "reachy2":
from .reachy2 import Reachy2Robot
return Reachy2Robot(config)
elif config.type == "mock_robot":
from tests.mocks.mock_robot import MockRobot
@@ -74,7 +70,7 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
def ensure_safe_goal_position(
goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[str, float]
goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[float]
) -> dict[str, float]:
"""Caps relative action target magnitude for safety."""

View File

@@ -141,10 +141,10 @@ python lerobot/scripts/control_robot.py \
## Train a policy
To train a policy to control your robot, use the [`lerobot-train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
lerobot-train \
python -m lerobot.scripts.train \
--dataset.repo_id=${HF_USER}/aloha_test \
--policy.type=act \
--output_dir=outputs/train/act_aloha_test \

View File

@@ -28,15 +28,15 @@ class ViperXConfig(RobotConfig):
# /!\ FOR SAFETY, READ THIS /!\
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
# For Aloha, for every goal position request, motor rotations are capped at 5 degrees by default.
# When you feel more confident with teleoperation or running the policy, you can extend
# this safety limit and even removing it by setting it to `null`.
# Also, everything is expected to work safely out-of-the-box, but we highly advise to
# first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml),
# then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully
max_relative_target: float | dict[str, float] = 5.0
max_relative_target: int | None = 5
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)

View File

@@ -21,7 +21,7 @@ You want to evaluate a model from the hub (eg: https://huggingface.co/lerobot/di
for 10 episodes.
```
lerobot-eval \
python -m lerobot.scripts.eval \
--policy.path=lerobot/diffusion_pusht \
--env.type=pusht \
--eval.batch_size=10 \
@@ -32,7 +32,7 @@ lerobot-eval \
OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes.
```
lerobot-eval \
python -m lerobot.scripts.eval \
--policy.path=outputs/train/diffusion_pusht/checkpoints/005000/pretrained_model \
--env.type=pusht \
--eval.batch_size=10 \

View File

@@ -302,6 +302,11 @@ class RobotClient:
self.logger.debug(f"Current latest action: {latest_action}")
# Get queue state before changes
old_size, old_timesteps = self._inspect_action_queue()
if not old_timesteps:
old_timesteps = [latest_action] # queue was empty
# Get queue state before changes
old_size, old_timesteps = self._inspect_action_queue()
if not old_timesteps:

View File

@@ -18,7 +18,7 @@ Helper to set motor ids and baudrate.
Example:
```shell
lerobot-setup-motors \
python -m lerobot.setup_motors \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem575E0031751
```

View File

@@ -18,7 +18,7 @@ Simple script to control a robot from teleoperation.
Example:
```shell
lerobot-teleoperate \
python -m lerobot.teleoperate \
--robot.type=so101_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
@@ -32,7 +32,7 @@ lerobot-teleoperate \
Example teleoperation with bimanual so100:
```shell
lerobot-teleoperate \
python -m lerobot.teleoperate \
--robot.type=bi_so100_follower \
--robot.left_arm_port=/dev/tty.usbmodem5A460851411 \
--robot.right_arm_port=/dev/tty.usbmodem5A460812391 \

View File

@@ -88,7 +88,6 @@ class KochLeader(Teleoperator):
return self.bus.is_calibrated
def calibrate(self) -> None:
self.bus.disable_torque()
if self.calibration:
# Calibration file exists, ask user whether to use it or run new calibration
user_input = input(
@@ -99,6 +98,7 @@ class KochLeader(Teleoperator):
self.bus.write_calibration(self.calibration)
return
logger.info(f"\nRunning calibration of {self}")
self.bus.disable_torque()
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)

View File

@@ -1,25 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig
from .reachy2_teleoperator import (
REACHY2_ANTENNAS_JOINTS,
REACHY2_L_ARM_JOINTS,
REACHY2_NECK_JOINTS,
REACHY2_R_ARM_JOINTS,
REACHY2_VEL,
Reachy2Teleoperator,
)

View File

@@ -1,51 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("reachy2_teleoperator")
@dataclass
class Reachy2TeleoperatorConfig(TeleoperatorConfig):
# IP address of the Reachy 2 robot used as teleoperator
ip_address: str | None = "localhost"
# Whether to use the present position of the joints as actions
# if False, the goal position of the joints will be used
use_present_position: bool = False
# Which parts of the robot to use
with_mobile_base: bool = True
with_l_arm: bool = True
with_r_arm: bool = True
with_neck: bool = True
with_antennas: bool = True
def __post_init__(self):
if not (
self.with_mobile_base
or self.with_l_arm
or self.with_r_arm
or self.with_neck
or self.with_antennas
):
raise ValueError(
"No Reachy2Teleoperator part used.\n"
"At least one part of the robot must be set to True "
"(with_mobile_base, with_l_arm, with_r_arm, with_neck, with_antennas)"
)

View File

@@ -1,164 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from reachy2_sdk import ReachySDK
from ..teleoperator import Teleoperator
from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig
logger = logging.getLogger(__name__)
# {lerobot_keys: reachy2_sdk_keys}
REACHY2_NECK_JOINTS = {
"neck_yaw.pos": "head.neck.yaw",
"neck_pitch.pos": "head.neck.pitch",
"neck_roll.pos": "head.neck.roll",
}
REACHY2_ANTENNAS_JOINTS = {
"l_antenna.pos": "head.l_antenna",
"r_antenna.pos": "head.r_antenna",
}
REACHY2_R_ARM_JOINTS = {
"r_shoulder_pitch.pos": "r_arm.shoulder.pitch",
"r_shoulder_roll.pos": "r_arm.shoulder.roll",
"r_elbow_yaw.pos": "r_arm.elbow.yaw",
"r_elbow_pitch.pos": "r_arm.elbow.pitch",
"r_wrist_roll.pos": "r_arm.wrist.roll",
"r_wrist_pitch.pos": "r_arm.wrist.pitch",
"r_wrist_yaw.pos": "r_arm.wrist.yaw",
"r_gripper.pos": "r_arm.gripper",
}
REACHY2_L_ARM_JOINTS = {
"l_shoulder_pitch.pos": "l_arm.shoulder.pitch",
"l_shoulder_roll.pos": "l_arm.shoulder.roll",
"l_elbow_yaw.pos": "l_arm.elbow.yaw",
"l_elbow_pitch.pos": "l_arm.elbow.pitch",
"l_wrist_roll.pos": "l_arm.wrist.roll",
"l_wrist_pitch.pos": "l_arm.wrist.pitch",
"l_wrist_yaw.pos": "l_arm.wrist.yaw",
"l_gripper.pos": "l_arm.gripper",
}
REACHY2_VEL = {
"mobile_base.vx": "vx",
"mobile_base.vy": "vy",
"mobile_base.vtheta": "vtheta",
}
class Reachy2Teleoperator(Teleoperator):
"""
[Reachy 2](https://www.pollen-robotics.com/reachy/), by Pollen Robotics.
"""
config_class = Reachy2TeleoperatorConfig
name = "reachy2_specific"
def __init__(self, config: Reachy2TeleoperatorConfig):
super().__init__(config)
self.config = config
self.reachy: None | ReachySDK = None
self.joints_dict: dict[str, str] = self._generate_joints_dict()
def _generate_joints_dict(self) -> dict[str, str]:
joints = {}
if self.config.with_neck:
joints.update(REACHY2_NECK_JOINTS)
if self.config.with_l_arm:
joints.update(REACHY2_L_ARM_JOINTS)
if self.config.with_r_arm:
joints.update(REACHY2_R_ARM_JOINTS)
if self.config.with_antennas:
joints.update(REACHY2_ANTENNAS_JOINTS)
return joints
@property
def action_features(self) -> dict[str, type]:
if self.config.with_mobile_base:
return {
**dict.fromkeys(
self.joints_dict.keys(),
float,
),
**dict.fromkeys(
REACHY2_VEL.keys(),
float,
),
}
else:
return dict.fromkeys(self.joints_dict.keys(), float)
@property
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.reachy.is_connected() if self.reachy is not None else False
def connect(self, calibrate: bool = True) -> None:
self.reachy = ReachySDK(self.config.ip_address)
if not self.is_connected:
raise ConnectionError()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
return True
def calibrate(self) -> None:
pass
def configure(self) -> None:
pass
def get_action(self) -> dict[str, float]:
start = time.perf_counter()
if self.reachy and self.is_connected:
if self.config.use_present_position:
joint_action = {
k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()
}
else:
joint_action = {k: self.reachy.joints[v].goal_position for k, v in self.joints_dict.items()}
if not self.config.with_mobile_base:
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return joint_action
if self.config.use_present_position:
vel_action = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
else:
vel_action = {k: self.reachy.mobile_base.last_cmd_vel[v] for k, v in REACHY2_VEL.items()}
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return {**joint_action, **vel_action}
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError
def disconnect(self) -> None:
if self.reachy and self.is_connected:
self.reachy.disconnect()

View File

@@ -65,9 +65,5 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
from .bi_so100_leader import BiSO100Leader
return BiSO100Leader(config)
elif config.type == "reachy2_teleoperator":
from .reachy2_teleoperator import Reachy2Teleoperator
return Reachy2Teleoperator(config)
else:
raise ValueError(config.type)

View File

@@ -44,7 +44,7 @@ Below is the short version on how to train and run inference/eval:
### Train from scratch
```bash
lerobot-train \
python -m lerobot.scripts.train \
--dataset.repo_id=${HF_USER}/<dataset> \
--policy.type=act \
--output_dir=outputs/train/<desired_policy_repo_id> \
@@ -59,7 +59,7 @@ _Writes checkpoints to `outputs/train/<desired_policy_repo_id>/checkpoints/`._
### Evaluate the policy/run inference
```bash
lerobot-record \
python -m lerobot.record \
--robot.type=so100_follower \
--dataset.repo_id=<hf_user>/eval_<dataset> \
--policy.path=<hf_user>/<desired_policy_repo_id> \

View File

@@ -17,9 +17,10 @@ import time
def busy_wait(seconds):
if platform.system() == "Darwin" or platform.system() == "Windows":
# On Mac and Windows, `time.sleep` is not accurate and we need to use this while loop trick,
if platform.system() == "Darwin":
# On Mac, `time.sleep` is not accurate and we need to use this while loop trick,
# but it consumes CPU cycles.
# TODO(rcadene): find an alternative: from python 11, time.sleep is precise
end_time = time.perf_counter() + seconds
while time.perf_counter() < end_time:
pass

View File

@@ -1,177 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from lerobot.cameras.reachy2_camera import Reachy2Camera, Reachy2CameraConfig
from lerobot.errors import DeviceNotConnectedError
PARAMS = [
("teleop", "left"),
("teleop", "right"),
("depth", "rgb"),
# ("depth", "depth"), # Depth camera is not available yet
]
def _make_cam_manager_mock():
c = MagicMock(name="CameraManagerMock")
teleop = MagicMock(name="TeleopCam")
teleop.width = 640
teleop.height = 480
teleop.get_frame = MagicMock(
side_effect=lambda *_, **__: (
np.zeros((480, 640, 3), dtype=np.uint8),
time.time(),
)
)
depth = MagicMock(name="DepthCam")
depth.width = 640
depth.height = 480
depth.get_frame = MagicMock(
side_effect=lambda *_, **__: (
np.zeros((480, 640, 3), dtype=np.uint8),
time.time(),
)
)
c.is_connected.return_value = True
c.teleop = teleop
c.depth = depth
def _connect():
c.teleop = teleop
c.depth = depth
c.is_connected.return_value = True
def _disconnect():
c.teleop = None
c.depth = None
c.is_connected.return_value = False
c.connect = MagicMock(side_effect=_connect)
c.disconnect = MagicMock(side_effect=_disconnect)
# Mock methods
c.initialize_cameras = MagicMock()
return c
@pytest.fixture(
params=PARAMS,
# ids=["teleop-left", "teleop-right", "torso-rgb", "torso-depth"],
ids=["teleop-left", "teleop-right", "torso-rgb"],
)
def camera(request):
name, image_type = request.param
with (
patch(
"lerobot.cameras.reachy2_camera.reachy2_camera.CameraManager",
side_effect=lambda *a, **k: _make_cam_manager_mock(),
),
):
config = Reachy2CameraConfig(name=name, image_type=image_type)
cam = Reachy2Camera(config)
yield cam
if cam.is_connected:
cam.disconnect()
def test_connect(camera):
camera.connect()
assert camera.is_connected
camera.cam_manager.initialize_cameras.assert_called_once()
def test_read(camera):
camera.connect()
img = camera.read()
if camera.config.name == "teleop":
camera.cam_manager.teleop.get_frame.assert_called_once()
elif camera.config.name == "depth":
camera.cam_manager.depth.get_frame.assert_called_once()
assert isinstance(img, np.ndarray)
assert img.shape == (480, 640, 3)
def test_disconnect(camera):
camera.connect()
camera.disconnect()
assert not camera.is_connected
def test_async_read(camera):
camera.connect()
try:
img = camera.async_read()
assert camera.thread is not None
assert camera.thread.is_alive()
assert isinstance(img, np.ndarray)
finally:
if camera.is_connected:
camera.disconnect()
def test_async_read_timeout(camera):
camera.connect()
try:
with pytest.raises(TimeoutError):
camera.async_read(timeout_ms=0)
finally:
if camera.is_connected:
camera.disconnect()
def test_read_before_connect(camera):
with pytest.raises(DeviceNotConnectedError):
_ = camera.read()
def test_disconnect_before_connect(camera):
with pytest.raises(DeviceNotConnectedError):
camera.disconnect()
def test_async_read_before_connect(camera):
with pytest.raises(DeviceNotConnectedError):
_ = camera.async_read()
def test_wrong_camera_name():
with pytest.raises(ValueError):
_ = Reachy2CameraConfig(name="wrong-name", image_type="left")
def test_wrong_image_type():
with pytest.raises(ValueError):
_ = Reachy2CameraConfig(name="teleop", image_type="rgb")
with pytest.raises(ValueError):
_ = Reachy2CameraConfig(name="depth", image_type="left")
def test_wrong_color_mode():
with pytest.raises(ValueError):
_ = Reachy2CameraConfig(name="teleop", image_type="left", color_mode="wrong-color")

View File

@@ -19,7 +19,6 @@ import traceback
import pytest
from serial import SerialException
from lerobot.configs.types import FeatureType, PolicyFeature
from tests.utils import DEVICE
# Import fixture modules as plugins
@@ -28,7 +27,6 @@ pytest_plugins = [
"tests.fixtures.files",
"tests.fixtures.hub",
"tests.fixtures.optimizers",
"tests.plugins.reachy2_sdk",
]
@@ -71,19 +69,3 @@ def patch_builtins_input(monkeypatch):
print(text)
monkeypatch.setattr("builtins.input", print_text)
@pytest.fixture
def policy_feature_factory():
"""PolicyFeature factory"""
def _pf(ft: FeatureType, shape: tuple[int, ...]) -> PolicyFeature:
return PolicyFeature(type=ft, shape=shape)
return _pf
def assert_contract_is_typed(features: dict[str, PolicyFeature]) -> None:
assert isinstance(features, dict)
assert all(isinstance(k, str) for k in features.keys())
assert all(isinstance(v, PolicyFeature) for v in features.values())

View File

@@ -1,30 +0,0 @@
import sys
import types
from unittest.mock import MagicMock
def _install_reachy2_sdk_stub():
sdk = types.ModuleType("reachy2_sdk")
sdk.__path__ = []
sdk.ReachySDK = MagicMock(name="ReachySDK")
media = types.ModuleType("reachy2_sdk.media")
media.__path__ = []
camera = types.ModuleType("reachy2_sdk.media.camera")
camera.CameraView = MagicMock(name="CameraView")
camera_manager = types.ModuleType("reachy2_sdk.media.camera_manager")
camera_manager.CameraManager = MagicMock(name="CameraManager")
sdk.media = media
media.camera = camera
media.camera_manager = camera_manager
# Register in sys.modules
sys.modules.setdefault("reachy2_sdk", sdk)
sys.modules.setdefault("reachy2_sdk.media", media)
sys.modules.setdefault("reachy2_sdk.media.camera", camera)
sys.modules.setdefault("reachy2_sdk.media.camera_manager", camera_manager)
def pytest_sessionstart(session):
_install_reachy2_sdk_stub()

View File

@@ -27,13 +27,11 @@ from lerobot import available_policies
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.utils import cycle, dataset_to_policy_features
from lerobot.envs.factory import make_env, make_env_config
from lerobot.envs.utils import preprocess_observation
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.modeling_act import ACTTemporalEnsembler
from lerobot.policies.factory import (
get_policy_class,
@@ -365,54 +363,6 @@ def test_normalize(insert_temporal_dim):
unnormalize(output_batch)
@pytest.mark.parametrize("multikey", [True, False])
def test_multikey_construction(multikey: bool):
"""
Asserts that multiple keys with type State/Action are correctly processed by the policy constructor,
preventing erroneous creation of the policy object.
"""
input_features = {
"observation.state": PolicyFeature(
type=FeatureType.STATE,
shape=(10,),
),
}
output_features = {
"action": PolicyFeature(
type=FeatureType.ACTION,
shape=(5,),
),
}
if multikey:
"""Simulates the complete state/action is constructed from more granular multiple
keys, of the same type as the overall state/action"""
input_features = {}
input_features["observation.state.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
input_features["observation.state.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(10,))
output_features = {}
output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,))
output_features["action.last_two_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(2,))
output_features["action"] = PolicyFeature(
type=FeatureType.ACTION,
shape=(5,),
)
config = ACTConfig(input_features=input_features, output_features=output_features)
state_condition = config.robot_state_feature == input_features[OBS_STATE]
action_condition = config.action_feature == output_features[ACTION]
assert state_condition, (
f"Discrepancy detected. Robot state feature is {config.robot_state_feature} but policy expects {input_features[OBS_STATE]}"
)
assert action_condition, (
f"Discrepancy detected. Action feature is {config.action_feature} but policy expects {output_features[ACTION]}"
)
@pytest.mark.parametrize(
"ds_repo_id, policy_name, policy_kwargs, file_name_extra",
[

View File

@@ -1,282 +0,0 @@
import torch
from lerobot.processor.pipeline import (
RobotProcessor,
TransitionKey,
_default_batch_to_transition,
_default_transition_to_batch,
)
def _dummy_batch():
"""Create a dummy batch using the new format with observation.* and next.* keys."""
return {
"observation.image.left": torch.randn(1, 3, 128, 128),
"observation.image.right": torch.randn(1, 3, 128, 128),
"observation.state": torch.tensor([[0.1, 0.2, 0.3, 0.4]]),
"action": torch.tensor([[0.5]]),
"next.reward": 1.0,
"next.done": False,
"next.truncated": False,
"info": {"key": "value"},
}
def test_observation_grouping_roundtrip():
"""Test that observation.* keys are properly grouped and ungrouped."""
proc = RobotProcessor([])
batch_in = _dummy_batch()
batch_out = proc(batch_in)
# Check that all observation.* keys are preserved
original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith("observation.")}
reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith("observation.")}
assert set(original_obs_keys.keys()) == set(reconstructed_obs_keys.keys())
# Check tensor values
assert torch.allclose(batch_out["observation.image.left"], batch_in["observation.image.left"])
assert torch.allclose(batch_out["observation.image.right"], batch_in["observation.image.right"])
assert torch.allclose(batch_out["observation.state"], batch_in["observation.state"])
# Check other fields
assert torch.allclose(batch_out["action"], batch_in["action"])
assert batch_out["next.reward"] == batch_in["next.reward"]
assert batch_out["next.done"] == batch_in["next.done"]
assert batch_out["next.truncated"] == batch_in["next.truncated"]
assert batch_out["info"] == batch_in["info"]
def test_batch_to_transition_observation_grouping():
"""Test that _default_batch_to_transition correctly groups observation.* keys."""
batch = {
"observation.image.top": torch.randn(1, 3, 128, 128),
"observation.image.left": torch.randn(1, 3, 128, 128),
"observation.state": [1, 2, 3, 4],
"action": "action_data",
"next.reward": 1.5,
"next.done": True,
"next.truncated": False,
"info": {"episode": 42},
}
transition = _default_batch_to_transition(batch)
# Check observation is a dict with all observation.* keys
assert isinstance(transition[TransitionKey.OBSERVATION], dict)
assert "observation.image.top" in transition[TransitionKey.OBSERVATION]
assert "observation.image.left" in transition[TransitionKey.OBSERVATION]
assert "observation.state" in transition[TransitionKey.OBSERVATION]
# Check values are preserved
assert torch.allclose(
transition[TransitionKey.OBSERVATION]["observation.image.top"], batch["observation.image.top"]
)
assert torch.allclose(
transition[TransitionKey.OBSERVATION]["observation.image.left"], batch["observation.image.left"]
)
assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
# Check other fields
assert transition[TransitionKey.ACTION] == "action_data"
assert transition[TransitionKey.REWARD] == 1.5
assert transition[TransitionKey.DONE]
assert not transition[TransitionKey.TRUNCATED]
assert transition[TransitionKey.INFO] == {"episode": 42}
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
def test_transition_to_batch_observation_flattening():
"""Test that _default_transition_to_batch correctly flattens observation dict."""
observation_dict = {
"observation.image.top": torch.randn(1, 3, 128, 128),
"observation.image.left": torch.randn(1, 3, 128, 128),
"observation.state": [1, 2, 3, 4],
}
transition = {
TransitionKey.OBSERVATION: observation_dict,
TransitionKey.ACTION: "action_data",
TransitionKey.REWARD: 1.5,
TransitionKey.DONE: True,
TransitionKey.TRUNCATED: False,
TransitionKey.INFO: {"episode": 42},
TransitionKey.COMPLEMENTARY_DATA: {},
}
batch = _default_transition_to_batch(transition)
# Check that observation.* keys are flattened back to batch
assert "observation.image.top" in batch
assert "observation.image.left" in batch
assert "observation.state" in batch
# Check values are preserved
assert torch.allclose(batch["observation.image.top"], observation_dict["observation.image.top"])
assert torch.allclose(batch["observation.image.left"], observation_dict["observation.image.left"])
assert batch["observation.state"] == [1, 2, 3, 4]
# Check other fields are mapped to next.* format
assert batch["action"] == "action_data"
assert batch["next.reward"] == 1.5
assert batch["next.done"]
assert not batch["next.truncated"]
assert batch["info"] == {"episode": 42}
def test_no_observation_keys():
"""Test behavior when there are no observation.* keys."""
batch = {
"action": "action_data",
"next.reward": 2.0,
"next.done": False,
"next.truncated": True,
"info": {"test": "no_obs"},
}
transition = _default_batch_to_transition(batch)
# Observation should be None when no observation.* keys
assert transition[TransitionKey.OBSERVATION] is None
# Check other fields
assert transition[TransitionKey.ACTION] == "action_data"
assert transition[TransitionKey.REWARD] == 2.0
assert not transition[TransitionKey.DONE]
assert transition[TransitionKey.TRUNCATED]
assert transition[TransitionKey.INFO] == {"test": "no_obs"}
# Round trip should work
reconstructed_batch = _default_transition_to_batch(transition)
assert reconstructed_batch["action"] == "action_data"
assert reconstructed_batch["next.reward"] == 2.0
assert not reconstructed_batch["next.done"]
assert reconstructed_batch["next.truncated"]
assert reconstructed_batch["info"] == {"test": "no_obs"}
def test_minimal_batch():
"""Test with minimal batch containing only observation.* and action."""
batch = {"observation.state": "minimal_state", "action": "minimal_action"}
transition = _default_batch_to_transition(batch)
# Check observation
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
assert transition[TransitionKey.ACTION] == "minimal_action"
# Check defaults
assert transition[TransitionKey.REWARD] == 0.0
assert not transition[TransitionKey.DONE]
assert not transition[TransitionKey.TRUNCATED]
assert transition[TransitionKey.INFO] == {}
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
# Round trip
reconstructed_batch = _default_transition_to_batch(transition)
assert reconstructed_batch["observation.state"] == "minimal_state"
assert reconstructed_batch["action"] == "minimal_action"
assert reconstructed_batch["next.reward"] == 0.0
assert not reconstructed_batch["next.done"]
assert not reconstructed_batch["next.truncated"]
assert reconstructed_batch["info"] == {}
def test_empty_batch():
"""Test behavior with empty batch."""
batch = {}
transition = _default_batch_to_transition(batch)
# All fields should have defaults
assert transition[TransitionKey.OBSERVATION] is None
assert transition[TransitionKey.ACTION] is None
assert transition[TransitionKey.REWARD] == 0.0
assert not transition[TransitionKey.DONE]
assert not transition[TransitionKey.TRUNCATED]
assert transition[TransitionKey.INFO] == {}
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
# Round trip
reconstructed_batch = _default_transition_to_batch(transition)
assert reconstructed_batch["action"] is None
assert reconstructed_batch["next.reward"] == 0.0
assert not reconstructed_batch["next.done"]
assert not reconstructed_batch["next.truncated"]
assert reconstructed_batch["info"] == {}
def test_complex_nested_observation():
"""Test with complex nested observation data."""
batch = {
"observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890},
"observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
"observation.state": torch.randn(7),
"action": torch.randn(8),
"next.reward": 3.14,
"next.done": False,
"next.truncated": True,
"info": {"episode_length": 200, "success": True},
}
transition = _default_batch_to_transition(batch)
reconstructed_batch = _default_transition_to_batch(transition)
# Check that all observation keys are preserved
original_obs_keys = {k for k in batch if k.startswith("observation.")}
reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith("observation.")}
assert original_obs_keys == reconstructed_obs_keys
# Check tensor values
assert torch.allclose(batch["observation.state"], reconstructed_batch["observation.state"])
# Check nested dict with tensors
assert torch.allclose(
batch["observation.image.top"]["image"], reconstructed_batch["observation.image.top"]["image"]
)
assert torch.allclose(
batch["observation.image.left"]["image"], reconstructed_batch["observation.image.left"]["image"]
)
# Check action tensor
assert torch.allclose(batch["action"], reconstructed_batch["action"])
# Check other fields
assert batch["next.reward"] == reconstructed_batch["next.reward"]
assert batch["next.done"] == reconstructed_batch["next.done"]
assert batch["next.truncated"] == reconstructed_batch["next.truncated"]
assert batch["info"] == reconstructed_batch["info"]
def test_custom_converter():
"""Test that custom converters can still be used."""
def to_tr(batch):
# Custom converter that modifies the reward
tr = _default_batch_to_transition(batch)
# Double the reward
reward = tr.get(TransitionKey.REWARD, 0.0)
new_tr = tr.copy()
new_tr[TransitionKey.REWARD] = reward * 2 if reward is not None else 0.0
return new_tr
def to_batch(tr):
batch = _default_transition_to_batch(tr)
return batch
processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch)
batch = {
"observation.state": torch.randn(1, 4),
"action": torch.randn(1, 2),
"next.reward": 1.0,
"next.done": False,
}
result = processor(batch)
# Check the reward was doubled by our custom converter
assert result["next.reward"] == 2.0
assert torch.allclose(result["observation.state"], batch["observation.state"])
assert torch.allclose(result["action"], batch["action"])

View File

@@ -1,628 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock
import numpy as np
import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor.normalize_processor import (
NormalizerProcessor,
UnnormalizerProcessor,
_convert_stats_to_tensors,
)
from lerobot.processor.pipeline import RobotProcessor, TransitionKey
def create_transition(
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
):
"""Helper to create an EnvTransition dictionary."""
return {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: action,
TransitionKey.REWARD: reward,
TransitionKey.DONE: done,
TransitionKey.TRUNCATED: truncated,
TransitionKey.INFO: info,
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
}
def test_numpy_conversion():
stats = {
"observation.image": {
"mean": np.array([0.5, 0.5, 0.5]),
"std": np.array([0.2, 0.2, 0.2]),
}
}
tensor_stats = _convert_stats_to_tensors(stats)
assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor)
assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor)
assert torch.allclose(tensor_stats["observation.image"]["mean"], torch.tensor([0.5, 0.5, 0.5]))
assert torch.allclose(tensor_stats["observation.image"]["std"], torch.tensor([0.2, 0.2, 0.2]))
def test_tensor_conversion():
stats = {
"action": {
"mean": torch.tensor([0.0, 0.0]),
"std": torch.tensor([1.0, 1.0]),
}
}
tensor_stats = _convert_stats_to_tensors(stats)
assert tensor_stats["action"]["mean"].dtype == torch.float32
assert tensor_stats["action"]["std"].dtype == torch.float32
def test_scalar_conversion():
stats = {
"reward": {
"mean": 0.5,
"std": 0.1,
}
}
tensor_stats = _convert_stats_to_tensors(stats)
assert torch.allclose(tensor_stats["reward"]["mean"], torch.tensor(0.5))
assert torch.allclose(tensor_stats["reward"]["std"], torch.tensor(0.1))
def test_list_conversion():
stats = {
"observation.state": {
"min": [0.0, -1.0, -2.0],
"max": [1.0, 1.0, 2.0],
}
}
tensor_stats = _convert_stats_to_tensors(stats)
assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0]))
assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0]))
def test_unsupported_type():
stats = {
"bad_key": {
"mean": "string_value",
}
}
with pytest.raises(TypeError, match="Unsupported type"):
_convert_stats_to_tensors(stats)
# Helper functions to create feature maps and norm maps
def _create_observation_features():
return {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
}
def _create_observation_norm_map():
return {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.STATE: NormalizationMode.MIN_MAX,
}
# Fixtures for observation normalisation tests using NormalizerProcessor
@pytest.fixture
def observation_stats():
return {
"observation.image": {
"mean": np.array([0.5, 0.5, 0.5]),
"std": np.array([0.2, 0.2, 0.2]),
},
"observation.state": {
"min": np.array([0.0, -1.0]),
"max": np.array([1.0, 1.0]),
},
}
@pytest.fixture
def observation_normalizer(observation_stats):
"""Return a NormalizerProcessor that only has observation stats (no action)."""
features = _create_observation_features()
norm_map = _create_observation_norm_map()
return NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats)
def test_mean_std_normalization(observation_normalizer):
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = create_transition(observation=observation)
normalized_transition = observation_normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Check mean/std normalization
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
assert torch.allclose(normalized_obs["observation.image"], expected_image)
def test_min_max_normalization(observation_normalizer):
observation = {
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = create_transition(observation=observation)
normalized_transition = observation_normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Check min/max normalization to [-1, 1]
# For state[0]: 2 * (0.5 - 0.0) / (1.0 - 0.0) - 1 = 0.0
# For state[1]: 2 * (0.0 - (-1.0)) / (1.0 - (-1.0)) - 1 = 0.0
expected_state = torch.tensor([0.0, 0.0])
assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6)
def test_selective_normalization(observation_stats):
features = _create_observation_features()
norm_map = _create_observation_norm_map()
normalizer = NormalizerProcessor(
features=features, norm_map=norm_map, stats=observation_stats, normalize_keys={"observation.image"}
)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = create_transition(observation=observation)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Only image should be normalized
assert torch.allclose(normalized_obs["observation.image"], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2)
# State should remain unchanged
assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_device_compatibility(observation_stats):
features = _create_observation_features()
norm_map = _create_observation_norm_map()
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(),
}
transition = create_transition(observation=observation)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
assert normalized_obs["observation.image"].device.type == "cuda"
def test_from_lerobot_dataset():
# Mock dataset
mock_dataset = Mock()
mock_dataset.meta.stats = {
"observation.image": {"mean": [0.5], "std": [0.2]},
"action": {"mean": [0.0], "std": [1.0]},
}
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
"action": PolicyFeature(FeatureType.ACTION, (1,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.ACTION: NormalizationMode.MEAN_STD,
}
normalizer = NormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
# Both observation and action statistics should be present in tensor stats
assert "observation.image" in normalizer._tensor_stats
assert "action" in normalizer._tensor_stats
def test_state_dict_save_load(observation_normalizer):
# Save state
state_dict = observation_normalizer.state_dict()
# Create new normalizer and load state
features = _create_observation_features()
norm_map = _create_observation_norm_map()
new_normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={})
new_normalizer.load_state_dict(state_dict)
# Test that it works the same
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
transition = create_transition(observation=observation)
result1 = observation_normalizer(transition)[TransitionKey.OBSERVATION]
result2 = new_normalizer(transition)[TransitionKey.OBSERVATION]
assert torch.allclose(result1["observation.image"], result2["observation.image"])
# Fixtures for ActionUnnormalizer tests
@pytest.fixture
def action_stats_mean_std():
return {
"mean": np.array([0.0, 0.0, 0.0]),
"std": np.array([1.0, 2.0, 0.5]),
}
@pytest.fixture
def action_stats_min_max():
return {
"min": np.array([-1.0, -2.0, 0.0]),
"max": np.array([1.0, 2.0, 1.0]),
}
def _create_action_features():
return {
"action": PolicyFeature(FeatureType.ACTION, (3,)),
}
def _create_action_norm_map_mean_std():
return {
FeatureType.ACTION: NormalizationMode.MEAN_STD,
}
def _create_action_norm_map_min_max():
return {
FeatureType.ACTION: NormalizationMode.MIN_MAX,
}
def test_mean_std_unnormalization(action_stats_mean_std):
features = _create_action_features()
norm_map = _create_action_norm_map_mean_std()
unnormalizer = UnnormalizerProcessor(
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
)
normalized_action = torch.tensor([1.0, -0.5, 2.0])
transition = create_transition(action=normalized_action)
unnormalized_transition = unnormalizer(transition)
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
# action * std + mean
expected = torch.tensor([1.0 * 1.0 + 0.0, -0.5 * 2.0 + 0.0, 2.0 * 0.5 + 0.0])
assert torch.allclose(unnormalized_action, expected)
def test_min_max_unnormalization(action_stats_min_max):
features = _create_action_features()
norm_map = _create_action_norm_map_min_max()
unnormalizer = UnnormalizerProcessor(
features=features, norm_map=norm_map, stats={"action": action_stats_min_max}
)
# Actions in [-1, 1]
normalized_action = torch.tensor([0.0, -1.0, 1.0])
transition = create_transition(action=normalized_action)
unnormalized_transition = unnormalizer(transition)
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
# Map from [-1, 1] to [min, max]
# (action + 1) / 2 * (max - min) + min
expected = torch.tensor(
[
(0.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0), # 0.0
(-1.0 + 1) / 2 * (2.0 - (-2.0)) + (-2.0), # -2.0
(1.0 + 1) / 2 * (1.0 - 0.0) + 0.0, # 1.0
]
)
assert torch.allclose(unnormalized_action, expected)
def test_numpy_action_input(action_stats_mean_std):
features = _create_action_features()
norm_map = _create_action_norm_map_mean_std()
unnormalizer = UnnormalizerProcessor(
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
)
normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32)
transition = create_transition(action=normalized_action)
unnormalized_transition = unnormalizer(transition)
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
assert isinstance(unnormalized_action, torch.Tensor)
expected = torch.tensor([1.0, -1.0, 1.0])
assert torch.allclose(unnormalized_action, expected)
def test_none_action(action_stats_mean_std):
features = _create_action_features()
norm_map = _create_action_norm_map_mean_std()
unnormalizer = UnnormalizerProcessor(
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
)
transition = create_transition()
result = unnormalizer(transition)
# Should return transition unchanged
assert result == transition
def test_action_from_lerobot_dataset():
mock_dataset = Mock()
mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}}
features = {"action": PolicyFeature(FeatureType.ACTION, (1,))}
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
unnormalizer = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
assert "mean" in unnormalizer._tensor_stats["action"]
# Fixtures for NormalizerProcessor tests
@pytest.fixture
def full_stats():
return {
"observation.image": {
"mean": np.array([0.5, 0.5, 0.5]),
"std": np.array([0.2, 0.2, 0.2]),
},
"observation.state": {
"min": np.array([0.0, -1.0]),
"max": np.array([1.0, 1.0]),
},
"action": {
"mean": np.array([0.0, 0.0]),
"std": np.array([1.0, 2.0]),
},
}
def _create_full_features():
return {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
"action": PolicyFeature(FeatureType.ACTION, (2,)),
}
def _create_full_norm_map():
return {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.STATE: NormalizationMode.MIN_MAX,
FeatureType.ACTION: NormalizationMode.MEAN_STD,
}
@pytest.fixture
def normalizer_processor(full_stats):
features = _create_full_features()
norm_map = _create_full_norm_map()
return NormalizerProcessor(features=features, norm_map=norm_map, stats=full_stats)
def test_combined_normalization(normalizer_processor):
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
action = torch.tensor([1.0, -0.5])
transition = create_transition(
observation=observation,
action=action,
reward=1.0,
done=False,
truncated=False,
info={},
complementary_data={},
)
processed_transition = normalizer_processor(transition)
# Check normalized observations
processed_obs = processed_transition[TransitionKey.OBSERVATION]
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
assert torch.allclose(processed_obs["observation.image"], expected_image)
# Check normalized action
processed_action = processed_transition[TransitionKey.ACTION]
expected_action = torch.tensor([(1.0 - 0.0) / 1.0, (-0.5 - 0.0) / 2.0])
assert torch.allclose(processed_action, expected_action)
# Check other fields remain unchanged
assert processed_transition[TransitionKey.REWARD] == 1.0
assert not processed_transition[TransitionKey.DONE]
def test_processor_from_lerobot_dataset(full_stats):
# Mock dataset
mock_dataset = Mock()
mock_dataset.meta.stats = full_stats
features = _create_full_features()
norm_map = _create_full_norm_map()
processor = NormalizerProcessor.from_lerobot_dataset(
mock_dataset, features, norm_map, normalize_keys={"observation.image"}
)
assert processor.normalize_keys == {"observation.image"}
assert "observation.image" in processor._tensor_stats
assert "action" in processor._tensor_stats
def test_get_config(full_stats):
features = _create_full_features()
norm_map = _create_full_norm_map()
processor = NormalizerProcessor(
features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6
)
config = processor.get_config()
expected_config = {
"normalize_keys": ["observation.image"],
"eps": 1e-6,
"features": {
"observation.image": {"type": "VISUAL", "shape": (3, 96, 96)},
"observation.state": {"type": "STATE", "shape": (2,)},
"action": {"type": "ACTION", "shape": (2,)},
},
"norm_map": {
"VISUAL": "MEAN_STD",
"STATE": "MIN_MAX",
"ACTION": "MEAN_STD",
},
}
assert config == expected_config
def test_integration_with_robot_processor(normalizer_processor):
"""Test integration with RobotProcessor pipeline"""
robot_processor = RobotProcessor([normalizer_processor])
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
action = torch.tensor([1.0, -0.5])
transition = create_transition(
observation=observation,
action=action,
reward=1.0,
done=False,
truncated=False,
info={},
complementary_data={},
)
processed_transition = robot_processor(transition)
# Verify the processing worked
assert isinstance(processed_transition[TransitionKey.OBSERVATION], dict)
assert isinstance(processed_transition[TransitionKey.ACTION], torch.Tensor)
# Edge case tests
def test_empty_observation():
stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
transition = create_transition()
result = normalizer(transition)
assert result == transition
def test_empty_stats():
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={})
observation = {"observation.image": torch.tensor([0.5])}
transition = create_transition(observation=observation)
result = normalizer(transition)
# Should return observation unchanged since no stats are available
assert torch.allclose(
result[TransitionKey.OBSERVATION]["observation.image"], observation["observation.image"]
)
def test_partial_stats():
"""If statistics are incomplete, the value should pass through unchanged."""
stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max)
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
observation = {"observation.image": torch.tensor([0.7])}
transition = create_transition(observation=observation)
processed = normalizer(transition)[TransitionKey.OBSERVATION]
assert torch.allclose(processed["observation.image"], observation["observation.image"])
def test_missing_action_stats_no_error():
mock_dataset = Mock()
mock_dataset.meta.stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
# The tensor stats should not contain the 'action' key
assert "action" not in processor._tensor_stats
def test_serialization_roundtrip(full_stats):
"""Test that features and norm_map can be serialized and deserialized correctly."""
features = _create_full_features()
norm_map = _create_full_norm_map()
original_processor = NormalizerProcessor(
features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6
)
# Get config (serialization)
config = original_processor.get_config()
# Create a new processor from the config (deserialization)
new_processor = NormalizerProcessor(
features=config["features"],
norm_map=config["norm_map"],
stats=full_stats,
normalize_keys=set(config["normalize_keys"]),
eps=config["eps"],
)
# Test that both processors work the same way
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
action = torch.tensor([1.0, -0.5])
transition = create_transition(
observation=observation,
action=action,
reward=1.0,
done=False,
truncated=False,
info={},
complementary_data={},
)
result1 = original_processor(transition)
result2 = new_processor(transition)
# Compare results
assert torch.allclose(
result1[TransitionKey.OBSERVATION]["observation.image"],
result2[TransitionKey.OBSERVATION]["observation.image"],
)
assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION])
# Verify features and norm_map are correctly reconstructed
assert new_processor.features.keys() == original_processor.features.keys()
for key in new_processor.features:
assert new_processor.features[key].type == original_processor.features[key].type
assert new_processor.features[key].shape == original_processor.features[key].shape
assert new_processor.norm_map == original_processor.norm_map

View File

@@ -1,486 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import torch
from lerobot.configs.types import FeatureType
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor import VanillaObservationProcessor
from lerobot.processor.pipeline import TransitionKey
from tests.conftest import assert_contract_is_typed
def create_transition(
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
):
"""Helper to create an EnvTransition dictionary."""
return {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: action,
TransitionKey.REWARD: reward,
TransitionKey.DONE: done,
TransitionKey.TRUNCATED: truncated,
TransitionKey.INFO: info,
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
}
def test_process_single_image():
"""Test processing a single image."""
processor = VanillaObservationProcessor()
# Create a mock image (H, W, C) format, uint8
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check that the image was processed correctly
assert "observation.image" in processed_obs
processed_img = processed_obs["observation.image"]
# Check shape: should be (1, 3, 64, 64) - batch, channels, height, width
assert processed_img.shape == (1, 3, 64, 64)
# Check dtype and range
assert processed_img.dtype == torch.float32
assert processed_img.min() >= 0.0
assert processed_img.max() <= 1.0
def test_process_image_dict():
"""Test processing multiple images in a dictionary."""
processor = VanillaObservationProcessor()
# Create mock images
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
observation = {"pixels": {"camera1": image1, "camera2": image2}}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check that both images were processed
assert "observation.images.camera1" in processed_obs
assert "observation.images.camera2" in processed_obs
# Check shapes
assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32)
assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48)
def test_process_batched_image():
"""Test processing already batched images."""
processor = VanillaObservationProcessor()
# Create a batched image (B, H, W, C)
image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check that batch dimension is preserved
assert processed_obs["observation.image"].shape == (2, 3, 64, 64)
def test_invalid_image_format():
"""Test error handling for invalid image formats."""
processor = VanillaObservationProcessor()
# Test wrong channel order (channels first)
image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
observation = {"pixels": image}
transition = create_transition(observation=observation)
with pytest.raises(ValueError, match="Expected channel-last images"):
processor(transition)
def test_invalid_image_dtype():
"""Test error handling for invalid image dtype."""
processor = VanillaObservationProcessor()
# Test wrong dtype
image = np.random.rand(64, 64, 3).astype(np.float32)
observation = {"pixels": image}
transition = create_transition(observation=observation)
with pytest.raises(ValueError, match="Expected torch.uint8 images"):
processor(transition)
def test_no_pixels_in_observation():
"""Test processor when no pixels are in observation."""
processor = VanillaObservationProcessor()
observation = {"other_data": np.array([1, 2, 3])}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Should preserve other data unchanged
assert "other_data" in processed_obs
np.testing.assert_array_equal(processed_obs["other_data"], np.array([1, 2, 3]))
def test_none_observation():
"""Test processor with None observation."""
processor = VanillaObservationProcessor()
transition = create_transition()
result = processor(transition)
assert result == transition
def test_serialization_methods():
"""Test serialization methods."""
processor = VanillaObservationProcessor()
# Test get_config
config = processor.get_config()
assert isinstance(config, dict)
# Test state_dict
state = processor.state_dict()
assert isinstance(state, dict)
# Test load_state_dict (should not raise)
processor.load_state_dict(state)
# Test reset (should not raise)
processor.reset()
def test_process_environment_state():
"""Test processing environment_state."""
processor = VanillaObservationProcessor()
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
observation = {"environment_state": env_state}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check that environment_state was renamed and processed
assert "observation.environment_state" in processed_obs
assert "environment_state" not in processed_obs
processed_state = processed_obs["observation.environment_state"]
assert processed_state.shape == (1, 3) # Batch dimension added
assert processed_state.dtype == torch.float32
torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]]))
def test_process_agent_pos():
"""Test processing agent_pos."""
processor = VanillaObservationProcessor()
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check that agent_pos was renamed and processed
assert "observation.state" in processed_obs
assert "agent_pos" not in processed_obs
processed_state = processed_obs["observation.state"]
assert processed_state.shape == (1, 3) # Batch dimension added
assert processed_state.dtype == torch.float32
torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]]))
def test_process_batched_states():
"""Test processing already batched states."""
processor = VanillaObservationProcessor()
env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
observation = {"environment_state": env_state, "agent_pos": agent_pos}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check that batch dimensions are preserved
assert processed_obs["observation.environment_state"].shape == (2, 2)
assert processed_obs["observation.state"].shape == (2, 2)
def test_process_both_states():
"""Test processing both environment_state and agent_pos."""
processor = VanillaObservationProcessor()
env_state = np.array([1.0, 2.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5], dtype=np.float32)
observation = {"environment_state": env_state, "agent_pos": agent_pos, "other_data": "keep_me"}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check that both states were processed
assert "observation.environment_state" in processed_obs
assert "observation.state" in processed_obs
# Check that original keys were removed
assert "environment_state" not in processed_obs
assert "agent_pos" not in processed_obs
# Check that other data was preserved
assert processed_obs["other_data"] == "keep_me"
def test_no_states_in_observation():
"""Test processor when no states are in observation."""
processor = VanillaObservationProcessor()
observation = {"other_data": np.array([1, 2, 3])}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Should preserve data unchanged
np.testing.assert_array_equal(processed_obs, observation)
def test_complete_observation_processing():
"""Test processing a complete observation with both images and states."""
processor = VanillaObservationProcessor()
# Create mock data
image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {
"pixels": image,
"environment_state": env_state,
"agent_pos": agent_pos,
"other_data": "preserve_me",
}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check that image was processed
assert "observation.image" in processed_obs
assert processed_obs["observation.image"].shape == (1, 3, 32, 32)
# Check that states were processed
assert "observation.environment_state" in processed_obs
assert "observation.state" in processed_obs
# Check that original keys were removed
assert "pixels" not in processed_obs
assert "environment_state" not in processed_obs
assert "agent_pos" not in processed_obs
# Check that other data was preserved
assert processed_obs["other_data"] == "preserve_me"
def test_image_only_processing():
"""Test processing observation with only images."""
processor = VanillaObservationProcessor()
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
assert "observation.image" in processed_obs
assert len(processed_obs) == 1
def test_state_only_processing():
"""Test processing observation with only states."""
processor = VanillaObservationProcessor()
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
assert "observation.state" in processed_obs
assert "agent_pos" not in processed_obs
def test_empty_observation():
"""Test processing empty observation."""
processor = VanillaObservationProcessor()
observation = {}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
assert processed_obs == {}
def test_equivalent_to_original_function():
"""Test that ObservationProcessor produces equivalent results to preprocess_observation."""
# Import the original function for comparison
from lerobot.envs.utils import preprocess_observation
processor = VanillaObservationProcessor()
# Create test data similar to what the original function expects
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {"pixels": image, "environment_state": env_state, "agent_pos": agent_pos}
# Process with original function
original_result = preprocess_observation(observation)
# Process with new processor
transition = create_transition(observation=observation)
processor_result = processor(transition)[TransitionKey.OBSERVATION]
# Compare results
assert set(original_result.keys()) == set(processor_result.keys())
for key in original_result:
torch.testing.assert_close(original_result[key], processor_result[key])
def test_equivalent_with_image_dict():
"""Test equivalence with dictionary of images."""
from lerobot.envs.utils import preprocess_observation
processor = VanillaObservationProcessor()
# Create test data with multiple cameras
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
observation = {"pixels": {"cam1": image1, "cam2": image2}, "agent_pos": agent_pos}
# Process with original function
original_result = preprocess_observation(observation)
# Process with new processor
transition = create_transition(observation=observation)
processor_result = processor(transition)[TransitionKey.OBSERVATION]
# Compare results
assert set(original_result.keys()) == set(processor_result.keys())
for key in original_result:
torch.testing.assert_close(original_result[key], processor_result[key])
def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory):
processor = VanillaObservationProcessor()
features = {
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
}
out = processor.feature_contract(features.copy())
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["pixels"]
assert "pixels" not in out
assert out["keep"] == features["keep"]
assert_contract_is_typed(out)
def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory):
processor = VanillaObservationProcessor()
features = {
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
}
out = processor.feature_contract(features.copy())
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["observation.pixels"]
assert "observation.pixels" not in out
assert out["keep"] == features["keep"]
assert_contract_is_typed(out)
def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory):
processor = VanillaObservationProcessor()
features = {
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (7,)),
}
out = processor.feature_contract(features.copy())
assert f"{OBS_IMAGES}.front" in out and out[f"{OBS_IMAGES}.front"] == features["pixels.front"]
assert f"{OBS_IMAGES}.wrist" in out and out[f"{OBS_IMAGES}.wrist"] == features["pixels.wrist"]
assert f"{OBS_IMAGES}.rear" in out and out[f"{OBS_IMAGES}.rear"] == features["observation.pixels.rear"]
assert "pixels.front" not in out and "pixels.wrist" not in out and "observation.pixels.rear" not in out
assert out["keep"] == features["keep"]
assert_contract_is_typed(out)
def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory):
processor = VanillaObservationProcessor()
features = {
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
}
out = processor.feature_contract(features.copy())
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["environment_state"]
assert OBS_STATE in out and out[OBS_STATE] == features["agent_pos"]
assert "environment_state" not in out and "agent_pos" not in out
assert out["keep"] == features["keep"]
assert_contract_is_typed(out)
def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory):
proc = VanillaObservationProcessor()
features = {
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
}
out = proc.feature_contract(features.copy())
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["observation.environment_state"]
assert OBS_STATE in out and out[OBS_STATE] == features["observation.agent_pos"]
assert "environment_state" not in out and "agent_pos" not in out
assert_contract_is_typed(out)

File diff suppressed because it is too large Load Diff

View File

@@ -1,467 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
from pathlib import Path
import numpy as np
import torch
from lerobot.configs.types import FeatureType
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey
from tests.conftest import assert_contract_is_typed
def create_transition(
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
):
"""Helper to create an EnvTransition dictionary."""
return {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: action,
TransitionKey.REWARD: reward,
TransitionKey.DONE: done,
TransitionKey.TRUNCATED: truncated,
TransitionKey.INFO: info,
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
}
def test_basic_renaming():
"""Test basic key renaming functionality."""
rename_map = {
"old_key1": "new_key1",
"old_key2": "new_key2",
}
processor = RenameProcessor(rename_map=rename_map)
observation = {
"old_key1": torch.tensor([1.0, 2.0]),
"old_key2": np.array([3.0, 4.0]),
"unchanged_key": "keep_me",
}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check renamed keys
assert "new_key1" in processed_obs
assert "new_key2" in processed_obs
assert "old_key1" not in processed_obs
assert "old_key2" not in processed_obs
# Check values are preserved
torch.testing.assert_close(processed_obs["new_key1"], torch.tensor([1.0, 2.0]))
np.testing.assert_array_equal(processed_obs["new_key2"], np.array([3.0, 4.0]))
# Check unchanged key is preserved
assert processed_obs["unchanged_key"] == "keep_me"
def test_empty_rename_map():
"""Test processor with empty rename map (should pass through unchanged)."""
processor = RenameProcessor(rename_map={})
observation = {
"key1": torch.tensor([1.0]),
"key2": "value2",
}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# All keys should be unchanged
assert processed_obs.keys() == observation.keys()
torch.testing.assert_close(processed_obs["key1"], observation["key1"])
assert processed_obs["key2"] == observation["key2"]
def test_none_observation():
"""Test processor with None observation."""
processor = RenameProcessor(rename_map={"old": "new"})
transition = create_transition()
result = processor(transition)
# Should return transition unchanged
assert result == transition
def test_overlapping_rename():
"""Test renaming when new names might conflict."""
rename_map = {
"a": "b",
"b": "c", # This creates a potential conflict
}
processor = RenameProcessor(rename_map=rename_map)
observation = {
"a": 1,
"b": 2,
"x": 3,
}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check that renaming happens correctly
assert "a" not in processed_obs
assert processed_obs["b"] == 1 # 'a' renamed to 'b'
assert processed_obs["c"] == 2 # original 'b' renamed to 'c'
assert processed_obs["x"] == 3
def test_partial_rename():
"""Test renaming only some keys."""
rename_map = {
"observation.state": "observation.proprio_state",
"pixels": "observation.image",
}
processor = RenameProcessor(rename_map=rename_map)
observation = {
"observation.state": torch.randn(10),
"pixels": np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8),
"reward": 1.0,
"info": {"episode": 1},
}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check renamed keys
assert "observation.proprio_state" in processed_obs
assert "observation.image" in processed_obs
assert "observation.state" not in processed_obs
assert "pixels" not in processed_obs
# Check unchanged keys
assert processed_obs["reward"] == 1.0
assert processed_obs["info"] == {"episode": 1}
def test_get_config():
"""Test configuration serialization."""
rename_map = {
"old1": "new1",
"old2": "new2",
}
processor = RenameProcessor(rename_map=rename_map)
config = processor.get_config()
assert config == {"rename_map": rename_map}
def test_state_dict():
"""Test state dict (should be empty for RenameProcessor)."""
processor = RenameProcessor(rename_map={"old": "new"})
state = processor.state_dict()
assert state == {}
# Load state dict should work even with empty dict
processor.load_state_dict({})
def test_integration_with_robot_processor():
"""Test integration with RobotProcessor pipeline."""
rename_map = {
"agent_pos": "observation.state",
"pixels": "observation.image",
}
rename_processor = RenameProcessor(rename_map=rename_map)
pipeline = RobotProcessor([rename_processor])
observation = {
"agent_pos": np.array([1.0, 2.0, 3.0]),
"pixels": np.zeros((32, 32, 3), dtype=np.uint8),
"other_data": "preserve_me",
}
transition = create_transition(
observation=observation, reward=0.5, done=False, truncated=False, info={}, complementary_data={}
)
result = pipeline(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check renaming worked through pipeline
assert "observation.state" in processed_obs
assert "observation.image" in processed_obs
assert "agent_pos" not in processed_obs
assert "pixels" not in processed_obs
assert processed_obs["other_data"] == "preserve_me"
# Check other transition elements unchanged
assert result[TransitionKey.REWARD] == 0.5
assert result[TransitionKey.DONE] is False
def test_save_and_load_pretrained():
"""Test saving and loading processor with RobotProcessor."""
rename_map = {
"old_state": "observation.state",
"old_image": "observation.image",
}
processor = RenameProcessor(rename_map=rename_map)
pipeline = RobotProcessor([processor], name="TestRenameProcessor")
with tempfile.TemporaryDirectory() as tmp_dir:
# Save pipeline
pipeline.save_pretrained(tmp_dir)
# Check files were created
config_path = Path(tmp_dir) / "testrenameprocessor.json" # Based on name="TestRenameProcessor"
assert config_path.exists()
# No state files should be created for RenameProcessor
state_files = list(Path(tmp_dir).glob("*.safetensors"))
assert len(state_files) == 0
# Load pipeline
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
assert loaded_pipeline.name == "TestRenameProcessor"
assert len(loaded_pipeline) == 1
# Check that loaded processor works correctly
loaded_processor = loaded_pipeline.steps[0]
assert isinstance(loaded_processor, RenameProcessor)
assert loaded_processor.rename_map == rename_map
# Test functionality after loading
observation = {"old_state": [1, 2, 3], "old_image": "image_data"}
transition = create_transition(observation=observation)
result = loaded_pipeline(transition)
processed_obs = result[TransitionKey.OBSERVATION]
assert "observation.state" in processed_obs
assert "observation.image" in processed_obs
assert processed_obs["observation.state"] == [1, 2, 3]
assert processed_obs["observation.image"] == "image_data"
def test_registry_functionality():
"""Test that RenameProcessor is properly registered."""
# Check that it's registered
assert "rename_processor" in ProcessorStepRegistry.list()
# Get from registry
retrieved_class = ProcessorStepRegistry.get("rename_processor")
assert retrieved_class is RenameProcessor
# Create instance from registry
instance = retrieved_class(rename_map={"old": "new"})
assert isinstance(instance, RenameProcessor)
assert instance.rename_map == {"old": "new"}
def test_registry_based_save_load():
"""Test save/load using registry name instead of module path."""
processor = RenameProcessor(rename_map={"key1": "renamed_key1"})
pipeline = RobotProcessor([processor])
with tempfile.TemporaryDirectory() as tmp_dir:
# Save and load
pipeline.save_pretrained(tmp_dir)
# Verify config uses registry name
import json
with open(Path(tmp_dir) / "robotprocessor.json") as f: # Default name is "RobotProcessor"
config = json.load(f)
assert "registry_name" in config["steps"][0]
assert config["steps"][0]["registry_name"] == "rename_processor"
assert "class" not in config["steps"][0] # Should use registry, not module path
# Load should work
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
loaded_processor = loaded_pipeline.steps[0]
assert isinstance(loaded_processor, RenameProcessor)
assert loaded_processor.rename_map == {"key1": "renamed_key1"}
def test_chained_rename_processors():
"""Test multiple RenameProcessors in a pipeline."""
# First processor: rename raw keys to intermediate format
processor1 = RenameProcessor(
rename_map={
"pos": "agent_position",
"img": "camera_image",
}
)
# Second processor: rename to final format
processor2 = RenameProcessor(
rename_map={
"agent_position": "observation.state",
"camera_image": "observation.image",
}
)
pipeline = RobotProcessor([processor1, processor2])
observation = {
"pos": np.array([1.0, 2.0]),
"img": "image_data",
"extra": "keep_me",
}
transition = create_transition(observation=observation)
# Step through to see intermediate results
results = list(pipeline.step_through(transition))
# After first processor
assert "agent_position" in results[1][TransitionKey.OBSERVATION]
assert "camera_image" in results[1][TransitionKey.OBSERVATION]
# After second processor
final_obs = results[2][TransitionKey.OBSERVATION]
assert "observation.state" in final_obs
assert "observation.image" in final_obs
assert final_obs["extra"] == "keep_me"
# Original keys should be gone
assert "pos" not in final_obs
assert "img" not in final_obs
assert "agent_position" not in final_obs
assert "camera_image" not in final_obs
def test_nested_observation_rename():
"""Test renaming with nested observation structures."""
rename_map = {
"observation.images.left": "observation.camera.left_view",
"observation.images.right": "observation.camera.right_view",
"observation.proprio": "observation.proprioception",
}
processor = RenameProcessor(rename_map=rename_map)
observation = {
"observation.images.left": torch.randn(3, 64, 64),
"observation.images.right": torch.randn(3, 64, 64),
"observation.proprio": torch.randn(7),
"observation.gripper": torch.tensor([0.0]), # Not renamed
}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check renames
assert "observation.camera.left_view" in processed_obs
assert "observation.camera.right_view" in processed_obs
assert "observation.proprioception" in processed_obs
# Check unchanged key
assert "observation.gripper" in processed_obs
# Check old keys removed
assert "observation.images.left" not in processed_obs
assert "observation.images.right" not in processed_obs
assert "observation.proprio" not in processed_obs
def test_value_types_preserved():
"""Test that various value types are preserved during renaming."""
rename_map = {"old_tensor": "new_tensor", "old_array": "new_array", "old_scalar": "new_scalar"}
processor = RenameProcessor(rename_map=rename_map)
tensor_value = torch.randn(3, 3)
array_value = np.random.rand(2, 2)
observation = {
"old_tensor": tensor_value,
"old_array": array_value,
"old_scalar": 42,
"old_string": "hello",
"old_dict": {"nested": "value"},
"old_list": [1, 2, 3],
}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check that values and types are preserved
assert torch.equal(processed_obs["new_tensor"], tensor_value)
assert np.array_equal(processed_obs["new_array"], array_value)
assert processed_obs["new_scalar"] == 42
assert processed_obs["old_string"] == "hello"
assert processed_obs["old_dict"] == {"nested": "value"}
assert processed_obs["old_list"] == [1, 2, 3]
def test_feature_contract_basic_renaming(policy_feature_factory):
processor = RenameProcessor(rename_map={"a": "x", "b": "y"})
features = {
"a": policy_feature_factory(FeatureType.STATE, (2,)),
"b": policy_feature_factory(FeatureType.ACTION, (3,)),
"c": policy_feature_factory(FeatureType.ENV, (1,)),
}
out = processor.feature_contract(features.copy())
# Values preserved and typed
assert out["x"] == features["a"]
assert out["y"] == features["b"]
assert out["c"] == features["c"]
assert_contract_is_typed(out)
# Input not mutated
assert set(features) == {"a", "b", "c"}
def test_feature_contract_overlapping_keys(policy_feature_factory):
# Overlapping renames: both 'a' and 'b' exist. 'a'->'b', 'b'->'c'
processor = RenameProcessor(rename_map={"a": "b", "b": "c"})
features = {
"a": policy_feature_factory(FeatureType.STATE, (1,)),
"b": policy_feature_factory(FeatureType.STATE, (2,)),
}
out = processor.feature_contract(features)
assert set(out) == {"b", "c"}
assert out["b"] == features["a"] # 'a' renamed to'b'
assert out["c"] == features["b"] # 'b' renamed to 'c'
assert_contract_is_typed(out)
def test_feature_contract_chained_processors(policy_feature_factory):
# Chain two rename processors at the contract level
processor1 = RenameProcessor(rename_map={"pos": "agent_position", "img": "camera_image"})
processor2 = RenameProcessor(
rename_map={"agent_position": "observation.state", "camera_image": "observation.image"}
)
pipeline = RobotProcessor([processor1, processor2])
spec = {
"pos": policy_feature_factory(FeatureType.STATE, (7,)),
"img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"extra": policy_feature_factory(FeatureType.ENV, (1,)),
}
out = pipeline.feature_contract(initial_features=spec)
assert set(out) == {"observation.state", "observation.image", "extra"}
assert out["observation.state"] == spec["pos"]
assert out["observation.image"] == spec["img"]
assert out["extra"] == spec["extra"]
assert_contract_is_typed(out)

View File

@@ -1,326 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from lerobot.robots.reachy2 import (
REACHY2_ANTENNAS_JOINTS,
REACHY2_L_ARM_JOINTS,
REACHY2_NECK_JOINTS,
REACHY2_R_ARM_JOINTS,
REACHY2_VEL,
Reachy2Robot,
Reachy2RobotConfig,
)
# {lerobot_keys: reachy2_sdk_keys}
REACHY2_JOINTS = {
**REACHY2_NECK_JOINTS,
**REACHY2_ANTENNAS_JOINTS,
**REACHY2_R_ARM_JOINTS,
**REACHY2_L_ARM_JOINTS,
}
PARAMS = [
{}, # default config
{"with_mobile_base": False},
{"with_mobile_base": False, "with_l_arm": False, "with_antennas": False},
{"with_r_arm": False, "with_neck": False, "with_antennas": False},
{"use_external_commands": True, "disable_torque_on_disconnect": True},
{"use_external_commands": True, "with_mobile_base": False, "with_neck": False},
{"disable_torque_on_disconnect": False},
{"max_relative_target": 5},
{"with_right_teleop_camera": False},
{"with_left_teleop_camera": False, "with_right_teleop_camera": False},
{"with_left_teleop_camera": False, "with_torso_camera": True},
]
def _make_reachy2_sdk_mock():
class JointSpy:
__slots__ = (
"present_position",
"_goal_position",
"_on_set",
)
def __init__(self, present_position=0.0, on_set=None):
self.present_position = present_position
self._goal_position = present_position
self._on_set = on_set
@property
def goal_position(self):
return self._goal_position
@goal_position.setter
def goal_position(self, v):
self._goal_position = v
if self._on_set:
self._on_set()
r = MagicMock(name="ReachySDKMock")
r.is_connected.return_value = True
def _connect():
r.is_connected.return_value = True
def _disconnect():
r.is_connected.return_value = False
# Global counter of goal_position sets
r._goal_position_set_total = 0
def _on_any_goal_set():
r._goal_position_set_total += 1
# Mock joints with some dummy positions
joints = {
k: JointSpy(
present_position=float(i),
on_set=_on_any_goal_set,
)
for i, k in enumerate(REACHY2_JOINTS.values())
}
r.joints = joints
# Mock mobile base with some dummy odometry
r.mobile_base = MagicMock()
r.mobile_base.odometry = {
"x": 0.1,
"y": -0.2,
"theta": 21.3,
"vx": 0.001,
"vy": 0.002,
"vtheta": 0.0,
}
r.connect = MagicMock(side_effect=_connect)
r.disconnect = MagicMock(side_effect=_disconnect)
# Mock methods
r.turn_on = MagicMock()
r.reset_default_limits = MagicMock()
r.send_goal_positions = MagicMock()
r.turn_off_smoothly = MagicMock()
r.mobile_base.set_goal_speed = MagicMock()
r.mobile_base.send_speed_command = MagicMock()
return r
def _make_reachy2_camera_mock(*args, **kwargs):
cfg = args[0] if args else kwargs.get("config")
name = getattr(cfg, "name", kwargs.get("name", "cam"))
image_type = getattr(cfg, "image_type", kwargs.get("image_type", "cam"))
width = getattr(cfg, "width", kwargs.get("width", 640))
height = getattr(cfg, "height", kwargs.get("height", 480))
cam = MagicMock(name=f"Reachy2CameraMock:{name}")
cam.name = name
cam.image_type = image_type
cam.width = width
cam.height = height
cam.connect = MagicMock()
cam.disconnect = MagicMock()
cam.async_read = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8))
return cam
@pytest.fixture(params=PARAMS, ids=lambda p: "default" if not p else ",".join(p.keys()))
def reachy2(request):
with (
patch(
"lerobot.robots.reachy2.robot_reachy2.ReachySDK",
side_effect=lambda *a, **k: _make_reachy2_sdk_mock(),
),
patch(
"lerobot.cameras.reachy2_camera.reachy2_camera.Reachy2Camera",
side_effect=_make_reachy2_camera_mock,
),
):
overrides = request.param
cfg = Reachy2RobotConfig(ip_address="192.168.0.200", **overrides)
robot = Reachy2Robot(cfg)
yield robot
if robot.is_connected:
robot.disconnect()
def test_connect_disconnect(reachy2):
assert not reachy2.is_connected
reachy2.connect()
assert reachy2.is_connected
reachy2.reachy.turn_on.assert_called_once()
reachy2.reachy.reset_default_limits.assert_called_once()
reachy2.disconnect()
assert not reachy2.is_connected
if reachy2.config.disable_torque_on_disconnect:
reachy2.reachy.turn_off_smoothly.assert_called_once()
else:
reachy2.reachy.turn_off_smoothly.assert_not_called()
reachy2.reachy.disconnect.assert_called_once()
def test_get_joints_dict(reachy2):
reachy2.connect()
if reachy2.config.with_neck:
assert "neck_yaw.pos" in reachy2.joints_dict
assert "neck_pitch.pos" in reachy2.joints_dict
assert "neck_roll.pos" in reachy2.joints_dict
else:
assert "neck_yaw.pos" not in reachy2.joints_dict
assert "neck_pitch.pos" not in reachy2.joints_dict
assert "neck_roll.pos" not in reachy2.joints_dict
if reachy2.config.with_antennas:
assert "l_antenna.pos" in reachy2.joints_dict
assert "r_antenna.pos" in reachy2.joints_dict
else:
assert "l_antenna.pos" not in reachy2.joints_dict
assert "r_antenna.pos" not in reachy2.joints_dict
if reachy2.config.with_r_arm:
assert "r_shoulder_pitch.pos" in reachy2.joints_dict
assert "r_shoulder_roll.pos" in reachy2.joints_dict
assert "r_elbow_yaw.pos" in reachy2.joints_dict
assert "r_elbow_pitch.pos" in reachy2.joints_dict
assert "r_wrist_roll.pos" in reachy2.joints_dict
assert "r_wrist_pitch.pos" in reachy2.joints_dict
assert "r_wrist_yaw.pos" in reachy2.joints_dict
assert "r_gripper.pos" in reachy2.joints_dict
else:
assert "r_shoulder_pitch.pos" not in reachy2.joints_dict
assert "r_shoulder_roll.pos" not in reachy2.joints_dict
assert "r_elbow_yaw.pos" not in reachy2.joints_dict
assert "r_elbow_pitch.pos" not in reachy2.joints_dict
assert "r_wrist_roll.pos" not in reachy2.joints_dict
assert "r_wrist_pitch.pos" not in reachy2.joints_dict
assert "r_wrist_yaw.pos" not in reachy2.joints_dict
assert "r_gripper.pos" not in reachy2.joints_dict
if reachy2.config.with_l_arm:
assert "l_shoulder_pitch.pos" in reachy2.joints_dict
assert "l_shoulder_roll.pos" in reachy2.joints_dict
assert "l_elbow_yaw.pos" in reachy2.joints_dict
assert "l_elbow_pitch.pos" in reachy2.joints_dict
assert "l_wrist_roll.pos" in reachy2.joints_dict
assert "l_wrist_pitch.pos" in reachy2.joints_dict
assert "l_wrist_yaw.pos" in reachy2.joints_dict
assert "l_gripper.pos" in reachy2.joints_dict
else:
assert "l_shoulder_pitch.pos" not in reachy2.joints_dict
assert "l_shoulder_roll.pos" not in reachy2.joints_dict
assert "l_elbow_yaw.pos" not in reachy2.joints_dict
assert "l_elbow_pitch.pos" not in reachy2.joints_dict
assert "l_wrist_roll.pos" not in reachy2.joints_dict
assert "l_wrist_pitch.pos" not in reachy2.joints_dict
assert "l_wrist_yaw.pos" not in reachy2.joints_dict
assert "l_gripper.pos" not in reachy2.joints_dict
def test_get_observation(reachy2):
reachy2.connect()
obs = reachy2.get_observation()
expected_keys = set(reachy2.joints_dict)
expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base)
expected_keys.update(reachy2.cameras.keys())
assert set(obs.keys()) == expected_keys
for motor in reachy2.joints_dict.keys():
assert obs[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
if reachy2.config.with_mobile_base:
for vel in REACHY2_VEL.keys():
assert obs[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]]
if reachy2.config.with_left_teleop_camera:
assert obs["teleop_left"].shape == (
reachy2.config.cameras["teleop_left"].height,
reachy2.config.cameras["teleop_left"].width,
3,
)
if reachy2.config.with_right_teleop_camera:
assert obs["teleop_right"].shape == (
reachy2.config.cameras["teleop_right"].height,
reachy2.config.cameras["teleop_right"].width,
3,
)
if reachy2.config.with_torso_camera:
assert obs["torso_rgb"].shape == (
reachy2.config.cameras["torso_rgb"].height,
reachy2.config.cameras["torso_rgb"].width,
3,
)
def test_send_action(reachy2):
reachy2.connect()
action = {k: i * 10.0 for i, k in enumerate(reachy2.joints_dict.keys(), start=1)}
if reachy2.config.with_mobile_base:
action.update({k: i * 0.1 for i, k in enumerate(REACHY2_VEL.keys(), start=1)})
previous_present_position = {
k: reachy2.reachy.joints[REACHY2_JOINTS[k]].present_position for k in reachy2.joints_dict.keys()
}
returned = reachy2.send_action(action)
if reachy2.config.max_relative_target is None:
assert returned == action
assert reachy2.reachy._goal_position_set_total == len(reachy2.joints_dict)
for motor in reachy2.joints_dict.keys():
expected_pos = action[motor]
real_pos = reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position
if reachy2.config.max_relative_target is None:
assert real_pos == expected_pos
else:
assert real_pos == previous_present_position[motor] + np.sign(expected_pos) * min(
abs(expected_pos - real_pos), reachy2.config.max_relative_target
)
if reachy2.config.with_mobile_base:
goal_speed = [i * 0.1 for i, _ in enumerate(REACHY2_VEL.keys(), start=1)]
reachy2.reachy.mobile_base.set_goal_speed.assert_called_once_with(*goal_speed)
if reachy2.config.use_external_commands:
reachy2.reachy.send_goal_positions.assert_not_called()
if reachy2.config.with_mobile_base:
reachy2.reachy.mobile_base.send_speed_command.assert_not_called()
else:
reachy2.reachy.send_goal_positions.assert_called_once()
if reachy2.config.with_mobile_base:
reachy2.reachy.mobile_base.send_speed_command.assert_called_once()
def test_no_part_declared():
with pytest.raises(ValueError):
_ = Reachy2RobotConfig(
ip_address="192.168.0.200",
with_mobile_base=False,
with_l_arm=False,
with_r_arm=False,
with_neck=False,
with_antennas=False,
)

View File

@@ -1,150 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock, patch
import pytest
from lerobot.teleoperators.reachy2_teleoperator import (
REACHY2_ANTENNAS_JOINTS,
REACHY2_L_ARM_JOINTS,
REACHY2_NECK_JOINTS,
REACHY2_R_ARM_JOINTS,
REACHY2_VEL,
Reachy2Teleoperator,
Reachy2TeleoperatorConfig,
)
# {lerobot_keys: reachy2_sdk_keys}
REACHY2_JOINTS = {
**REACHY2_NECK_JOINTS,
**REACHY2_ANTENNAS_JOINTS,
**REACHY2_R_ARM_JOINTS,
**REACHY2_L_ARM_JOINTS,
}
PARAMS = [
{}, # default config
{"with_mobile_base": False},
{"with_mobile_base": False, "with_l_arm": False, "with_antennas": False},
{"with_r_arm": False, "with_neck": False, "with_antennas": False},
{"with_mobile_base": False, "with_neck": False},
{"use_present_position": True},
]
def _make_reachy2_sdk_mock():
r = MagicMock(name="ReachySDKMock")
r.is_connected.return_value = True
def _connect():
r.is_connected.return_value = True
def _disconnect():
r.is_connected.return_value = False
# Mock joints with some dummy positions
joints = {
k: MagicMock(
present_position=float(i),
goal_position=float(i) + 0.5,
)
for i, k in enumerate(REACHY2_JOINTS.values())
}
r.joints = joints
# Mock mobile base with some dummy odometry
r.mobile_base = MagicMock()
r.mobile_base.last_cmd_vel = {
"vx": -0.2,
"vy": 0.2,
"vtheta": 11.0,
}
r.mobile_base.odometry = {
"x": 1.0,
"y": 2.0,
"theta": 20.0,
"vx": 0.1,
"vy": -0.1,
"vtheta": 8.0,
}
r.connect = MagicMock(side_effect=_connect)
r.disconnect = MagicMock(side_effect=_disconnect)
return r
@pytest.fixture(params=PARAMS, ids=lambda p: "default" if not p else ",".join(p.keys()))
def reachy2(request):
with (
patch(
"lerobot.teleoperators.reachy2_teleoperator.reachy2_teleoperator.ReachySDK",
side_effect=lambda *a, **k: _make_reachy2_sdk_mock(),
),
):
overrides = request.param
cfg = Reachy2TeleoperatorConfig(ip_address="192.168.0.200", **overrides)
robot = Reachy2Teleoperator(cfg)
yield robot
if robot.is_connected:
robot.disconnect()
def test_connect_disconnect(reachy2):
assert not reachy2.is_connected
reachy2.connect()
assert reachy2.is_connected
reachy2.disconnect()
assert not reachy2.is_connected
reachy2.reachy.disconnect.assert_called_once()
def test_get_action(reachy2):
reachy2.connect()
action = reachy2.get_action()
expected_keys = set(reachy2.joints_dict)
expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base)
assert set(action.keys()) == expected_keys
for motor in reachy2.joints_dict.keys():
if reachy2.config.use_present_position:
assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
else:
assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position
if reachy2.config.with_mobile_base:
if reachy2.config.use_present_position:
for vel in REACHY2_VEL.keys():
assert action[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]]
else:
for vel in REACHY2_VEL.keys():
assert action[vel] == reachy2.reachy.mobile_base.last_cmd_vel[REACHY2_VEL[vel]]
def test_no_part_declared():
with pytest.raises(ValueError):
_ = Reachy2TeleoperatorConfig(
ip_address="192.168.0.200",
with_mobile_base=False,
with_l_arm=False,
with_r_arm=False,
with_neck=False,
with_antennas=False,
)