Compare commits

..

9 Commits

Author SHA1 Message Date
Pepijn
a07ff640bb fix metadata info.json 2026-02-05 17:09:01 +01:00
Pepijn
0753415244 fix 2026-02-05 16:56:47 +01:00
Pepijn
a054663e38 cleanup output from earlier runs 2026-02-05 16:54:16 +01:00
Pepijn
aceb651e40 import logging 2026-02-05 16:15:53 +01:00
Pepijn
76a4529d29 fix bug 2026-02-05 16:03:45 +01:00
Pepijn
39e14c086c add push to hub 2026-02-05 15:20:57 +01:00
Pepijn
0af2029328 add slurm mirror dataset 2026-02-05 15:15:10 +01:00
Pepijn
c027b2971c Merge branch 'main' into tmp/fold_training 2026-01-31 13:52:19 +01:00
Michel Aractingi
9cc203034e temp_training 2026-01-26 15:19:52 +00:00
84 changed files with 944 additions and 1441 deletions

View File

@@ -101,11 +101,9 @@ jobs:
runs-on:
group: aws-general-8-plus
if: |
github.repository == 'huggingface/lerobot' && (
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) ||
github.event_name == 'push' ||
github.event_name == 'workflow_dispatch'
)
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) ||
github.event_name == 'push' ||
github.event_name == 'workflow_dispatch'
outputs:
image_tag: ${{ steps.set_tag.outputs.image_tag }}
env:

View File

@@ -91,7 +91,6 @@ jobs:
name: Build and Push Docker
runs-on:
group: aws-general-8-plus
if: github.repository == 'huggingface/lerobot'
outputs:
image_tag: ${{ env.DOCKER_IMAGE_NAME }}
env:

View File

@@ -28,9 +28,9 @@ We don't expect the same optimal settings for a dataset of images from a simulat
For these reasons, we run this benchmark on four representative datasets:
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
- `aliberts/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
- `aliberts/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
Note: The datasets used for this benchmark need to be image datasets, not video datasets.
@@ -179,7 +179,7 @@ python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
--repo-ids \
lerobot/pusht_image \
lerobot/aloha_mobile_shrimp_image \
aliberts/aloha_mobile_shrimp_image \
--vcodec libx264 libx265 \
--pix-fmt yuv444p yuv420p \
--g 2 20 None \
@@ -203,9 +203,9 @@ python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
--repo-ids \
lerobot/pusht_image \
lerobot/aloha_mobile_shrimp_image \
lerobot/paris_street \
lerobot/kitchen \
aliberts/aloha_mobile_shrimp_image \
aliberts/paris_street \
aliberts/kitchen \
--vcodec libx264 libx265 \
--pix-fmt yuv444p yuv420p \
--g 1 2 3 4 5 6 10 15 20 40 None \
@@ -221,9 +221,9 @@ python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
--repo-ids \
lerobot/pusht_image \
lerobot/aloha_mobile_shrimp_image \
lerobot/paris_street \
lerobot/kitchen \
aliberts/aloha_mobile_shrimp_image \
aliberts/paris_street \
aliberts/kitchen \
--vcodec libsvtav1 \
--pix-fmt yuv420p \
--g 1 2 3 4 5 6 10 15 20 40 None \
@@ -252,37 +252,37 @@ Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_read
These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`
| video_images_size_ratio | vcodec | pix_fmt | | | |
| --------------------------------- | ---------- | ------- | --------- | --------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
| video_images_size_ratio | vcodec | pix_fmt | | | |
| ---------------------------------- | ---------- | ------- | --------- | --------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
| aliberts/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
| aliberts/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
| aliberts/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
| --------------------------------- | ------- | ------- | -------- | ------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
| ---------------------------------- | ------- | ------- | -------- | ------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
| aliberts/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
| aliberts/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
| aliberts/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
| | | vcodec | pix_fmt | | | |
| --------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
| | | libx264 | | libx265 | | libsvtav1 |
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
| | | vcodec | pix_fmt | | | |
| ---------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
| | | libx264 | | libx265 | | libsvtav1 |
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
| aliberts/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
| aliberts/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
| aliberts/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |

View File

@@ -185,7 +185,7 @@ echo $HF_USER
Use the standard recording command:
```bash
lerobot-record \
python src/lerobot/scripts/lerobot_record.py \
--robot.type=earthrover_mini_plus \
--teleop.type=keyboard_rover \
--dataset.repo_id=your_username/dataset_name \

View File

@@ -224,7 +224,7 @@ lerobot-record \
--teleop.port=/dev/tty.usbmodem1201 \
--teleop.id=right \
--teleop.side=right \
--dataset.repo_id=<USER>/hand_record_test_with_video_data \
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
--dataset.single_task="Hand recording test with video data" \
--dataset.num_episodes=1 \
--dataset.episode_time_s=5 \
@@ -241,7 +241,7 @@ lerobot-replay \
--robot.port=/dev/tty.usbmodem58760432281 \
--robot.id=right \
--robot.side=right \
--dataset.repo_id=<USER>/hand_record_test_with_camera \
--dataset.repo_id=nepyope/hand_record_test_with_camera \
--dataset.episode=0
```
@@ -249,13 +249,13 @@ lerobot-replay \
```bash
lerobot-train \
--dataset.repo_id=<USER>/hand_record_test_with_video_data \
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
--policy.type=act \
--output_dir=outputs/train/hopejr_hand \
--job_name=hopejr \
--policy.device=mps \
--wandb.enable=true \
--policy.repo_id=<USER>/hand_test_policy
--policy.repo_id=nepyope/hand_test_policy
```
### Evaluate
@@ -270,7 +270,7 @@ lerobot-record \
--robot.side=right \
--robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \
--display_data=false \
--dataset.repo_id=<USER>/eval_hopejr \
--dataset.repo_id=nepyope/eval_hopejr \
--dataset.single_task="Evaluate hopejr hand policy" \
--dataset.num_episodes=10 \
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model

View File

@@ -1,15 +1,13 @@
# Installation
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
## Install [`miniforge`](https://conda-forge.org/download/)
```bash
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
bash Miniforge3-$(uname)-$(uname -m).sh
```
## Step 2: Environment Setup
## Environment Setup
Create a virtual environment with Python 3.10, using conda:
@@ -40,7 +38,7 @@ conda install ffmpeg -c conda-forge
>
> - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
## Step 3: Install LeRobot 🤗
## Install LeRobot 🤗
### From Source

View File

@@ -60,7 +60,7 @@ policy.type=pi0
For training π₀, you can use the standard LeRobot training script with the appropriate configuration:
```bash
lerobot-train \
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your_dataset \
--policy.type=pi0 \
--output_dir=./outputs/pi0_training \

View File

@@ -56,7 +56,7 @@ policy.type=pi05
Here's a complete training command for finetuning the base π₀.₅ model on your own dataset:
```bash
lerobot-train \
python src/lerobot/scripts/lerobot_train.py\
--dataset.repo_id=your_dataset \
--policy.type=pi05 \
--output_dir=./outputs/pi05_training \

View File

@@ -269,7 +269,7 @@ This generates visualizations showing video frames with subtask boundaries overl
Train with **no annotations** - uses linear progress from 0 to 1:
```bash
lerobot-train \
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=sarm \
--policy.annotation_mode=single_stage \
@@ -288,7 +288,7 @@ lerobot-train \
Train with **dense annotations only** (sparse auto-generated):
```bash
lerobot-train \
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=sarm \
--policy.annotation_mode=dense_only \
@@ -307,7 +307,7 @@ lerobot-train \
Train with **both sparse and dense annotations**:
```bash
lerobot-train \
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=sarm \
--policy.annotation_mode=dual \
@@ -468,7 +468,7 @@ This script:
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
```bash
lerobot-train \
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=pi0 \
--use_rabc=true \

View File

@@ -216,7 +216,7 @@ lerobot-teleoperate \
### Record Dataset in Simulation
```bash
lerobot-record \
python -m lerobot.scripts.lerobot_record \
--robot.type=unitree_g1 \
--robot.is_simulation=true \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
@@ -266,7 +266,7 @@ lerobot-teleoperate \
### Record Dataset on Real Robot
```bash
lerobot-record \
python -m lerobot.scripts.lerobot_record \
--robot.type=unitree_g1 \
--robot.is_simulation=false \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \

View File

@@ -12,7 +12,6 @@ LeRobot provides several utilities for manipulating datasets:
4. **Add Features** - Add new features to a dataset
5. **Remove Features** - Remove features from a dataset
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
7. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc.
The core implementation is in `lerobot.datasets.dataset_tools`.
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
@@ -157,30 +156,6 @@ lerobot-edit-dataset \
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
### Show the information of datasets
Show the information of datasets such as number of episode, number of frame, File size and so on.
No change will be made to the dataset
```bash
# Show dataset information without feature details
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type info \
# Show dataset information with feature details
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type info \
--operation.show_features true
```
**Parameters:**
- `parameters`: The flag to control show or no show dataset information with feature details.(default=false)
### Push to Hub
Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:

View File

@@ -45,7 +45,7 @@ policy.type=wall_x
For training WallX, you can use the standard LeRobot training script with the appropriate configuration:
```bash
lerobot-train \
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your_dataset \
--policy.type=wall_x \
--output_dir=./outputs/wallx_training \

View File

