mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
Compare commits
9 Commits
feat/slurm
...
tmp/fold_t
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a07ff640bb | ||
|
|
0753415244 | ||
|
|
a054663e38 | ||
|
|
aceb651e40 | ||
|
|
76a4529d29 | ||
|
|
39e14c086c | ||
|
|
0af2029328 | ||
|
|
c027b2971c | ||
|
|
9cc203034e |
8
.github/workflows/full_tests.yml
vendored
8
.github/workflows/full_tests.yml
vendored
@@ -101,11 +101,9 @@ jobs:
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
if: |
|
||||
github.repository == 'huggingface/lerobot' && (
|
||||
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) ||
|
||||
github.event_name == 'push' ||
|
||||
github.event_name == 'workflow_dispatch'
|
||||
)
|
||||
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) ||
|
||||
github.event_name == 'push' ||
|
||||
github.event_name == 'workflow_dispatch'
|
||||
outputs:
|
||||
image_tag: ${{ steps.set_tag.outputs.image_tag }}
|
||||
env:
|
||||
|
||||
1
.github/workflows/unbound_deps_tests.yml
vendored
1
.github/workflows/unbound_deps_tests.yml
vendored
@@ -91,7 +91,6 @@ jobs:
|
||||
name: Build and Push Docker
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
if: github.repository == 'huggingface/lerobot'
|
||||
outputs:
|
||||
image_tag: ${{ env.DOCKER_IMAGE_NAME }}
|
||||
env:
|
||||
|
||||
@@ -28,9 +28,9 @@ We don't expect the same optimal settings for a dataset of images from a simulat
|
||||
For these reasons, we run this benchmark on four representative datasets:
|
||||
|
||||
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
|
||||
- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
||||
- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
|
||||
- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
|
||||
- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
||||
- `aliberts/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
|
||||
- `aliberts/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
|
||||
|
||||
Note: The datasets used for this benchmark need to be image datasets, not video datasets.
|
||||
|
||||
@@ -179,7 +179,7 @@ python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 2 20 None \
|
||||
@@ -203,9 +203,9 @@ python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
lerobot/paris_street \
|
||||
lerobot/kitchen \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
aliberts/paris_street \
|
||||
aliberts/kitchen \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
@@ -221,9 +221,9 @@ python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
lerobot/paris_street \
|
||||
lerobot/kitchen \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
aliberts/paris_street \
|
||||
aliberts/kitchen \
|
||||
--vcodec libsvtav1 \
|
||||
--pix-fmt yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
@@ -252,37 +252,37 @@ Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_read
|
||||
|
||||
These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`
|
||||
|
||||
| video_images_size_ratio | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | ---------- | ------- | --------- | --------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
|
||||
| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
|
||||
| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
|
||||
| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
|
||||
| video_images_size_ratio | vcodec | pix_fmt | | | |
|
||||
| ---------------------------------- | ---------- | ------- | --------- | --------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
|
||||
| aliberts/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
|
||||
| aliberts/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
|
||||
| aliberts/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
|
||||
|
||||
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | ------- | ------- | -------- | ------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
|
||||
| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
|
||||
| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
|
||||
| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
|
||||
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
|
||||
| ---------------------------------- | ------- | ------- | -------- | ------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
|
||||
| aliberts/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
|
||||
| aliberts/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
|
||||
| aliberts/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
|
||||
|
||||
| | | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
|
||||
| | | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
|
||||
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
|
||||
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
|
||||
| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
|
||||
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
|
||||
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
|
||||
| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
|
||||
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
|
||||
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
|
||||
| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
|
||||
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
|
||||
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
|
||||
| | | vcodec | pix_fmt | | | |
|
||||
| ---------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
|
||||
| | | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
|
||||
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
|
||||
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
|
||||
| aliberts/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
|
||||
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
|
||||
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
|
||||
| aliberts/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
|
||||
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
|
||||
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
|
||||
| aliberts/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
|
||||
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
|
||||
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
|
||||
|
||||
@@ -185,7 +185,7 @@ echo $HF_USER
|
||||
Use the standard recording command:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
python src/lerobot/scripts/lerobot_record.py \
|
||||
--robot.type=earthrover_mini_plus \
|
||||
--teleop.type=keyboard_rover \
|
||||
--dataset.repo_id=your_username/dataset_name \
|
||||
|
||||
@@ -224,7 +224,7 @@ lerobot-record \
|
||||
--teleop.port=/dev/tty.usbmodem1201 \
|
||||
--teleop.id=right \
|
||||
--teleop.side=right \
|
||||
--dataset.repo_id=<USER>/hand_record_test_with_video_data \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
|
||||
--dataset.single_task="Hand recording test with video data" \
|
||||
--dataset.num_episodes=1 \
|
||||
--dataset.episode_time_s=5 \
|
||||
@@ -241,7 +241,7 @@ lerobot-replay \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=right \
|
||||
--robot.side=right \
|
||||
--dataset.repo_id=<USER>/hand_record_test_with_camera \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_camera \
|
||||
--dataset.episode=0
|
||||
```
|
||||
|
||||
@@ -249,13 +249,13 @@ lerobot-replay \
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=<USER>/hand_record_test_with_video_data \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/hopejr_hand \
|
||||
--job_name=hopejr \
|
||||
--policy.device=mps \
|
||||
--wandb.enable=true \
|
||||
--policy.repo_id=<USER>/hand_test_policy
|
||||
--policy.repo_id=nepyope/hand_test_policy
|
||||
```
|
||||
|
||||
### Evaluate
|
||||
@@ -270,7 +270,7 @@ lerobot-record \
|
||||
--robot.side=right \
|
||||
--robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \
|
||||
--display_data=false \
|
||||
--dataset.repo_id=<USER>/eval_hopejr \
|
||||
--dataset.repo_id=nepyope/eval_hopejr \
|
||||
--dataset.single_task="Evaluate hopejr hand policy" \
|
||||
--dataset.num_episodes=10 \
|
||||
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
# Installation
|
||||
|
||||
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
|
||||
|
||||
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
|
||||
## Install [`miniforge`](https://conda-forge.org/download/)
|
||||
|
||||
```bash
|
||||
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
|
||||
bash Miniforge3-$(uname)-$(uname -m).sh
|
||||
```
|
||||
|
||||
## Step 2: Environment Setup
|
||||
## Environment Setup
|
||||
|
||||
Create a virtual environment with Python 3.10, using conda:
|
||||
|
||||
@@ -40,7 +38,7 @@ conda install ffmpeg -c conda-forge
|
||||
>
|
||||
> - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
|
||||
|
||||
## Step 3: Install LeRobot 🤗
|
||||
## Install LeRobot 🤗
|
||||
|
||||
### From Source
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ policy.type=pi0
|
||||
For training π₀, you can use the standard LeRobot training script with the appropriate configuration:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi0 \
|
||||
--output_dir=./outputs/pi0_training \
|
||||
|
||||
@@ -56,7 +56,7 @@ policy.type=pi05
|
||||
Here's a complete training command for finetuning the base π₀.₅ model on your own dataset:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py\
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi05 \
|
||||
--output_dir=./outputs/pi05_training \
|
||||
|
||||
@@ -269,7 +269,7 @@ This generates visualizations showing video frames with subtask boundaries overl
|
||||
Train with **no annotations** - uses linear progress from 0 to 1:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=single_stage \
|
||||
@@ -288,7 +288,7 @@ lerobot-train \
|
||||
Train with **dense annotations only** (sparse auto-generated):
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=dense_only \
|
||||
@@ -307,7 +307,7 @@ lerobot-train \
|
||||
Train with **both sparse and dense annotations**:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=dual \
|
||||
@@ -468,7 +468,7 @@ This script:
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--use_rabc=true \
|
||||
|
||||
@@ -216,7 +216,7 @@ lerobot-teleoperate \
|
||||
### Record Dataset in Simulation
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
python -m lerobot.scripts.lerobot_record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
@@ -266,7 +266,7 @@ lerobot-teleoperate \
|
||||
### Record Dataset on Real Robot
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
python -m lerobot.scripts.lerobot_record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
|
||||
@@ -12,7 +12,6 @@ LeRobot provides several utilities for manipulating datasets:
|
||||
4. **Add Features** - Add new features to a dataset
|
||||
5. **Remove Features** - Remove features from a dataset
|
||||
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
|
||||
7. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc.
|
||||
|
||||
The core implementation is in `lerobot.datasets.dataset_tools`.
|
||||
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
|
||||
@@ -157,30 +156,6 @@ lerobot-edit-dataset \
|
||||
|
||||
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
|
||||
|
||||
### Show the information of datasets
|
||||
|
||||
Show the information of datasets such as number of episode, number of frame, File size and so on.
|
||||
No change will be made to the dataset
|
||||
|
||||
```bash
|
||||
|
||||
# Show dataset information without feature details
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
|
||||
# Show dataset information with feature details
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
--operation.show_features true
|
||||
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
|
||||
- `parameters`: The flag to control show or no show dataset information with feature details.(default=false)
|
||||
|
||||
### Push to Hub
|
||||
|
||||
Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
|
||||
|
||||
@@ -45,7 +45,7 @@ policy.type=wall_x
|
||||
For training WallX, you can use the standard LeRobot training script with the appropriate configuration:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=wall_x \
|
||||
--output_dir=./outputs/wallx_training \
|
||||
|
||||
@@ -154,7 +154,7 @@ lerobot-train \
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=<USER>/bimanual-so100-handover-cube \
|
||||
--dataset.repo_id=pepijn223/bimanual-so100-handover-cube \
|
||||
--output_dir=./outputs/xvla_bimanual \
|
||||
--job_name=xvla_so101_training \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
|
||||
@@ -22,7 +22,7 @@ lerobot-replay \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
--dataset.repo_id=<USER>/record-test \
|
||||
--dataset.repo_id=aliberts/record-test \
|
||||
--dataset.episode=2
|
||||
```
|
||||
"""
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -27,8 +27,8 @@ measuring consistency and ground truth alignment.
|
||||
Usage:
|
||||
# Basic usage with smolvla policy
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps \
|
||||
--rtc.max_guidance_weight=10.0 \
|
||||
@@ -58,16 +58,16 @@ Usage:
|
||||
--device=cuda
|
||||
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=<USER>/reuben_pi0 \
|
||||
--dataset.repo_id=<USER>/so101_cube_in_cup \
|
||||
--policy.path=lipsop/reuben_pi0 \
|
||||
--dataset.repo_id=ReubenLim/so101_cube_in_cup \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda
|
||||
|
||||
# With torch.compile for faster inference (PyTorch 2.0+)
|
||||
# Note: CUDA graphs disabled by default due to in-place ops in denoising loop
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps \
|
||||
--use_torch_compile=true \
|
||||
@@ -75,8 +75,8 @@ Usage:
|
||||
|
||||
# With torch.compile on CUDA (CUDA graphs disabled by default)
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda \
|
||||
--use_torch_compile=true \
|
||||
@@ -84,8 +84,8 @@ Usage:
|
||||
|
||||
# Enable CUDA graphs (advanced - may cause tensor aliasing errors)
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--use_torch_compile=true \
|
||||
--torch_compile_backend=inductor \
|
||||
--torch_compile_mode=max-autotune \
|
||||
|
||||
@@ -28,7 +28,7 @@ For simulation environments, see eval_with_simulation.py
|
||||
Usage:
|
||||
# Run RTC with Real robot with RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
@@ -41,7 +41,7 @@ Usage:
|
||||
|
||||
# Run RTC with Real robot without RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=false \
|
||||
--robot.type=so100_follower \
|
||||
@@ -53,7 +53,7 @@ Usage:
|
||||
|
||||
# Run RTC with Real robot with pi0.5 policy
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=<USER>/pi05_check_rtc \
|
||||
--policy.path=helper2424/pi05_check_rtc \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
|
||||
@@ -59,7 +59,7 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
|
||||
dependencies = [
|
||||
|
||||
# Hugging Face dependencies
|
||||
"datasets>=4.0.0,<5.0.0",
|
||||
"datasets>=4.0.0,<4.2.0",
|
||||
"diffusers>=0.27.2,<0.36.0",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
|
||||
"accelerate>=1.10.0,<2.0.0",
|
||||
@@ -76,9 +76,9 @@ dependencies = [
|
||||
"pyserial>=3.5,<4.0",
|
||||
"wandb>=0.24.0,<0.25.0",
|
||||
|
||||
"torch>=2.2.1,<2.11.0", # TODO: Bump dependency
|
||||
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
|
||||
"torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency
|
||||
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
|
||||
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
|
||||
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
|
||||
|
||||
"draccus==0.10.0", # TODO: Remove ==
|
||||
"gymnasium>=1.1.1,<2.0.0",
|
||||
@@ -360,9 +360,9 @@ ignore_errors = false
|
||||
module = "lerobot.cameras.*"
|
||||
ignore_errors = false
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.motors.*"
|
||||
ignore_errors = false
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.motors.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.robots.*"
|
||||
|
||||
@@ -13,5 +13,5 @@
|
||||
# limitations under the License.
|
||||
|
||||
from .camera import Camera
|
||||
from .configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation
|
||||
from .configs import CameraConfig, ColorMode, Cv2Rotation
|
||||
from .utils import make_cameras_from_configs
|
||||
|
||||
@@ -150,7 +150,7 @@ class Camera(abc.ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
|
||||
@@ -25,10 +25,6 @@ class ColorMode(str, Enum):
|
||||
RGB = "rgb"
|
||||
BGR = "bgr"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: object) -> None:
|
||||
raise ValueError(f"`color_mode` is expected to be in {list(cls)}, but {value} is provided.")
|
||||
|
||||
|
||||
class Cv2Rotation(int, Enum):
|
||||
NO_ROTATION = 0
|
||||
@@ -36,25 +32,6 @@ class Cv2Rotation(int, Enum):
|
||||
ROTATE_180 = 180
|
||||
ROTATE_270 = -90
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: object) -> None:
|
||||
raise ValueError(f"`rotation` is expected to be in {list(cls)}, but {value} is provided.")
|
||||
|
||||
|
||||
# Subset from https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html
|
||||
class Cv2Backends(int, Enum):
|
||||
ANY = 0
|
||||
V4L2 = 200
|
||||
DSHOW = 700
|
||||
PVAPI = 800
|
||||
ANDROID = 1000
|
||||
AVFOUNDATION = 1200
|
||||
MSMF = 1400
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: object) -> None:
|
||||
raise ValueError(f"`backend` is expected to be in {list(cls)}, but {value} is provided.")
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus
|
||||
|
||||
@@ -32,11 +32,10 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
|
||||
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from ..utils import get_cv2_rotation
|
||||
from ..utils import get_cv2_backend, get_cv2_rotation
|
||||
from .configuration_opencv import ColorMode, OpenCVCameraConfig
|
||||
|
||||
# NOTE(Steven): The maximum opencv device index depends on your operating system. For instance,
|
||||
@@ -118,7 +117,7 @@ class OpenCVCamera(Camera):
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
self.rotation: int | None = get_cv2_rotation(config.rotation)
|
||||
self.backend: int = config.backend
|
||||
self.backend: int = get_cv2_backend()
|
||||
|
||||
if self.height and self.width:
|
||||
self.capture_width, self.capture_height = self.width, self.height
|
||||
@@ -133,7 +132,6 @@ class OpenCVCamera(Camera):
|
||||
"""Checks if the camera is currently connected and opened."""
|
||||
return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened()
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""
|
||||
Connects to the OpenCV camera specified in the configuration.
|
||||
@@ -150,6 +148,8 @@ class OpenCVCamera(Camera):
|
||||
ConnectionError: If the specified camera index/path is not found or fails to open.
|
||||
RuntimeError: If the camera opens but fails to apply requested settings.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
|
||||
|
||||
# Use 1 thread for OpenCV operations to avoid potential conflicts or
|
||||
# blocking in multi-threaded applications, especially during data collection.
|
||||
@@ -178,7 +178,6 @@ class OpenCVCamera(Camera):
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@check_if_not_connected
|
||||
def _configure_capture_settings(self) -> None:
|
||||
"""
|
||||
Applies the specified FOURCC, FPS, width, and height settings to the connected camera.
|
||||
@@ -198,6 +197,8 @@ class OpenCVCamera(Camera):
|
||||
to the requested value.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
|
||||
|
||||
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
|
||||
if self.config.fourcc is not None:
|
||||
@@ -347,7 +348,6 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
@@ -374,6 +374,9 @@ class OpenCVCamera(Camera):
|
||||
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
@@ -487,7 +490,6 @@ class OpenCVCamera(Camera):
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
@check_if_not_connected
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
@@ -510,6 +512,8 @@ class OpenCVCamera(Camera):
|
||||
TimeoutError: If no frame becomes available within the specified timeout.
|
||||
RuntimeError: If an unexpected error occurs.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
@@ -529,8 +533,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
@@ -545,6 +548,8 @@ class OpenCVCamera(Camera):
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
@@ -15,9 +15,9 @@
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from ..configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation
|
||||
from ..configs import CameraConfig, ColorMode, Cv2Rotation
|
||||
|
||||
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation", "Cv2Backends"]
|
||||
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"]
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("opencv")
|
||||
@@ -50,7 +50,6 @@ class OpenCVCameraConfig(CameraConfig):
|
||||
rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation.
|
||||
warmup_s: Time reading frames before returning from connect (in seconds)
|
||||
fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect).
|
||||
backend: OpenCV backend identifier (https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html). Defaults to ANY.
|
||||
|
||||
Note:
|
||||
- Only 3-channel color output (RGB/BGR) is currently supported.
|
||||
@@ -63,12 +62,22 @@ class OpenCVCameraConfig(CameraConfig):
|
||||
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
|
||||
warmup_s: int = 1
|
||||
fourcc: str | None = None
|
||||
backend: Cv2Backends = Cv2Backends.ANY
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.color_mode = ColorMode(self.color_mode)
|
||||
self.rotation = Cv2Rotation(self.rotation)
|
||||
self.backend = Cv2Backends(self.backend)
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
if self.rotation not in (
|
||||
Cv2Rotation.NO_ROTATION,
|
||||
Cv2Rotation.ROTATE_90,
|
||||
Cv2Rotation.ROTATE_180,
|
||||
Cv2Rotation.ROTATE_270,
|
||||
):
|
||||
raise ValueError(
|
||||
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
|
||||
)
|
||||
|
||||
if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4):
|
||||
raise ValueError(
|
||||
|
||||
@@ -74,4 +74,7 @@ class Reachy2CameraConfig(CameraConfig):
|
||||
f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided."
|
||||
)
|
||||
|
||||
self.color_mode = ColorMode(self.color_mode)
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
@@ -32,7 +32,6 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
import numpy as np # type: ignore # TODO: add type stubs for numpy
|
||||
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
from lerobot.utils.import_utils import _reachy2_sdk_available
|
||||
|
||||
if TYPE_CHECKING or _reachy2_sdk_available:
|
||||
@@ -124,7 +123,6 @@ class Reachy2Camera(Camera):
|
||||
"""
|
||||
raise NotImplementedError("Camera detection is not implemented for Reachy2 cameras.")
|
||||
|
||||
@check_if_not_connected
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
@@ -138,6 +136,9 @@ class Reachy2Camera(Camera):
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.cam_manager is None:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
@@ -183,7 +184,6 @@ class Reachy2Camera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Same as read()
|
||||
@@ -197,11 +197,12 @@ class Reachy2Camera(Camera):
|
||||
TimeoutError: If no frame becomes available within the specified timeout.
|
||||
RuntimeError: If an unexpected error occurs.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
return self.read()
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
@@ -218,6 +219,8 @@ class Reachy2Camera(Camera):
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.latest_frame is None or self.latest_timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any frames yet.")
|
||||
@@ -230,7 +233,6 @@ class Reachy2Camera(Camera):
|
||||
|
||||
return self.latest_frame
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Stops the background read thread (if running).
|
||||
@@ -238,6 +240,8 @@ class Reachy2Camera(Camera):
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is already disconnected.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} not connected.")
|
||||
|
||||
if self.cam_manager is not None:
|
||||
self.cam_manager.disconnect()
|
||||
|
||||
@@ -30,8 +30,7 @@ try:
|
||||
except Exception as e:
|
||||
logging.info(f"Could not import realsense: {e}")
|
||||
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from ..configs import ColorMode
|
||||
@@ -153,7 +152,6 @@ class RealSenseCamera(Camera):
|
||||
"""Checks if the camera pipeline is started and streams are active."""
|
||||
return self.rs_pipeline is not None and self.rs_profile is not None
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""
|
||||
Connects to the RealSense camera specified in the configuration.
|
||||
@@ -171,6 +169,8 @@ class RealSenseCamera(Camera):
|
||||
ConnectionError: If the camera is found but fails to start the pipeline or no RealSense devices are detected at all.
|
||||
RuntimeError: If the pipeline starts but fails to apply requested settings.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
|
||||
|
||||
self.rs_pipeline = rs.pipeline()
|
||||
rs_config = rs.config()
|
||||
@@ -290,7 +290,6 @@ class RealSenseCamera(Camera):
|
||||
if self.use_depth:
|
||||
rs_config.enable_stream(rs.stream.depth)
|
||||
|
||||
@check_if_not_connected
|
||||
def _configure_capture_settings(self) -> None:
|
||||
"""Sets fps, width, and height from device stream if not already configured.
|
||||
|
||||
@@ -300,6 +299,8 @@ class RealSenseCamera(Camera):
|
||||
Raises:
|
||||
DeviceNotConnectedError: If device is not connected.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.")
|
||||
|
||||
if self.rs_profile is None:
|
||||
raise RuntimeError(f"{self}: rs_profile must be initialized before use.")
|
||||
@@ -319,7 +320,6 @@ class RealSenseCamera(Camera):
|
||||
self.width, self.height = actual_width, actual_height
|
||||
self.capture_width, self.capture_height = actual_width, actual_height
|
||||
|
||||
@check_if_not_connected
|
||||
def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame (depth) synchronously from the camera.
|
||||
@@ -345,6 +345,9 @@ class RealSenseCamera(Camera):
|
||||
f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}."
|
||||
)
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
@@ -371,7 +374,6 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 0) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame (color) synchronously from the camera.
|
||||
@@ -401,6 +403,9 @@ class RealSenseCamera(Camera):
|
||||
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
@@ -529,7 +534,6 @@ class RealSenseCamera(Camera):
|
||||
self.new_frame_event.clear()
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame data (color) asynchronously.
|
||||
@@ -552,6 +556,8 @@ class RealSenseCamera(Camera):
|
||||
TimeoutError: If no frame data becomes available within the specified timeout.
|
||||
RuntimeError: If the background thread died unexpectedly or another error occurs.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
@@ -572,8 +578,7 @@ class RealSenseCamera(Camera):
|
||||
return frame
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent (color) frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
@@ -588,6 +593,8 @@ class RealSenseCamera(Camera):
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
@@ -60,8 +60,20 @@ class RealSenseCameraConfig(CameraConfig):
|
||||
warmup_s: int = 1
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.color_mode = ColorMode(self.color_mode)
|
||||
self.rotation = Cv2Rotation(self.rotation)
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
if self.rotation not in (
|
||||
Cv2Rotation.NO_ROTATION,
|
||||
Cv2Rotation.ROTATE_90,
|
||||
Cv2Rotation.ROTATE_180,
|
||||
Cv2Rotation.ROTATE_270,
|
||||
):
|
||||
raise ValueError(
|
||||
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
|
||||
)
|
||||
|
||||
values = (self.fps, self.width, self.height)
|
||||
if any(v is not None for v in values) and any(v is None for v in values):
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import platform
|
||||
from typing import cast
|
||||
|
||||
from lerobot.utils.import_utils import make_device_from_device_class
|
||||
@@ -67,3 +68,14 @@ def get_cv2_rotation(rotation: Cv2Rotation) -> int | None:
|
||||
return int(cv2.ROTATE_90_COUNTERCLOCKWISE)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_cv2_backend() -> int:
|
||||
import cv2
|
||||
|
||||
if platform.system() == "Windows":
|
||||
return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION
|
||||
# elif platform.system() == "Darwin": # macOS
|
||||
# return cv2.CAP_AVFOUNDATION
|
||||
else: # Linux and others
|
||||
return int(cv2.CAP_ANY)
|
||||
|
||||
@@ -34,8 +34,7 @@ import cv2
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from ..configs import ColorMode
|
||||
@@ -105,7 +104,6 @@ class ZMQCamera(Camera):
|
||||
"""Checks if the ZMQ socket is initialized and connected."""
|
||||
return self._connected and self.context is not None and self.socket is not None
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""Connect to ZMQ camera server.
|
||||
|
||||
@@ -113,6 +111,8 @@ class ZMQCamera(Camera):
|
||||
warmup (bool): If True, waits for the camera to provide at least one
|
||||
valid frame before returning. Defaults to True.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
|
||||
|
||||
logger.info(f"Connecting to {self}...")
|
||||
|
||||
@@ -211,7 +211,6 @@ class ZMQCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
@@ -229,6 +228,9 @@ class ZMQCamera(Camera):
|
||||
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
@@ -299,7 +301,6 @@ class ZMQCamera(Camera):
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
@check_if_not_connected
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
@@ -316,6 +317,8 @@ class ZMQCamera(Camera):
|
||||
TimeoutError: If no frame data becomes available within the specified timeout.
|
||||
RuntimeError: If the background thread is not running.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
@@ -332,7 +335,6 @@ class ZMQCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
@@ -348,6 +350,8 @@ class ZMQCamera(Camera):
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
@@ -32,7 +32,10 @@ class ZMQCameraConfig(CameraConfig):
|
||||
warmup_s: int = 1
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.color_mode = ColorMode(self.color_mode)
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
if self.timeout_ms <= 0:
|
||||
raise ValueError(f"`timeout_ms` must be positive, but {self.timeout_ms} is provided.")
|
||||
|
||||
@@ -45,12 +45,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
Args:
|
||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||
current step and additional steps going back).
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
input_normalization_modes: A dictionary with key representing the modality and the value specifies the
|
||||
normalization mode to apply.
|
||||
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
|
||||
the original scale.
|
||||
"""
|
||||
|
||||
n_obs_steps: int = 1
|
||||
|
||||
@@ -72,10 +72,11 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||
raise ValueError(
|
||||
f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}."
|
||||
)
|
||||
if features != meta.features:
|
||||
raise ValueError(
|
||||
f"Same features is expected, but got features={meta.features} instead of {features}."
|
||||
)
|
||||
# TODO: Temporarily disabled for merging datasets with different features (e.g. shirt_id)
|
||||
# if features != meta.features:
|
||||
# raise ValueError(
|
||||
# f"Same features is expected, but got features={meta.features} instead of {features}."
|
||||
# )
|
||||
|
||||
return fps, robot_type, features
|
||||
|
||||
|
||||
@@ -563,7 +563,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episodes: list[int] | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[str, list[float]] | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
tolerance_s: float = 1e-2,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
download_videos: bool = True,
|
||||
@@ -656,7 +656,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
|
||||
will be stored under root/repo_id.
|
||||
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
|
||||
set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to
|
||||
set the LEROBOT_HOME environment variable to point to a different location. Defaults to
|
||||
'~/.cache/huggingface/lerobot'.
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list. Defaults to None.
|
||||
@@ -1572,7 +1572,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
root: str | Path | None = None,
|
||||
robot_type: str | None = None,
|
||||
use_videos: bool = True,
|
||||
tolerance_s: float = 1e-4,
|
||||
tolerance_s: float = 1e-2,
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
video_backend: str | None = None,
|
||||
|
||||
@@ -216,17 +216,16 @@ class ImageTransformsConfig:
|
||||
|
||||
|
||||
def make_transform_from_config(cfg: ImageTransformConfig):
|
||||
if cfg.type == "SharpnessJitter":
|
||||
if cfg.type == "Identity":
|
||||
return v2.Identity(**cfg.kwargs)
|
||||
elif cfg.type == "ColorJitter":
|
||||
return v2.ColorJitter(**cfg.kwargs)
|
||||
elif cfg.type == "SharpnessJitter":
|
||||
return SharpnessJitter(**cfg.kwargs)
|
||||
|
||||
transform_cls = getattr(v2, cfg.type, None)
|
||||
if isinstance(transform_cls, type) and issubclass(transform_cls, Transform):
|
||||
return transform_cls(**cfg.kwargs)
|
||||
|
||||
raise ValueError(
|
||||
f"Transform '{cfg.type}' is not valid. It must be a class in "
|
||||
f"torchvision.transforms.v2 or 'SharpnessJitter'."
|
||||
)
|
||||
elif cfg.type == "RandomAffine":
|
||||
return v2.RandomAffine(**cfg.kwargs)
|
||||
else:
|
||||
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
||||
|
||||
|
||||
class ImageTransforms(Transform):
|
||||
|
||||
@@ -122,9 +122,19 @@ def load_nested_dataset(
|
||||
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
||||
|
||||
with SuppressProgressBars():
|
||||
# We use .from_parquet() memory-mapped loading for efficiency
|
||||
filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None
|
||||
return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features)
|
||||
# When no filtering needed, Dataset uses memory-mapped loading for efficiency
|
||||
# PyArrow loads the entire dataset into memory
|
||||
if episodes is None:
|
||||
return Dataset.from_parquet([str(path) for path in paths], features=features)
|
||||
|
||||
arrow_dataset = pa_ds.dataset(paths, format="parquet")
|
||||
filter_expr = pa_ds.field("episode_index").isin(episodes)
|
||||
table = arrow_dataset.to_table(filter=filter_expr)
|
||||
|
||||
if features is not None:
|
||||
table = table.cast(features.arrow_schema)
|
||||
|
||||
return Dataset(table)
|
||||
|
||||
|
||||
def get_parquet_num_frames(parquet_path: str | Path) -> int:
|
||||
|
||||
@@ -529,7 +529,7 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
|
||||
"(e.g. `lerobot/pusht`, `<USER>/aloha_sim_insertion_human`).",
|
||||
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--branch",
|
||||
|
||||
@@ -205,7 +205,6 @@ class ObservationConfig:
|
||||
|
||||
add_joint_velocity_to_observation: bool = False
|
||||
add_current_to_observation: bool = False
|
||||
add_ee_pose_to_observation: bool = False
|
||||
display_cameras: bool = False
|
||||
|
||||
|
||||
|
||||
@@ -112,7 +112,6 @@ class LiberoEnv(gym.Env):
|
||||
visualization_height: int = 480,
|
||||
init_states: bool = True,
|
||||
episode_index: int = 0,
|
||||
n_envs: int = 1,
|
||||
camera_name_mapping: dict[str, str] | None = None,
|
||||
num_steps_wait: int = 10,
|
||||
control_mode: str = "relative",
|
||||
@@ -146,9 +145,7 @@ class LiberoEnv(gym.Env):
|
||||
self.episode_length = episode_length
|
||||
# Load once and keep
|
||||
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
|
||||
self._reset_stride = n_envs # when performing a reset, append `_reset_stride` to `init_state_id`.
|
||||
|
||||
self.init_state_id = self.episode_index # tie each sub-env to a fixed init state
|
||||
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
|
||||
|
||||
self._env = self._make_envs_task(task_suite, self.task_id)
|
||||
default_steps = 500
|
||||
@@ -298,8 +295,7 @@ class LiberoEnv(gym.Env):
|
||||
self._env.seed(seed)
|
||||
raw_obs = self._env.reset()
|
||||
if self.init_states and self._init_states is not None:
|
||||
raw_obs = self._env.set_init_state(self._init_states[self.init_state_id % len(self._init_states)])
|
||||
self.init_state_id += self._reset_stride # Change init_state_id when reset
|
||||
raw_obs = self._env.set_init_state(self._init_states[self._init_state_id])
|
||||
|
||||
# After reset, objects may be unstable (slightly floating, intersecting, etc.).
|
||||
# Step the simulator with a no-op action for a few frames so everything settles.
|
||||
@@ -377,7 +373,6 @@ def _make_env_fns(
|
||||
init_states=init_states,
|
||||
episode_length=episode_length,
|
||||
episode_index=episode_index,
|
||||
n_envs=n_envs,
|
||||
control_mode=control_mode,
|
||||
**local_kwargs,
|
||||
)
|
||||
|
||||
@@ -221,7 +221,7 @@ class RangeFinderGUI:
|
||||
|
||||
self.bus = bus
|
||||
self.groups = groups if groups is not None else {"all": list(bus.motors)}
|
||||
self.group_names = list(self.groups)
|
||||
self.group_names = list(groups)
|
||||
self.current_group = self.group_names[0]
|
||||
|
||||
if not bus.is_connected:
|
||||
@@ -230,20 +230,18 @@ class RangeFinderGUI:
|
||||
self.calibration = bus.read_calibration()
|
||||
self.res_table = bus.model_resolution_table
|
||||
self.present_cache = {
|
||||
m: bus.read("Present_Position", m, normalize=False)
|
||||
for motors in self.groups.values()
|
||||
for m in motors
|
||||
m: bus.read("Present_Position", m, normalize=False) for motors in groups.values() for m in motors
|
||||
}
|
||||
|
||||
pygame.init()
|
||||
self.font = pygame.font.Font(None, FONT_SIZE)
|
||||
|
||||
label_pad = max(self.font.size(m)[0] for ms in self.groups.values() for m in ms)
|
||||
label_pad = max(self.font.size(m)[0] for ms in groups.values() for m in ms)
|
||||
self.label_pad = label_pad
|
||||
width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10
|
||||
self.controls_bottom = 10 + SAVE_H
|
||||
self.base_y = self.controls_bottom + TOP_GAP
|
||||
height = self.base_y + PADDING_Y * len(self.groups[self.current_group]) + 40
|
||||
height = self.base_y + PADDING_Y * len(groups[self.current_group]) + 40
|
||||
|
||||
self.screen = pygame.display.set_mode((width, height))
|
||||
pygame.display.set_caption("Motors range finder")
|
||||
|
||||
@@ -23,7 +23,6 @@ from copy import deepcopy
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.import_utils import _can_available
|
||||
|
||||
if TYPE_CHECKING or _can_available:
|
||||
@@ -37,6 +36,7 @@ else:
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
@@ -155,7 +155,6 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
"""Check if the CAN bus is connected."""
|
||||
return self._is_connected and self.canbus is not None
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, handshake: bool = True) -> None:
|
||||
"""
|
||||
Open the CAN bus and initialize communication.
|
||||
@@ -163,6 +162,10 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
Args:
|
||||
handshake: If True, ping all motors to verify they're present
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is already connected."
|
||||
)
|
||||
|
||||
try:
|
||||
# Auto-detect interface type based on port name
|
||||
@@ -208,9 +211,6 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
logger.info("Starting handshake with motors...")
|
||||
|
||||
# Drain any pending messages
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
while self.canbus.recv(timeout=0.01):
|
||||
pass
|
||||
|
||||
@@ -246,7 +246,6 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
)
|
||||
logger.info("Handshake successful. All motors ready.")
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self, disable_torque: bool = True) -> None:
|
||||
"""
|
||||
Close the CAN bus connection.
|
||||
@@ -254,6 +253,8 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
Args:
|
||||
disable_torque: If True, disable torque on all motors before disconnecting
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self.__class__.__name__}('{self.port}') is not connected.")
|
||||
|
||||
if disable_torque:
|
||||
try:
|
||||
@@ -282,10 +283,6 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [0xFF] * 7 + [command_byte]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
self.canbus.send(msg)
|
||||
if msg := self._recv_motor_response(expected_recv_id=recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
@@ -344,10 +341,6 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
self.canbus.send(msg)
|
||||
return self._recv_motor_response(expected_recv_id=recv_id)
|
||||
|
||||
@@ -363,10 +356,6 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
Returns:
|
||||
CAN message if received, None otherwise
|
||||
"""
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
messages_seen = []
|
||||
@@ -405,13 +394,10 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
Returns:
|
||||
Dictionary mapping recv_id to CAN message
|
||||
"""
|
||||
responses: dict[int, can.Message] = {}
|
||||
responses = {}
|
||||
expected_set = set(expected_recv_ids)
|
||||
start_time = time.time()
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
try:
|
||||
while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout:
|
||||
# 100us poll timeout
|
||||
@@ -475,9 +461,6 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_type = self._motor_types[motor_name]
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
self.canbus.send(msg)
|
||||
@@ -505,9 +488,6 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
|
||||
recv_id_to_motor: dict[int, str] = {}
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
# Step 1: Send all MIT control commands
|
||||
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
|
||||
motor_id = self._get_motor_id(motor)
|
||||
@@ -582,9 +562,10 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to decode response from {motor}: {e}")
|
||||
|
||||
@check_if_not_connected
|
||||
def read(self, data_name: str, motor: str) -> Value:
|
||||
"""Read a value from a single motor. Positions are always in degrees."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Refresh motor to get latest state
|
||||
msg = self._refresh_motor(motor)
|
||||
@@ -614,7 +595,6 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
raise ValueError(f"Unknown data_name: {data_name}")
|
||||
return mapping[data_name]
|
||||
|
||||
@check_if_not_connected
|
||||
def write(
|
||||
self,
|
||||
data_name: str,
|
||||
@@ -625,6 +605,8 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
Write a value to a single motor. Positions are always in degrees.
|
||||
Can write 'Goal_Position', 'Kp', or 'Kd'.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if data_name in ("Kp", "Kd"):
|
||||
self._gains[motor][data_name.lower()] = float(value)
|
||||
@@ -674,10 +656,6 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
|
||||
def _batch_refresh(self, motors: list[str]) -> None:
|
||||
"""Internal helper to refresh a list of motors and update cache."""
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
# Send refresh commands
|
||||
for motor in motors:
|
||||
motor_id = self._get_motor_id(motor)
|
||||
@@ -700,12 +678,10 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
else:
|
||||
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||
|
||||
@check_if_not_connected
|
||||
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
|
||||
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
|
||||
"""
|
||||
Write values to multiple motors simultaneously. Positions are always in degrees.
|
||||
"""
|
||||
|
||||
if data_name in ("Kp", "Kd"):
|
||||
key = data_name.lower()
|
||||
for motor, val in values.items():
|
||||
@@ -714,8 +690,6 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
elif data_name == "Goal_Position":
|
||||
# Step 1: Send all MIT control commands
|
||||
recv_id_to_motor: dict[int, str] = {}
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
for motor, value_degrees in values.items():
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
@@ -758,9 +732,9 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
|
||||
def record_ranges_of_motion(
|
||||
self,
|
||||
motors: str | list[str] | None = None,
|
||||
motors: NameOrID | list[NameOrID] | None = None,
|
||||
display_values: bool = True,
|
||||
) -> tuple[dict[str, Value], dict[str, Value]]:
|
||||
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
|
||||
"""
|
||||
Interactively record the min/max values of each motor in degrees.
|
||||
|
||||
|
||||
@@ -181,10 +181,10 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
for motor, m in self.motors.items():
|
||||
calibration[motor] = MotorCalibration(
|
||||
id=m.id,
|
||||
drive_mode=int(drive_modes[motor]),
|
||||
homing_offset=int(offsets[motor]),
|
||||
range_min=int(mins[motor]),
|
||||
range_max=int(maxes[motor]),
|
||||
drive_mode=drive_modes[motor],
|
||||
homing_offset=offsets[motor],
|
||||
range_min=mins[motor],
|
||||
range_max=maxes[motor],
|
||||
)
|
||||
|
||||
return calibration
|
||||
@@ -198,7 +198,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
if cache:
|
||||
self.calibration = calibration_dict
|
||||
|
||||
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
|
||||
@@ -206,7 +206,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
|
||||
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
|
||||
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
|
||||
|
||||
@@ -235,7 +235,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
On Dynamixel Motors:
|
||||
Present_Position = Actual_Position + Homing_Offset
|
||||
"""
|
||||
half_turn_homings: dict[NameOrID, Value] = {}
|
||||
half_turn_homings = {}
|
||||
for motor, pos in positions.items():
|
||||
model = self._get_motor_model(motor)
|
||||
max_res = self.model_resolution_table[model] - 1
|
||||
@@ -258,6 +258,6 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
if raise_on_error:
|
||||
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||
|
||||
return None
|
||||
return
|
||||
|
||||
return {id_: data[0] for id_, data in data_list.items()}
|
||||
|
||||
@@ -126,7 +126,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
|
||||
self.port_handler = scs.PortHandler(self.port)
|
||||
# HACK: monkeypatch
|
||||
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign]
|
||||
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__(
|
||||
self.port_handler, scs.PortHandler
|
||||
)
|
||||
self.packet_handler = scs.PacketHandler(protocol_version)
|
||||
@@ -262,9 +262,9 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
calibration[motor] = MotorCalibration(
|
||||
id=m.id,
|
||||
drive_mode=0,
|
||||
homing_offset=int(offsets[motor]),
|
||||
range_min=int(mins[motor]),
|
||||
range_max=int(maxes[motor]),
|
||||
homing_offset=offsets[motor],
|
||||
range_min=mins[motor],
|
||||
range_max=maxes[motor],
|
||||
)
|
||||
|
||||
return calibration
|
||||
@@ -284,7 +284,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
On Feetech Motors:
|
||||
Present_Position = Actual_Position - Homing_Offset
|
||||
"""
|
||||
half_turn_homings: dict[NameOrID, Value] = {}
|
||||
half_turn_homings = {}
|
||||
for motor, pos in positions.items():
|
||||
model = self._get_motor_model(motor)
|
||||
max_res = self.model_resolution_table[model] - 1
|
||||
@@ -292,7 +292,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
|
||||
return half_turn_homings
|
||||
|
||||
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
self.write("Lock", motor, 0, num_retry=num_retry)
|
||||
@@ -303,7 +303,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Lock")
|
||||
self._write(addr, length, motor, 0, num_retry=num_retry)
|
||||
|
||||
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
|
||||
self.write("Lock", motor, 1, num_retry=num_retry)
|
||||
@@ -334,7 +334,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
def _broadcast_ping(self) -> tuple[dict[int, int], int]:
|
||||
import scservo_sdk as scs
|
||||
|
||||
data_list: dict[int, int] = {}
|
||||
data_list = {}
|
||||
|
||||
status_length = 6
|
||||
|
||||
@@ -414,7 +414,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
if not self._is_comm_success(comm):
|
||||
if raise_on_error:
|
||||
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||
return None
|
||||
return
|
||||
|
||||
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
|
||||
if ids_errors:
|
||||
|
||||
@@ -23,7 +23,6 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
@@ -94,7 +93,7 @@ class MotorsBusBase(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
|
||||
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
|
||||
"""Write values to multiple motors."""
|
||||
pass
|
||||
|
||||
@@ -180,16 +179,15 @@ class Motor:
|
||||
|
||||
|
||||
class PortHandler(Protocol):
|
||||
is_open: bool
|
||||
baudrate: int
|
||||
packet_start_time: float
|
||||
packet_timeout: float
|
||||
tx_time_per_byte: float
|
||||
is_using: bool
|
||||
port_name: str
|
||||
ser: serial.Serial
|
||||
|
||||
def __init__(self, port_name: str) -> None: ...
|
||||
def __init__(self, port_name):
|
||||
self.is_open: bool
|
||||
self.baudrate: int
|
||||
self.packet_start_time: float
|
||||
self.packet_timeout: float
|
||||
self.tx_time_per_byte: float
|
||||
self.is_using: bool
|
||||
self.port_name: str
|
||||
self.ser: serial.Serial
|
||||
|
||||
def openPort(self): ...
|
||||
def closePort(self): ...
|
||||
@@ -242,22 +240,19 @@ class PacketHandler(Protocol):
|
||||
def regWriteTxRx(self, port, id, address, length, data): ...
|
||||
def syncReadTx(self, port, start_address, data_length, param, param_length): ...
|
||||
def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ...
|
||||
def broadcastPing(self, port): ...
|
||||
|
||||
|
||||
class GroupSyncRead(Protocol):
|
||||
port: str
|
||||
ph: PortHandler
|
||||
start_address: int
|
||||
data_length: int
|
||||
last_result: bool
|
||||
is_param_changed: bool
|
||||
param: list
|
||||
data_dict: dict
|
||||
def __init__(self, port, ph, start_address, data_length):
|
||||
self.port: str
|
||||
self.ph: PortHandler
|
||||
self.start_address: int
|
||||
self.data_length: int
|
||||
self.last_result: bool
|
||||
self.is_param_changed: bool
|
||||
self.param: list
|
||||
self.data_dict: dict
|
||||
|
||||
def __init__(
|
||||
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
|
||||
) -> None: ...
|
||||
def makeParam(self): ...
|
||||
def addParam(self, id): ...
|
||||
def removeParam(self, id): ...
|
||||
@@ -270,17 +265,15 @@ class GroupSyncRead(Protocol):
|
||||
|
||||
|
||||
class GroupSyncWrite(Protocol):
|
||||
port: str
|
||||
ph: PortHandler
|
||||
start_address: int
|
||||
data_length: int
|
||||
is_param_changed: bool
|
||||
param: list
|
||||
data_dict: dict
|
||||
def __init__(self, port, ph, start_address, data_length):
|
||||
self.port: str
|
||||
self.ph: PortHandler
|
||||
self.start_address: int
|
||||
self.data_length: int
|
||||
self.is_param_changed: bool
|
||||
self.param: list
|
||||
self.data_dict: dict
|
||||
|
||||
def __init__(
|
||||
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
|
||||
) -> None: ...
|
||||
def makeParam(self): ...
|
||||
def addParam(self, id, data): ...
|
||||
def removeParam(self, id): ...
|
||||
@@ -407,7 +400,7 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
else:
|
||||
raise TypeError(f"'{motor}' should be int, str.")
|
||||
|
||||
def _get_motor_model(self, motor: NameOrID) -> str:
|
||||
def _get_motor_model(self, motor: NameOrID) -> int:
|
||||
if isinstance(motor, str):
|
||||
return self.motors[motor].model
|
||||
elif isinstance(motor, int):
|
||||
@@ -415,19 +408,17 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
else:
|
||||
raise TypeError(f"'{motor}' should be int, str.")
|
||||
|
||||
def _get_motors_list(self, motors: NameOrID | Sequence[NameOrID] | None) -> list[str]:
|
||||
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
|
||||
if motors is None:
|
||||
return list(self.motors)
|
||||
elif isinstance(motors, str):
|
||||
return [motors]
|
||||
elif isinstance(motors, int):
|
||||
return [self._id_to_name(motors)]
|
||||
elif isinstance(motors, Sequence):
|
||||
return [m if isinstance(m, str) else self._id_to_name(m) for m in motors]
|
||||
elif isinstance(motors, list):
|
||||
return motors.copy()
|
||||
else:
|
||||
raise TypeError(motors)
|
||||
|
||||
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> dict[int, Value]:
|
||||
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
|
||||
if isinstance(values, (int | float)):
|
||||
return dict.fromkeys(self.ids, values)
|
||||
elif isinstance(values, dict):
|
||||
@@ -649,19 +640,18 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Enable torque on selected motors.
|
||||
|
||||
Args:
|
||||
motors (int | str | list[str] | None, optional): Same semantics as :pymeth:`disable_torque`.
|
||||
Defaults to `None`.
|
||||
motor (int): Same semantics as :pymeth:`disable_torque`. Defaults to `None`.
|
||||
num_retry (int, optional): Number of additional retry attempts on communication failure.
|
||||
Defaults to 0.
|
||||
"""
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
def torque_disabled(self, motors: str | list[str] | None = None):
|
||||
def torque_disabled(self, motors: int | str | list[str] | None = None):
|
||||
"""Context-manager that guarantees torque is re-enabled.
|
||||
|
||||
This helper is useful to temporarily disable torque when configuring motors.
|
||||
@@ -738,19 +728,24 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
"""
|
||||
pass
|
||||
|
||||
def reset_calibration(self, motors: NameOrID | Sequence[NameOrID] | None = None) -> None:
|
||||
def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None:
|
||||
"""Restore factory calibration for the selected motors.
|
||||
|
||||
Homing offset is set to ``0`` and min/max position limits are set to the full usable range.
|
||||
The in-memory :pyattr:`calibration` is cleared.
|
||||
|
||||
Args:
|
||||
motors (NameOrID | Sequence[NameOrID] | None, optional): Selection of motors. `None` (default)
|
||||
motors (NameOrID | list[NameOrID] | None, optional): Selection of motors. `None` (default)
|
||||
resets every motor.
|
||||
"""
|
||||
motor_names = self._get_motors_list(motors)
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
|
||||
for motor in motor_names:
|
||||
for motor in motors:
|
||||
model = self._get_motor_model(motor)
|
||||
max_res = self.model_resolution_table[model] - 1
|
||||
self.write("Homing_Offset", motor, 0, normalize=False)
|
||||
@@ -759,9 +754,7 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
|
||||
self.calibration = {}
|
||||
|
||||
def set_half_turn_homings(
|
||||
self, motors: NameOrID | Sequence[NameOrID] | None = None
|
||||
) -> dict[NameOrID, Value]:
|
||||
def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) -> dict[NameOrID, Value]:
|
||||
"""Centre each motor range around its current position.
|
||||
|
||||
The function computes and writes a homing offset such that the present position becomes exactly one
|
||||
@@ -771,12 +764,17 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
motors (NameOrID | list[NameOrID] | None, optional): Motors to adjust. Defaults to all motors (`None`).
|
||||
|
||||
Returns:
|
||||
dict[str, Value]: Mapping *motor name → written homing offset*.
|
||||
dict[NameOrID, Value]: Mapping *motor → written homing offset*.
|
||||
"""
|
||||
motor_names = self._get_motors_list(motors)
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
|
||||
self.reset_calibration(motor_names)
|
||||
actual_positions = self.sync_read("Present_Position", motor_names, normalize=False)
|
||||
self.reset_calibration(motors)
|
||||
actual_positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
homing_offsets = self._get_half_turn_homings(actual_positions)
|
||||
for motor, offset in homing_offsets.items():
|
||||
self.write("Homing_Offset", motor, offset)
|
||||
@@ -788,8 +786,8 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
pass
|
||||
|
||||
def record_ranges_of_motion(
|
||||
self, motors: NameOrID | Sequence[NameOrID] | None = None, display_values: bool = True
|
||||
) -> tuple[dict[str, Value], dict[str, Value]]:
|
||||
self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True
|
||||
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
|
||||
"""Interactively record the min/max encoder values of each motor.
|
||||
|
||||
Move the joints by hand (with torque disabled) while the method streams live positions. Press
|
||||
@@ -801,25 +799,30 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
display_values (bool, optional): When `True` (default) a live table is printed to the console.
|
||||
|
||||
Returns:
|
||||
tuple[dict[str, Value], dict[str, Value]]: Two dictionaries *mins* and *maxes* with the
|
||||
tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: Two dictionaries *mins* and *maxes* with the
|
||||
extreme values observed for each motor.
|
||||
"""
|
||||
motor_names = self._get_motors_list(motors)
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
|
||||
start_positions = self.sync_read("Present_Position", motor_names, normalize=False)
|
||||
start_positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
mins = start_positions.copy()
|
||||
maxes = start_positions.copy()
|
||||
|
||||
user_pressed_enter = False
|
||||
while not user_pressed_enter:
|
||||
positions = self.sync_read("Present_Position", motor_names, normalize=False)
|
||||
positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()}
|
||||
maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()}
|
||||
|
||||
if display_values:
|
||||
print("\n-------------------------------------------")
|
||||
print(f"{'NAME':<15} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}")
|
||||
for motor in motor_names:
|
||||
for motor in motors:
|
||||
print(f"{motor:<15} | {mins[motor]:>6} | {positions[motor]:>6} | {maxes[motor]:>6}")
|
||||
|
||||
if enter_pressed():
|
||||
@@ -827,9 +830,9 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
|
||||
if display_values and not user_pressed_enter:
|
||||
# Move cursor up to overwrite the previous output
|
||||
move_cursor_up(len(motor_names) + 3)
|
||||
move_cursor_up(len(motors) + 3)
|
||||
|
||||
same_min_max = [motor for motor in motor_names if mins[motor] == maxes[motor]]
|
||||
same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]]
|
||||
if same_min_max:
|
||||
raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}")
|
||||
|
||||
@@ -952,12 +955,12 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
if raise_on_error:
|
||||
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||
else:
|
||||
return None
|
||||
return
|
||||
if self._is_error(error):
|
||||
if raise_on_error:
|
||||
raise RuntimeError(self.packet_handler.getRxPacketError(error))
|
||||
else:
|
||||
return None
|
||||
return
|
||||
|
||||
return model_number
|
||||
|
||||
@@ -1004,13 +1007,12 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
|
||||
value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
|
||||
decoded = self._decode_sign(data_name, {id_: value})
|
||||
id_value = self._decode_sign(data_name, {id_: value})
|
||||
|
||||
if normalize and data_name in self.normalized_data:
|
||||
normalized = self._normalize(decoded)
|
||||
return normalized[id_]
|
||||
id_value = self._normalize(id_value)
|
||||
|
||||
return decoded[id_]
|
||||
return id_value[id_]
|
||||
|
||||
def _read(
|
||||
self,
|
||||
@@ -1021,7 +1023,7 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
num_retry: int = 0,
|
||||
raise_on_error: bool = True,
|
||||
err_msg: str = "",
|
||||
) -> tuple[int, int, int]:
|
||||
) -> tuple[int, int]:
|
||||
if length == 1:
|
||||
read_fn = self.packet_handler.read1ByteTxRx
|
||||
elif length == 2:
|
||||
@@ -1071,14 +1073,13 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
model = self.motors[motor].model
|
||||
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||
|
||||
int_value = int(value)
|
||||
if normalize and data_name in self.normalized_data:
|
||||
int_value = self._unnormalize({id_: value})[id_]
|
||||
value = self._unnormalize({id_: value})[id_]
|
||||
|
||||
int_value = self._encode_sign(data_name, {id_: int_value})[id_]
|
||||
value = self._encode_sign(data_name, {id_: value})[id_]
|
||||
|
||||
err_msg = f"Failed to write '{data_name}' on {id_=} with '{int_value}' after {num_retry + 1} tries."
|
||||
self._write(addr, length, id_, int_value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
|
||||
self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
|
||||
def _write(
|
||||
self,
|
||||
@@ -1112,7 +1113,7 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
def sync_read(
|
||||
self,
|
||||
data_name: str,
|
||||
motors: NameOrID | Sequence[NameOrID] | None = None,
|
||||
motors: str | list[str] | None = None,
|
||||
*,
|
||||
normalize: bool = True,
|
||||
num_retry: int = 0,
|
||||
@@ -1121,7 +1122,7 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
|
||||
Args:
|
||||
data_name (str): Register name.
|
||||
motors (NameOrID | Sequence[NameOrID] | None, optional): Motors to query. `None` (default) reads every motor.
|
||||
motors (str | list[str] | None, optional): Motors to query. `None` (default) reads every motor.
|
||||
normalize (bool, optional): Normalisation flag. Defaults to `True`.
|
||||
num_retry (int, optional): Retry attempts. Defaults to `0`.
|
||||
|
||||
@@ -1142,17 +1143,16 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||
|
||||
err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
|
||||
raw_ids_values, _ = self._sync_read(
|
||||
ids_values, _ = self._sync_read(
|
||||
addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
|
||||
)
|
||||
|
||||
decoded = self._decode_sign(data_name, raw_ids_values)
|
||||
ids_values = self._decode_sign(data_name, ids_values)
|
||||
|
||||
if normalize and data_name in self.normalized_data:
|
||||
normalized = self._normalize(decoded)
|
||||
return {self._id_to_name(id_): value for id_, value in normalized.items()}
|
||||
ids_values = self._normalize(ids_values)
|
||||
|
||||
return {self._id_to_name(id_): value for id_, value in decoded.items()}
|
||||
return {self._id_to_name(id_): value for id_, value in ids_values.items()}
|
||||
|
||||
def _sync_read(
|
||||
self,
|
||||
@@ -1224,24 +1224,21 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
num_retry (int, optional): Retry attempts. Defaults to `0`.
|
||||
"""
|
||||
|
||||
raw_ids_values = self._get_ids_values_dict(values)
|
||||
models = [self._id_to_model(id_) for id_ in raw_ids_values]
|
||||
ids_values = self._get_ids_values_dict(values)
|
||||
models = [self._id_to_model(id_) for id_ in ids_values]
|
||||
if self._has_different_ctrl_tables:
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
|
||||
model = next(iter(models))
|
||||
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||
|
||||
int_ids_values = {id_: int(val) for id_, val in raw_ids_values.items()}
|
||||
if normalize and data_name in self.normalized_data:
|
||||
int_ids_values = self._unnormalize(raw_ids_values)
|
||||
ids_values = self._unnormalize(ids_values)
|
||||
|
||||
int_ids_values = self._encode_sign(data_name, int_ids_values)
|
||||
ids_values = self._encode_sign(data_name, ids_values)
|
||||
|
||||
err_msg = f"Failed to sync write '{data_name}' with ids_values={int_ids_values} after {num_retry + 1} tries."
|
||||
self._sync_write(
|
||||
addr, length, int_ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
|
||||
)
|
||||
err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
|
||||
self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
|
||||
def _sync_write(
|
||||
self,
|
||||
|
||||
@@ -28,7 +28,7 @@ class ACTConfig(PreTrainedConfig):
|
||||
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_features` and `output_features`.
|
||||
Those are: `input_shapes` and 'output_shapes`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- Either:
|
||||
@@ -48,12 +48,21 @@ class ACTConfig(PreTrainedConfig):
|
||||
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
|
||||
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
|
||||
environment, and throws the other 50 out.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
|
||||
@@ -30,7 +30,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_features` and `output_features`.
|
||||
Those are: `input_shapes` and `output_shapes`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- "observation.state" is required as an input key.
|
||||
@@ -48,12 +48,21 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
||||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||||
See `DiffusionPolicy.select_action` for more details.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
@@ -64,7 +73,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
|
||||
use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view.
|
||||
use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view.
|
||||
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
|
||||
You may provide a variable number of dimensions, therefore also controlling the degree of
|
||||
downsampling.
|
||||
|
||||
@@ -61,8 +61,6 @@ class PI05Config(PreTrainedConfig):
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
empty_cameras: int = 0
|
||||
|
||||
tokenizer_max_length: int = 200 # see openpi `__post_init__`
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
|
||||
@@ -27,18 +27,18 @@ Usage:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4
|
||||
--reward-model-path pepijn223/sarm_single_uni4
|
||||
|
||||
# Faster computation with stride (compute every 5 frames, interpolate the rest)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--stride 5
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
--num-visualizations 5
|
||||
|
||||
@@ -714,12 +714,12 @@ Examples:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4
|
||||
--reward-model-path pepijn223/sarm_single_uni4
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
--num-visualizations 10
|
||||
""",
|
||||
|
||||
@@ -30,7 +30,7 @@ Example of finetuning the smolvla pretrained model (`smolvla_base`):
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--dataset.repo_id=<USER>/svla_so100_task1_v3 \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
--steps=200000
|
||||
```
|
||||
@@ -40,7 +40,7 @@ and an action expert.
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type=smolvla \
|
||||
--dataset.repo_id=<USER>/svla_so100_task1_v3 \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
--steps=200000
|
||||
```
|
||||
@@ -378,16 +378,16 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
actions_is_pad = batch.get("actions_id_pad")
|
||||
loss_dict = {}
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
loss_dict["losses_after_forward"] = losses.clone().mean().item()
|
||||
loss_dict["losses_after_forward"] = losses.clone()
|
||||
|
||||
if actions_is_pad is not None:
|
||||
in_episode_bound = ~actions_is_pad
|
||||
losses = losses * in_episode_bound.unsqueeze(-1)
|
||||
loss_dict["losses_after_in_ep_bound"] = losses.clone().mean().item()
|
||||
loss_dict["losses_after_in_ep_bound"] = losses.clone()
|
||||
|
||||
# Remove padding
|
||||
losses = losses[:, :, : self.config.max_action_dim]
|
||||
loss_dict["losses_after_rm_padding"] = losses.clone().mean().item()
|
||||
loss_dict["losses_after_rm_padding"] = losses.clone()
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
|
||||
@@ -30,7 +30,7 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_features`, `output_features`, and perhaps `max_random_shift_ratio`.
|
||||
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.
|
||||
|
||||
Args:
|
||||
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
|
||||
@@ -40,12 +40,24 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
is an alternative to using action repeats. If this is set to more than 1, then we require
|
||||
`n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this
|
||||
approach of using multiple steps from the plan is not in the original implementation.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to
|
||||
match the original implementation.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping
|
||||
to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max"
|
||||
normalization mode here.
|
||||
image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding.
|
||||
state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding.
|
||||
latent_dim: Observation's latent embedding dimension.
|
||||
|
||||
@@ -32,7 +32,7 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_features` and `output_features`.
|
||||
Those are: `input_shapes` and `output_shapes`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- "observation.state" is required as an input key.
|
||||
@@ -46,12 +46,21 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
current step and additional steps going back).
|
||||
n_action_pred_token: Total number of current token and future tokens that VQ-BeT predicts.
|
||||
action_chunk_size: Action chunk size of each action prediction token.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
The key represents the input data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "observation.image" refers to an input from
|
||||
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
|
||||
Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
The key represents the output data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
|
||||
@@ -44,7 +44,6 @@ from .hil_processor import (
|
||||
AddTeleopActionAsComplimentaryDataStep,
|
||||
AddTeleopEventsAsInfoStep,
|
||||
GripperPenaltyProcessorStep,
|
||||
GymHILAdapterProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
InterventionActionProcessorStep,
|
||||
RewardClassifierProcessorStep,
|
||||
@@ -88,7 +87,6 @@ __all__ = [
|
||||
"DoneProcessorStep",
|
||||
"EnvAction",
|
||||
"EnvTransition",
|
||||
"GymHILAdapterProcessorStep",
|
||||
"GripperPenaltyProcessorStep",
|
||||
"hotswap_stats",
|
||||
"IdentityProcessorStep",
|
||||
|
||||
@@ -17,7 +17,7 @@ from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
|
||||
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
@@ -92,7 +92,7 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
|
||||
# copy over non-STATE features
|
||||
for ft, feats in features.items():
|
||||
if ft != FeatureType.STATE:
|
||||
if ft != PipelineFeatureType.STATE:
|
||||
new_features[ft] = feats.copy()
|
||||
|
||||
# rebuild STATE features
|
||||
@@ -100,11 +100,13 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
|
||||
# add our new flattened state
|
||||
state_feats[OBS_STATE] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
key=OBS_STATE,
|
||||
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
|
||||
dtype="float32",
|
||||
description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."),
|
||||
)
|
||||
|
||||
new_features[FeatureType.STATE] = state_feats
|
||||
new_features[PipelineFeatureType.STATE] = state_feats
|
||||
|
||||
return new_features
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
|
||||
from .converters import to_tensor
|
||||
from .core import EnvAction, EnvTransition, PolicyAction
|
||||
from .hil_processor import TELEOP_ACTION_KEY
|
||||
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@@ -90,13 +89,6 @@ class Numpy2TorchActionProcessorStep(ProcessorStep):
|
||||
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
|
||||
new_transition[TransitionKey.ACTION] = torch_action
|
||||
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if TELEOP_ACTION_KEY in complementary_data:
|
||||
teleop_action = complementary_data[TELEOP_ACTION_KEY]
|
||||
if isinstance(teleop_action, EnvAction):
|
||||
complementary_data[TELEOP_ACTION_KEY] = to_tensor(teleop_action)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
|
||||
@@ -312,40 +312,9 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("gym_hil_adapter_processor")
|
||||
class GymHILAdapterProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Adapts the output of the `gym-hil` environment to the format expected by `lerobot` processors.
|
||||
|
||||
This step normalizes the `transition` object by:
|
||||
1. Copying `teleop_action` from `info` to `complementary_data`.
|
||||
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
info = transition.get(TransitionKey.INFO, {})
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
|
||||
if TELEOP_ACTION_KEY in info:
|
||||
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
|
||||
|
||||
if "is_intervention" in info:
|
||||
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
|
||||
|
||||
transition[TransitionKey.INFO] = info
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
Applies a penalty for inefficient gripper usage.
|
||||
|
||||
@@ -360,27 +329,26 @@ class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
penalty: float = -0.01
|
||||
max_gripper_pos: float = 30.0
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
"""
|
||||
Calculates the gripper penalty and adds it to the complementary data.
|
||||
|
||||
Args:
|
||||
transition: The incoming environment transition.
|
||||
complementary_data: The incoming complementary data, which should contain
|
||||
raw joint positions.
|
||||
|
||||
Returns:
|
||||
The modified transition with the penalty added to complementary data.
|
||||
A new complementary data dictionary with the `discrete_penalty` key added.
|
||||
"""
|
||||
new_transition = transition.copy()
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
action = self.transition.get(TransitionKey.ACTION)
|
||||
|
||||
raw_joint_positions = complementary_data.get("raw_joint_positions")
|
||||
if raw_joint_positions is None:
|
||||
return new_transition
|
||||
return complementary_data
|
||||
|
||||
current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None)
|
||||
if current_gripper_pos is None:
|
||||
return new_transition
|
||||
return complementary_data
|
||||
|
||||
# Gripper action is a PolicyAction at this stage
|
||||
gripper_action = action[-1].item()
|
||||
@@ -396,12 +364,11 @@ class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
|
||||
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
||||
|
||||
# Update complementary data with penalty info
|
||||
# Create new complementary data with penalty info
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
|
||||
return new_transition
|
||||
return new_complementary_data
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
|
||||
@@ -413,7 +413,7 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
|
||||
Args:
|
||||
save_directory: The directory where the pipeline will be saved. If None, saves to
|
||||
HF_LEROBOT_HOME/processors/{sanitized_pipeline_name}.
|
||||
repo_id: ID of your repository on the Hub. Used only if `push_to_hub=true`.
|
||||
repo_id: ID of your repository on the Hub. Used only if `push_to_hub=True`.
|
||||
push_to_hub: Whether or not to push your object to the Hugging Face Hub after saving it.
|
||||
card_kwargs: Additional arguments passed to the card template to customize the card.
|
||||
config_filename: The name of the JSON configuration file. If None, a name is
|
||||
|
||||
@@ -36,7 +36,6 @@ from lerobot.processor import (
|
||||
DeviceProcessorStep,
|
||||
EnvTransition,
|
||||
GripperPenaltyProcessorStep,
|
||||
GymHILAdapterProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
InterventionActionProcessorStep,
|
||||
MapDeltaActionToRobotActionStep,
|
||||
@@ -380,7 +379,6 @@ def make_processors(
|
||||
]
|
||||
|
||||
env_pipeline_steps = [
|
||||
GymHILAdapterProcessorStep(),
|
||||
Numpy2TorchActionProcessorStep(),
|
||||
VanillaObservationProcessorStep(),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
@@ -414,10 +412,7 @@ def make_processors(
|
||||
if cfg.processor.observation.add_current_to_observation:
|
||||
env_pipeline_steps.append(MotorCurrentProcessorStep(robot=env.robot))
|
||||
|
||||
add_ee_pose = (
|
||||
cfg.processor.observation is not None and cfg.processor.observation.add_ee_pose_to_observation
|
||||
)
|
||||
if kinematics_solver is not None and add_ee_pose:
|
||||
if kinematics_solver is not None:
|
||||
env_pipeline_steps.append(
|
||||
ForwardKinematicsJointsToEEObservation(
|
||||
kinematics=kinematics_solver,
|
||||
@@ -440,12 +435,7 @@ def make_processors(
|
||||
)
|
||||
|
||||
# Add gripper penalty processor if gripper config exists and enabled
|
||||
# Only add if max_gripper_pos is explicitly configured (required for normalization)
|
||||
if (
|
||||
cfg.processor.gripper is not None
|
||||
and cfg.processor.gripper.use_gripper
|
||||
and cfg.processor.max_gripper_pos is not None
|
||||
):
|
||||
if cfg.processor.gripper is not None and cfg.processor.gripper.use_gripper:
|
||||
env_pipeline_steps.append(
|
||||
GripperPenaltyProcessorStep(
|
||||
penalty=cfg.processor.gripper.gripper_penalty,
|
||||
@@ -610,14 +600,7 @@ def control_loop(
|
||||
|
||||
dataset = None
|
||||
if cfg.mode == "record":
|
||||
if teleop_device:
|
||||
action_features = teleop_device.action_features
|
||||
else:
|
||||
action_features = {
|
||||
"dtype": "float32",
|
||||
"shape": (4,),
|
||||
"names": ["delta_x", "delta_y", "delta_z", "gripper"],
|
||||
}
|
||||
action_features = teleop_device.action_features
|
||||
features = {
|
||||
ACTION: action_features,
|
||||
REWARD: {"dtype": "float32", "shape": (1,), "names": None},
|
||||
@@ -665,7 +648,7 @@ def control_loop(
|
||||
# Create a neutral action (no movement)
|
||||
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
|
||||
if use_gripper:
|
||||
neutral_action = torch.cat([neutral_action, torch.tensor([0.0])]) # Gripper stay
|
||||
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
|
||||
|
||||
# Use the new step function
|
||||
transition = step_env_and_process_transition(
|
||||
@@ -734,8 +717,6 @@ def control_loop(
|
||||
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
|
||||
|
||||
if dataset is not None and cfg.dataset.push_to_hub:
|
||||
logging.info("Finalizing dataset before pushing to hub")
|
||||
dataset.finalize()
|
||||
logging.info("Pushing dataset to hub")
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
@@ -26,21 +26,8 @@ from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.utils.constants import PRETRAINED_MODEL_DIR
|
||||
|
||||
|
||||
def cfg_to_group(
|
||||
cfg: TrainPipelineConfig, return_list: bool = False, truncate_tags: bool = False, max_tag_length: int = 64
|
||||
) -> list[str] | str:
|
||||
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
|
||||
"""Return a group name for logging. Optionally returns group name as list."""
|
||||
|
||||
def _maybe_truncate(tag: str) -> str:
|
||||
"""Truncate tag to max_tag_length characters if required.
|
||||
|
||||
wandb rejects tags longer than 64 characters.
|
||||
See: https://github.com/wandb/wandb/blob/main/wandb/sdk/wandb_settings.py
|
||||
"""
|
||||
if len(tag) <= max_tag_length:
|
||||
return tag
|
||||
return tag[:max_tag_length]
|
||||
|
||||
lst = [
|
||||
f"policy:{cfg.policy.type}",
|
||||
f"seed:{cfg.seed}",
|
||||
@@ -49,8 +36,6 @@ def cfg_to_group(
|
||||
lst.append(f"dataset:{cfg.dataset.repo_id}")
|
||||
if cfg.env is not None:
|
||||
lst.append(f"env:{cfg.env.type}")
|
||||
if truncate_tags:
|
||||
lst = [_maybe_truncate(tag) for tag in lst]
|
||||
return lst if return_list else "-".join(lst)
|
||||
|
||||
|
||||
@@ -98,7 +83,7 @@ class WandBLogger:
|
||||
entity=self.cfg.entity,
|
||||
name=self.job_name,
|
||||
notes=self.cfg.notes,
|
||||
tags=cfg_to_group(cfg, return_list=True, truncate_tags=True),
|
||||
tags=cfg_to_group(cfg, return_list=True),
|
||||
dir=self.log_dir,
|
||||
config=cfg.to_dict(),
|
||||
# TODO(rcadene): try set to True
|
||||
|
||||
@@ -19,7 +19,6 @@ from functools import cached_property
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_bi_openarm_follower import BiOpenArmFollowerConfig
|
||||
@@ -113,7 +112,6 @@ class BiOpenArmFollower(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
@@ -135,7 +133,6 @@ class BiOpenArmFollower(Robot):
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
)
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
|
||||
@@ -149,7 +146,6 @@ class BiOpenArmFollower(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(
|
||||
self,
|
||||
action: RobotAction,
|
||||
@@ -174,7 +170,6 @@ class BiOpenArmFollower(Robot):
|
||||
|
||||
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -19,7 +19,6 @@ from functools import cached_property
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_bi_so_follower import BiSOFollowerConfig
|
||||
@@ -97,7 +96,6 @@ class BiSOFollower(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
@@ -118,7 +116,6 @@ class BiSOFollower(Robot):
|
||||
self.left_arm.setup_motors()
|
||||
self.right_arm.setup_motors()
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
|
||||
@@ -132,7 +129,6 @@ class BiSOFollower(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
# Remove "left_" prefix
|
||||
left_action = {
|
||||
@@ -152,7 +148,6 @@ class BiSOFollower(Robot):
|
||||
|
||||
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -140,7 +140,7 @@ class HopeJrArm(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -171,7 +171,7 @@ class HopeJrHand(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -193,7 +193,7 @@ class KochFollower(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -360,7 +360,7 @@ class LeKiwi(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -176,7 +176,7 @@ class OmxFollower(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.damiao import DamiaoMotorsBus
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
@@ -119,7 +119,6 @@ class OpenArmFollower(Robot):
|
||||
"""Check if robot is connected."""
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
Connect to the robot and optionally calibrate.
|
||||
@@ -127,6 +126,8 @@ class OpenArmFollower(Robot):
|
||||
We assume that at connection time, the arms are in a safe rest position,
|
||||
and torque can be safely disabled to run calibration if needed.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
# Connect to CAN bus
|
||||
logger.info(f"Connecting arm on {self.config.port}...")
|
||||
@@ -218,7 +219,6 @@ class OpenArmFollower(Robot):
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
)
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""
|
||||
Get current observation from robot including position, velocity, and torque.
|
||||
@@ -228,6 +228,9 @@ class OpenArmFollower(Robot):
|
||||
"""
|
||||
start = time.perf_counter()
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
obs_dict: dict[str, Any] = {}
|
||||
|
||||
states = self.bus.sync_read_all_states()
|
||||
@@ -241,7 +244,7 @@ class OpenArmFollower(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
@@ -250,7 +253,6 @@ class OpenArmFollower(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(
|
||||
self,
|
||||
action: RobotAction,
|
||||
@@ -270,6 +272,8 @@ class OpenArmFollower(Robot):
|
||||
Returns:
|
||||
The action actually sent (potentially clipped)
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
@@ -329,9 +333,10 @@ class OpenArmFollower(Robot):
|
||||
|
||||
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
"""Disconnect from robot."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Disconnect CAN bus
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
|
||||
@@ -180,7 +180,7 @@ class Reachy2Robot(Robot):
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
|
||||
return obs_dict
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ class SOFollowerConfig:
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
# Set to `True` for backward compatibility with previous policies/dataset
|
||||
use_degrees: bool = True
|
||||
use_degrees: bool = False
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("so101_follower")
|
||||
|
||||
@@ -187,7 +187,7 @@ class SOFollower(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -324,7 +324,7 @@ class UnitreeG1(Robot):
|
||||
|
||||
# Cameras - read images from ZMQ cameras
|
||||
for cam_name, cam in self._cameras.items():
|
||||
obs[cam_name] = cam.read_latest()
|
||||
obs[cam_name] = cam.async_read()
|
||||
|
||||
return obs
|
||||
|
||||
|
||||
@@ -47,14 +47,16 @@ local$ rerun lerobot_pusht_episode_0.rrd
|
||||
```
|
||||
|
||||
- Visualize data stored on a distant machine through streaming:
|
||||
(You need to forward the websocket port to the distant machine, with
|
||||
`ssh -L 9087:localhost:9087 username@remote-host`)
|
||||
```
|
||||
distant$ lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--episode-index 0 \
|
||||
--mode distant \
|
||||
--grpc-port 9876
|
||||
--ws-port 9087
|
||||
|
||||
local$ rerun rerun+http://IP:GRPC_PORT/proxy
|
||||
local$ rerun ws://localhost:9087
|
||||
```
|
||||
|
||||
"""
|
||||
@@ -73,7 +75,6 @@ import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
||||
@@ -92,11 +93,10 @@ def visualize_dataset(
|
||||
num_workers: int = 0,
|
||||
mode: str = "local",
|
||||
web_port: int = 9090,
|
||||
grpc_port: int = 9876,
|
||||
ws_port: int = 9087,
|
||||
save: bool = False,
|
||||
output_dir: Path | None = None,
|
||||
display_compressed_images: bool = False,
|
||||
**kwargs,
|
||||
) -> Path | None:
|
||||
if save:
|
||||
assert output_dir is not None, (
|
||||
@@ -126,9 +126,7 @@ def visualize_dataset(
|
||||
gc.collect()
|
||||
|
||||
if mode == "distant":
|
||||
server_uri = rr.serve_grpc(grpc_port=grpc_port)
|
||||
logging.info(f"Connect to a Rerun Server: rerun rerun+http://IP:{grpc_port}/proxy")
|
||||
rr.serve_web_viewer(open_browser=False, web_port=web_port, connect_to=server_uri)
|
||||
rr.serve_web_viewer(open_browser=False, web_port=web_port)
|
||||
|
||||
logging.info("Logging to Rerun")
|
||||
|
||||
@@ -228,7 +226,7 @@ def main():
|
||||
"Mode of viewing between 'local' or 'distant'. "
|
||||
"'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
|
||||
"'distant' creates a server on the distant machine where the data is stored. "
|
||||
"Visualize the data by connecting to the server with `rerun rerun+http://IP:GRPC_PORT/proxy` on the local machine."
|
||||
"Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -240,13 +238,8 @@ def main():
|
||||
parser.add_argument(
|
||||
"--ws-port",
|
||||
type=int,
|
||||
help="deprecated, please use --grpc-port instead.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grpc-port",
|
||||
type=int,
|
||||
default=9876,
|
||||
help="gRPC port for rerun.io when `--mode distant` is set.",
|
||||
default=9087,
|
||||
help="Web socket port for rerun.io when `--mode distant` is set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save",
|
||||
@@ -272,7 +265,9 @@ def main():
|
||||
|
||||
parser.add_argument(
|
||||
"--display-compressed-images",
|
||||
action="store_true",
|
||||
type=bool,
|
||||
required=True,
|
||||
default=False,
|
||||
help="If set, display compressed images in Rerun instead of uncompressed ones.",
|
||||
)
|
||||
|
||||
@@ -282,14 +277,6 @@ def main():
|
||||
root = kwargs.pop("root")
|
||||
tolerance_s = kwargs.pop("tolerance_s")
|
||||
|
||||
if kwargs["ws_port"] is not None:
|
||||
logging.warning(
|
||||
"--ws-port is deprecated and will be removed in future versions. Please use --grpc-port instead."
|
||||
)
|
||||
logging.warning("Setting grpc_port to ws_port value.")
|
||||
kwargs["grpc_port"] = kwargs.pop("ws_port")
|
||||
|
||||
init_logging()
|
||||
logging.info("Loading dataset")
|
||||
dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s)
|
||||
|
||||
|
||||
@@ -24,112 +24,96 @@ When new_repo_id is specified, creates a new dataset.
|
||||
Usage Examples:
|
||||
|
||||
Delete episodes 0, 2, and 5 from a dataset:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
|
||||
Delete episodes and save to a new dataset:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_filtered \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
|
||||
Split dataset by fractions:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": 0.8, "val": 0.2}'
|
||||
|
||||
Split dataset by episode indices:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": [0, 1, 2, 3], "val": [4, 5]}'
|
||||
|
||||
Split into more than two splits:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}'
|
||||
|
||||
Merge multiple datasets:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_merged \
|
||||
--operation.type merge \
|
||||
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
|
||||
|
||||
Remove camera feature:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type remove_feature \
|
||||
--operation.feature_names "['observation.images.top']"
|
||||
|
||||
Modify tasks - set a single task for all episodes (WARNING: modifies in-place):
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.new_task "Pick up the cube and place it"
|
||||
|
||||
Modify tasks - set different tasks for specific episodes (WARNING: modifies in-place):
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.episode_tasks '{"0": "Task A", "1": "Task B", "2": "Task A"}'
|
||||
|
||||
Modify tasks - set default task with overrides for specific episodes (WARNING: modifies in-place):
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.new_task "Default task" \
|
||||
--operation.episode_tasks '{"5": "Special task for episode 5"}'
|
||||
|
||||
Convert image dataset to video format and save locally:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir /path/to/output/pusht_video
|
||||
|
||||
Convert image dataset to video format and save with new repo_id:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_image_to_video
|
||||
|
||||
Convert image dataset to video format and push to hub:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--push_to_hub true
|
||||
|
||||
Show dataset information:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
--operation.show_features true
|
||||
|
||||
Show dataset information without feature details:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
--operation.show_features false
|
||||
|
||||
Using JSON config file:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--config_path path/to/edit_config.json
|
||||
"""
|
||||
|
||||
import abc
|
||||
import logging
|
||||
import shutil
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.datasets.dataset_tools import (
|
||||
convert_image_to_video_dataset,
|
||||
@@ -145,46 +129,39 @@ from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
@dataclass
|
||||
class OperationConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("delete_episodes")
|
||||
@dataclass
|
||||
class DeleteEpisodesConfig(OperationConfig):
|
||||
class DeleteEpisodesConfig:
|
||||
type: str = "delete_episodes"
|
||||
episode_indices: list[int] | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("split")
|
||||
@dataclass
|
||||
class SplitConfig(OperationConfig):
|
||||
class SplitConfig:
|
||||
type: str = "split"
|
||||
splits: dict[str, float | list[int]] | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("merge")
|
||||
@dataclass
|
||||
class MergeConfig(OperationConfig):
|
||||
class MergeConfig:
|
||||
type: str = "merge"
|
||||
repo_ids: list[str] | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("remove_feature")
|
||||
@dataclass
|
||||
class RemoveFeatureConfig(OperationConfig):
|
||||
class RemoveFeatureConfig:
|
||||
type: str = "remove_feature"
|
||||
feature_names: list[str] | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("modify_tasks")
|
||||
@dataclass
|
||||
class ModifyTasksConfig(OperationConfig):
|
||||
class ModifyTasksConfig:
|
||||
type: str = "modify_tasks"
|
||||
new_task: str | None = None
|
||||
episode_tasks: dict[str, str] | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("convert_image_to_video")
|
||||
@dataclass
|
||||
class ConvertImageToVideoConfig(OperationConfig):
|
||||
class ConvertImageToVideoConfig:
|
||||
type: str = "convert_image_to_video"
|
||||
output_dir: str | None = None
|
||||
vcodec: str = "libsvtav1"
|
||||
pix_fmt: str = "yuv420p"
|
||||
@@ -197,17 +174,17 @@ class ConvertImageToVideoConfig(OperationConfig):
|
||||
max_frames_per_batch: int | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("info")
|
||||
@dataclass
|
||||
class InfoConfig(OperationConfig):
|
||||
type: str = "info"
|
||||
show_features: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class EditDatasetConfig:
|
||||
repo_id: str
|
||||
operation: OperationConfig
|
||||
operation: (
|
||||
DeleteEpisodesConfig
|
||||
| SplitConfig
|
||||
| MergeConfig
|
||||
| RemoveFeatureConfig
|
||||
| ModifyTasksConfig
|
||||
| ConvertImageToVideoConfig
|
||||
)
|
||||
root: str | None = None
|
||||
new_repo_id: str | None = None
|
||||
push_to_hub: bool = False
|
||||
@@ -456,49 +433,6 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||
logging.info("Dataset saved locally (not pushed to hub)")
|
||||
|
||||
|
||||
def _get_dataset_size(repo_path):
|
||||
import os
|
||||
|
||||
total = 0
|
||||
with os.scandir(repo_path) as it:
|
||||
for entry in it:
|
||||
if entry.is_file():
|
||||
total += entry.stat().st_size
|
||||
elif entry.is_dir():
|
||||
total += _get_dataset_size(entry.path)
|
||||
return total
|
||||
|
||||
|
||||
def handle_info(cfg: EditDatasetConfig):
|
||||
if not isinstance(cfg.operation, InfoConfig):
|
||||
raise ValueError("Operation config must be InfoConfig")
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
sys.stdout.write(f"======Info {dataset.meta.repo_id}\n")
|
||||
sys.stdout.write(f"Repository ID: {dataset.meta.repo_id} \n")
|
||||
sys.stdout.write(f"Total episode: {dataset.meta.total_episodes} \n")
|
||||
sys.stdout.write(f"Total task: {dataset.meta.total_tasks} \n")
|
||||
sys.stdout.write(f"Total frame(Actual Count): {dataset.meta.total_frames}({len(dataset)}) \n")
|
||||
sys.stdout.write(
|
||||
f"Average frame per episode: {dataset.meta.total_frames / dataset.meta.total_episodes:.1f}\n"
|
||||
)
|
||||
sys.stdout.write(
|
||||
f"Average episode time(sec): {(dataset.meta.total_frames / dataset.meta.total_episodes) / dataset.meta.fps:.1f}\n"
|
||||
)
|
||||
sys.stdout.write(f"FPS: {dataset.meta.fps}\n")
|
||||
|
||||
total_file_size = _get_dataset_size(dataset.root)
|
||||
sys.stdout.write(f"Size: {total_file_size / (1024 * 1024):.1f} MB\n")
|
||||
if cfg.operation.show_features:
|
||||
import json
|
||||
|
||||
feature_dump_str = json.dumps(
|
||||
dataset.meta.features, ensure_ascii=False, indent=4, sort_keys=True, separators=(",", ": ")
|
||||
)
|
||||
sys.stdout.write("Features:\n")
|
||||
sys.stdout.write(f"{feature_dump_str}\n")
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
operation_type = cfg.operation.type
|
||||
@@ -515,11 +449,11 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
handle_modify_tasks(cfg)
|
||||
elif operation_type == "convert_image_to_video":
|
||||
handle_convert_image_to_video(cfg)
|
||||
elif operation_type == "info":
|
||||
handle_info(cfg)
|
||||
else:
|
||||
available = ", ".join(OperationConfig.get_known_choices())
|
||||
raise ValueError(f"Unknown operation: {operation_type}\nAvailable operations: {available}")
|
||||
raise ValueError(
|
||||
f"Unknown operation type: {operation_type}\n"
|
||||
f"Available operations: delete_episodes, split, merge, remove_feature, modify_tasks, convert_image_to_video"
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
||||
@@ -398,14 +398,7 @@ def record_loop(
|
||||
)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
|
||||
sleep_time_s: float = 1 / fps - dt_s
|
||||
if sleep_time_s < 0:
|
||||
logging.warning(
|
||||
f"Record loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||
)
|
||||
|
||||
precise_sleep(max(sleep_time_s, 0.0))
|
||||
precise_sleep(max(1 / fps - dt_s, 0.0))
|
||||
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ lerobot-replay \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
--dataset.repo_id=<USER>/record-test \
|
||||
--dataset.repo_id=aliberts/record-test \
|
||||
--dataset.episode=0
|
||||
```
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ from dataclasses import dataclass, field
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.utils.import_utils import _can_available
|
||||
from lerobot.utils.import_utils import is_package_available
|
||||
|
||||
MOTOR_NAMES = {
|
||||
0x01: "joint_1",
|
||||
@@ -336,7 +336,7 @@ def run_speed(cfg: CANSetupConfig):
|
||||
|
||||
@draccus.wrap()
|
||||
def setup_can(cfg: CANSetupConfig):
|
||||
if not _can_available:
|
||||
if not is_package_available("can"):
|
||||
print("Error: python-can not installed. Install with: pip install python-can")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@@ -337,13 +337,28 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# create dataloader for offline training
|
||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||
# Filter out episodes - hardcoded list of bad episodes to discard
|
||||
episodes_to_discard = {
|
||||
133, 134, 502, 565, 568, 657, 910, 944, 1039, 1209, 1346, 1360, 1379,
|
||||
1605, 1690, 1790, 2105, 2106, 2122, 2118, 2156, 2575, 2764, 2876, 2925,
|
||||
3100, 3381, 3405, 3406, 68, 1214, 1456,
|
||||
}
|
||||
all_episodes = set(range(dataset.meta.total_episodes))
|
||||
episodes_to_use = dataset.episodes # May be None (all episodes) or a subset
|
||||
# If dataset.episodes is already filtered, start from that subset
|
||||
if episodes_to_use is not None:
|
||||
episodes_to_use = [ep for ep in episodes_to_use if ep not in episodes_to_discard]
|
||||
else:
|
||||
episodes_to_use = sorted(all_episodes - episodes_to_discard)
|
||||
|
||||
if hasattr(cfg.policy, "drop_n_last_frames") or episodes_to_use is not None:
|
||||
shuffle = False
|
||||
drop_n_last = getattr(cfg.policy, "drop_n_last_frames", 0)
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
episode_indices_to_use=episodes_to_use,
|
||||
drop_n_last_frames=drop_n_last,
|
||||
shuffle=True,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -19,7 +19,6 @@ from functools import cached_property
|
||||
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..openarm_leader import OpenArmLeader
|
||||
from ..teleoperator import Teleoperator
|
||||
@@ -89,7 +88,6 @@ class BiOpenArmLeader(Teleoperator):
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
@@ -111,7 +109,6 @@ class BiOpenArmLeader(Teleoperator):
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
)
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
action_dict = {}
|
||||
|
||||
@@ -129,7 +126,6 @@ class BiOpenArmLeader(Teleoperator):
|
||||
# TODO: Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -18,7 +18,7 @@ import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
|
||||
from ..so_leader import SOLeader
|
||||
from ..teleoperator import Teleoperator
|
||||
@@ -72,7 +72,6 @@ class BiSOLeader(Teleoperator):
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
@@ -111,7 +110,6 @@ class BiSOLeader(Teleoperator):
|
||||
# TODO: Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -21,7 +21,7 @@ from typing import Any
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.damiao import DamiaoMotorsBus
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_openarm_leader import OpenArmLeaderConfig
|
||||
@@ -84,7 +84,6 @@ class OpenArmLeader(Teleoperator):
|
||||
"""Check if teleoperator is connected."""
|
||||
return self.bus.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
Connect to the teleoperator.
|
||||
@@ -92,6 +91,8 @@ class OpenArmLeader(Teleoperator):
|
||||
For manual control, we disable torque after connecting so the
|
||||
arm can be moved by hand.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
# Connect to CAN bus
|
||||
logger.info(f"Connecting arm on {self.config.port}...")
|
||||
@@ -182,7 +183,6 @@ class OpenArmLeader(Teleoperator):
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
)
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
"""
|
||||
Get current action from the leader arm.
|
||||
@@ -193,6 +193,8 @@ class OpenArmLeader(Teleoperator):
|
||||
Reads all motor states (pos/vel/torque) in one CAN refresh cycle.
|
||||
"""
|
||||
start = time.perf_counter()
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
action_dict: dict[str, Any] = {}
|
||||
|
||||
@@ -212,9 +214,10 @@ class OpenArmLeader(Teleoperator):
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError("Feedback is not yet implemented for OpenArm leader.")
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from teleoperator."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Disconnect CAN bus
|
||||
# For manual control, ensure torque is disabled before disconnecting
|
||||
|
||||
@@ -28,7 +28,7 @@ class SOLeaderConfig:
|
||||
port: str
|
||||
|
||||
# Whether to use degrees for angles
|
||||
use_degrees: bool = True
|
||||
use_degrees: bool = False
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("so101_leader")
|
||||
|
||||
@@ -16,14 +16,14 @@ import platform
|
||||
import time
|
||||
|
||||
|
||||
def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.005):
|
||||
def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.003):
|
||||
"""
|
||||
Wait for `seconds` with better precision than time.sleep alone at the expense of more CPU usage.
|
||||
|
||||
Parameters:
|
||||
- seconds: duration to wait
|
||||
- spin_threshold: if remaining <= spin_threshold -> spin; otherwise sleep (seconds). Default 10ms
|
||||
- sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 5ms
|
||||
- sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 3ms
|
||||
|
||||
Note:
|
||||
The default parameters are chosen to prioritize timing accuracy over CPU usage for the common 30 FPS use case.
|
||||
|
||||
@@ -390,30 +390,6 @@ def test_sharpness_jitter_invalid_range_max_smaller():
|
||||
SharpnessJitter((2.0, 0.1))
|
||||
|
||||
|
||||
def test_make_transform_from_config_with_v2_resize(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_cfg = ImageTransformConfig(type="Resize", kwargs={"size": (32, 32)})
|
||||
tf = make_transform_from_config(tf_cfg)
|
||||
assert isinstance(tf, v2.Resize)
|
||||
output = tf(img_tensor)
|
||||
assert output.shape[-2:] == (32, 32)
|
||||
|
||||
|
||||
def test_make_transform_from_config_with_v2_identity(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_cfg = ImageTransformConfig(type="Identity", kwargs={})
|
||||
tf = make_transform_from_config(tf_cfg)
|
||||
assert isinstance(tf, v2.Identity)
|
||||
output = tf(img_tensor)
|
||||
assert output.shape == img_tensor.shape
|
||||
|
||||
|
||||
def test_make_transform_from_config_invalid_type():
|
||||
tf_cfg = ImageTransformConfig(type="NotARealTransform", kwargs={})
|
||||
with pytest.raises(ValueError, match="not valid"):
|
||||
make_transform_from_config(tf_cfg)
|
||||
|
||||
|
||||
def test_save_all_transforms(img_tensor_factory, tmp_path):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_cfg = ImageTransformsConfig(enable=True)
|
||||
|
||||
@@ -11,8 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
from lerobot.optim.schedulers import (
|
||||
@@ -40,10 +38,6 @@ def test_diffuser_scheduler(optimizer):
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
}
|
||||
|
||||
if Version(torch.__version__) >= Version("2.8"):
|
||||
expected_state_dict["_is_initial"] = False
|
||||
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
|
||||
@@ -62,10 +56,6 @@ def test_vqbet_scheduler(optimizer):
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
}
|
||||
|
||||
if Version(torch.__version__) >= Version("2.8"):
|
||||
expected_state_dict["_is_initial"] = False
|
||||
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
|
||||
@@ -86,10 +76,6 @@ def test_cosine_decay_with_warmup_scheduler(optimizer):
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
}
|
||||
|
||||
if Version(torch.__version__) >= Version("2.8"):
|
||||
expected_state_dict["_is_initial"] = False
|
||||
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
|
||||
|
||||
@@ -142,7 +142,6 @@ def _make_reachy2_camera_mock(*args, **kwargs):
|
||||
cam.connect = MagicMock()
|
||||
cam.disconnect = MagicMock()
|
||||
cam.async_read = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8))
|
||||
cam.read_latest = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8))
|
||||
return cam
|
||||
|
||||
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import draccus
|
||||
import pytest
|
||||
|
||||
from lerobot.scripts.lerobot_edit_dataset import (
|
||||
ConvertImageToVideoConfig,
|
||||
DeleteEpisodesConfig,
|
||||
EditDatasetConfig,
|
||||
InfoConfig,
|
||||
MergeConfig,
|
||||
ModifyTasksConfig,
|
||||
OperationConfig,
|
||||
RemoveFeatureConfig,
|
||||
SplitConfig,
|
||||
)
|
||||
|
||||
|
||||
def parse_cfg(cli_args: list[str]) -> EditDatasetConfig:
|
||||
"""Helper to parse CLI args into an EditDatasetConfig via draccus."""
|
||||
return draccus.parse(EditDatasetConfig, args=cli_args)
|
||||
|
||||
|
||||
class TestOperationTypeParsing:
|
||||
"""Test that --operation.type correctly selects the right config subclass."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"type_name, expected_cls",
|
||||
[
|
||||
("delete_episodes", DeleteEpisodesConfig),
|
||||
("split", SplitConfig),
|
||||
("merge", MergeConfig),
|
||||
("remove_feature", RemoveFeatureConfig),
|
||||
("modify_tasks", ModifyTasksConfig),
|
||||
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||
("info", InfoConfig),
|
||||
],
|
||||
)
|
||||
def test_operation_type_resolves_correct_class(self, type_name, expected_cls):
|
||||
cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name])
|
||||
assert isinstance(cfg.operation, expected_cls), (
|
||||
f"Expected {expected_cls.__name__}, got {type(cfg.operation).__name__}"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"type_name, expected_cls",
|
||||
[
|
||||
("delete_episodes", DeleteEpisodesConfig),
|
||||
("split", SplitConfig),
|
||||
("merge", MergeConfig),
|
||||
("remove_feature", RemoveFeatureConfig),
|
||||
("modify_tasks", ModifyTasksConfig),
|
||||
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||
("info", InfoConfig),
|
||||
],
|
||||
)
|
||||
def test_get_choice_name_roundtrips(self, type_name, expected_cls):
|
||||
cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name])
|
||||
resolved_name = OperationConfig.get_choice_name(type(cfg.operation))
|
||||
assert resolved_name == type_name
|
||||
Reference in New Issue
Block a user