@@ -154,7 +154,7 @@ lerobot-train \
```bash
lerobot-train \
--dataset.repo_id=<USER>/bimanual-so100-handover-cube \
--dataset.repo_id=pepijn223/bimanual-so100-handover-cube \
--output_dir=./outputs/xvla_bimanual \
--job_name=xvla_so101_training \
--policy.path="lerobot/xvla-base" \

View File

@@ -22,7 +22,7 @@ lerobot-replay \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=black \
--dataset.repo_id=<USER>/record-test \
--dataset.repo_id=aliberts/record-test \
--dataset.episode=2
```
"""

File diff suppressed because it is too large Load Diff

View File

@@ -27,8 +27,8 @@ measuring consistency and ground truth alignment.
Usage:
# Basic usage with smolvla policy
uv run python examples/rtc/eval_dataset.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--dataset.repo_id=<USER>/check_rtc \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=mps \
--rtc.max_guidance_weight=10.0 \
@@ -58,16 +58,16 @@ Usage:
--device=cuda
uv run python examples/rtc/eval_dataset.py \
--policy.path=<USER>/reuben_pi0 \
--dataset.repo_id=<USER>/so101_cube_in_cup \
--policy.path=lipsop/reuben_pi0 \
--dataset.repo_id=ReubenLim/so101_cube_in_cup \
--rtc.execution_horizon=8 \
--device=cuda
# With torch.compile for faster inference (PyTorch 2.0+)
# Note: CUDA graphs disabled by default due to in-place ops in denoising loop
uv run python examples/rtc/eval_dataset.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--dataset.repo_id=<USER>/check_rtc \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=mps \
--use_torch_compile=true \
@@ -75,8 +75,8 @@ Usage:
# With torch.compile on CUDA (CUDA graphs disabled by default)
uv run python examples/rtc/eval_dataset.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--dataset.repo_id=<USER>/check_rtc \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=cuda \
--use_torch_compile=true \
@@ -84,8 +84,8 @@ Usage:
# Enable CUDA graphs (advanced - may cause tensor aliasing errors)
uv run python examples/rtc/eval_dataset.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--dataset.repo_id=<USER>/check_rtc \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--use_torch_compile=true \
--torch_compile_backend=inductor \
--torch_compile_mode=max-autotune \

View File

@@ -28,7 +28,7 @@ For simulation environments, see eval_with_simulation.py
Usage:
# Run RTC with Real robot with RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
@@ -41,7 +41,7 @@ Usage:
# Run RTC with Real robot without RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=false \
--robot.type=so100_follower \
@@ -53,7 +53,7 @@ Usage:
# Run RTC with Real robot with pi0.5 policy
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=<USER>/pi05_check_rtc \
--policy.path=helper2424/pi05_check_rtc \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \

View File

@@ -59,7 +59,7 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
dependencies = [
# Hugging Face dependencies
"datasets>=4.0.0,<5.0.0",
"datasets>=4.0.0,<4.2.0",
"diffusers>=0.27.2,<0.36.0",
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
"accelerate>=1.10.0,<2.0.0",
@@ -76,9 +76,9 @@ dependencies = [
"pyserial>=3.5,<4.0",
"wandb>=0.24.0,<0.25.0",
"torch>=2.2.1,<2.11.0", # TODO: Bump dependency
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
"torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
"draccus==0.10.0", # TODO: Remove ==
"gymnasium>=1.1.1,<2.0.0",
@@ -360,9 +360,9 @@ ignore_errors = false
module = "lerobot.cameras.*"
ignore_errors = false
[[tool.mypy.overrides]]
module = "lerobot.motors.*"
ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.motors.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.robots.*"

View File

@@ -13,5 +13,5 @@
# limitations under the License.
from .camera import Camera
from .configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation
from .configs import CameraConfig, ColorMode, Cv2Rotation
from .utils import make_cameras_from_configs

View File

@@ -150,7 +150,7 @@ class Camera(abc.ABC):
"""
pass
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the

View File

@@ -25,10 +25,6 @@ class ColorMode(str, Enum):
RGB = "rgb"
BGR = "bgr"
@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(f"`color_mode` is expected to be in {list(cls)}, but {value} is provided.")
class Cv2Rotation(int, Enum):
NO_ROTATION = 0
@@ -36,25 +32,6 @@ class Cv2Rotation(int, Enum):
ROTATE_180 = 180
ROTATE_270 = -90
@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(f"`rotation` is expected to be in {list(cls)}, but {value} is provided.")
# Subset from https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html
class Cv2Backends(int, Enum):
ANY = 0
V4L2 = 200
DSHOW = 700
PVAPI = 800
ANDROID = 1000
AVFOUNDATION = 1200
MSMF = 1400
@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(f"`backend` is expected to be in {list(cls)}, but {value} is provided.")
@dataclass(kw_only=True)
class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus

View File

@@ -32,11 +32,10 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2 # type: ignore # TODO: add type stubs for OpenCV
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..camera import Camera
from ..utils import get_cv2_rotation
from ..utils import get_cv2_backend, get_cv2_rotation
from .configuration_opencv import ColorMode, OpenCVCameraConfig
# NOTE(Steven): The maximum opencv device index depends on your operating system. For instance,
@@ -118,7 +117,7 @@ class OpenCVCamera(Camera):
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
self.backend: int = config.backend
self.backend: int = get_cv2_backend()
if self.height and self.width:
self.capture_width, self.capture_height = self.width, self.height
@@ -133,7 +132,6 @@ class OpenCVCamera(Camera):
"""Checks if the camera is currently connected and opened."""
return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened()
@check_if_already_connected
def connect(self, warmup: bool = True) -> None:
"""
Connects to the OpenCV camera specified in the configuration.
@@ -150,6 +148,8 @@ class OpenCVCamera(Camera):
ConnectionError: If the specified camera index/path is not found or fails to open.
RuntimeError: If the camera opens but fails to apply requested settings.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
# Use 1 thread for OpenCV operations to avoid potential conflicts or
# blocking in multi-threaded applications, especially during data collection.
@@ -178,7 +178,6 @@ class OpenCVCamera(Camera):
logger.info(f"{self} connected.")
@check_if_not_connected
def _configure_capture_settings(self) -> None:
"""
Applies the specified FOURCC, FPS, width, and height settings to the connected camera.
@@ -198,6 +197,8 @@ class OpenCVCamera(Camera):
to the requested value.
DeviceNotConnectedError: If the camera is not connected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
if self.config.fourcc is not None:
@@ -347,7 +348,6 @@ class OpenCVCamera(Camera):
return frame
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
@@ -374,6 +374,9 @@ class OpenCVCamera(Camera):
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -487,7 +490,6 @@ class OpenCVCamera(Camera):
self.latest_timestamp = None
self.new_frame_event.clear()
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame asynchronously.
@@ -510,6 +512,8 @@ class OpenCVCamera(Camera):
TimeoutError: If no frame becomes available within the specified timeout.
RuntimeError: If an unexpected error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -529,8 +533,7 @@ class OpenCVCamera(Camera):
return frame
@check_if_not_connected
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
@@ -545,6 +548,8 @@ class OpenCVCamera(Camera):
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")

View File

@@ -15,9 +15,9 @@
from dataclasses import dataclass
from pathlib import Path
from ..configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation
from ..configs import CameraConfig, ColorMode, Cv2Rotation
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation", "Cv2Backends"]
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"]
@CameraConfig.register_subclass("opencv")
@@ -50,7 +50,6 @@ class OpenCVCameraConfig(CameraConfig):
rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation.
warmup_s: Time reading frames before returning from connect (in seconds)
fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect).
backend: OpenCV backend identifier (https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html). Defaults to ANY.
Note:
- Only 3-channel color output (RGB/BGR) is currently supported.
@@ -63,12 +62,22 @@ class OpenCVCameraConfig(CameraConfig):
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
warmup_s: int = 1
fourcc: str | None = None
backend: Cv2Backends = Cv2Backends.ANY
def __post_init__(self) -> None:
self.color_mode = ColorMode(self.color_mode)
self.rotation = Cv2Rotation(self.rotation)
self.backend = Cv2Backends(self.backend)
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
if self.rotation not in (
Cv2Rotation.NO_ROTATION,
Cv2Rotation.ROTATE_90,
Cv2Rotation.ROTATE_180,
Cv2Rotation.ROTATE_270,
):
raise ValueError(
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
)
if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4):
raise ValueError(

View File

@@ -74,4 +74,7 @@ class Reachy2CameraConfig(CameraConfig):
f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided."
)
self.color_mode = ColorMode(self.color_mode)
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)

View File

@@ -32,7 +32,6 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
import cv2 # type: ignore # TODO: add type stubs for OpenCV
import numpy as np # type: ignore # TODO: add type stubs for numpy
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.import_utils import _reachy2_sdk_available
if TYPE_CHECKING or _reachy2_sdk_available:
@@ -124,7 +123,6 @@ class Reachy2Camera(Camera):
"""
raise NotImplementedError("Camera detection is not implemented for Reachy2 cameras.")
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
@@ -138,6 +136,9 @@ class Reachy2Camera(Camera):
"""
start_time = time.perf_counter()
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.cam_manager is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
@@ -183,7 +184,6 @@ class Reachy2Camera(Camera):
return frame
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Same as read()
@@ -197,11 +197,12 @@ class Reachy2Camera(Camera):
TimeoutError: If no frame becomes available within the specified timeout.
RuntimeError: If an unexpected error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
return self.read()
@check_if_not_connected
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
@@ -218,6 +219,8 @@ class Reachy2Camera(Camera):
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.latest_frame is None or self.latest_timestamp is None:
raise RuntimeError(f"{self} has not captured any frames yet.")
@@ -230,7 +233,6 @@ class Reachy2Camera(Camera):
return self.latest_frame
@check_if_not_connected
def disconnect(self) -> None:
"""
Stops the background read thread (if running).
@@ -238,6 +240,8 @@ class Reachy2Camera(Camera):
Raises:
DeviceNotConnectedError: If the camera is already disconnected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} not connected.")
if self.cam_manager is not None:
self.cam_manager.disconnect()

View File

@@ -30,8 +30,7 @@ try:
except Exception as e:
logging.info(f"Could not import realsense: {e}")
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..camera import Camera
from ..configs import ColorMode
@@ -153,7 +152,6 @@ class RealSenseCamera(Camera):
"""Checks if the camera pipeline is started and streams are active."""
return self.rs_pipeline is not None and self.rs_profile is not None
@check_if_already_connected
def connect(self, warmup: bool = True) -> None:
"""
Connects to the RealSense camera specified in the configuration.
@@ -171,6 +169,8 @@ class RealSenseCamera(Camera):
ConnectionError: If the camera is found but fails to start the pipeline or no RealSense devices are detected at all.
RuntimeError: If the pipeline starts but fails to apply requested settings.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
self.rs_pipeline = rs.pipeline()
rs_config = rs.config()
@@ -290,7 +290,6 @@ class RealSenseCamera(Camera):
if self.use_depth:
rs_config.enable_stream(rs.stream.depth)
@check_if_not_connected
def _configure_capture_settings(self) -> None:
"""Sets fps, width, and height from device stream if not already configured.
@@ -300,6 +299,8 @@ class RealSenseCamera(Camera):
Raises:
DeviceNotConnectedError: If device is not connected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.")
if self.rs_profile is None:
raise RuntimeError(f"{self}: rs_profile must be initialized before use.")
@@ -319,7 +320,6 @@ class RealSenseCamera(Camera):
self.width, self.height = actual_width, actual_height
self.capture_width, self.capture_height = actual_width, actual_height
@check_if_not_connected
def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]:
"""
Reads a single frame (depth) synchronously from the camera.
@@ -345,6 +345,9 @@ class RealSenseCamera(Camera):
f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}."
)
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -371,7 +374,6 @@ class RealSenseCamera(Camera):
return frame
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 0) -> NDArray[Any]:
"""
Reads a single frame (color) synchronously from the camera.
@@ -401,6 +403,9 @@ class RealSenseCamera(Camera):
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
)
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -529,7 +534,6 @@ class RealSenseCamera(Camera):
self.new_frame_event.clear()
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame data (color) asynchronously.
@@ -552,6 +556,8 @@ class RealSenseCamera(Camera):
TimeoutError: If no frame data becomes available within the specified timeout.
RuntimeError: If the background thread died unexpectedly or another error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -572,8 +578,7 @@ class RealSenseCamera(Camera):
return frame
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent (color) frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
@@ -588,6 +593,8 @@ class RealSenseCamera(Camera):
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")

View File

@@ -60,8 +60,20 @@ class RealSenseCameraConfig(CameraConfig):
warmup_s: int = 1
def __post_init__(self) -> None:
self.color_mode = ColorMode(self.color_mode)
self.rotation = Cv2Rotation(self.rotation)
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
if self.rotation not in (
Cv2Rotation.NO_ROTATION,
Cv2Rotation.ROTATE_90,
Cv2Rotation.ROTATE_180,
Cv2Rotation.ROTATE_270,
):
raise ValueError(
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
)
values = (self.fps, self.width, self.height)
if any(v is not None for v in values) and any(v is None for v in values):

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
from typing import cast
from lerobot.utils.import_utils import make_device_from_device_class
@@ -67,3 +68,14 @@ def get_cv2_rotation(rotation: Cv2Rotation) -> int | None:
return int(cv2.ROTATE_90_COUNTERCLOCKWISE)
else:
return None
def get_cv2_backend() -> int:
import cv2
if platform.system() == "Windows":
return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION
# elif platform.system() == "Darwin": # macOS
# return cv2.CAP_AVFOUNDATION
else: # Linux and others
return int(cv2.CAP_ANY)

View File

@@ -34,8 +34,7 @@ import cv2
import numpy as np
from numpy.typing import NDArray
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..camera import Camera
from ..configs import ColorMode
@@ -105,7 +104,6 @@ class ZMQCamera(Camera):
"""Checks if the ZMQ socket is initialized and connected."""
return self._connected and self.context is not None and self.socket is not None
@check_if_already_connected
def connect(self, warmup: bool = True) -> None:
"""Connect to ZMQ camera server.
@@ -113,6 +111,8 @@ class ZMQCamera(Camera):
warmup (bool): If True, waits for the camera to provide at least one
valid frame before returning. Defaults to True.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
logger.info(f"Connecting to {self}...")
@@ -211,7 +211,6 @@ class ZMQCamera(Camera):
return frame
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
@@ -229,6 +228,9 @@ class ZMQCamera(Camera):
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -299,7 +301,6 @@ class ZMQCamera(Camera):
self.latest_timestamp = None
self.new_frame_event.clear()
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame asynchronously.
@@ -316,6 +317,8 @@ class ZMQCamera(Camera):
TimeoutError: If no frame data becomes available within the specified timeout.
RuntimeError: If the background thread is not running.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -332,7 +335,6 @@ class ZMQCamera(Camera):
return frame
@check_if_not_connected
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
@@ -348,6 +350,8 @@ class ZMQCamera(Camera):
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")

View File

@@ -32,7 +32,10 @@ class ZMQCameraConfig(CameraConfig):
warmup_s: int = 1
def __post_init__(self) -> None:
self.color_mode = ColorMode(self.color_mode)
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
if self.timeout_ms <= 0:
raise ValueError(f"`timeout_ms` must be positive, but {self.timeout_ms} is provided.")

View File

@@ -45,12 +45,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
input_shapes: A dictionary defining the shapes of the input data for the policy.
output_shapes: A dictionary defining the shapes of the output data for the policy.
input_normalization_modes: A dictionary with key representing the modality and the value specifies the
normalization mode to apply.
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
the original scale.
"""
n_obs_steps: int = 1

View File

@@ -72,10 +72,11 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
raise ValueError(
f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}."
)
if features != meta.features:
raise ValueError(
f"Same features is expected, but got features={meta.features} instead of {features}."
)
# TODO: Temporarily disabled for merging datasets with different features (e.g. shirt_id)
# if features != meta.features:
# raise ValueError(
# f"Same features is expected, but got features={meta.features} instead of {features}."
# )
return fps, robot_type, features

View File

@@ -563,7 +563,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
episodes: list[int] | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
tolerance_s: float = 1e-4,
tolerance_s: float = 1e-2,
revision: str | None = None,
force_cache_sync: bool = False,
download_videos: bool = True,
@@ -656,7 +656,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
will be stored under root/repo_id.
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to
set the LEROBOT_HOME environment variable to point to a different location. Defaults to
'~/.cache/huggingface/lerobot'.
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None.
@@ -1572,7 +1572,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
root: str | Path | None = None,
robot_type: str | None = None,
use_videos: bool = True,
tolerance_s: float = 1e-4,
tolerance_s: float = 1e-2,
image_writer_processes: int = 0,
image_writer_threads: int = 0,
video_backend: str | None = None,

View File

@@ -216,17 +216,16 @@ class ImageTransformsConfig:
def make_transform_from_config(cfg: ImageTransformConfig):
if cfg.type == "SharpnessJitter":
if cfg.type == "Identity":
return v2.Identity(**cfg.kwargs)
elif cfg.type == "ColorJitter":
return v2.ColorJitter(**cfg.kwargs)
elif cfg.type == "SharpnessJitter":
return SharpnessJitter(**cfg.kwargs)
transform_cls = getattr(v2, cfg.type, None)
if isinstance(transform_cls, type) and issubclass(transform_cls, Transform):
return transform_cls(**cfg.kwargs)
raise ValueError(
f"Transform '{cfg.type}' is not valid. It must be a class in "
f"torchvision.transforms.v2 or 'SharpnessJitter'."
)
elif cfg.type == "RandomAffine":
return v2.RandomAffine(**cfg.kwargs)
else:
raise ValueError(f"Transform '{cfg.type}' is not valid.")
class ImageTransforms(Transform):

View File

@@ -122,9 +122,19 @@ def load_nested_dataset(
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
with SuppressProgressBars():
# We use .from_parquet() memory-mapped loading for efficiency
filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None
return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features)
# When no filtering needed, Dataset uses memory-mapped loading for efficiency
# PyArrow loads the entire dataset into memory
if episodes is None:
return Dataset.from_parquet([str(path) for path in paths], features=features)
arrow_dataset = pa_ds.dataset(paths, format="parquet")
filter_expr = pa_ds.field("episode_index").isin(episodes)
table = arrow_dataset.to_table(filter=filter_expr)
if features is not None:
table = table.cast(features.arrow_schema)
return Dataset(table)
def get_parquet_num_frames(parquet_path: str | Path) -> int:

View File

@@ -529,7 +529,7 @@ if __name__ == "__main__":
type=str,
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
"(e.g. `lerobot/pusht`, `<USER>/aloha_sim_insertion_human`).",
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
)
parser.add_argument(
"--branch",

View File

@@ -205,7 +205,6 @@ class ObservationConfig:
add_joint_velocity_to_observation: bool = False
add_current_to_observation: bool = False
add_ee_pose_to_observation: bool = False
display_cameras: bool = False

View File

@@ -112,7 +112,6 @@ class LiberoEnv(gym.Env):
visualization_height: int = 480,
init_states: bool = True,
episode_index: int = 0,
n_envs: int = 1,
camera_name_mapping: dict[str, str] | None = None,
num_steps_wait: int = 10,
control_mode: str = "relative",
@@ -146,9 +145,7 @@ class LiberoEnv(gym.Env):
self.episode_length = episode_length
# Load once and keep
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
self._reset_stride = n_envs # when performing a reset, append `_reset_stride` to `init_state_id`.
self.init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._env = self._make_envs_task(task_suite, self.task_id)
default_steps = 500
@@ -298,8 +295,7 @@ class LiberoEnv(gym.Env):
self._env.seed(seed)
raw_obs = self._env.reset()
if self.init_states and self._init_states is not None:
raw_obs = self._env.set_init_state(self._init_states[self.init_state_id % len(self._init_states)])
self.init_state_id += self._reset_stride # Change init_state_id when reset
raw_obs = self._env.set_init_state(self._init_states[self._init_state_id])
# After reset, objects may be unstable (slightly floating, intersecting, etc.).
# Step the simulator with a no-op action for a few frames so everything settles.
@@ -377,7 +373,6 @@ def _make_env_fns(
init_states=init_states,
episode_length=episode_length,
episode_index=episode_index,
n_envs=n_envs,
control_mode=control_mode,
**local_kwargs,
)

View File

@@ -221,7 +221,7 @@ class RangeFinderGUI:
self.bus = bus
self.groups = groups if groups is not None else {"all": list(bus.motors)}
self.group_names = list(self.groups)
self.group_names = list(groups)
self.current_group = self.group_names[0]
if not bus.is_connected:
@@ -230,20 +230,18 @@ class RangeFinderGUI:
self.calibration = bus.read_calibration()
self.res_table = bus.model_resolution_table
self.present_cache = {
m: bus.read("Present_Position", m, normalize=False)
for motors in self.groups.values()
for m in motors
m: bus.read("Present_Position", m, normalize=False) for motors in groups.values() for m in motors
}
pygame.init()
self.font = pygame.font.Font(None, FONT_SIZE)
label_pad = max(self.font.size(m)[0] for ms in self.groups.values() for m in ms)
label_pad = max(self.font.size(m)[0] for ms in groups.values() for m in ms)
self.label_pad = label_pad
width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10
self.controls_bottom = 10 + SAVE_H
self.base_y = self.controls_bottom + TOP_GAP
height = self.base_y + PADDING_Y * len(self.groups[self.current_group]) + 40
height = self.base_y + PADDING_Y * len(groups[self.current_group]) + 40
self.screen = pygame.display.set_mode((width, height))
pygame.display.set_caption("Motors range finder")

View File

@@ -23,7 +23,6 @@ from copy import deepcopy
from functools import cached_property
from typing import TYPE_CHECKING, Any, TypedDict
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.import_utils import _can_available
if TYPE_CHECKING or _can_available:
@@ -37,6 +36,7 @@ else:
import numpy as np
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import enter_pressed, move_cursor_up
@@ -155,7 +155,6 @@ class DamiaoMotorsBus(MotorsBusBase):
"""Check if the CAN bus is connected."""
return self._is_connected and self.canbus is not None
@check_if_already_connected
def connect(self, handshake: bool = True) -> None:
"""
Open the CAN bus and initialize communication.
@@ -163,6 +162,10 @@ class DamiaoMotorsBus(MotorsBusBase):
Args:
handshake: If True, ping all motors to verify they're present
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(
f"{self.__class__.__name__}('{self.port}') is already connected."
)
try:
# Auto-detect interface type based on port name
@@ -208,9 +211,6 @@ class DamiaoMotorsBus(MotorsBusBase):
logger.info("Starting handshake with motors...")
# Drain any pending messages
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
while self.canbus.recv(timeout=0.01):
pass
@@ -246,7 +246,6 @@ class DamiaoMotorsBus(MotorsBusBase):
)
logger.info("Handshake successful. All motors ready.")
@check_if_not_connected
def disconnect(self, disable_torque: bool = True) -> None:
"""
Close the CAN bus connection.
@@ -254,6 +253,8 @@ class DamiaoMotorsBus(MotorsBusBase):
Args:
disable_torque: If True, disable torque on all motors before disconnecting
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self.__class__.__name__}('{self.port}') is not connected.")
if disable_torque:
try:
@@ -282,10 +283,6 @@ class DamiaoMotorsBus(MotorsBusBase):
recv_id = self._get_motor_recv_id(motor)
data = [0xFF] * 7 + [command_byte]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
self.canbus.send(msg)
if msg := self._recv_motor_response(expected_recv_id=recv_id):
self._process_response(motor_name, msg)
@@ -344,10 +341,6 @@ class DamiaoMotorsBus(MotorsBusBase):
recv_id = self._get_motor_recv_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd)
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
self.canbus.send(msg)
return self._recv_motor_response(expected_recv_id=recv_id)
@@ -363,10 +356,6 @@ class DamiaoMotorsBus(MotorsBusBase):
Returns:
CAN message if received, None otherwise
"""
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
try:
start_time = time.time()
messages_seen = []
@@ -405,13 +394,10 @@ class DamiaoMotorsBus(MotorsBusBase):
Returns:
Dictionary mapping recv_id to CAN message
"""
responses: dict[int, can.Message] = {}
responses = {}
expected_set = set(expected_recv_ids)
start_time = time.time()
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
try:
while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout:
# 100us poll timeout
@@ -475,9 +461,6 @@ class DamiaoMotorsBus(MotorsBusBase):
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types[motor_name]
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
self.canbus.send(msg)
@@ -505,9 +488,6 @@ class DamiaoMotorsBus(MotorsBusBase):
recv_id_to_motor: dict[int, str] = {}
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
# Step 1: Send all MIT control commands
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
motor_id = self._get_motor_id(motor)
@@ -582,9 +562,10 @@ class DamiaoMotorsBus(MotorsBusBase):
except Exception as e:
logger.warning(f"Failed to decode response from {motor}: {e}")
@check_if_not_connected
def read(self, data_name: str, motor: str) -> Value:
"""Read a value from a single motor. Positions are always in degrees."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Refresh motor to get latest state
msg = self._refresh_motor(motor)
@@ -614,7 +595,6 @@ class DamiaoMotorsBus(MotorsBusBase):
raise ValueError(f"Unknown data_name: {data_name}")
return mapping[data_name]
@check_if_not_connected
def write(
self,
data_name: str,
@@ -625,6 +605,8 @@ class DamiaoMotorsBus(MotorsBusBase):
Write a value to a single motor. Positions are always in degrees.
Can write 'Goal_Position', 'Kp', or 'Kd'.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if data_name in ("Kp", "Kd"):
self._gains[motor][data_name.lower()] = float(value)
@@ -674,10 +656,6 @@ class DamiaoMotorsBus(MotorsBusBase):
def _batch_refresh(self, motors: list[str]) -> None:
"""Internal helper to refresh a list of motors and update cache."""
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
# Send refresh commands
for motor in motors:
motor_id = self._get_motor_id(motor)
@@ -700,12 +678,10 @@ class DamiaoMotorsBus(MotorsBusBase):
else:
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
@check_if_not_connected
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
"""
Write values to multiple motors simultaneously. Positions are always in degrees.
"""
if data_name in ("Kp", "Kd"):
key = data_name.lower()
for motor, val in values.items():
@@ -714,8 +690,6 @@ class DamiaoMotorsBus(MotorsBusBase):
elif data_name == "Goal_Position":
# Step 1: Send all MIT control commands
recv_id_to_motor: dict[int, str] = {}
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
for motor, value_degrees in values.items():
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
@@ -758,9 +732,9 @@ class DamiaoMotorsBus(MotorsBusBase):
def record_ranges_of_motion(
self,
motors: str | list[str] | None = None,
motors: NameOrID | list[NameOrID] | None = None,
display_values: bool = True,
) -> tuple[dict[str, Value], dict[str, Value]]:
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
"""
Interactively record the min/max values of each motor in degrees.

View File

@@ -181,10 +181,10 @@ class DynamixelMotorsBus(SerialMotorsBus):
for motor, m in self.motors.items():
calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=int(drive_modes[motor]),
homing_offset=int(offsets[motor]),
range_min=int(mins[motor]),
range_max=int(maxes[motor]),
drive_mode=drive_modes[motor],
homing_offset=offsets[motor],
range_min=mins[motor],
range_max=maxes[motor],
)
return calibration
@@ -198,7 +198,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
if cache:
self.calibration = calibration_dict
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
@@ -206,7 +206,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
@@ -235,7 +235,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
On Dynamixel Motors:
Present_Position = Actual_Position + Homing_Offset
"""
half_turn_homings: dict[NameOrID, Value] = {}
half_turn_homings = {}
for motor, pos in positions.items():
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
@@ -258,6 +258,6 @@ class DynamixelMotorsBus(SerialMotorsBus):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return None
return
return {id_: data[0] for id_, data in data_list.items()}

View File

@@ -126,7 +126,7 @@ class FeetechMotorsBus(SerialMotorsBus):
self.port_handler = scs.PortHandler(self.port)
# HACK: monkeypatch
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign]
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__(
self.port_handler, scs.PortHandler
)
self.packet_handler = scs.PacketHandler(protocol_version)
@@ -262,9 +262,9 @@ class FeetechMotorsBus(SerialMotorsBus):
calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=0,
homing_offset=int(offsets[motor]),
range_min=int(mins[motor]),
range_max=int(maxes[motor]),
homing_offset=offsets[motor],
range_min=mins[motor],
range_max=maxes[motor],
)
return calibration
@@ -284,7 +284,7 @@ class FeetechMotorsBus(SerialMotorsBus):
On Feetech Motors:
Present_Position = Actual_Position - Homing_Offset
"""
half_turn_homings: dict[NameOrID, Value] = {}
half_turn_homings = {}
for motor, pos in positions.items():
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
@@ -292,7 +292,7 @@ class FeetechMotorsBus(SerialMotorsBus):
return half_turn_homings
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
self.write("Lock", motor, 0, num_retry=num_retry)
@@ -303,7 +303,7 @@ class FeetechMotorsBus(SerialMotorsBus):
addr, length = get_address(self.model_ctrl_table, model, "Lock")
self._write(addr, length, motor, 0, num_retry=num_retry)
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
self.write("Lock", motor, 1, num_retry=num_retry)
@@ -334,7 +334,7 @@ class FeetechMotorsBus(SerialMotorsBus):
def _broadcast_ping(self) -> tuple[dict[int, int], int]:
import scservo_sdk as scs
data_list: dict[int, int] = {}
data_list = {}
status_length = 6
@@ -414,7 +414,7 @@ class FeetechMotorsBus(SerialMotorsBus):
if not self._is_comm_success(comm):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return None
return
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
if ids_errors:

View File

@@ -23,7 +23,6 @@ from __future__ import annotations
import abc
import logging
from collections.abc import Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
@@ -94,7 +93,7 @@ class MotorsBusBase(abc.ABC):
pass
@abc.abstractmethod
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
"""Write values to multiple motors."""
pass
@@ -180,16 +179,15 @@ class Motor:
class PortHandler(Protocol):
is_open: bool
baudrate: int
packet_start_time: float
packet_timeout: float
tx_time_per_byte: float
is_using: bool
port_name: str
ser: serial.Serial
def __init__(self, port_name: str) -> None: ...
def __init__(self, port_name):
self.is_open: bool
self.baudrate: int
self.packet_start_time: float
self.packet_timeout: float
self.tx_time_per_byte: float
self.is_using: bool
self.port_name: str
self.ser: serial.Serial
def openPort(self): ...
def closePort(self): ...
@@ -242,22 +240,19 @@ class PacketHandler(Protocol):
def regWriteTxRx(self, port, id, address, length, data): ...
def syncReadTx(self, port, start_address, data_length, param, param_length): ...
def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ...
def broadcastPing(self, port): ...
class GroupSyncRead(Protocol):
port: str
ph: PortHandler
start_address: int
data_length: int
last_result: bool
is_param_changed: bool
param: list
data_dict: dict
def __init__(self, port, ph, start_address, data_length):
self.port: str
self.ph: PortHandler
self.start_address: int
self.data_length: int
self.last_result: bool
self.is_param_changed: bool
self.param: list
self.data_dict: dict
def __init__(
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
) -> None: ...
def makeParam(self): ...
def addParam(self, id): ...
def removeParam(self, id): ...
@@ -270,17 +265,15 @@ class GroupSyncRead(Protocol):
class GroupSyncWrite(Protocol):
port: str
ph: PortHandler
start_address: int
data_length: int
is_param_changed: bool
param: list
data_dict: dict
def __init__(self, port, ph, start_address, data_length):
self.port: str
self.ph: PortHandler
self.start_address: int
self.data_length: int
self.is_param_changed: bool
self.param: list
self.data_dict: dict
def __init__(
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
) -> None: ...
def makeParam(self): ...
def addParam(self, id, data): ...
def removeParam(self, id): ...
@@ -407,7 +400,7 @@ class SerialMotorsBus(MotorsBusBase):
else:
raise TypeError(f"'{motor}' should be int, str.")
def _get_motor_model(self, motor: NameOrID) -> str:
def _get_motor_model(self, motor: NameOrID) -> int:
if isinstance(motor, str):
return self.motors[motor].model
elif isinstance(motor, int):
@@ -415,19 +408,17 @@ class SerialMotorsBus(MotorsBusBase):
else:
raise TypeError(f"'{motor}' should be int, str.")
def _get_motors_list(self, motors: NameOrID | Sequence[NameOrID] | None) -> list[str]:
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
if motors is None:
return list(self.motors)
elif isinstance(motors, str):
return [motors]
elif isinstance(motors, int):
return [self._id_to_name(motors)]
elif isinstance(motors, Sequence):
return [m if isinstance(m, str) else self._id_to_name(m) for m in motors]
elif isinstance(motors, list):
return motors.copy()
else:
raise TypeError(motors)
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> dict[int, Value]:
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
if isinstance(values, (int | float)):
return dict.fromkeys(self.ids, values)
elif isinstance(values, dict):
@@ -649,19 +640,18 @@ class SerialMotorsBus(MotorsBusBase):
pass
@abc.abstractmethod
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Enable torque on selected motors.
Args:
motors (int | str | list[str] | None, optional): Same semantics as :pymeth:`disable_torque`.
Defaults to `None`.
motor (int): Same semantics as :pymeth:`disable_torque`. Defaults to `None`.
num_retry (int, optional): Number of additional retry attempts on communication failure.
Defaults to 0.
"""
pass
@contextmanager
def torque_disabled(self, motors: str | list[str] | None = None):
def torque_disabled(self, motors: int | str | list[str] | None = None):
"""Context-manager that guarantees torque is re-enabled.
This helper is useful to temporarily disable torque when configuring motors.
@@ -738,19 +728,24 @@ class SerialMotorsBus(MotorsBusBase):
"""
pass
def reset_calibration(self, motors: NameOrID | Sequence[NameOrID] | None = None) -> None:
def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None:
"""Restore factory calibration for the selected motors.
Homing offset is set to ``0`` and min/max position limits are set to the full usable range.
The in-memory :pyattr:`calibration` is cleared.
Args:
motors (NameOrID | Sequence[NameOrID] | None, optional): Selection of motors. `None` (default)
motors (NameOrID | list[NameOrID] | None, optional): Selection of motors. `None` (default)
resets every motor.
"""
motor_names = self._get_motors_list(motors)
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
for motor in motor_names:
for motor in motors:
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
self.write("Homing_Offset", motor, 0, normalize=False)
@@ -759,9 +754,7 @@ class SerialMotorsBus(MotorsBusBase):
self.calibration = {}
def set_half_turn_homings(
self, motors: NameOrID | Sequence[NameOrID] | None = None
) -> dict[NameOrID, Value]:
def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) -> dict[NameOrID, Value]:
"""Centre each motor range around its current position.
The function computes and writes a homing offset such that the present position becomes exactly one
@@ -771,12 +764,17 @@ class SerialMotorsBus(MotorsBusBase):
motors (NameOrID | list[NameOrID] | None, optional): Motors to adjust. Defaults to all motors (`None`).
Returns:
dict[str, Value]: Mapping *motor name → written homing offset*.
dict[NameOrID, Value]: Mapping *motor → written homing offset*.
"""
motor_names = self._get_motors_list(motors)
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
self.reset_calibration(motor_names)
actual_positions = self.sync_read("Present_Position", motor_names, normalize=False)
self.reset_calibration(motors)
actual_positions = self.sync_read("Present_Position", motors, normalize=False)
homing_offsets = self._get_half_turn_homings(actual_positions)
for motor, offset in homing_offsets.items():
self.write("Homing_Offset", motor, offset)
@@ -788,8 +786,8 @@ class SerialMotorsBus(MotorsBusBase):
pass
def record_ranges_of_motion(
self, motors: NameOrID | Sequence[NameOrID] | None = None, display_values: bool = True
) -> tuple[dict[str, Value], dict[str, Value]]:
self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
"""Interactively record the min/max encoder values of each motor.
Move the joints by hand (with torque disabled) while the method streams live positions. Press
@@ -801,25 +799,30 @@ class SerialMotorsBus(MotorsBusBase):
display_values (bool, optional): When `True` (default) a live table is printed to the console.
Returns:
tuple[dict[str, Value], dict[str, Value]]: Two dictionaries *mins* and *maxes* with the
tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: Two dictionaries *mins* and *maxes* with the
extreme values observed for each motor.
"""
motor_names = self._get_motors_list(motors)
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
start_positions = self.sync_read("Present_Position", motor_names, normalize=False)
start_positions = self.sync_read("Present_Position", motors, normalize=False)
mins = start_positions.copy()
maxes = start_positions.copy()
user_pressed_enter = False
while not user_pressed_enter:
positions = self.sync_read("Present_Position", motor_names, normalize=False)
positions = self.sync_read("Present_Position", motors, normalize=False)
mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()}
maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()}
if display_values:
print("\n-------------------------------------------")
print(f"{'NAME':<15} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}")
for motor in motor_names:
for motor in motors:
print(f"{motor:<15} | {mins[motor]:>6} | {positions[motor]:>6} | {maxes[motor]:>6}")
if enter_pressed():
@@ -827,9 +830,9 @@ class SerialMotorsBus(MotorsBusBase):
if display_values and not user_pressed_enter:
# Move cursor up to overwrite the previous output
move_cursor_up(len(motor_names) + 3)
move_cursor_up(len(motors) + 3)
same_min_max = [motor for motor in motor_names if mins[motor] == maxes[motor]]
same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]]
if same_min_max:
raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}")
@@ -952,12 +955,12 @@ class SerialMotorsBus(MotorsBusBase):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
else:
return None
return
if self._is_error(error):
if raise_on_error:
raise RuntimeError(self.packet_handler.getRxPacketError(error))
else:
return None
return
return model_number
@@ -1004,13 +1007,12 @@ class SerialMotorsBus(MotorsBusBase):
err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
decoded = self._decode_sign(data_name, {id_: value})
id_value = self._decode_sign(data_name, {id_: value})
if normalize and data_name in self.normalized_data:
normalized = self._normalize(decoded)
return normalized[id_]
id_value = self._normalize(id_value)
return decoded[id_]
return id_value[id_]
def _read(
self,
@@ -1021,7 +1023,7 @@ class SerialMotorsBus(MotorsBusBase):
num_retry: int = 0,
raise_on_error: bool = True,
err_msg: str = "",
) -> tuple[int, int, int]:
) -> tuple[int, int]:
if length == 1:
read_fn = self.packet_handler.read1ByteTxRx
elif length == 2:
@@ -1071,14 +1073,13 @@ class SerialMotorsBus(MotorsBusBase):
model = self.motors[motor].model
addr, length = get_address(self.model_ctrl_table, model, data_name)
int_value = int(value)
if normalize and data_name in self.normalized_data:
int_value = self._unnormalize({id_: value})[id_]
value = self._unnormalize({id_: value})[id_]
int_value = self._encode_sign(data_name, {id_: int_value})[id_]
value = self._encode_sign(data_name, {id_: value})[id_]
err_msg = f"Failed to write '{data_name}' on {id_=} with '{int_value}' after {num_retry + 1} tries."
self._write(addr, length, id_, int_value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
def _write(
self,
@@ -1112,7 +1113,7 @@ class SerialMotorsBus(MotorsBusBase):
def sync_read(
self,
data_name: str,
motors: NameOrID | Sequence[NameOrID] | None = None,
motors: str | list[str] | None = None,
*,
normalize: bool = True,
num_retry: int = 0,
@@ -1121,7 +1122,7 @@ class SerialMotorsBus(MotorsBusBase):
Args:
data_name (str): Register name.
motors (NameOrID | Sequence[NameOrID] | None, optional): Motors to query. `None` (default) reads every motor.
motors (str | list[str] | None, optional): Motors to query. `None` (default) reads every motor.
normalize (bool, optional): Normalisation flag. Defaults to `True`.
num_retry (int, optional): Retry attempts. Defaults to `0`.
@@ -1142,17 +1143,16 @@ class SerialMotorsBus(MotorsBusBase):
addr, length = get_address(self.model_ctrl_table, model, data_name)
err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
raw_ids_values, _ = self._sync_read(
ids_values, _ = self._sync_read(
addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
)
decoded = self._decode_sign(data_name, raw_ids_values)
ids_values = self._decode_sign(data_name, ids_values)
if normalize and data_name in self.normalized_data:
normalized = self._normalize(decoded)
return {self._id_to_name(id_): value for id_, value in normalized.items()}
ids_values = self._normalize(ids_values)
return {self._id_to_name(id_): value for id_, value in decoded.items()}
return {self._id_to_name(id_): value for id_, value in ids_values.items()}
def _sync_read(
self,
@@ -1224,24 +1224,21 @@ class SerialMotorsBus(MotorsBusBase):
num_retry (int, optional): Retry attempts. Defaults to `0`.
"""
raw_ids_values = self._get_ids_values_dict(values)
models = [self._id_to_model(id_) for id_ in raw_ids_values]
ids_values = self._get_ids_values_dict(values)
models = [self._id_to_model(id_) for id_ in ids_values]
if self._has_different_ctrl_tables:
assert_same_address(self.model_ctrl_table, models, data_name)
model = next(iter(models))
addr, length = get_address(self.model_ctrl_table, model, data_name)
int_ids_values = {id_: int(val) for id_, val in raw_ids_values.items()}
if normalize and data_name in self.normalized_data:
int_ids_values = self._unnormalize(raw_ids_values)
ids_values = self._unnormalize(ids_values)
int_ids_values = self._encode_sign(data_name, int_ids_values)
ids_values = self._encode_sign(data_name, ids_values)
err_msg = f"Failed to sync write '{data_name}' with ids_values={int_ids_values} after {num_retry + 1} tries."
self._sync_write(
addr, length, int_ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
)
err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
def _sync_write(
self,

View File

@@ -28,7 +28,7 @@ class ACTConfig(PreTrainedConfig):
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_features` and `output_features`.
Those are: `input_shapes` and 'output_shapes`.
Notes on the inputs and outputs:
- Either:
@@ -48,12 +48,21 @@ class ACTConfig(PreTrainedConfig):
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
environment, and throws the other 50 out.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights.

View File

@@ -30,7 +30,7 @@ class DiffusionConfig(PreTrainedConfig):
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_features` and `output_features`.
Those are: `input_shapes` and `output_shapes`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
@@ -48,12 +48,21 @@ class DiffusionConfig(PreTrainedConfig):
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
See `DiffusionPolicy.select_action` for more details.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
@@ -64,7 +73,7 @@ class DiffusionConfig(PreTrainedConfig):
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view.
use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view.
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
You may provide a variable number of dimensions, therefore also controlling the degree of
downsampling.

View File

@@ -61,8 +61,6 @@ class PI05Config(PreTrainedConfig):
# Add empty images. Used to add empty cameras when no image features are present.
empty_cameras: int = 0
tokenizer_max_length: int = 200 # see openpi `__post_init__`
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,

View File

@@ -27,18 +27,18 @@ Usage:
# Full RA-BC computation with visualizations
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4
--reward-model-path pepijn223/sarm_single_uni4
# Faster computation with stride (compute every 5 frames, interpolate the rest)
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4 \\
--reward-model-path pepijn223/sarm_single_uni4 \\
--stride 5
# Visualize predictions only (no RA-BC computation)
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4 \\
--reward-model-path pepijn223/sarm_single_uni4 \\
--visualize-only \\
--num-visualizations 5
@@ -714,12 +714,12 @@ Examples:
# Full RA-BC computation with visualizations
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4
--reward-model-path pepijn223/sarm_single_uni4
# Visualize predictions only (no RA-BC computation)
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4 \\
--reward-model-path pepijn223/sarm_single_uni4 \\
--visualize-only \\
--num-visualizations 10
""",

View File

@@ -30,7 +30,7 @@ Example of finetuning the smolvla pretrained model (`smolvla_base`):
```bash
lerobot-train \
--policy.path=lerobot/smolvla_base \
--dataset.repo_id=<USER>/svla_so100_task1_v3 \
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
--batch_size=64 \
--steps=200000
```
@@ -40,7 +40,7 @@ and an action expert.
```bash
lerobot-train \
--policy.type=smolvla \
--dataset.repo_id=<USER>/svla_so100_task1_v3 \
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
--batch_size=64 \
--steps=200000
```
@@ -378,16 +378,16 @@ class SmolVLAPolicy(PreTrainedPolicy):
actions_is_pad = batch.get("actions_id_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
loss_dict["losses_after_forward"] = losses.clone().mean().item()
loss_dict["losses_after_forward"] = losses.clone()
if actions_is_pad is not None:
in_episode_bound = ~actions_is_pad
losses = losses * in_episode_bound.unsqueeze(-1)
loss_dict["losses_after_in_ep_bound"] = losses.clone().mean().item()
loss_dict["losses_after_in_ep_bound"] = losses.clone()
# Remove padding
losses = losses[:, :, : self.config.max_action_dim]
loss_dict["losses_after_rm_padding"] = losses.clone().mean().item()
loss_dict["losses_after_rm_padding"] = losses.clone()
if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims

View File

@@ -30,7 +30,7 @@ class TDMPCConfig(PreTrainedConfig):
camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_features`, `output_features`, and perhaps `max_random_shift_ratio`.
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.
Args:
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
@@ -40,12 +40,24 @@ class TDMPCConfig(PreTrainedConfig):
is an alternative to using action repeats. If this is set to more than 1, then we require
`n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this
approach of using multiple steps from the plan is not in the original implementation.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to
match the original implementation.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping
to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max"
normalization mode here.
image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding.
state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding.
latent_dim: Observation's latent embedding dimension.

View File

@@ -32,7 +32,7 @@ class VQBeTConfig(PreTrainedConfig):
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_features` and `output_features`.
Those are: `input_shapes` and `output_shapes`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
@@ -46,12 +46,21 @@ class VQBeTConfig(PreTrainedConfig):
current step and additional steps going back).
n_action_pred_token: Total number of current token and future tokens that VQ-BeT predicts.
action_chunk_size: Action chunk size of each action prediction token.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
input_shapes: A dictionary defining the shapes of the input data for the policy.
The key represents the input data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "observation.image" refers to an input from
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
Importantly, shapes doesnt include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.

View File

@@ -44,7 +44,6 @@ from .hil_processor import (
AddTeleopActionAsComplimentaryDataStep,
AddTeleopEventsAsInfoStep,
GripperPenaltyProcessorStep,
GymHILAdapterProcessorStep,
ImageCropResizeProcessorStep,
InterventionActionProcessorStep,
RewardClassifierProcessorStep,
@@ -88,7 +87,6 @@ __all__ = [
"DoneProcessorStep",
"EnvAction",
"EnvTransition",
"GymHILAdapterProcessorStep",
"GripperPenaltyProcessorStep",
"hotswap_stats",
"IdentityProcessorStep",

View File

@@ -17,7 +17,7 @@ from dataclasses import dataclass
import torch
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@@ -92,7 +92,7 @@ class LiberoProcessorStep(ObservationProcessorStep):
# copy over non-STATE features
for ft, feats in features.items():
if ft != FeatureType.STATE:
if ft != PipelineFeatureType.STATE:
new_features[ft] = feats.copy()
# rebuild STATE features
@@ -100,11 +100,13 @@ class LiberoProcessorStep(ObservationProcessorStep):
# add our new flattened state
state_feats[OBS_STATE] = PolicyFeature(
type=FeatureType.STATE,
key=OBS_STATE,
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
dtype="float32",
description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."),
)
new_features[FeatureType.STATE] = state_feats
new_features[PipelineFeatureType.STATE] = state_feats
return new_features

View File

@@ -20,7 +20,6 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from .converters import to_tensor
from .core import EnvAction, EnvTransition, PolicyAction
from .hil_processor import TELEOP_ACTION_KEY
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry
@@ -90,13 +89,6 @@ class Numpy2TorchActionProcessorStep(ProcessorStep):
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
new_transition[TransitionKey.ACTION] = torch_action
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if TELEOP_ACTION_KEY in complementary_data:
teleop_action = complementary_data[TELEOP_ACTION_KEY]
if isinstance(teleop_action, EnvAction):
complementary_data[TELEOP_ACTION_KEY] = to_tensor(teleop_action)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return new_transition
def transform_features(

View File

@@ -312,40 +312,9 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
return features
@ProcessorStepRegistry.register("gym_hil_adapter_processor")
class GymHILAdapterProcessorStep(ProcessorStep):
"""
Adapts the output of the `gym-hil` environment to the format expected by `lerobot` processors.
This step normalizes the `transition` object by:
1. Copying `teleop_action` from `info` to `complementary_data`.
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
info = transition.get(TransitionKey.INFO, {})
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if TELEOP_ACTION_KEY in info:
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
if "is_intervention" in info:
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
transition[TransitionKey.INFO] = info
transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessorStep(ProcessorStep):
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
"""
Applies a penalty for inefficient gripper usage.
@@ -360,27 +329,26 @@ class GripperPenaltyProcessorStep(ProcessorStep):
penalty: float = -0.01
max_gripper_pos: float = 30.0
def __call__(self, transition: EnvTransition) -> EnvTransition:
def complementary_data(self, complementary_data: dict) -> dict:
"""
Calculates the gripper penalty and adds it to the complementary data.
Args:
transition: The incoming environment transition.
complementary_data: The incoming complementary data, which should contain
raw joint positions.
Returns:
The modified transition with the penalty added to complementary data.
A new complementary data dictionary with the `discrete_penalty` key added.
"""
new_transition = transition.copy()
action = new_transition.get(TransitionKey.ACTION)
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
action = self.transition.get(TransitionKey.ACTION)
raw_joint_positions = complementary_data.get("raw_joint_positions")
if raw_joint_positions is None:
return new_transition
return complementary_data
current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None)
if current_gripper_pos is None:
return new_transition
return complementary_data
# Gripper action is a PolicyAction at this stage
gripper_action = action[-1].item()
@@ -396,12 +364,11 @@ class GripperPenaltyProcessorStep(ProcessorStep):
gripper_penalty = self.penalty * int(gripper_penalty_bool)
# Update complementary data with penalty info
# Create new complementary data with penalty info
new_complementary_data = dict(complementary_data)
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
return new_complementary_data
def get_config(self) -> dict[str, Any]:
"""

View File

@@ -413,7 +413,7 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
Args:
save_directory: The directory where the pipeline will be saved. If None, saves to
HF_LEROBOT_HOME/processors/{sanitized_pipeline_name}.
repo_id: ID of your repository on the Hub. Used only if `push_to_hub=true`.
repo_id: ID of your repository on the Hub. Used only if `push_to_hub=True`.
push_to_hub: Whether or not to push your object to the Hugging Face Hub after saving it.
card_kwargs: Additional arguments passed to the card template to customize the card.
config_filename: The name of the JSON configuration file. If None, a name is

View File

@@ -36,7 +36,6 @@ from lerobot.processor import (
DeviceProcessorStep,
EnvTransition,
GripperPenaltyProcessorStep,
GymHILAdapterProcessorStep,
ImageCropResizeProcessorStep,
InterventionActionProcessorStep,
MapDeltaActionToRobotActionStep,
@@ -380,7 +379,6 @@ def make_processors(
]
env_pipeline_steps = [
GymHILAdapterProcessorStep(),
Numpy2TorchActionProcessorStep(),
VanillaObservationProcessorStep(),
AddBatchDimensionProcessorStep(),
@@ -414,10 +412,7 @@ def make_processors(
if cfg.processor.observation.add_current_to_observation:
env_pipeline_steps.append(MotorCurrentProcessorStep(robot=env.robot))
add_ee_pose = (
cfg.processor.observation is not None and cfg.processor.observation.add_ee_pose_to_observation
)
if kinematics_solver is not None and add_ee_pose:
if kinematics_solver is not None:
env_pipeline_steps.append(
ForwardKinematicsJointsToEEObservation(
kinematics=kinematics_solver,
@@ -440,12 +435,7 @@ def make_processors(
)
# Add gripper penalty processor if gripper config exists and enabled
# Only add if max_gripper_pos is explicitly configured (required for normalization)
if (
cfg.processor.gripper is not None
and cfg.processor.gripper.use_gripper
and cfg.processor.max_gripper_pos is not None
):
if cfg.processor.gripper is not None and cfg.processor.gripper.use_gripper:
env_pipeline_steps.append(
GripperPenaltyProcessorStep(
penalty=cfg.processor.gripper.gripper_penalty,
@@ -610,14 +600,7 @@ def control_loop(
dataset = None
if cfg.mode == "record":
if teleop_device:
action_features = teleop_device.action_features
else:
action_features = {
"dtype": "float32",
"shape": (4,),
"names": ["delta_x", "delta_y", "delta_z", "gripper"],
}
action_features = teleop_device.action_features
features = {
ACTION: action_features,
REWARD: {"dtype": "float32", "shape": (1,), "names": None},
@@ -665,7 +648,7 @@ def control_loop(
# Create a neutral action (no movement)
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
if use_gripper:
neutral_action = torch.cat([neutral_action, torch.tensor([0.0])]) # Gripper stay
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
# Use the new step function
transition = step_env_and_process_transition(
@@ -734,8 +717,6 @@ def control_loop(
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
if dataset is not None and cfg.dataset.push_to_hub:
logging.info("Finalizing dataset before pushing to hub")
dataset.finalize()
logging.info("Pushing dataset to hub")
dataset.push_to_hub()

View File

@@ -26,21 +26,8 @@ from lerobot.configs.train import TrainPipelineConfig
from lerobot.utils.constants import PRETRAINED_MODEL_DIR
def cfg_to_group(
cfg: TrainPipelineConfig, return_list: bool = False, truncate_tags: bool = False, max_tag_length: int = 64
) -> list[str] | str:
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
"""Return a group name for logging. Optionally returns group name as list."""
def _maybe_truncate(tag: str) -> str:
"""Truncate tag to max_tag_length characters if required.
wandb rejects tags longer than 64 characters.
See: https://github.com/wandb/wandb/blob/main/wandb/sdk/wandb_settings.py
"""
if len(tag) <= max_tag_length:
return tag
return tag[:max_tag_length]
lst = [
f"policy:{cfg.policy.type}",
f"seed:{cfg.seed}",
@@ -49,8 +36,6 @@ def cfg_to_group(
lst.append(f"dataset:{cfg.dataset.repo_id}")
if cfg.env is not None:
lst.append(f"env:{cfg.env.type}")
if truncate_tags:
lst = [_maybe_truncate(tag) for tag in lst]
return lst if return_list else "-".join(lst)
@@ -98,7 +83,7 @@ class WandBLogger:
entity=self.cfg.entity,
name=self.job_name,
notes=self.cfg.notes,
tags=cfg_to_group(cfg, return_list=True, truncate_tags=True),
tags=cfg_to_group(cfg, return_list=True),
dir=self.log_dir,
config=cfg.to_dict(),
# TODO(rcadene): try set to True

View File

@@ -19,7 +19,6 @@ from functools import cached_property
from lerobot.processor import RobotAction, RobotObservation
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..robot import Robot
from .config_bi_openarm_follower import BiOpenArmFollowerConfig
@@ -113,7 +112,6 @@ class BiOpenArmFollower(Robot):
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@@ -135,7 +133,6 @@ class BiOpenArmFollower(Robot):
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
)
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict = {}
@@ -149,7 +146,6 @@ class BiOpenArmFollower(Robot):
return obs_dict
@check_if_not_connected
def send_action(
self,
action: RobotAction,
@@ -174,7 +170,6 @@ class BiOpenArmFollower(Robot):
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()

View File

@@ -19,7 +19,6 @@ from functools import cached_property
from lerobot.processor import RobotAction, RobotObservation
from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..robot import Robot
from .config_bi_so_follower import BiSOFollowerConfig
@@ -97,7 +96,6 @@ class BiSOFollower(Robot):
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@@ -118,7 +116,6 @@ class BiSOFollower(Robot):
self.left_arm.setup_motors()
self.right_arm.setup_motors()
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict = {}
@@ -132,7 +129,6 @@ class BiSOFollower(Robot):
return obs_dict
@check_if_not_connected
def send_action(self, action: RobotAction) -> RobotAction:
# Remove "left_" prefix
left_action = {
@@ -152,7 +148,6 @@ class BiSOFollower(Robot):
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()

View File

@@ -140,7 +140,7 @@ class HopeJrArm(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.read_latest()
obs_dict[cam_key] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")

View File

@@ -171,7 +171,7 @@ class HopeJrHand(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.read_latest()
obs_dict[cam_key] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")

View File

@@ -193,7 +193,7 @@ class KochFollower(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.read_latest()
obs_dict[cam_key] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")

View File

@@ -360,7 +360,7 @@ class LeKiwi(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.read_latest()
obs_dict[cam_key] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")

View File

@@ -176,7 +176,7 @@ class OmxFollower(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.read_latest()
obs_dict[cam_key] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")

View File

@@ -23,7 +23,7 @@ from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.damiao import DamiaoMotorsBus
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from ..utils import ensure_safe_goal_position
@@ -119,7 +119,6 @@ class OpenArmFollower(Robot):
"""Check if robot is connected."""
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
"""
Connect to the robot and optionally calibrate.
@@ -127,6 +126,8 @@ class OpenArmFollower(Robot):
We assume that at connection time, the arms are in a safe rest position,
and torque can be safely disabled to run calibration if needed.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
# Connect to CAN bus
logger.info(f"Connecting arm on {self.config.port}...")
@@ -218,7 +219,6 @@ class OpenArmFollower(Robot):
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
)
@check_if_not_connected
def get_observation(self) -> RobotObservation:
"""
Get current observation from robot including position, velocity, and torque.
@@ -228,6 +228,9 @@ class OpenArmFollower(Robot):
"""
start = time.perf_counter()
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
obs_dict: dict[str, Any] = {}
states = self.bus.sync_read_all_states()
@@ -241,7 +244,7 @@ class OpenArmFollower(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.read_latest()
obs_dict[cam_key] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
@@ -250,7 +253,6 @@ class OpenArmFollower(Robot):
return obs_dict
@check_if_not_connected
def send_action(
self,
action: RobotAction,
@@ -270,6 +272,8 @@ class OpenArmFollower(Robot):
Returns:
The action actually sent (potentially clipped)
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
@@ -329,9 +333,10 @@ class OpenArmFollower(Robot):
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
@check_if_not_connected
def disconnect(self):
"""Disconnect from robot."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Disconnect CAN bus
self.bus.disconnect(self.config.disable_torque_on_disconnect)

View File

@@ -180,7 +180,7 @@ class Reachy2Robot(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
obs_dict[cam_key] = cam.read_latest()
obs_dict[cam_key] = cam.async_read()
return obs_dict

View File

@@ -40,7 +40,7 @@ class SOFollowerConfig:
cameras: dict[str, CameraConfig] = field(default_factory=dict)
# Set to `True` for backward compatibility with previous policies/dataset
use_degrees: bool = True
use_degrees: bool = False
@RobotConfig.register_subclass("so101_follower")

View File

@@ -187,7 +187,7 @@ class SOFollower(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.read_latest()
obs_dict[cam_key] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")

View File

@@ -324,7 +324,7 @@ class UnitreeG1(Robot):
# Cameras - read images from ZMQ cameras
for cam_name, cam in self._cameras.items():
obs[cam_name] = cam.read_latest()
obs[cam_name] = cam.async_read()
return obs

View File

@@ -47,14 +47,16 @@ local$ rerun lerobot_pusht_episode_0.rrd
```
- Visualize data stored on a distant machine through streaming:
(You need to forward the websocket port to the distant machine, with
`ssh -L 9087:localhost:9087 username@remote-host`)
```
distant$ lerobot-dataset-viz \
--repo-id lerobot/pusht \
--episode-index 0 \
--mode distant \
--grpc-port 9876
--ws-port 9087
local$ rerun rerun+http://IP:GRPC_PORT/proxy
local$ rerun ws://localhost:9087
```
"""
@@ -73,7 +75,6 @@ import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
from lerobot.utils.utils import init_logging
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
@@ -92,11 +93,10 @@ def visualize_dataset(
num_workers: int = 0,
mode: str = "local",
web_port: int = 9090,
grpc_port: int = 9876,
ws_port: int = 9087,
save: bool = False,
output_dir: Path | None = None,
display_compressed_images: bool = False,
**kwargs,
) -> Path | None:
if save:
assert output_dir is not None, (
@@ -126,9 +126,7 @@ def visualize_dataset(
gc.collect()
if mode == "distant":
server_uri = rr.serve_grpc(grpc_port=grpc_port)
logging.info(f"Connect to a Rerun Server: rerun rerun+http://IP:{grpc_port}/proxy")
rr.serve_web_viewer(open_browser=False, web_port=web_port, connect_to=server_uri)
rr.serve_web_viewer(open_browser=False, web_port=web_port)
logging.info("Logging to Rerun")
@@ -228,7 +226,7 @@ def main():
"Mode of viewing between 'local' or 'distant'. "
"'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
"'distant' creates a server on the distant machine where the data is stored. "
"Visualize the data by connecting to the server with `rerun rerun+http://IP:GRPC_PORT/proxy` on the local machine."
"Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
),
)
parser.add_argument(
@@ -240,13 +238,8 @@ def main():
parser.add_argument(
"--ws-port",
type=int,
help="deprecated, please use --grpc-port instead.",
)
parser.add_argument(
"--grpc-port",
type=int,
default=9876,
help="gRPC port for rerun.io when `--mode distant` is set.",
default=9087,
help="Web socket port for rerun.io when `--mode distant` is set.",
)
parser.add_argument(
"--save",
@@ -272,7 +265,9 @@ def main():
parser.add_argument(
"--display-compressed-images",
action="store_true",
type=bool,
required=True,
default=False,
help="If set, display compressed images in Rerun instead of uncompressed ones.",
)
@@ -282,14 +277,6 @@ def main():
root = kwargs.pop("root")
tolerance_s = kwargs.pop("tolerance_s")
if kwargs["ws_port"] is not None:
logging.warning(
"--ws-port is deprecated and will be removed in future versions. Please use --grpc-port instead."
)
logging.warning("Setting grpc_port to ws_port value.")
kwargs["grpc_port"] = kwargs.pop("ws_port")
init_logging()
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s)

View File

@@ -24,112 +24,96 @@ When new_repo_id is specified, creates a new dataset.
Usage Examples:
Delete episodes 0, 2, and 5 from a dataset:
lerobot-edit-dataset \
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type delete_episodes \
--operation.episode_indices "[0, 2, 5]"
Delete episodes and save to a new dataset:
lerobot-edit-dataset \
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--new_repo_id lerobot/pusht_filtered \
--operation.type delete_episodes \
--operation.episode_indices "[0, 2, 5]"
Split dataset by fractions:
lerobot-edit-dataset \
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type split \
--operation.splits '{"train": 0.8, "val": 0.2}'
Split dataset by episode indices:
lerobot-edit-dataset \
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type split \
--operation.splits '{"train": [0, 1, 2, 3], "val": [4, 5]}'
Split into more than two splits:
lerobot-edit-dataset \
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type split \
--operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}'
Merge multiple datasets:
lerobot-edit-dataset \
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht_merged \
--operation.type merge \
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
Remove camera feature:
lerobot-edit-dataset \
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type remove_feature \
--operation.feature_names "['observation.images.top']"
Modify tasks - set a single task for all episodes (WARNING: modifies in-place):
lerobot-edit-dataset \
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type modify_tasks \
--operation.new_task "Pick up the cube and place it"
Modify tasks - set different tasks for specific episodes (WARNING: modifies in-place):
lerobot-edit-dataset \
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type modify_tasks \
--operation.episode_tasks '{"0": "Task A", "1": "Task B", "2": "Task A"}'
Modify tasks - set default task with overrides for specific episodes (WARNING: modifies in-place):
lerobot-edit-dataset \
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type modify_tasks \
--operation.new_task "Default task" \
--operation.episode_tasks '{"5": "Special task for episode 5"}'
Convert image dataset to video format and save locally:
lerobot-edit-dataset \
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_image_to_video \
--operation.output_dir /path/to/output/pusht_video
Convert image dataset to video format and save with new repo_id:
lerobot-edit-dataset \
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_image_to_video
Convert image dataset to video format and push to hub:
lerobot-edit-dataset \
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_image_to_video \
--push_to_hub true
Show dataset information:
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type info \
--operation.show_features true
Show dataset information without feature details:
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type info \
--operation.show_features false
Using JSON config file:
lerobot-edit-dataset \
python -m lerobot.scripts.lerobot_edit_dataset \
--config_path path/to/edit_config.json
"""
import abc
import logging
import shutil
import sys
from dataclasses import dataclass
from pathlib import Path
import draccus
from lerobot.configs import parser
from lerobot.datasets.dataset_tools import (
convert_image_to_video_dataset,
@@ -145,46 +129,39 @@ from lerobot.utils.utils import init_logging
@dataclass
class OperationConfig(draccus.ChoiceRegistry, abc.ABC):
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@OperationConfig.register_subclass("delete_episodes")
@dataclass
class DeleteEpisodesConfig(OperationConfig):
class DeleteEpisodesConfig:
type: str = "delete_episodes"
episode_indices: list[int] | None = None
@OperationConfig.register_subclass("split")
@dataclass
class SplitConfig(OperationConfig):
class SplitConfig:
type: str = "split"
splits: dict[str, float | list[int]] | None = None
@OperationConfig.register_subclass("merge")
@dataclass
class MergeConfig(OperationConfig):
class MergeConfig:
type: str = "merge"
repo_ids: list[str] | None = None
@OperationConfig.register_subclass("remove_feature")
@dataclass
class RemoveFeatureConfig(OperationConfig):
class RemoveFeatureConfig:
type: str = "remove_feature"
feature_names: list[str] | None = None
@OperationConfig.register_subclass("modify_tasks")
@dataclass
class ModifyTasksConfig(OperationConfig):
class ModifyTasksConfig:
type: str = "modify_tasks"
new_task: str | None = None
episode_tasks: dict[str, str] | None = None
@OperationConfig.register_subclass("convert_image_to_video")
@dataclass
class ConvertImageToVideoConfig(OperationConfig):
class ConvertImageToVideoConfig:
type: str = "convert_image_to_video"
output_dir: str | None = None
vcodec: str = "libsvtav1"
pix_fmt: str = "yuv420p"
@@ -197,17 +174,17 @@ class ConvertImageToVideoConfig(OperationConfig):
max_frames_per_batch: int | None = None
@OperationConfig.register_subclass("info")
@dataclass
class InfoConfig(OperationConfig):
type: str = "info"
show_features: bool = False
@dataclass
class EditDatasetConfig:
repo_id: str
operation: OperationConfig
operation: (
DeleteEpisodesConfig
| SplitConfig
| MergeConfig
| RemoveFeatureConfig
| ModifyTasksConfig
| ConvertImageToVideoConfig
)
root: str | None = None
new_repo_id: str | None = None
push_to_hub: bool = False
@@ -456,49 +433,6 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
logging.info("Dataset saved locally (not pushed to hub)")
def _get_dataset_size(repo_path):
import os
total = 0
with os.scandir(repo_path) as it:
for entry in it:
if entry.is_file():
total += entry.stat().st_size
elif entry.is_dir():
total += _get_dataset_size(entry.path)
return total
def handle_info(cfg: EditDatasetConfig):
if not isinstance(cfg.operation, InfoConfig):
raise ValueError("Operation config must be InfoConfig")
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
sys.stdout.write(f"======Info {dataset.meta.repo_id}\n")
sys.stdout.write(f"Repository ID: {dataset.meta.repo_id} \n")
sys.stdout.write(f"Total episode: {dataset.meta.total_episodes} \n")
sys.stdout.write(f"Total task: {dataset.meta.total_tasks} \n")
sys.stdout.write(f"Total frame(Actual Count): {dataset.meta.total_frames}({len(dataset)}) \n")
sys.stdout.write(
f"Average frame per episode: {dataset.meta.total_frames / dataset.meta.total_episodes:.1f}\n"
)
sys.stdout.write(
f"Average episode time(sec): {(dataset.meta.total_frames / dataset.meta.total_episodes) / dataset.meta.fps:.1f}\n"
)
sys.stdout.write(f"FPS: {dataset.meta.fps}\n")
total_file_size = _get_dataset_size(dataset.root)
sys.stdout.write(f"Size: {total_file_size / (1024 * 1024):.1f} MB\n")
if cfg.operation.show_features:
import json
feature_dump_str = json.dumps(
dataset.meta.features, ensure_ascii=False, indent=4, sort_keys=True, separators=(",", ": ")
)
sys.stdout.write("Features:\n")
sys.stdout.write(f"{feature_dump_str}\n")
@parser.wrap()
def edit_dataset(cfg: EditDatasetConfig) -> None:
operation_type = cfg.operation.type
@@ -515,11 +449,11 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
handle_modify_tasks(cfg)
elif operation_type == "convert_image_to_video":
handle_convert_image_to_video(cfg)
elif operation_type == "info":
handle_info(cfg)
else:
available = ", ".join(OperationConfig.get_known_choices())
raise ValueError(f"Unknown operation: {operation_type}\nAvailable operations: {available}")
raise ValueError(
f"Unknown operation type: {operation_type}\n"
f"Available operations: delete_episodes, split, merge, remove_feature, modify_tasks, convert_image_to_video"
)
def main() -> None:

View File

@@ -398,14 +398,7 @@ def record_loop(
)
dt_s = time.perf_counter() - start_loop_t
sleep_time_s: float = 1 / fps - dt_s
if sleep_time_s < 0:
logging.warning(
f"Record loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
precise_sleep(max(sleep_time_s, 0.0))
precise_sleep(max(1 / fps - dt_s, 0.0))
timestamp = time.perf_counter() - start_episode_t

View File

@@ -22,7 +22,7 @@ lerobot-replay \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=black \
--dataset.repo_id=<USER>/record-test \
--dataset.repo_id=aliberts/record-test \
--dataset.episode=0
```

View File

@@ -45,7 +45,7 @@ from dataclasses import dataclass, field
import draccus
from lerobot.utils.import_utils import _can_available
from lerobot.utils.import_utils import is_package_available
MOTOR_NAMES = {
0x01: "joint_1",
@@ -336,7 +336,7 @@ def run_speed(cfg: CANSetupConfig):
@draccus.wrap()
def setup_can(cfg: CANSetupConfig):
if not _can_available:
if not is_package_available("can"):
print("Error: python-can not installed. Install with: pip install python-can")
sys.exit(1)

View File

@@ -337,13 +337,28 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
if hasattr(cfg.policy, "drop_n_last_frames"):
# Filter out episodes - hardcoded list of bad episodes to discard
episodes_to_discard = {
133, 134, 502, 565, 568, 657, 910, 944, 1039, 1209, 1346, 1360, 1379,
1605, 1690, 1790, 2105, 2106, 2122, 2118, 2156, 2575, 2764, 2876, 2925,
3100, 3381, 3405, 3406, 68, 1214, 1456,
}
all_episodes = set(range(dataset.meta.total_episodes))
episodes_to_use = dataset.episodes # May be None (all episodes) or a subset
# If dataset.episodes is already filtered, start from that subset
if episodes_to_use is not None:
episodes_to_use = [ep for ep in episodes_to_use if ep not in episodes_to_discard]
else:
episodes_to_use = sorted(all_episodes - episodes_to_discard)
if hasattr(cfg.policy, "drop_n_last_frames") or episodes_to_use is not None:
shuffle = False
drop_n_last = getattr(cfg.policy, "drop_n_last_frames", 0)
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=cfg.policy.drop_n_last_frames,
episode_indices_to_use=episodes_to_use,
drop_n_last_frames=drop_n_last,
shuffle=True,
)
else:

View File

@@ -19,7 +19,6 @@ from functools import cached_property
from lerobot.processor import RobotAction
from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..openarm_leader import OpenArmLeader
from ..teleoperator import Teleoperator
@@ -89,7 +88,6 @@ class BiOpenArmLeader(Teleoperator):
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@@ -111,7 +109,6 @@ class BiOpenArmLeader(Teleoperator):
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
)
@check_if_not_connected
def get_action(self) -> RobotAction:
action_dict = {}
@@ -129,7 +126,6 @@ class BiOpenArmLeader(Teleoperator):
# TODO: Implement force feedback
raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()

View File

@@ -18,7 +18,7 @@ import logging
from functools import cached_property
from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.decorators import check_if_not_connected
from ..so_leader import SOLeader
from ..teleoperator import Teleoperator
@@ -72,7 +72,6 @@ class BiSOLeader(Teleoperator):
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@@ -111,7 +110,6 @@ class BiSOLeader(Teleoperator):
# TODO: Implement force feedback
raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()

View File

@@ -21,7 +21,7 @@ from typing import Any
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.damiao import DamiaoMotorsBus
from lerobot.processor import RobotAction
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..teleoperator import Teleoperator
from .config_openarm_leader import OpenArmLeaderConfig
@@ -84,7 +84,6 @@ class OpenArmLeader(Teleoperator):
"""Check if teleoperator is connected."""
return self.bus.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
"""
Connect to the teleoperator.
@@ -92,6 +91,8 @@ class OpenArmLeader(Teleoperator):
For manual control, we disable torque after connecting so the
arm can be moved by hand.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
# Connect to CAN bus
logger.info(f"Connecting arm on {self.config.port}...")
@@ -182,7 +183,6 @@ class OpenArmLeader(Teleoperator):
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
)
@check_if_not_connected
def get_action(self) -> RobotAction:
"""
Get current action from the leader arm.
@@ -193,6 +193,8 @@ class OpenArmLeader(Teleoperator):
Reads all motor states (pos/vel/torque) in one CAN refresh cycle.
"""
start = time.perf_counter()
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
action_dict: dict[str, Any] = {}
@@ -212,9 +214,10 @@ class OpenArmLeader(Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError("Feedback is not yet implemented for OpenArm leader.")
@check_if_not_connected
def disconnect(self) -> None:
"""Disconnect from teleoperator."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Disconnect CAN bus
# For manual control, ensure torque is disabled before disconnecting

View File

@@ -28,7 +28,7 @@ class SOLeaderConfig:
port: str
# Whether to use degrees for angles
use_degrees: bool = True
use_degrees: bool = False
@TeleoperatorConfig.register_subclass("so101_leader")

View File

@@ -16,14 +16,14 @@ import platform
import time
def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.005):
def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.003):
"""
Wait for `seconds` with better precision than time.sleep alone at the expense of more CPU usage.
Parameters:
- seconds: duration to wait
- spin_threshold: if remaining <= spin_threshold -> spin; otherwise sleep (seconds). Default 10ms
- sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 5ms
- sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 3ms
Note:
The default parameters are chosen to prioritize timing accuracy over CPU usage for the common 30 FPS use case.

View File

@@ -390,30 +390,6 @@ def test_sharpness_jitter_invalid_range_max_smaller():
SharpnessJitter((2.0, 0.1))
def test_make_transform_from_config_with_v2_resize(img_tensor_factory):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformConfig(type="Resize", kwargs={"size": (32, 32)})
tf = make_transform_from_config(tf_cfg)
assert isinstance(tf, v2.Resize)
output = tf(img_tensor)
assert output.shape[-2:] == (32, 32)
def test_make_transform_from_config_with_v2_identity(img_tensor_factory):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformConfig(type="Identity", kwargs={})
tf = make_transform_from_config(tf_cfg)
assert isinstance(tf, v2.Identity)
output = tf(img_tensor)
assert output.shape == img_tensor.shape
def test_make_transform_from_config_invalid_type():
tf_cfg = ImageTransformConfig(type="NotARealTransform", kwargs={})
with pytest.raises(ValueError, match="not valid"):
make_transform_from_config(tf_cfg)
def test_save_all_transforms(img_tensor_factory, tmp_path):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(enable=True)

View File

@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from packaging.version import Version
from torch.optim.lr_scheduler import LambdaLR
from lerobot.optim.schedulers import (
@@ -40,10 +38,6 @@ def test_diffuser_scheduler(optimizer):
"last_epoch": 1,
"lr_lambdas": [None],
}
if Version(torch.__version__) >= Version("2.8"):
expected_state_dict["_is_initial"] = False
assert scheduler.state_dict() == expected_state_dict
@@ -62,10 +56,6 @@ def test_vqbet_scheduler(optimizer):
"last_epoch": 1,
"lr_lambdas": [None],
}
if Version(torch.__version__) >= Version("2.8"):
expected_state_dict["_is_initial"] = False
assert scheduler.state_dict() == expected_state_dict
@@ -86,10 +76,6 @@ def test_cosine_decay_with_warmup_scheduler(optimizer):
"last_epoch": 1,
"lr_lambdas": [None],
}
if Version(torch.__version__) >= Version("2.8"):
expected_state_dict["_is_initial"] = False
assert scheduler.state_dict() == expected_state_dict

View File

@@ -142,7 +142,6 @@ def _make_reachy2_camera_mock(*args, **kwargs):
cam.connect = MagicMock()
cam.disconnect = MagicMock()
cam.async_read = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8))
cam.read_latest = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8))
return cam

View File

@@ -1,74 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import draccus
import pytest
from lerobot.scripts.lerobot_edit_dataset import (
ConvertImageToVideoConfig,
DeleteEpisodesConfig,
EditDatasetConfig,
InfoConfig,
MergeConfig,
ModifyTasksConfig,
OperationConfig,
RemoveFeatureConfig,
SplitConfig,
)
def parse_cfg(cli_args: list[str]) -> EditDatasetConfig:
"""Helper to parse CLI args into an EditDatasetConfig via draccus."""
return draccus.parse(EditDatasetConfig, args=cli_args)
class TestOperationTypeParsing:
"""Test that --operation.type correctly selects the right config subclass."""
@pytest.mark.parametrize(
"type_name, expected_cls",
[
("delete_episodes", DeleteEpisodesConfig),
("split", SplitConfig),
("merge", MergeConfig),
("remove_feature", RemoveFeatureConfig),
("modify_tasks", ModifyTasksConfig),
("convert_image_to_video", ConvertImageToVideoConfig),
("info", InfoConfig),
],
)
def test_operation_type_resolves_correct_class(self, type_name, expected_cls):
cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name])
assert isinstance(cfg.operation, expected_cls), (
f"Expected {expected_cls.__name__}, got {type(cfg.operation).__name__}"
)
@pytest.mark.parametrize(
"type_name, expected_cls",
[
("delete_episodes", DeleteEpisodesConfig),
("split", SplitConfig),
("merge", MergeConfig),
("remove_feature", RemoveFeatureConfig),
("modify_tasks", ModifyTasksConfig),
("convert_image_to_video", ConvertImageToVideoConfig),
("info", InfoConfig),
],
)
def test_get_choice_name_roundtrips(self, type_name, expected_cls):
cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name])
resolved_name = OperationConfig.get_choice_name(type(cfg.operation))
assert resolved_name == type_name