Compare commits

..

3 Commits

Author SHA1 Message Date
Steven Palma
224be5be9a Merge branch 'main' into feat/add_macos_ci 2025-10-14 18:52:04 +02:00
Steven Palma
67269e33a5 ci: add more env flags 2025-10-08 17:14:20 +02:00
Steven Palma
66936f278f feat(ci): add macos runner testing 2025-10-08 14:58:55 +02:00
160 changed files with 1958 additions and 19224 deletions

View File

@@ -57,7 +57,11 @@ jobs:
# It runs everytime we commit to a PR or push to main
fast-pytest-tests:
name: Fast Pytest Tests
runs-on: ubuntu-latest
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
env:
MUJOCO_GL: egl
steps:
@@ -67,12 +71,21 @@ jobs:
lfs: true
# TODO(Steven): Evaluate the need of these dependencies
- name: Install apt dependencies
- name: Install dependencies
run: |
sudo apt-get update && sudo apt-get install -y build-essential git \
curl libglib2.0-0 libegl1-mesa-dev ffmpeg \
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev
if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
sudo apt-get update && sudo apt-get install -y build-essential \
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
speech-dispatcher libgeos-dev portaudio19-dev
elif [[ "${{ matrix.os }}" == 'macos-latest' ]]; then
brew update && brew install git geos portaudio ffmpeg@7
# Add ffmpeg@7 paths for subsequent steps
echo "PATH=/opt/homebrew/opt/ffmpeg@7/bin:$PATH" >> $GITHUB_ENV
echo "LDFLAGS=-L/opt/homebrew/opt/ffmpeg@7/lib" >> $GITHUB_ENV
echo "CPPFLAGS=-I/opt/homebrew/opt/ffmpeg@7/include" >> $GITHUB_ENV
echo "PKG_CONFIG_PATH=/opt/homebrew/opt/ffmpeg@7/lib/pkgconfig" >> $GITHUB_ENV
echo "DYLD_LIBRARY_PATH=/opt/homebrew/opt/ffmpeg@7/lib:/opt/homebrew/lib:/usr/local/lib:$DYLD_LIBRARY_PATH" >> $GITHUB_ENV
fi
- name: Setup uv and Python
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
with:

View File

@@ -51,7 +51,11 @@ jobs:
# It runs everytime a PR is approved or a push to main
full-tests:
name: Full Tests
runs-on: ubuntu-latest
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
if: |
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved') ||
github.event_name == 'push' ||
@@ -64,11 +68,16 @@ jobs:
lfs: true
persist-credentials: false
- name: Install apt dependencies
- name: Install dependencies
run: |
sudo apt-get update && sudo apt-get install -y build-essential \
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
speech-dispatcher libgeos-dev portaudio19-dev
if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
sudo apt-get update && sudo apt-get install -y build-essential \
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
speech-dispatcher libgeos-dev portaudio19-dev
elif [[ "${{ matrix.os }}" == 'macos-latest' ]]; then
brew update && brew install git geos portaudio ffmpeg@7
echo "DYLD_LIBRARY_PATH=/opt/homebrew/opt/ffmpeg@7/lib:/opt/homebrew/lib:/usr/local/lib:$DYLD_LIBRARY_PATH" >> $GITHUB_ENV
fi
- name: Setup uv and Python
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
@@ -78,7 +87,7 @@ jobs:
python-version: ${{ env.PYTHON_VERSION }}
- name: Install lerobot with all extras
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
run: uv sync --all-extras
- name: Run pytest (all extras)
run: uv run pytest tests -vv --maxfail=10

View File

@@ -119,7 +119,6 @@ jobs:
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
container:
image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --shm-size "16gb"
credentials:
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
@@ -159,36 +158,3 @@ jobs:
run: pytest tests -vv --maxfail=10
- name: Run end-to-end tests
run: make test-end-to-end
# This job runs multi-GPU training tests with 4 GPUs
nightly-multi-gpu-tests:
name: Nightly Multi-GPU Tests
needs: [build-docker-gpu-nightly]
runs-on:
group: aws-g4dn-12xlarge # Instance with 4 GPUs
env:
HF_HOME: /home/user_lerobot/.cache/huggingface
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
CUDA_VISIBLE_DEVICES: "0,1,2,3"
container:
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
credentials:
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
defaults:
run:
shell: bash
working-directory: /lerobot
steps:
- name: Verify GPU availability
run: |
nvidia-smi
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
- name: Run multi-GPU training tests
# TODO(Steven): Investigate why motors tests are failing in multi-GPU setup
run: pytest tests -vv --maxfail=10 --ignore=tests/motors/
timeout-minutes: 10

View File

@@ -82,14 +82,6 @@ jobs:
exit 1
fi
- name: Remove Tags with Git dependencies
# TODO(Steven): Temporary patch to remove libero and pi from PyPi 0.4.0 release due to its reliance on git dependencies.
run: |
echo "::info:: Checking for Git dependencies to remove from pyproject.toml..."
grep -E '@ git\+https|lerobot\[pi\]|lerobot\[libero\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
sed -E -i '/@ git\+https|lerobot\[pi\]|lerobot\[libero\]/d' pyproject.toml
echo "::info:: Git dependencies removed. Proceeding with build."
- name: Install build dependencies
run: python -m pip install build
@@ -111,7 +103,7 @@ jobs:
- name: Publish to TestPyPI for pre-releases
# True for tags like 'v0.2.0-rc1'
if: startsWith(github.ref, 'refs/tags/v') && contains(github.ref, '-')
uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
with:
repository-url: https://test.pypi.org/legacy/
verbose: true
@@ -119,7 +111,7 @@ jobs:
- name: Publish to PyPI
if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-')
uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
with:
verbose: true
print-hash: true
@@ -128,7 +120,11 @@ jobs:
test-release:
name: Test Release
needs: [build-and-publish]
runs-on: ubuntu-latest
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
permissions:
contents: read
env:
@@ -138,15 +134,20 @@ jobs:
with:
lfs: true
persist-credentials: false
- name: Install apt dependencies
- name: Install dependencies
run: |
sudo apt-get update && sudo apt-get install -y build-essential \
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
speech-dispatcher libgeos-dev portaudio19-dev
if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
sudo apt-get update && sudo apt-get install -y build-essential \
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
speech-dispatcher libgeos-dev portaudio19-dev
elif [[ "${{ matrix.os }}" == 'macos-latest' ]]; then
brew update && brew install git geos portaudio ffmpeg@7
echo "DYLD_LIBRARY_PATH=/opt/homebrew/opt/ffmpeg@7/lib:/opt/homebrew/lib:/usr/local/lib:$DYLD_LIBRARY_PATH" >> $GITHUB_ENV
fi
- name: Setup uv and Python
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
with:
enable-cache: true # zizmor: ignore[cache-poisoning]
enable-cache: true
version: ${{ env.UV_VERSION }}
python-version: ${{ env.PYTHON_VERSION }}
- name: Create uv virtual environment

View File

@@ -27,17 +27,15 @@ env:
This issue was closed because it has been stalled for 14 days with no activity.
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
CLOSE_PR_MESSAGE: >
This PR was closed because it has been stalled for 21 days with no activity.
This PR was closed because it has been stalled for 14 days with no activity.
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
WARN_ISSUE_MESSAGE: >
This issue has been automatically marked as stale because it has not had
recent activity (6 months). It will be closed if no further activity occurs.
Any change, comment or update to this issue will reset this count.
Thank you for your contributions.
WARN_PR_MESSAGE: >
This PR has been automatically marked as stale because it has not had
recent activity (1 year). It will be closed if no further activity occurs.
Any change, comment or update to this PR will reset this count.
recent activity (6 months). It will be closed if no further activity occurs.
Thank you for your contributions.
jobs:
@@ -58,10 +56,10 @@ jobs:
stale-pr-label: stale
exempt-issue-labels: never-stale
exempt-pr-labels: never-stale
days-before-issue-stale: 180
days-before-issue-stale: 180 # TODO(Steven): Will modify this to 90 after initial cleanup
days-before-issue-close: 14
days-before-pr-stale: 365
days-before-pr-close: 21
days-before-pr-stale: 180
days-before-pr-close: 14
delete-branch: true
close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }}
close-pr-message: ${{ env.CLOSE_PR_MESSAGE }}

View File

@@ -42,7 +42,11 @@ jobs:
# This job runs the E2E tests + pytest with all unbound extras
full-tests:
name: Full Unbound Tests
runs-on: ubuntu-latest
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
env:
MUJOCO_GL: egl
steps:
@@ -51,11 +55,16 @@ jobs:
lfs: true
persist-credentials: false
- name: Install apt dependencies
- name: Install dependencies
run: |
sudo apt-get update && sudo apt-get install -y build-essential \
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
speech-dispatcher libgeos-dev portaudio19-dev
if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
sudo apt-get update && sudo apt-get install -y build-essential \
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
speech-dispatcher libgeos-dev portaudio19-dev
elif [[ "${{ matrix.os }}" == 'macos-latest' ]]; then
brew update && brew install git geos portaudio ffmpeg@7
echo "DYLD_LIBRARY_PATH=/opt/homebrew/opt/ffmpeg@7/lib:/opt/homebrew/lib:/usr/local/lib:$DYLD_LIBRARY_PATH" >> $GITHUB_ENV
fi
- name: Setup uv and Python
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]

View File

@@ -26,7 +26,7 @@ repos:
##### General Code Quality & Formatting #####
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
rev: v5.0.0
hooks:
- id: check-added-large-files
args: ['--maxkb=1024']
@@ -39,20 +39,20 @@ repos:
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.1
rev: v0.12.4
hooks:
- id: ruff-format
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/adhtruong/mirrors-typos
rev: v1.38.1
rev: v1.34.0
hooks:
- id: typos
args: [--force-exclude]
- repo: https://github.com/asottile/pyupgrade
rev: v3.21.0
rev: v3.20.0
hooks:
- id: pyupgrade
args: [--py310-plus]
@@ -68,12 +68,12 @@ repos:
##### Security #####
- repo: https://github.com/gitleaks/gitleaks
rev: v8.28.0
rev: v8.27.2
hooks:
- id: gitleaks
- repo: https://github.com/woodruffw/zizmor-pre-commit
rev: v1.15.2
rev: v1.11.0
hooks:
- id: zizmor
@@ -87,7 +87,7 @@ repos:
# TODO(Steven): Uncomment when ready to use
##### Static Analysis & Typing #####
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.18.2
rev: v1.16.0
hooks:
- id: mypy
args: [--config-file=pyproject.toml]

View File

@@ -137,7 +137,7 @@ Follow these steps to start contributing:
4. for development, we advise to use a tool like `poetry` or `uv` instead of just `pip` to easily track our dependencies.
Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already.
Set up a development environment with conda:
Set up a development environment with conda or miniconda:
```bash
conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev

View File

@@ -104,14 +104,14 @@ LeRobot works with Python 3.10+ and PyTorch 2.2+.
### Environment Setup
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniforge`](https://conda-forge.org/download/):
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
```bash
conda create -y -n lerobot python=3.10
conda activate lerobot
```
When using `conda`, install `ffmpeg` in your environment:
When using `miniconda`, install `ffmpeg` in your environment:
```bash
conda install ffmpeg -c conda-forge
@@ -185,11 +185,6 @@ _Replace `[...]` with your desired features._
For a full list of optional dependencies, see:
https://pypi.org/project/lerobot/
> [!NOTE]
> For lerobot 0.4.0, if you want to install libero or pi tags, you will have to do: `pip install "lerobot[pi,libero]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be solved in the next patch release
### Weights & Biases
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
@@ -212,13 +207,13 @@ lerobot-dataset-viz \
--episode-index 0
```
or from a dataset in a local folder with the `root` option and the `--mode local` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
or from a dataset in a local folder with the `root` option and the `--local-files-only` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
```bash
lerobot-dataset-viz \
--repo-id lerobot/pusht \
--root ./my_local_data_dir \
--mode local \
--local-files-only 1 \
--episode-index 0
```
@@ -315,7 +310,7 @@ To upload these to the hub, run the following:
huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model
```
See [lerobot_eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_eval.py) for an example of how other people may use your policy.
See [eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/eval.py) for an example of how other people may use your policy.
### Acknowledgment
@@ -342,3 +337,7 @@ If you want, you can cite this work with:
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=huggingface/lerobot&type=Timeline)](https://star-history.com/#huggingface/lerobot&Timeline)
```
```

View File

@@ -17,8 +17,6 @@
title: Train RL in Simulation
- local: async
title: Use Async Inference
- local: multi_gpu_training
title: Multi GPU training
title: "Tutorials"
- sections:
- local: lerobot-dataset-v3
@@ -37,8 +35,6 @@
title: π₀ (Pi0)
- local: pi05
title: π₀.₅ (Pi05)
- local: groot
title: NVIDIA GR00T N1.5
title: "Policies"
- sections:
- local: il_sim

View File

@@ -1,122 +0,0 @@
# GR00T N1.5 Policy
GR00T N1.5 is an open foundation model from NVIDIA designed for generalized humanoid robot reasoning and skills. It is a cross-embodiment model that accepts multimodal input, including language and images, to perform manipulation tasks in diverse environments.
This document outlines the specifics of its integration and usage within the LeRobot framework.
## Model Overview
NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. It is built to improve generalization and language-following abilities for humanoid robots.
Developers and researchers can post-train GR00T N1.5 with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception.
Its strong performance comes from being trained on an expansive and diverse humanoid dataset, which includes:
- Real captured data from robots.
- Synthetic data generated using NVIDIA Isaac GR00T Blueprint.
- Internet-scale video data.
This approach allows the model to be highly adaptable through post-training for specific embodiments, tasks, and environments.
## Installation Requirements
As of today, GR00T N1.5 requires flash attention for it's internal working.
We are working on making this optional, but in the meantime that means that we require an extra installation step and it can only be used in CUDA enabled devices.
1. Following the Environment Setup of our [Installation Guide](./installation). **Attention** don't install `lerobot` in this step.
2. Install [Flash Attention](https://github.com/Dao-AILab/flash-attention) by running:
```bash
# Check https://pytorch.org/get-started/locally/ for your system
pip install "torch>=2.2.1,<2.8.0" "torchvision>=0.21.0,<0.23.0" # --index-url https://download.pytorch.org/whl/cu1XX
pip install ninja "packaging>=24.2,<26.0" # flash attention dependencies
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
```
3. Install LeRobot by running:
```bash
pip install lerobot[groot] # consider also installing libero,dev and test tags
```
## Usage
To use GR00T in your LeRobot configuration, specify the policy type as:
```python
policy.type=groot
```
## Training
### Training Command Example
Here's a complete training command for finetuning the base GR00T model on your own dataset:
```bash
# Using a multi-GPU setup
accelerate launch \
--multi_gpu \
--num_processes=$NUM_GPUS \
$(which lerobot-train) \
--output_dir=$OUTPUT_DIR \
--save_checkpoint=true \
--batch_size=$BATCH_SIZE \
--steps=$NUM_STEPS \
--save_freq=$SAVE_FREQ \
--log_freq=$LOG_FREQ \
--policy.push_to_hub=true \
--policy.type=groot \
--policy.repo_id=$REPO_ID \
--policy.tune_diffusion_model=false \
--dataset.repo_id=$DATASET_ID \
--wandb.enable=true \
--wandb.disable_artifact=true \
--job_name=$JOB_NAME
```
## Performance Results
### Libero Benchmark Results
GR00T has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the GR00T N1.5 model for 30k steps on the Libero dataset and compared the results to the GR00T reference results.
| Benchmark | LeRobot Implementation | GR00T Reference |
| ------------------ | ---------------------- | --------------- |
| **Libero Spatial** | 82.0% | 92.0% |
| **Libero Object** | 99.0% | 92.0% |
| **Libero Long** | 82.0% | 76.0% |
| **Average** | 87.0% | 87.0% |
These results demonstrate GR00T's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
### Evaluate in your hardware setup
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
```bash
lerobot-record \
--robot.type=bi_so100_follower \
--robot.left_arm_port=/dev/ttyACM1 \
--robot.right_arm_port=/dev/ttyACM0 \
--robot.id=bimanual_follower \
--robot.cameras='{ right: {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30},
left: {"type": "opencv", "index_or_path": 2, "width": 640, "height": 480, "fps": 30},
top: {"type": "opencv", "index_or_path": 4, "width": 640, "height": 480, "fps": 30},
}' \
--display_data=true \
--dataset.repo_id=<user>/eval_groot-bimanual \
--dataset.num_episodes=10 \
--dataset.single_task="Grab and handover the red cube to the other arm"
--policy.path=<user>/groot-bimanual # your trained model
--dataset.episode_time_s=30
--dataset.reset_time_s=10
```
## License
This model follows the **Apache 2.0 License**, consistent with the original [GR00T repository](https://github.com/NVIDIA/Isaac-GR00T).

View File

@@ -165,7 +165,7 @@ huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
Then store your Hugging Face repository name in a variable:
```bash
HF_USER=$(hf auth whoami | head -n 1)
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```

View File

@@ -1,15 +1,8 @@
# Installation
## Install [`miniforge`](https://conda-forge.org/download/)
```bash
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
bash Miniforge3-$(uname)-$(uname -m).sh
```
## Environment Setup
Create a virtual environment with Python 3.10, using conda:
Create a virtual environment with Python 3.10, using [`Miniconda`](https://docs.anaconda.com/miniconda/install/#quick-command-line-install)
```bash
conda create -y -n lerobot python=3.10
@@ -21,7 +14,7 @@ Then activate your conda environment, you have to do this each time you open a s
conda activate lerobot
```
When using `conda`, install `ffmpeg` in your environment:
When using `miniconda`, install `ffmpeg` in your environment:
```bash
conda install ffmpeg -c conda-forge
@@ -81,9 +74,6 @@ _Replace `[...]` with your desired features._
For a full list of optional dependencies, see:
https://pypi.org/project/lerobot/
> [!NOTE]
> For lerobot 0.4.0, if you want to install libero or pi, you will have to do: `pip install "lerobot[pi,libero]@git+https://github.com/huggingface/lerobot.git"`
### Troubleshooting
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.

View File

@@ -208,36 +208,34 @@ LeRobot supports saving and loading calibration data automatically. This is usef
<!-- prettier-ignore-start -->
```python
@property
def is_calibrated(self) -> bool:
return True
def calibrate(self) -> None:
pass
```
<!-- prettier-ignore-end -->
> @property
> def is_calibrated(self) -> bool:
> return True
>
> def calibrate(self) -> None:
> pass
> ```
### `is_calibrated`
This should reflect whether your robot has the required calibration loaded.
<!-- prettier-ignore-start -->
```python
```
<!-- prettier-ignore-end -->python
@property
def is_calibrated(self) -> bool:
return self.bus.is_calibrated
```
<!-- prettier-ignore-end -->
### `calibrate()`
The goal of the calibration is twofold:
- Know the physical range of motion of each motors in order to only send commands within this range.
- Normalize raw motors positions to sensible continuous values (e.g. percentages, degrees) instead of arbitrary discrete value dependant on the specific motor used that will not replicate elsewhere.
- Know the physical range of motion of each motors in order to only send commands within this range.
- Normalize raw motors positions to sensible continuous values (e.g. percentages, degrees) instead of arbitrary discrete value dependant on the specific motor used that will not replicate elsewhere.
It should implement the logic for calibration (if relevant) and update the `self.calibration` dictionary. If you are using Feetech or Dynamixel motors, our bus interfaces already include methods to help with this.
<!-- prettier-ignore-start -->
```python
def calibrate(self) -> None:

View File

@@ -1,125 +0,0 @@
# Multi-GPU Training
This guide shows you how to train policies on multiple GPUs using [Hugging Face Accelerate](https://huggingface.co/docs/accelerate).
## Installation
First, ensure you have accelerate installed:
```bash
pip install accelerate
```
## Training with Multiple GPUs
You can launch training in two ways:
### Option 1: Without config (specify parameters directly)
You can specify all parameters directly in the command without running `accelerate config`:
```bash
accelerate launch \
--multi_gpu \
--num_processes=2 \
$(which lerobot-train) \
--dataset.repo_id=${HF_USER}/my_dataset \
--policy.type=act \
--policy.repo_id=${HF_USER}/my_trained_policy \
--output_dir=outputs/train/act_multi_gpu \
--job_name=act_multi_gpu \
--wandb.enable=true
```
**Key accelerate parameters:**
- `--multi_gpu`: Enable multi-GPU training
- `--num_processes=2`: Number of GPUs to use
- `--mixed_precision=fp16`: Use fp16 mixed precision (or `bf16` if supported)
### Option 2: Using accelerate config
If you prefer to save your configuration, you can optionally configure accelerate for your hardware setup by running:
```bash
accelerate config
```
This interactive setup will ask you questions about your training environment (number of GPUs, mixed precision settings, etc.) and saves the configuration for future use. For a simple multi-GPU setup on a single machine, you can use these recommended settings:
- Compute environment: This machine
- Number of machines: 1
- Number of processes: (number of GPUs you want to use)
- GPU ids to use: (leave empty to use all)
- Mixed precision: fp16 or bf16 (recommended for faster training)
Then launch training with:
```bash
accelerate launch $(which lerobot-train) \
--dataset.repo_id=${HF_USER}/my_dataset \
--policy.type=act \
--policy.repo_id=${HF_USER}/my_trained_policy \
--output_dir=outputs/train/act_multi_gpu \
--job_name=act_multi_gpu \
--wandb.enable=true
```
## How It Works
When you launch training with accelerate:
1. **Automatic detection**: LeRobot automatically detects if it's running under accelerate
2. **Data distribution**: Your batch is automatically split across GPUs
3. **Gradient synchronization**: Gradients are synchronized across GPUs during backpropagation
4. **Single process logging**: Only the main process logs to wandb and saves checkpoints
## Learning Rate and Training Steps Scaling
**Important:** LeRobot does **NOT** automatically scale learning rates or training steps based on the number of GPUs. This gives you full control over your training hyperparameters.
### Why No Automatic Scaling?
Many distributed training frameworks automatically scale the learning rate by the number of GPUs (e.g., `lr = base_lr × num_gpus`).
However, LeRobot keeps the learning rate exactly as you specify it.
### When and How to Scale
If you want to scale your hyperparameters when using multiple GPUs, you should do it manually:
**Learning Rate Scaling:**
```bash
# Example: 2 GPUs with linear LR scaling
# Base LR: 1e-4, with 2 GPUs -> 2e-4
accelerate launch --num_processes=2 $(which lerobot-train) \
--optimizer.lr=2e-4 \
--dataset.repo_id=lerobot/pusht \
--policy=act
```
**Training Steps Scaling:**
Since the effective batch size `bs` increases with multiple GPUs (batch_size × num_gpus), you may want to reduce the number of training steps proportionally:
```bash
# Example: 2 GPUs with effective batch size 2x larger
# Original: batch_size=8, steps=100000
# With 2 GPUs: batch_size=8 (16 in total), steps=50000
accelerate launch --num_processes=2 $(which lerobot-train) \
--batch_size=8 \
--steps=50000 \
--dataset.repo_id=lerobot/pusht \
--policy=act
```
## Notes
- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
- Training logs, checkpoints, and hub uploads are only done by the main process to avoid conflicts. Non-main processes have console logging disabled to prevent duplicate output.
- The effective batch size is `batch_size × num_gpus`. If you use 4 GPUs with `--batch_size=8`, your effective batch size is 32.
- Learning rate scheduling is handled correctly across multiple processes—LeRobot sets `step_scheduler_with_optimizer=False` to prevent accelerate from adjusting scheduler steps based on the number of processes.
- When saving or pushing models, LeRobot automatically unwraps the model from accelerate's distributed wrapper to ensure compatibility.
- WandB integration automatically initializes only on the main process, preventing multiple runs from being created.
For more advanced configurations and troubleshooting, see the [Accelerate documentation](https://huggingface.co/docs/accelerate). If you want to learn more about how to train on a large number of GPUs, checkout this awesome guide: [Ultrascale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook).

View File

@@ -1,328 +0,0 @@
# OpenArms Robot
OpenArms is a 7 DOF robotic arm with a gripper, designed by [Enactic, Inc.](https://www.enactic.com/) It uses Damiao motors controlled via CAN bus communication and MIT control mode for smooth, precise motion.
## Hardware Overview
- **7 DOF per arm** (14 DOF total for dual arm setup)
- **1 gripper per arm** (2 grippers total)
- **Damiao motors** with 4 different types:
- **DM8009** (DM-J8009P-2EC) for shoulders (J1, J2) - high torque
- **DM4340** for shoulder rotation and elbow (J3, J4)
- **DM4310** (DM-J4310-2EC V1.1) for wrist (J5, J6, J7) and gripper (J8)
- **24V power supply** required
- **CAN interface device**:
- **Linux**: Any SocketCAN-compatible adapter
- **macOS**: CANable, PEAK PCAN-USB, or Kvaser USBcan
- Proper CAN wiring (CANH, CANL, 120Ω termination)
## Motor Configuration
Each arm has the following motor configuration based on the [OpenArm setup guide](https://docs.openarm.dev/software/setup/):
| Joint | Motor | Motor Type | Sender CAN ID | Receiver ID | Description |
|-------|-------|------------|---------------|-------------|-------------|
| J1 | joint_1 | DM8009 | 0x01 | 0x11 | Shoulder pan |
| J2 | joint_2 | DM8009 | 0x02 | 0x12 | Shoulder lift |
| J3 | joint_3 | DM4340 | 0x03 | 0x13 | Shoulder rotation |
| J4 | joint_4 | DM4340 | 0x04 | 0x14 | Elbow flex |
| J5 | joint_5 | DM4310 | 0x05 | 0x15 | Wrist roll |
| J6 | joint_6 | DM4310 | 0x06 | 0x16 | Wrist pitch |
| J7 | joint_7 | DM4310 | 0x07 | 0x17 | Wrist rotation |
| J8 | gripper | DM4310 | 0x08 | 0x18 | Gripper |
For dual arm setups, the left arm uses IDs 0x09-0x10 for joints 1-8 with the same motor types.
## Quick Start
```bash
# Install system dependencies
sudo apt install can-utils iproute2
# Install LeRobot with OpenArms support
pip install -e ".[openarms]"
```
## Setup Guide
### Step 1: Motor ID Configuration
**IMPORTANT**: Before using the robot, motors must be configured with the correct CAN IDs.
Refer to the [OpenArm Motor ID Configuration Guide](https://docs.openarm.dev/software/setup/motor-id) for detailed instructions using the Damiao Debugging Tools on Windows.
Key points:
- Each motor needs a unique **Sender CAN ID** (0x01-0x08)
- Each motor needs a unique **Receiver/Master ID** (0x11-0x18)
- Use the Damiao Debugging Tools to set these IDs
### Step 2: Setup CAN Interface
Configure your CAN interface as described in the [OpenArm CAN Setup Guide](https://docs.openarm.dev/software/setup/can-setup):
#### Linux (SocketCAN)
```bash
# Find your CAN interface
ip link show
# Configure can0, 1, 2, 3
sudo ip link set can0 down
sudo ip link set can0 type can bitrate 1000000
sudo ip link set can0 up
sudo ip link set can1 down
sudo ip link set can1 type can bitrate 1000000
sudo ip link set can1 up
sudo ip link set can2 down
sudo ip link set can2 type can bitrate 1000000
sudo ip link set can2 up
sudo ip link set can3 down
sudo ip link set can3 type can bitrate 1000000
sudo ip link set can3 up
# Verify configuration
ip link show can0
```
or run:
`examples/openarms/setup_can.sh`
### Testing canbus and motor connection
Please run this script to check if all motors can be found and to find your can-fd speed: `python examples/openarms/debug_can_communication.py`
## Usage
### Basic Setup
```python
from lerobot.robots.openarms import OpenArmsFollower
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
# Configure for dual arm setup
config = OpenArmsFollowerConfig(
port="can0",
can_interface="socketcan", # Or "auto" for auto-detection
id="openarms_dual",
is_dual_arm=True,
)
robot = OpenArmsFollower(config)
robot.connect()
```
### Calibration
On first use, you'll need to calibrate the robot:
```python
robot.calibrate()
```
The calibration process will:
1. Disable torque on all motors
2. Ask you to position arms in **hanging position with grippers closed**
3. Set this as the zero position
4. Ask you to move each joint through its full range
5. Record min/max positions for each joint
6. Save calibration to file
### Reading Observations
The robot provides comprehensive state information:
```python
observation = robot.get_observation()
# Observation includes for each motor:
# - {motor_name}.pos: Position in degrees
# - {motor_name}.vel: Velocity in degrees/second
# - {motor_name}.torque: Motor torque
# - {camera_name}: Camera images (if configured)
print(f"Right arm joint 1 position: {observation['right_joint_1.pos']:.1f}°")
print(f"Right arm joint 1 velocity: {observation['right_joint_1.vel']:.1f}°/s")
print(f"Right arm joint 1 torque: {observation['right_joint_1.torque']:.3f} N·m")
```
### Sending Actions
```python
# Send target positions (in degrees)
action = {
"right_joint_1.pos": 45.0,
"right_joint_2.pos": -30.0,
# ... all joints
"right_gripper.pos": 45.0, # Half-closed
}
actual_action = robot.send_action(action)
```
### Gripper Control
```python
# Open gripper
robot.open_gripper(arm="right")
# Close gripper
robot.close_gripper(arm="right")
```
## Safety Features
### 1. Maximum Relative Target
Limits how far a joint can move in a single command to prevent sudden movements:
```python
config = OpenArmsFollowerConfig(
port="can0",
# Limit all joints to 10 degrees per command
max_relative_target=10.0,
# Or set per-motor limits
max_relative_target={
"right_joint_1": 15.0, # Slower moving joint
"right_joint_2": 10.0,
"right_gripper": 5.0, # Very slow gripper
}
)
```
**How it works**: If current position is 50° and you command 80°, with `max_relative_target=10.0`, the robot will only move to 60° in that step.
### 2. Torque Limits
Control maximum torque output, especially important for grippers and teleoperation:
```python
config = OpenArmsFollowerConfig(
port="can0",
# Gripper torque limit (fraction of motor's max torque)
gripper_torque_limit=0.5, # 50% of max torque
)
```
Lower torque limits prevent damage when gripping delicate objects.
### 3. MIT Control Gains
Control responsiveness and stability via PID-like gains:
```python
config = OpenArmsFollowerConfig(
port="can0",
position_kp=10.0, # Position gain (higher = more responsive)
position_kd=0.5, # Velocity damping (higher = more damped)
)
```
**Guidelines**:
- **For following (robot)**: Higher gains for responsiveness
- `position_kp=10.0`, `position_kd=0.5`
- **For teleoperation (leader)**: Lower gains or disable torque for manual movement
- `manual_control=True` (torque disabled)
### 4. Velocity Limits
Velocity limits are enforced by the Damiao motors based on motor type. For DM4310:
- Max velocity: 30 rad/s ≈ 1718°/s
The motors will automatically limit velocity to safe values.
## Teleoperation
### Leader Arm Setup
The leader arm is moved manually (torque disabled) to generate commands:
```python
from lerobot.teleoperators.openarms import OpenArmsLeader
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
config = OpenArmsLeaderConfig(
port="can1", # Separate CAN interface for leader
id="openarms_leader",
manual_control=True, # Torque disabled for manual movement
is_dual_arm=True,
)
leader = OpenArmsLeader(config)
leader.connect()
# Read current position as action
action = leader.get_action()
# action contains positions for all joints in degrees
```
### Safety Considerations for Teleoperation
1. **Use separate CAN interfaces** for leader and follower to avoid conflicts
2. **Enable max_relative_target** on follower to smooth abrupt movements
3. **Lower torque limits** on follower to prevent damage from tracking errors
4. **Test with one arm** before enabling dual arm teleoperation
5. **Have emergency stop** ready (power switch or CAN disable)
```python
# Recommended follower config for teleoperation
follower_config = OpenArmsFollowerConfig(
port="can0",
max_relative_target=5.0, # Small steps for smooth following
gripper_torque_limit=0.3, # Low torque for safety
position_kp=5.0, # Lower gains for gentler following
position_kd=0.3,
)
```
## Troubleshooting
### Motor Shaking/Unstable
- **Lower control gains**: Reduce `position_kp` and `position_kd`
- **Check calibration**: Re-run calibration procedure
- **Verify power**: Insufficient current can cause instability
- **Check mechanical**: Loose connections, binding, or damaged components
### CAN Bus Errors
```bash
# Check for errors
ip -s link show can0
# Reset CAN interface
sudo ip link set can0 down
sudo ip link set can0 up
```
### Control Mode
OpenArms uses **MIT control mode** which allows simultaneous control of:
- Position (degrees)
- Velocity (degrees/second)
- Torque (N·m)
- Position gain (Kp)
- Velocity damping (Kd)
### Communication
- **Protocol**: CAN 2.0 at 1 Mbps (or CAN-FD at 5 Mbps)
- **Frame format**: Standard 11-bit IDs
- **Update rate**: Typically 50-100 Hz depending on motor count
- **Latency**: ~10-20ms per motor command
## References
- [OpenArm Official Documentation](https://docs.openarm.dev/)
- [OpenArm Setup Guide](https://docs.openarm.dev/software/setup/)
- [Motor ID Configuration](https://docs.openarm.dev/software/setup/motor-id)
- [CAN Interface Setup](https://docs.openarm.dev/software/setup/can-setup)
- [Motor Communication Test](https://docs.openarm.dev/software/setup/configure-test)
- [Damiao Motor Documentation](https://wiki.seeedstudio.com/damiao_series/)
- [Enactic GitHub](https://github.com/enactic/openarm_can)

View File

@@ -1,27 +0,0 @@
## Research Paper
Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
## Repository
Code: https://github.com/NVIDIA/Isaac-GR00T
## Citation
```bibtex
@inproceedings{gr00tn1_2025,
archivePrefix = {arxiv},
eprint = {2503.14734},
title = {{GR00T} {N1}: An Open Foundation Model for Generalist Humanoid Robots},
author = {NVIDIA and Johan Bjorck andFernando Castañeda, Nikita Cherniadev and Xingye Da and Runyu Ding and Linxi "Jim" Fan and Yu Fang and Dieter Fox and Fengyuan Hu and Spencer Huang and Joel Jang and Zhenyu Jiang and Jan Kautz and Kaushil Kundalia and Lawrence Lao and Zhiqi Li and Zongyu Lin and Kevin Lin and Guilin Liu and Edith Llontop and Loic Magne and Ajay Mandlekar and Avnish Narayan and Soroush Nasiriany and Scott Reed and You Liang Tan and Guanzhi Wang and Zu Wang and Jing Wang and Qi Wang and Jiannan Xiang and Yuqi Xie and Yinzhen Xu and Zhenjia Xu and Seonghyeon Ye and Zhiding Yu and Ao Zhang and Hao Zhang and Yizhou Zhao and Ruijie Zheng and Yuke Zhu},
month = {March},
year = {2025},
booktitle = {ArXiv Preprint},
}
```
## Additional Resources
Blog: https://developer.nvidia.com/isaac/gr00t
Hugging Face Model: https://huggingface.co/nvidia/GR00T-N1.5-3B

View File

@@ -132,15 +132,17 @@ print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
if __name__ == "__main__":
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=32,
shuffle=True,
)
for batch in dataloader:
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
print(f"{batch['action'].shape=}") # (32, 64, c)
break
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
# PyTorch datasets.
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=32,
shuffle=True,
)
for batch in dataloader:
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
print(f"{batch['action'].shape=}") # (32, 64, c)
break

View File

@@ -1,416 +0,0 @@
#!/usr/bin/env python3
"""
Comprehensive debug script for OpenArms CAN FD communication.
Tests all 4 CAN interfaces with CAN FD support.
"""
import can
import time
import sys
import subprocess
def check_can_interface(port):
"""Check if CAN interface is UP and configured."""
try:
result = subprocess.run(['ip', 'link', 'show', port],
capture_output=True, text=True)
if result.returncode != 0:
return False, "Interface not found", None
output = result.stdout
if 'UP' not in output:
return False, "Interface is DOWN", None
# Check if CAN FD is enabled
is_fd = 'fd on' in output.lower() or 'canfd' in output.lower()
return True, "Interface is UP", is_fd
except FileNotFoundError:
return None, "Cannot check (ip command not found)", None
def test_motor_on_interface(bus, motor_id, timeout=2.0, use_fd=False):
"""
Test a single motor and return all responses.
Returns:
list of (arbitration_id, data) tuples for all responses received
"""
# Send enable command
enable_msg = can.Message(
arbitration_id=motor_id,
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
is_extended_id=False,
is_fd=use_fd
)
try:
bus.send(enable_msg)
except Exception as e:
return None, f"Send error: {e}"
# Listen for responses
responses = []
start_time = time.time()
while time.time() - start_time < timeout:
msg = bus.recv(timeout=0.1)
if msg:
responses.append((msg.arbitration_id, msg.data, msg.is_fd if hasattr(msg, 'is_fd') else False))
# Send disable command
disable_msg = can.Message(
arbitration_id=motor_id,
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFD],
is_extended_id=False,
is_fd=use_fd
)
try:
bus.send(disable_msg)
except:
pass
return responses, None
def test_interface(port, interface_type="socketcan", use_can_fd=True):
"""Test all 8 motors on a single CAN interface."""
results = {
'interface': port,
'status': None,
'is_fd': use_can_fd,
'motors': {}
}
# Check interface status
status_ok, status_msg, interface_has_fd = check_can_interface(port)
if interface_has_fd is not None:
results['interface_fd_enabled'] = interface_has_fd
if use_can_fd and not interface_has_fd:
status_msg += " (CAN FD NOT enabled on interface!)"
elif interface_has_fd:
status_msg += " (CAN FD enabled)"
results['status'] = status_msg
if status_ok is False:
return results
# Try to connect
try:
if use_can_fd:
print(f" Connecting to {port} with CAN FD (1 Mbps / 5 Mbps)...")
bus = can.interface.Bus(
channel=port,
interface=interface_type,
bitrate=1000000,
data_bitrate=5000000,
fd=True
)
else:
print(f" Connecting to {port} with CAN 2.0 (1 Mbps)...")
bus = can.interface.Bus(
channel=port,
interface=interface_type,
bitrate=1000000
)
except Exception as e:
results['status'] = f"Connection failed: {e}"
return results
try:
# Clear any pending messages
while bus.recv(timeout=0.01):
pass
# Test each motor (0x01 to 0x08)
for motor_id in range(0x01, 0x09):
responses, error = test_motor_on_interface(bus, motor_id, timeout=1.0, use_fd=use_can_fd)
if error:
results['motors'][motor_id] = {'error': error}
elif responses:
results['motors'][motor_id] = {
'found': True,
'responses': responses
}
else:
results['motors'][motor_id] = {
'found': False,
'responses': []
}
time.sleep(0.05) # Small delay between motors
finally:
bus.shutdown()
return results
def print_results(all_results):
"""Print formatted results for all interfaces."""
print("SUMMARY - Motors Found on Each Interface")
motor_names = {
0x01: "joint_1 (Shoulder pan)",
0x02: "joint_2 (Shoulder lift)",
0x03: "joint_3 (Shoulder rotation)",
0x04: "joint_4 (Elbow flex)",
0x05: "joint_5 (Wrist roll)",
0x06: "joint_6 (Wrist pitch)",
0x07: "joint_7 (Wrist rotation)",
0x08: "gripper",
}
total_found = 0
for result in all_results:
interface = result['interface']
status = result['status']
print(f"{interface}: {status}")
if result.get('is_fd'):
print(f" Mode: CAN FD")
else:
print(f" Mode: CAN 2.0")
if 'Connection failed' in status or 'DOWN' in status:
print(f" ⚠ Cannot test {interface}")
continue
motors_found = 0
for motor_id in range(0x01, 0x09):
motor_data = result['motors'].get(motor_id, {})
motor_name = motor_names.get(motor_id, "Unknown")
if motor_data.get('error'):
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ {motor_data['error']}")
elif motor_data.get('found'):
motors_found += 1
total_found += 1
responses = motor_data['responses']
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✓ FOUND")
for resp_id, data, is_fd in responses:
data_hex = data.hex()
fd_flag = " [FD]" if is_fd else " [2.0]"
print(f" → Response from 0x{resp_id:02X}{fd_flag}: {data_hex}")
else:
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ No response")
print(f"\n Summary: {motors_found}/8 motors found on {interface}")
# Overall summary
print("OVERALL SUMMARY")
print(f"Total motors found across all interfaces: {total_found}")
# Analyze configuration
print("DIAGNOSIS")
for result in all_results:
interface = result['interface']
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
if motors_found == 0:
print(f"\n{interface}: NO MOTORS FOUND")
print(" Possible issues:")
print(" 1. CAN FD mode mismatch (interface vs motor configuration)")
print(" 2. Missing 120Ω termination resistors at BOTH cable ends")
print(" 3. Motor timeout parameter set incorrectly (should NOT be 0)")
print(" 4. CANH/CANL wiring issue")
print(" 5. Cable too long (>40m for CAN FD at 5Mbps)")
# Check FD mismatch
if result.get('is_fd') and not result.get('interface_fd_enabled'):
print(" ⚠️ CRITICAL: Trying CAN FD but interface NOT configured for FD!")
print(f" Fix: sudo ip link set {interface} type can bitrate 1000000 dbitrate 5000000 fd on")
elif motors_found < 8:
print(f"\n{interface}: Only {motors_found}/8 motors responding")
print(" Check power and connections for missing motors")
else:
print(f"\n{interface}: All 8 motors responding correctly!")
# Check for unexpected response IDs
print("RESPONSE ID ANALYSIS")
for result in all_results:
interface = result['interface']
unexpected = []
for motor_id, motor_data in result['motors'].items():
if motor_data.get('found'):
expected_id = motor_id + 0x10
actual_ids = [resp[0] for resp in motor_data['responses']]
if expected_id not in actual_ids:
unexpected.append((motor_id, actual_ids))
if unexpected:
print(f"\n{interface}: Unexpected response IDs detected")
for motor_id, actual_ids in unexpected:
expected_id = motor_id + 0x10
print(f" Motor 0x{motor_id:02X}: Expected 0x{expected_id:02X}, "
f"got {[f'0x{id:02X}' for id in actual_ids]}")
print(" → Motor Master IDs need reconfiguration")
else:
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
if motors_found > 0:
print(f"\n{interface}: All responding motors use correct IDs")
def test_communication_speed(interface, motor_id, num_iterations=100):
"""
Test communication speed with a motor.
Returns:
tuple: (hz, avg_latency_ms) or (None, None) if test failed
"""
try:
# Connect to interface
bus = can.interface.Bus(
channel=interface,
interface="socketcan",
bitrate=1000000,
data_bitrate=5000000,
fd=True
)
# Send refresh commands and measure round-trip time
latencies = []
successful = 0
for _ in range(num_iterations):
start = time.perf_counter()
# Send enable command (lightweight operation)
enable_msg = can.Message(
arbitration_id=motor_id,
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
is_extended_id=False,
is_fd=True
)
bus.send(enable_msg)
# Wait for response
msg = bus.recv(timeout=0.1)
if msg:
latency = (time.perf_counter() - start) * 1000 # Convert to ms
latencies.append(latency)
successful += 1
bus.shutdown()
if successful > 0:
avg_latency = sum(latencies) / len(latencies)
hz = 1000.0 / avg_latency if avg_latency > 0 else 0
return hz, avg_latency
return None, None
except Exception as e:
print(f" Speed test error: {e}")
return None, None
def main():
"""Main function to test all CAN interfaces with CAN FD."""
print("\nThis will test all 4 CAN interfaces (can0-can3) with CAN FD")
print("Testing motors 0x01-0x08 on each interface")
print()
print("Make sure:")
print(" ✓ Motors are powered (24V)")
print(" ✓ CAN interfaces configured with FD mode:")
print(" ./examples/openarms/setup_can.sh")
print(" ✓ Motor 'timeout' parameter NOT set to 0 (use Damiao tools)")
print(" ✓ CAN wiring includes 120Ω termination at BOTH ends")
print()
input("Press ENTER to start testing...")
# Test all 4 interfaces with CAN FD
all_results = []
for i in range(4):
interface = f"can{i}"
print(f"Testing {interface}...")
result = test_interface(interface, use_can_fd=True)
all_results.append(result)
# Quick status
if 'Connection failed' in result['status'] or 'DOWN' in result['status']:
print(f"{interface}: {result['status']}")
else:
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
print(f" {interface}: {motors_found}/8 motors found")
time.sleep(0.2)
# Print detailed results
print_results(all_results)
print("Testing Complete!")
all_found = sum(sum(1 for m in r['motors'].values() if m.get('found')) for r in all_results)
if all_found == 0:
print("\n⚠️ CRITICAL: No motors found on any interface!")
print("\nTop issues to check:")
print(" 1. Motor 'timeout' parameter (use Damiao tools to set > 0)")
print(" 2. CAN FD not enabled (run ./examples/openarms/setup_can.sh)")
print(" 3. Missing termination resistors")
print("\nTry:")
print(" a) Check motor parameters with Damiao Debugging Tools")
print(" b) Verify CAN FD is enabled: ip -d link show can0 | grep fd")
print(" c) Run setup script: ./examples/openarms/setup_can.sh")
else:
# Run speed test on interfaces with motors
print("COMMUNICATION SPEED TEST")
print("\nTesting maximum communication frequency...")
for result in all_results:
interface = result['interface']
# Find first responding motor
responding_motor = None
for motor_id, motor_data in result['motors'].items():
if motor_data.get('found'):
responding_motor = motor_id
break
if responding_motor:
print(f"\n{interface}: Testing with motor 0x{responding_motor:02X}...")
hz, latency = test_communication_speed(interface, responding_motor, num_iterations=100)
if hz:
print(f" ✓ Max frequency: {hz:.1f} Hz")
print(f" ✓ Avg latency: {latency:.2f} ms")
print(f" ✓ Commands per second: ~{int(hz)}")
else:
print(f" ✗ Speed test failed")
else:
print(f"\n{interface}: No motors found, skipping speed test")
print()
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\n\nTesting interrupted by user.")
sys.exit(1)
except Exception as e:
print(f"\nUnexpected error: {e}")
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -1,360 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
OpenArms Policy Evaluation
Evaluates a trained policy on the OpenArms robot by running inference and recording
the evaluation episodes to a dataset. Supports optional leader arm for manual resets.
Example usage:
python examples/openarms/evaluate.py
"""
import time
from pathlib import Path
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import combine_feature_dicts
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.processor import make_default_processors
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.scripts.lerobot_record import record_loop
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
HF_MODEL_ID = "lerobot-data-collection/three-folds-pi0" # TODO: Replace with your trained model
HF_EVAL_DATASET_ID = "lerobot-data-collection/three-folds-pi0_eval7" # TODO: Replace with your eval dataset name
TASK_DESCRIPTION = "three-folds-dataset" # TODO: Replace with your task, this should match!!
NUM_EPISODES = 1
FPS = 30
EPISODE_TIME_SEC = 300
RESET_TIME_SEC = 60
# Robot CAN interfaces
FOLLOWER_LEFT_PORT = "can0"
FOLLOWER_RIGHT_PORT = "can1"
# If enabled, you can manually reset the environment between evaluation episodes
USE_LEADER_FOR_RESETS = True # Set to False if you don't want to use leader
LEADER_LEFT_PORT = "can2"
LEADER_RIGHT_PORT = "can3"
# Camera configuration
CAMERA_CONFIG = {
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=640, height=480, fps=FPS),
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=FPS),
"base": OpenCVCameraConfig(index_or_path="/dev/video3", width=640, height=480, fps=FPS),
}
def main():
"""Main evaluation function."""
print("OpenArms Policy Evaluation")
print(f"\nModel: {HF_MODEL_ID}")
print(f"Evaluation Dataset: {HF_EVAL_DATASET_ID}")
print(f"Task: {TASK_DESCRIPTION}")
print(f"Episodes: {NUM_EPISODES}")
print(f"Episode Duration: {EPISODE_TIME_SEC}s")
print(f"Reset Duration: {RESET_TIME_SEC}s")
print(f"Use Leader for Resets: {USE_LEADER_FOR_RESETS}")
follower_config = OpenArmsFollowerConfig(
port_left=FOLLOWER_LEFT_PORT,
port_right=FOLLOWER_RIGHT_PORT,
can_interface="socketcan",
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=10.0,
cameras=CAMERA_CONFIG,
)
follower = OpenArmsFollower(follower_config)
follower.connect(calibrate=False)
if not follower.is_connected:
raise RuntimeError("Follower robot failed to connect!")
leader = None
if USE_LEADER_FOR_RESETS:
leader_config = OpenArmsLeaderConfig(
port_left=LEADER_LEFT_PORT,
port_right=LEADER_RIGHT_PORT,
can_interface="socketcan",
id="openarms_leader",
manual_control=False, # Enable torque control for gravity compensation
)
leader = OpenArmsLeader(leader_config)
leader.connect(calibrate=False)
if not leader.is_connected:
raise RuntimeError("Leader robot failed to connect!")
# Enable gravity compensation
if leader.pin_robot is not None:
leader.bus_right.enable_torque()
leader.bus_left.enable_torque()
time.sleep(0.1)
print(f"Leader connected with gravity compensation ({LEADER_LEFT_PORT}, {LEADER_RIGHT_PORT})")
else:
print(f"Leader connected but gravity compensation unavailable (no URDF)")
# Build default processors for action and observation
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
# Build dataset features from robot features and processors
# For actions, only include positions (no velocity or torque)
action_features_hw = {}
for key, value in follower.action_features.items():
if key.endswith(".pos"):
action_features_hw[key] = value
dataset_features = combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=teleop_action_processor,
initial_features=create_initial_features(action=action_features_hw),
use_videos=True,
),
aggregate_pipeline_dataset_features(
pipeline=robot_observation_processor,
initial_features=create_initial_features(observation=follower.observation_features),
use_videos=True,
),
)
# Check if dataset already exists
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / HF_EVAL_DATASET_ID
if dataset_path.exists():
print(f"Evaluation dataset already exists at: {dataset_path}")
print("This will append new episodes to the existing dataset.")
choice = input(" Continue? (y/n): ").strip().lower()
if choice != 'y':
print(" Aborting evaluation.")
follower.disconnect()
if leader:
leader.disconnect()
return
# Create dataset
dataset = LeRobotDataset.create(
repo_id=HF_EVAL_DATASET_ID,
fps=FPS,
features=dataset_features,
robot_type=follower.name,
use_videos=True,
image_writer_processes=0,
image_writer_threads=12,
)
# Load policy config from pretrained model and create policy using factory
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
policy_config.pretrained_path = HF_MODEL_ID
policy = make_policy(policy_config, ds_meta=dataset.meta)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy.config,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
preprocessor_overrides={
"device_processor": {"device": str(policy.config.device)}
},
)
print(f"\nRunning evaluation...")
# Initialize keyboard listener and visualization
listener, events = init_keyboard_listener()
init_rerun(session_name="openarms_evaluation")
episode_idx = 0
try:
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Evaluating episode {episode_idx + 1} of {NUM_EPISODES}")
print(f"\nRunning inference for episode {episode_idx + 1}...")
# Run inference with policy
record_loop(
robot=follower,
events=events,
fps=FPS,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
# Handle re-recording
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Save episode
if dataset.episode_buffer is not None and dataset.episode_buffer.get("size", 0) > 0:
print(f"Saving episode {episode_idx + 1} ({dataset.episode_buffer['size']} frames)...")
dataset.save_episode()
episode_idx += 1
# Reset environment between episodes (if not last episode)
if not events["stop_recording"] and episode_idx < NUM_EPISODES:
if USE_LEADER_FOR_RESETS and leader:
log_say("Reset the environment using leader arms")
print(f"\nManual reset period ({RESET_TIME_SEC}s)...")
# Use leader for manual reset with gravity compensation
import numpy as np
dt = 1 / FPS
reset_start_time = time.perf_counter()
while time.perf_counter() - reset_start_time < RESET_TIME_SEC:
if events["exit_early"] or events["stop_recording"]:
break
loop_start = time.perf_counter()
# Get leader state
leader_action = leader.get_action()
# Extract positions and velocities
leader_positions_deg = {}
leader_velocities_deg_per_sec = {}
for motor in leader.bus_right.motors:
pos_key = f"right_{motor}.pos"
vel_key = f"right_{motor}.vel"
if pos_key in leader_action:
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
if vel_key in leader_action:
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
for motor in leader.bus_left.motors:
pos_key = f"left_{motor}.pos"
vel_key = f"left_{motor}.vel"
if pos_key in leader_action:
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
if vel_key in leader_action:
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
# Calculate gravity and friction torques
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
leader_friction_torques_nm = leader._friction_from_velocity(
leader_velocities_rad_per_sec,
friction_scale=1.0
)
# Combine torques
leader_total_torques_nm = {}
for motor_name in leader_gravity_torques_nm:
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
friction = leader_friction_torques_nm.get(motor_name, 0.0)
leader_total_torques_nm[motor_name] = gravity + friction
# Apply compensation
for motor in leader.bus_right.motors:
full_name = f"right_{motor}"
position = leader_positions_deg.get(full_name, 0.0)
torque = leader_total_torques_nm.get(full_name, 0.0)
kd = leader.get_damping_kd(motor)
leader.bus_right._mit_control(
motor=motor, kp=0.0, kd=kd,
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque,
)
for motor in leader.bus_left.motors:
full_name = f"left_{motor}"
position = leader_positions_deg.get(full_name, 0.0)
torque = leader_total_torques_nm.get(full_name, 0.0)
kd = leader.get_damping_kd(motor)
leader.bus_left._mit_control(
motor=motor, kp=0.0, kd=kd,
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque,
)
# Send leader positions to follower
follower_action = {}
for joint in leader_positions_deg.keys():
pos_key = f"{joint}.pos"
if pos_key in leader_action:
follower_action[pos_key] = leader_action[pos_key]
if follower_action:
follower.send_action(follower_action)
# Maintain loop rate
loop_duration = time.perf_counter() - loop_start
sleep_time = dt - loop_duration
if sleep_time > 0:
time.sleep(sleep_time)
print("Reset complete")
else:
log_say("Waiting for manual reset")
print(f"Manually reset the environment and press ENTER to continue")
input("Press ENTER when ready...")
print(f"Evaluation complete! {episode_idx} episodes recorded")
log_say("Evaluation complete", blocking=True)
except KeyboardInterrupt:
print("\n\nEvaluation interrupted by user")
finally:
if leader:
leader.bus_right.disable_torque()
leader.bus_left.disable_torque()
time.sleep(0.1)
leader.disconnect()
follower.disconnect()
if listener is not None:
listener.stop()
dataset.finalize()
print("\nUploading to Hugging Face Hub...")
dataset.push_to_hub(private=True)
if __name__ == "__main__":
main()

View File

@@ -1,216 +0,0 @@
import time
import numpy as np
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
# Friction model parameters from OpenArms config/follower.yaml
# τ_fric(ω) = Fo + Fv·ω + Fc·tanh(k·ω)
# For 8 motors: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
FRICTION_PARAMS = {
"Fc": [0.306, 0.306, 0.40, 0.166, 0.050, 0.093, 0.172, 0.0512], # Coulomb friction [Nm]
"k": [28.417, 28.417, 29.065, 130.038, 151.771, 242.287, 7.888, 4.000], # tanh steepness
"Fv": [0.063, 0.0630, 0.604, 0.813, 0.029, 0.072, 0.084, 0.084], # Viscous friction [Nm·s/rad]
"Fo": [0.088, 0.088, 0.008, -0.058, 0.005, 0.009, -0.059, -0.050], # Offset torque [Nm]
}
# Constants from OpenArms C++ implementation
AMP_TMP = 1.0
COEF_TMP = 0.1
FRICTION_SCALE = 1.0 # OpenArms C++ uses 0.3 factor in unilateral mode
DAMPING_KD = [0.5, 0.5, 0.5, 0.5, 0.1, 0.1, 0.1, 0.1] # Damping gains for stability
def compute_friction_torque(velocity_rad_per_sec: float, motor_index: int) -> float:
"""
Compute friction torque for a single motor using the tanh friction model.
Args:
velocity_rad_per_sec: Angular velocity in rad/s
motor_index: Index of the motor (0-7)
Returns:
Friction torque in N·m (scaled for stability)
"""
Fc = FRICTION_PARAMS["Fc"][motor_index]
k = FRICTION_PARAMS["k"][motor_index]
Fv = FRICTION_PARAMS["Fv"][motor_index]
Fo = FRICTION_PARAMS["Fo"][motor_index]
# Friction model: τ_fric = amp * Fc * tanh(coef * k * ω) + Fv * ω + Fo
friction_torque = (
AMP_TMP * Fc * np.tanh(COEF_TMP * k * velocity_rad_per_sec) +
Fv * velocity_rad_per_sec +
Fo
)
# Scale down friction compensation for stability at lower control rates
# (OpenArms C++ uses 0.3 factor in unilateral mode)!!
friction_torque *= FRICTION_SCALE
return friction_torque
def main() -> None:
config = OpenArmsFollowerConfig(
port_left="can0",
port_right="can1",
can_interface="socketcan",
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=5.0,
)
print("Initializing robot...")
follower = OpenArmsFollower(config)
follower.connect(calibrate=True)
print(f"Applying friction compensation")
print(" 1. Support the arm before starting")
print(" 2. The arm will be held in place by friction compensation")
print(" 3. You should be able to move it with gentle force")
print("\nPress ENTER when ready to start...")
input()
print(f"✓ Motors enabled")
print("\nStarting friction compensation loop...")
print("Press Ctrl+C to stop\n")
loop_times = []
last_print_time = time.perf_counter()
# Motor name to index mapping
motor_name_to_index = {
"joint_1": 0,
"joint_2": 1,
"joint_3": 2,
"joint_4": 3,
"joint_5": 4,
"joint_6": 5,
"joint_7": 6,
"gripper": 7,
}
try:
while True:
loop_start = time.perf_counter()
# Get current joint positions and velocities from robot
obs = follower.get_observation()
# Extract velocities in degrees per second
velocities_deg_per_sec = {}
positions_deg = {}
for motor in follower.bus_right.motors:
vel_key = f"right_{motor}.vel"
pos_key = f"right_{motor}.pos"
if vel_key in obs:
velocities_deg_per_sec[f"right_{motor}"] = obs[vel_key]
if pos_key in obs:
positions_deg[f"right_{motor}"] = obs[pos_key]
for motor in follower.bus_left.motors:
vel_key = f"left_{motor}.vel"
pos_key = f"left_{motor}.pos"
if vel_key in obs:
velocities_deg_per_sec[f"left_{motor}"] = obs[vel_key]
if pos_key in obs:
positions_deg[f"left_{motor}"] = obs[pos_key]
# Convert velocities to rad/s and compute friction torques
friction_torques_nm = {}
for motor_full_name, velocity_deg_per_sec in velocities_deg_per_sec.items():
# Extract motor name without arm prefix
if motor_full_name.startswith("right_"):
motor_name = motor_full_name.removeprefix("right_")
elif motor_full_name.startswith("left_"):
motor_name = motor_full_name.removeprefix("left_")
else:
continue
# Get motor index for friction parameters
motor_index = motor_name_to_index.get(motor_name, 0)
# Convert velocity to rad/s
velocity_rad_per_sec = np.deg2rad(velocity_deg_per_sec)
# Compute friction torque
friction_torque = compute_friction_torque(velocity_rad_per_sec, motor_index)
friction_torques_nm[motor_full_name] = friction_torque
# Apply friction compensation to right arm (all joints INCLUDING gripper)
for motor in follower.bus_right.motors:
full_name = f"right_{motor}"
position = positions_deg.get(full_name, 0.0)
torque = friction_torques_nm.get(full_name, 0.0)
# Get motor index for damping gain
motor_index = motor_name_to_index.get(motor, 0)
kd = DAMPING_KD[motor_index]
# Send MIT control command with friction compensation + damping
follower.bus_right._mit_control(
motor=motor,
kp=0.0, # No position control
kd=kd, # Add damping for stability
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque
)
# Apply friction compensation to left arm (all joints INCLUDING gripper)
for motor in follower.bus_left.motors:
full_name = f"left_{motor}"
position = positions_deg.get(full_name, 0.0)
torque = friction_torques_nm.get(full_name, 0.0)
# Get motor index for damping gain
motor_index = motor_name_to_index.get(motor, 0)
kd = DAMPING_KD[motor_index]
# Send MIT control command with friction compensation + damping
follower.bus_left._mit_control(
motor=motor,
kp=0.0, # No position control
kd=kd, # Add damping for stability
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque
)
# Measure loop time
loop_end = time.perf_counter()
loop_time = loop_end - loop_start
loop_times.append(loop_time)
# Print status every 2 seconds
if loop_end - last_print_time >= 2.0:
if loop_times:
avg_time = sum(loop_times) / len(loop_times)
current_hz = 1.0 / avg_time if avg_time > 0 else 0
print(f"{current_hz:.1f} Hz")
loop_times = []
last_print_time = loop_end
time.sleep(0.001)
except KeyboardInterrupt:
print("\n\nStopping friction compensation...")
finally:
print("\nDisabling all motors and disconnecting...")
follower.bus_right.disable_torque()
follower.bus_left.disable_torque()
time.sleep(0.1)
follower.disconnect()
print("✓ Safe shutdown complete")
if __name__ == "__main__":
main()

View File

@@ -1,142 +0,0 @@
import time
import numpy as np
import pinocchio as pin
from os.path import join, dirname, exists, expanduser
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
def main() -> None:
config = OpenArmsFollowerConfig(
port_left="can0",
port_right="can1",
can_interface="socketcan",
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=5.0,
)
print("Initializing robot...")
follower = OpenArmsFollower(config)
follower.connect(calibrate=True)
# Load URDF for Pinocchio dynamics
urdf_path = "/home/croissant/Documents/openarm_description/openarm_bimanual_pybullet.urdf"
pin_robot = pin.RobotWrapper.BuildFromURDF(urdf_path, dirname(urdf_path))
pin_robot.data = pin_robot.model.createData()
print(f"✓ Loaded Pinocchio model with {pin_robot.nq} DoFs")
follower.pin_robot = pin_robot
print(f"Applying gravity compensation")
print(" 1. Support the arm before starting")
print(" 2. The arm will be held in place by gravity compensation")
print(" 3. You should be able to move it with gentle force")
print("\nPress ENTER when ready to start...")
input()
print(f"✓ Motors enabled")
print("\nStarting gravity compensation loop...")
print("Press Ctrl+C to stop\n")
loop_times = []
last_print_time = time.perf_counter()
try:
while True:
loop_start = time.perf_counter()
# Get current joint positions from robot
obs = follower.get_observation()
# Extract positions in degrees
positions_deg = {}
for motor in follower.bus_right.motors:
key = f"right_{motor}.pos"
if key in obs:
positions_deg[f"right_{motor}"] = obs[key]
for motor in follower.bus_left.motors:
key = f"left_{motor}.pos"
if key in obs:
positions_deg[f"left_{motor}"] = obs[key]
# Convert to radians and calculate gravity torques
# Use the built-in method from OpenArmsFollower
positions_rad = {k: np.deg2rad(v) for k, v in positions_deg.items()}
torques_nm = follower._gravity_from_q(positions_rad)
# Apply gravity compensation to right arm (all joints except gripper)
for motor in follower.bus_right.motors:
if motor == "gripper":
continue # Skip gripper
full_name = f"right_{motor}"
position = positions_deg.get(full_name, 0.0)
torque = torques_nm.get(full_name, 0.0)
# Send MIT control command with gravity compensation torque
follower.bus_right._mit_control(
motor=motor,
kp=0.0, # No position control
kd=0.0, # No velocity damping
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque
)
# Apply gravity compensation to left arm (all joints except gripper)
for motor in follower.bus_left.motors:
if motor == "gripper":
continue # Skip gripper
full_name = f"left_{motor}"
position = positions_deg.get(full_name, 0.0)
torque = torques_nm.get(full_name, 0.0)
# Send MIT control command with gravity compensation torque
follower.bus_left._mit_control(
motor=motor,
kp=0.0, # No position control
kd=0.0, # No velocity damping
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque
)
# Measure loop time
loop_end = time.perf_counter()
loop_time = loop_end - loop_start
loop_times.append(loop_time)
# Print status every 2 seconds
if loop_end - last_print_time >= 2.0:
if loop_times:
avg_time = sum(loop_times) / len(loop_times)
current_hz = 1.0 / avg_time if avg_time > 0 else 0
print(f"{current_hz:.1f} Hz ({avg_time*1000:.1f} ms)")
loop_times = []
last_print_time = loop_end
time.sleep(0.005)
except KeyboardInterrupt:
print("\n\nStopping gravity compensation...")
finally:
print("\nDisabling all motors and disconnecting...")
follower.bus_right.disable_torque()
follower.bus_left.disable_torque()
time.sleep(0.1)
follower.disconnect()
print("✓ Safe shutdown complete")
if __name__ == "__main__":
main()

View File

@@ -1,395 +0,0 @@
"""
OpenArms Dataset Recording with Gravity + Friction Compensation
Records a dataset using OpenArms follower robot with leader teleoperator.
Leader arms have gravity and friction compensation for weightless, easy movement.
Includes 3 cameras: left wrist, right wrist, and base camera.
Uses the same compensation approach as teleop_with_compensation.py
"""
import shutil
import time
from pathlib import Path
import numpy as np
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
# Recording parameters
NUM_EPISODES = 1
FPS = 30
EPISODE_TIME_SEC = 600
RESET_TIME_SEC = 120
TASK_DESCRIPTION = "OpenArms task description"
# Friction compensation scale factor (1.0 = full, 0.3 = 30% for stability)
FRICTION_SCALE = 1.0
def record_loop_with_compensation(
robot,
leader,
events,
fps,
dataset,
dataset_features,
control_time_s,
single_task,
display_data=True,
):
"""
Custom record loop that applies gravity + friction compensation to leader.
Based on record_loop but with integrated compensation.
"""
dt = 1 / fps
episode_start_time = time.perf_counter()
# All joints (both arms)
all_joints = []
for motor in leader.bus_right.motors:
all_joints.append(f"right_{motor}")
for motor in leader.bus_left.motors:
all_joints.append(f"left_{motor}")
while True:
loop_start = time.perf_counter()
elapsed = loop_start - episode_start_time
# Check if we should exit
if elapsed >= control_time_s or events["exit_early"] or events["stop_recording"]:
break
# Get leader state
leader_action = leader.get_action()
# Extract positions and velocities in degrees
leader_positions_deg = {}
leader_velocities_deg_per_sec = {}
for motor in leader.bus_right.motors:
pos_key = f"right_{motor}.pos"
vel_key = f"right_{motor}.vel"
if pos_key in leader_action:
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
if vel_key in leader_action:
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
for motor in leader.bus_left.motors:
pos_key = f"left_{motor}.pos"
vel_key = f"left_{motor}.vel"
if pos_key in leader_action:
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
if vel_key in leader_action:
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
# Calculate gravity torques for leader using built-in method
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
# Calculate friction torques for leader using built-in method
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
leader_friction_torques_nm = leader._friction_from_velocity(
leader_velocities_rad_per_sec,
friction_scale=FRICTION_SCALE
)
# Combine gravity + friction torques
leader_total_torques_nm = {}
for motor_name in leader_gravity_torques_nm:
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
friction = leader_friction_torques_nm.get(motor_name, 0.0)
leader_total_torques_nm[motor_name] = gravity + friction
# Apply gravity + friction compensation to leader RIGHT arm (all joints including gripper)
for motor in leader.bus_right.motors:
full_name = f"right_{motor}"
position = leader_positions_deg.get(full_name, 0.0)
torque = leader_total_torques_nm.get(full_name, 0.0)
# Get damping gain for stability
kd = leader.get_damping_kd(motor)
leader.bus_right._mit_control(
motor=motor,
kp=0.0,
kd=kd, # Add damping for stability
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque,
)
# Apply gravity + friction compensation to leader LEFT arm (all joints including gripper)
for motor in leader.bus_left.motors:
full_name = f"left_{motor}"
position = leader_positions_deg.get(full_name, 0.0)
torque = leader_total_torques_nm.get(full_name, 0.0)
# Get damping gain for stability
kd = leader.get_damping_kd(motor)
leader.bus_left._mit_control(
motor=motor,
kp=0.0,
kd=kd, # Add damping for stability
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque,
)
# Send leader positions to follower (both arms)
follower_action = {}
for joint in all_joints:
pos_key = f"{joint}.pos"
if pos_key in leader_action:
follower_action[pos_key] = leader_action[pos_key]
# Send action to robot
if follower_action:
robot.send_action(follower_action)
# Get observation from robot (includes camera images)
observation = robot.get_observation()
# Add to dataset if we have a dataset
if dataset is not None:
# Build properly formatted observation frame
obs_frame = build_dataset_frame(dataset_features, observation, prefix="observation")
# Build properly formatted action frame (keep .pos suffix - it matches the feature names)
action_frame = build_dataset_frame(dataset_features, follower_action, prefix="action")
# Combine into single frame
frame = {**obs_frame, **action_frame}
# Add metadata (task is required, timestamp will be auto-calculated by add_frame)
frame["task"] = single_task
dataset.add_frame(frame)
# Display data if requested
if display_data:
log_rerun_data(observation=observation, action=follower_action)
# Maintain loop rate
loop_duration = time.perf_counter() - loop_start
sleep_time = dt - loop_duration
if sleep_time > 0:
time.sleep(sleep_time)
def main():
"""Main recording loop with gravity compensation."""
print("=" * 70)
print("OpenArms Dataset Recording with Compensation")
print("=" * 70)
# Create camera configurations (3 cameras: left wrist, right wrist, base)
# Using actual device paths found by lerobot-find-cameras opencv
camera_config = {
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video0", width=640, height=480, fps=FPS),
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=FPS),
"base": OpenCVCameraConfig(index_or_path="/dev/video7", width=640, height=480, fps=FPS),
}
# Configure follower robot with cameras
follower_config = OpenArmsFollowerConfig(
port_left="can2",
port_right="can3",
can_interface="socketcan",
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=10.0,
cameras=camera_config,
)
# Configure leader teleoperator (no cameras needed)
leader_config = OpenArmsLeaderConfig(
port_left="can0",
port_right="can1",
can_interface="socketcan",
id="openarms_leader",
manual_control=False, # Enable torque control for gravity compensation
)
# Initialize robot and teleoperator
print("\nInitializing devices...")
follower = OpenArmsFollower(follower_config)
leader = OpenArmsLeader(leader_config)
# Connect devices
print("Connecting and calibrating...")
follower.connect(calibrate=True)
leader.connect(calibrate=True)
# Verify URDF is loaded for gravity compensation
if leader.pin_robot is None:
raise RuntimeError("URDF model not loaded on leader. Gravity compensation not available.")
# Configure the dataset features
# For actions, we only want to record positions (not velocity or torque)
action_features_hw = {}
for key, value in follower.action_features.items():
if key.endswith(".pos"):
action_features_hw[key] = value
action_features = hw_to_dataset_features(action_features_hw, "action")
obs_features = hw_to_dataset_features(follower.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
# Create the dataset
print("\nCreating dataset...")
repo_id = "<hf_username>/<dataset_repo_id>" # TODO: Replace with your Hugging Face repo
# Check if dataset already exists and prompt user
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / repo_id
while dataset_path.exists():
print(f"\nDataset already exists at: {dataset_path}")
print("\nOptions:")
print(" 1. Overwrite existing dataset")
print(" 2. Use a different name")
print(" 3. Abort")
choice = input("\nEnter your choice (1/2/3): ").strip()
if choice == '1':
print(f"Removing existing dataset...")
shutil.rmtree(dataset_path)
print("✓ Existing dataset removed")
break
elif choice == '2':
print("\nCurrent repo_id:", repo_id)
new_repo_id = input("Enter new repo_id (format: <username>/<dataset_name>): ").strip()
if new_repo_id and '/' in new_repo_id:
repo_id = new_repo_id
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / repo_id
print(f"✓ Using new repo_id: {repo_id}")
# Loop will continue if this new path also exists
else:
print("Invalid repo_id format. Please use format: <username>/<dataset_name>")
elif choice == '3':
print("Aborting. Please remove the existing dataset manually or restart with a different repo_id.")
follower.disconnect()
leader.disconnect()
return
else:
print("Invalid choice. Please enter 1, 2, or 3.")
dataset = LeRobotDataset.create(
repo_id=repo_id,
fps=FPS,
features=dataset_features,
robot_type=follower.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize keyboard listener and visualization
_, events = init_keyboard_listener()
init_rerun(session_name="openarms_recording")
# Enable motors on both leader arms for gravity compensation
leader.bus_right.enable_torque()
leader.bus_left.enable_torque()
time.sleep(0.1)
print("\n" + "=" * 70)
print(f"Recording {NUM_EPISODES} episodes")
print(f"Task: {TASK_DESCRIPTION}")
print("=" * 70)
print("\nLeader BOTH arms: Gravity + Friction comp | Follower BOTH arms: Teleop")
print("\nKeyboard controls:")
print(" - Press 'q' to stop recording")
print(" - Press 'r' to re-record current episode")
print("=" * 70)
episode_idx = 0
try:
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
# Record episode with compensation active
record_loop_with_compensation(
robot=follower,
leader=leader,
events=events,
fps=FPS,
dataset=dataset,
dataset_features=dataset_features,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop_with_compensation(
robot=follower,
leader=leader,
events=events,
fps=FPS,
dataset=None, # Don't save reset period
dataset_features=dataset_features,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
# Handle re-recording
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Only save episode if frames were recorded
if dataset.episode_buffer is not None and dataset.episode_buffer["size"] > 0:
dataset.save_episode()
episode_idx += 1
else:
log_say("No frames recorded, skipping episode save")
# Clear the empty buffer
dataset.episode_buffer = None
except KeyboardInterrupt:
print("\n\nStopping recording...")
finally:
# Clean up
log_say("Stop recording")
try:
leader.bus_right.disable_torque()
leader.bus_left.disable_torque()
time.sleep(0.1)
leader.disconnect()
follower.disconnect()
print("✓ Shutdown complete")
except Exception as e:
print(f"Shutdown error: {e}")
# Upload dataset
print("\nUploading dataset to Hugging Face Hub...")
try:
dataset.push_to_hub()
print("✓ Dataset uploaded successfully")
except Exception as e:
print(f"Warning: Failed to upload dataset: {e}")
print("You can manually upload later using: dataset.push_to_hub()")
print("✓ Recording complete!")
if __name__ == "__main__":
main()

View File

@@ -1,166 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
OpenArms Dataset Replay Example
Replays position actions from a recorded dataset on an OpenArms follower robot.
Only position commands (ending with .pos) are replayed, not velocity or torque.
Example usage:
python examples/openarms/replay.py
"""
import time
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
# Configuration
EPISODE_IDX = 0
DATASET_REPO_ID = "lerobot-data-collection/replay-this-2025-11-02-17-58" # TODO: Replace with your dataset
DATASET_ROOT = None # Use default cache location, or specify custom path
# Robot configuration - adjust these to match your setup
ROBOT_CONFIG = OpenArmsFollowerConfig(
port_left="can2", # CAN interface for left arm
port_right="can3", # CAN interface for right arm
can_interface="socketcan",
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=10.0, # Safety limit: max degrees to move per step
)
def main():
"""Main replay function."""
print("=" * 70)
print("OpenArms Dataset Replay")
print("=" * 70)
print(f"\nDataset: {DATASET_REPO_ID}")
print(f"Episode: {EPISODE_IDX}")
print(f"Robot: {ROBOT_CONFIG.id}")
print(f" Left arm: {ROBOT_CONFIG.port_left}")
print(f" Right arm: {ROBOT_CONFIG.port_right}")
print("\n" + "=" * 70)
# Initialize the robot
print("\n[1/3] Initializing robot...")
robot = OpenArmsFollower(ROBOT_CONFIG)
# Load the dataset
print(f"\n[2/3] Loading dataset '{DATASET_REPO_ID}'...")
dataset = LeRobotDataset(
DATASET_REPO_ID,
root=DATASET_ROOT,
episodes=[EPISODE_IDX]
)
# Filter dataset to only include frames from the specified episode
# (required for dataset V3.0 where episodes are chunked)
episode_frames = dataset.hf_dataset.filter(
lambda x: x["episode_index"] == EPISODE_IDX
)
if len(episode_frames) == 0:
raise ValueError(
f"No frames found for episode {EPISODE_IDX} in dataset {DATASET_REPO_ID}"
)
print(f" Found {len(episode_frames)} frames in episode {EPISODE_IDX}")
# Extract action features from dataset
action_features = dataset.features.get(ACTION, {})
action_names = action_features.get("names", [])
# Filter to only position actions (ending with .pos)
position_action_names = [name for name in action_names if name.endswith(".pos")]
if not position_action_names:
raise ValueError(
f"No position actions found in dataset. Action names: {action_names}"
)
print(f" Found {len(position_action_names)} position actions to replay")
print(f" Actions: {', '.join(position_action_names[:5])}{'...' if len(position_action_names) > 5 else ''}")
# Select only action columns from dataset
actions = episode_frames.select_columns(ACTION)
# Connect to the robot
print(f"\n[3/3] Connecting to robot...")
robot.connect(calibrate=False) # Skip calibration for replay
if not robot.is_connected:
raise RuntimeError("Robot failed to connect!")
print("\n" + "=" * 70)
print("Ready to replay!")
print("=" * 70)
print("\nThe robot will replay the recorded positions.")
print("Press Ctrl+C to stop at any time.\n")
input("Press ENTER to start replaying...")
# Replay loop
log_say(f"Replaying episode {EPISODE_IDX}", blocking=True)
try:
for idx in range(len(episode_frames)):
loop_start = time.perf_counter()
# Extract action array from dataset
action_array = actions[idx][ACTION]
# Build action dictionary, but only include position actions
action = {}
for i, name in enumerate(action_names):
# Only include position actions (ending with .pos)
if name.endswith(".pos"):
action[name] = float(action_array[i])
# Send action to robot
robot.send_action(action)
# Maintain replay rate (use dataset fps)
loop_duration = time.perf_counter() - loop_start
dt_s = 1.0 / dataset.fps - loop_duration
busy_wait(dt_s)
# Progress indicator every 100 frames
if (idx + 1) % 100 == 0:
progress = (idx + 1) / len(episode_frames) * 100
print(f"Progress: {idx + 1}/{len(episode_frames)} frames ({progress:.1f}%)")
print(f"\n✓ Successfully replayed {len(episode_frames)} frames")
log_say("Replay complete", blocking=True)
except KeyboardInterrupt:
print("\n\nReplay interrupted by user")
finally:
# Disconnect robot
print("\nDisconnecting robot...")
robot.disconnect()
print("✓ Replay complete!")
if __name__ == "__main__":
main()

View File

@@ -1,73 +0,0 @@
#!/bin/bash
# Setup all OpenArms CAN interfaces with CAN FD
set -e
echo "=========================================="
echo "OpenArms CAN FD Interface Setup"
echo "=========================================="
echo ""
echo "Mode: CAN FD"
echo " - Nominal bitrate: 1 Mbps"
echo " - Data bitrate: 5 Mbps"
echo ""
echo "Configuring interfaces can0, can1, can2, can3..."
echo ""
# Configure each CAN interface with CAN FD
for i in 0 1 2 3; do
interface="can$i"
# Check if interface exists
if ! ip link show "$interface" &> /dev/null; then
echo "$interface: Not found, skipping"
continue
fi
# Bring down interface
sudo ip link set "$interface" down 2>/dev/null
# Configure CAN FD mode
sudo ip link set "$interface" type can \
bitrate 1000000 \
dbitrate 5000000 \
fd on
# Bring up interface
sudo ip link set "$interface" up
# Verify configuration
if ip link show "$interface" | grep -q "UP"; then
echo "$interface: Configured and UP"
else
echo "$interface: Failed to bring UP"
fi
done
echo ""
echo "=========================================="
echo "Verification"
echo "=========================================="
echo ""
# Show detailed status for each interface
for i in 0 1 2 3; do
interface="can$i"
if ip link show "$interface" &> /dev/null; then
echo "$interface:"
# Show key parameters
ip -d link show "$interface" | grep -E "can|state|bitrate|dbitrate" | head -3
echo ""
fi
done
echo "=========================================="
echo "Setup Complete!"
echo "=========================================="
echo ""
echo "All interfaces configured for CAN FD mode"
echo ""
echo "Next steps:"
echo " 1. Test motors: python debug_can_communication.py"
echo " 2. Run teleoperation: python examples/openarms/teleop.py"
echo ""

View File

@@ -1,148 +0,0 @@
"""
OpenArms Teleoperation Example - Full Dual Arms
This script demonstrates teleoperation of OpenArms follower robot using an OpenArms leader arm.
It first calibrates both devices, then enters a teleoperation loop for both arms.
"""
import time
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
follower_config = OpenArmsFollowerConfig(
port_left="can2", # CAN interface for follower left arm
port_right="can3", # CAN interface for follower right arm
can_interface="socketcan", # Linux SocketCAN
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=5.0, # Safety limit
)
leader_config = OpenArmsLeaderConfig(
port_left="can0", # CAN interface for leader left arm
port_right="can1", # CAN interface for leader right arm
can_interface="socketcan", # Linux SocketCAN
id="openarms_leader",
manual_control=True, # Enable manual control (torque disabled)
)
print("=" * 60)
print("OpenArms Teleoperation - Full Dual Arms")
print("=" * 60)
# Initialize devices
print("\n[1/4] Initializing devices...")
follower = OpenArmsFollower(follower_config)
leader = OpenArmsLeader(leader_config)
# Connect and calibrate follower
print("\n[2/4] Connecting and calibrating follower robot...")
print("Note: If you have existing calibration, just press ENTER to use it.")
follower.connect(calibrate=True)
# Connect and calibrate leader
print("\n[3/4] Connecting and calibrating leader arm...")
print("Note: The leader arm will have torque disabled for manual control.")
leader.connect(calibrate=True)
# Wait for user to be ready
print("\n[4/4] Ready for teleoperation!")
print("\nBoth arms will be controlled (16 motors total):")
print(" RIGHT ARM: joints 1-7 + gripper")
print(" LEFT ARM: joints 1-7 + gripper")
print("\nPress ENTER to start teleoperation...")
input()
print("\nTeleoperation started! Move both leader arms.")
print("Press Ctrl+C to stop.\n")
# All joints for both arms (16 motors total)
all_joints = [
# Right arm
"right_joint_1",
"right_joint_2",
"right_joint_3",
"right_joint_4",
"right_joint_5",
"right_joint_6",
"right_joint_7",
"right_gripper",
# Left arm
"left_joint_1",
"left_joint_2",
"left_joint_3",
"left_joint_4",
"left_joint_5",
"left_joint_6",
"left_joint_7",
"left_gripper",
]
# Performance monitoring
loop_times = []
start_time = time.perf_counter()
last_print_time = start_time
try:
while True:
loop_start = time.perf_counter()
# Get action from leader
leader_action = leader.get_action()
# Filter to only position data for all joints (both arms)
joint_action = {}
for joint in all_joints:
pos_key = f"{joint}.pos"
if pos_key in leader_action:
joint_action[pos_key] = leader_action[pos_key]
# Send action to follower (both arms)
if joint_action:
follower.send_action(joint_action)
# Measure loop time
loop_end = time.perf_counter()
loop_time = loop_end - loop_start
loop_times.append(loop_time)
# Print stats every 2 seconds
if loop_end - last_print_time >= 2.0:
if loop_times:
avg_time = sum(loop_times) / len(loop_times)
current_hz = 1.0 / avg_time if avg_time > 0 else 0
min_time = min(loop_times)
max_time = max(loop_times)
max_hz = 1.0 / min_time if min_time > 0 else 0
min_hz = 1.0 / max_time if max_time > 0 else 0
print(f"[Hz Stats] Avg: {current_hz:.1f} Hz | "
f"Range: {min_hz:.1f}-{max_hz:.1f} Hz | "
f"Avg loop time: {avg_time*1000:.1f} ms")
# Reset for next measurement window
loop_times = []
last_print_time = loop_end
except KeyboardInterrupt:
print("\n\nStopping teleoperation...")
finally:
# Disconnect devices
print("Disconnecting devices...")
try:
follower.disconnect()
except Exception as e:
print(f"Error disconnecting follower: {e}")
try:
leader.disconnect()
except Exception as e:
print(f"Error disconnecting leader: {e}")
print("Done!")

View File

@@ -1,197 +0,0 @@
"""
OpenArms Mini Teleoperation Example
This script demonstrates teleoperation of an OpenArms follower robot using
an OpenArms Mini leader (Feetech-based) with dual arms (16 motors total).
The OpenArms Mini has:
- Right arm: 8 motors (joint_1 to joint_7 + gripper)
- Left arm: 8 motors (joint_1 to joint_7 + gripper)
Note on gripper normalization:
- OpenArms Mini gripper: 0-100 scale (0=closed, 100=open)
- OpenArms follower gripper: degrees (0=closed, -65=open)
- This script automatically converts between the two ranges
"""
import time
import os
import sys
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
from lerobot.teleoperators.openarms_mini.openarms_mini import OpenArmsMini
from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig
from lerobot.utils.robot_utils import busy_wait
# Target control frequency
TARGET_FPS = 30
# Configure the OpenArms follower (Damiao motors on CAN bus)
follower_config = OpenArmsFollowerConfig(
port_left="can0", # CAN interface for follower left arm
port_right="can1", # CAN interface for follower right arm
can_interface="socketcan", # Linux SocketCAN
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=10.0, # Safety limit (degrees per step)
)
# Configure the OpenArms Mini leader (Feetech motors on serial)
leader_config = OpenArmsMiniConfig(
port_right="/dev/ttyACM0", # Serial port for right arm
port_left="/dev/ttyACM1", # Serial port for left arm
id="openarms_mini",
use_degrees=True,
)
print("OpenArms Mini → OpenArms Follower Teleoperation")
# Initialize devices
follower = OpenArmsFollower(follower_config)
leader = OpenArmsMini(leader_config)
# Connect and calibrate follower
print("Note: If you have existing calibration, just press ENTER to use it.")
follower.connect(calibrate=True)
# Connect and calibrate leader
print("Note: The leader arms will have torque disabled for manual control.")
leader.connect(calibrate=True)
print("\nPress ENTER to start teleoperation...")
input()
print("Press Ctrl+C to stop.\n")
# All joints for both arms (16 motors total)
all_joints = [
# Right arm
"right_joint_1",
"right_joint_2",
"right_joint_3",
"right_joint_4",
"right_joint_5",
"right_joint_6",
"right_joint_7",
"right_gripper",
# Left arm
"left_joint_1",
"left_joint_2",
"left_joint_3",
"left_joint_4",
"left_joint_5",
"left_joint_6",
"left_joint_7",
"left_gripper",
]
# Performance monitoring
loop_times = []
avg_loop_time = 0.0
min_loop_time = float('inf')
max_loop_time = 0.0
stats_update_interval = 1.0 # Update stats every 1 second
last_stats_update = time.perf_counter()
SWAPPED_JOINTS = {
"right_joint_6": "right_joint_7",
"right_joint_7": "right_joint_6",
"left_joint_6": "left_joint_7",
"left_joint_7": "left_joint_6",
}
try:
while True:
loop_start = time.perf_counter()
# Get actions and observations
leader_action = leader.get_action()
follower_obs = follower.get_observation()
joint_action = {}
for joint in all_joints:
leader_key = f"{joint}.pos"
# Determine which follower joint this leader joint controls
follower_joint = SWAPPED_JOINTS.get(joint, joint)
follower_key = f"{follower_joint}.pos"
# Get leader position (default 0 if missing)
pos = leader_action.get(leader_key, 0.0)
# Convert gripper values: Mini uses 0-100, OpenArms uses 0 to -65 degrees
if "gripper" in joint:
# Map 0-100 (Mini) to 0 to -65 (OpenArms)
# 0 (closed) -> 0°, 100 (open) -> -65°
pos = (pos / 100.0) * -65.0
# Store in action dict for follower
joint_action[follower_key] = pos
follower.send_action(joint_action)
# Loop timing
loop_end = time.perf_counter()
loop_time = loop_end - loop_start
loop_times.append(loop_time)
# Update stats periodically
current_time = time.perf_counter()
if current_time - last_stats_update >= stats_update_interval:
if loop_times:
avg_loop_time = sum(loop_times) / len(loop_times)
min_loop_time = min(loop_times)
max_loop_time = max(loop_times)
loop_times = []
last_stats_update = current_time
# Display everything
sys.stdout.write("\033[H\033[J") # Clear screen
# Show timing stats at the top
if avg_loop_time > 0:
avg_hz = 1.0 / avg_loop_time
min_hz = 1.0 / max_loop_time if max_loop_time > 0 else 0
max_hz = 1.0 / min_loop_time if min_loop_time > 0 and min_loop_time < float('inf') else 0
print(f"[Performance] Target: {TARGET_FPS} Hz | Avg: {avg_hz:.1f} Hz | Range: {min_hz:.1f}-{max_hz:.1f} Hz | Loop: {avg_loop_time*1000:.1f} ms\n")
else:
print(f"[Performance] Target: {TARGET_FPS} Hz | Measuring...\n")
# Show joint positions
print(f"{'Joint':<20} {'Leader':>15} {'Follower':>15}")
print(f"{'':20} {'(0-100/deg)':>15} {'(deg)':>15}")
print("-" * 52)
for joint in all_joints:
leader_key = f"{joint}.pos"
follower_joint = SWAPPED_JOINTS.get(joint, joint)
follower_key = f"{follower_joint}.pos"
leader_pos = leader_action.get(leader_key, 0.0)
follower_pos = follower_obs.get(follower_key, 0.0)
print(f"{joint:<20} {leader_pos:>15.2f} {follower_pos:>15.2f}")
# Smart sleep to maintain target FPS
dt_s = time.perf_counter() - loop_start
busy_wait(max(0, 1.0 / TARGET_FPS - dt_s))
except KeyboardInterrupt:
print("\n\nStopping teleoperation...")
finally:
# Disconnect devices
print("Disconnecting devices...")
try:
follower.disconnect()
except Exception as e:
print(f"Error disconnecting follower: {e}")
try:
leader.disconnect()
except Exception as e:
print(f"Error disconnecting leader: {e}")
print("Done!")

View File

@@ -1,202 +0,0 @@
"""
OpenArms Teleoperation with Gravity + Friction Compensation
Leader arms (both LEFT and RIGHT): Gravity + Friction compensation (weightless, easy to move)
Follower arms (both LEFT and RIGHT): Mirror leader movements
Uses the URDF file from the lerobot repository.
"""
import time
import numpy as np
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
# Friction compensation scale factor (1.0 = full, 0.3 = 30% for stability)
FRICTION_SCALE = 1.0
def main():
"""Main teleoperation loop with gravity compensation"""
print("=" * 70)
print("OpenArms Teleoperation with Gravity Compensation")
print("=" * 70)
# Configuration
follower_config = OpenArmsFollowerConfig(
port_left="can2",
port_right="can3",
can_interface="socketcan",
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=10.0,
)
leader_config = OpenArmsLeaderConfig(
port_left="can0",
port_right="can1",
can_interface="socketcan",
id="openarms_leader",
manual_control=False, # Enable torque control for gravity compensation
)
# Initialize and connect
print("\nInitializing devices...")
follower = OpenArmsFollower(follower_config)
leader = OpenArmsLeader(leader_config)
follower.connect()
leader.connect()
# URDF is automatically loaded in the leader constructor
if leader.pin_robot is None:
raise RuntimeError("URDF model not loaded on leader. Gravity compensation not available.")
print("\nLeader BOTH arms: Gravity + Friction comp | Follower BOTH arms: Teleop")
print("Press ENTER to start...")
input()
# Enable motors on both leader arms for gravity compensation
leader.bus_right.enable_torque()
leader.bus_left.enable_torque()
time.sleep(0.1)
print("Press Ctrl+C to stop\n")
# Main control loop
loop_times = []
last_print_time = time.perf_counter()
# All joints (both arms)
all_joints = []
for motor in leader.bus_right.motors:
all_joints.append(f"right_{motor}")
for motor in leader.bus_left.motors:
all_joints.append(f"left_{motor}")
try:
while True:
loop_start = time.perf_counter()
# Get leader state
leader_action = leader.get_action()
# Extract positions and velocities in degrees
leader_positions_deg = {}
leader_velocities_deg_per_sec = {}
for motor in leader.bus_right.motors:
pos_key = f"right_{motor}.pos"
vel_key = f"right_{motor}.vel"
if pos_key in leader_action:
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
if vel_key in leader_action:
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
for motor in leader.bus_left.motors:
pos_key = f"left_{motor}.pos"
vel_key = f"left_{motor}.vel"
if pos_key in leader_action:
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
if vel_key in leader_action:
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
# Calculate gravity torques for leader using built-in method
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
# Calculate friction torques for leader using built-in method
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
leader_friction_torques_nm = leader._friction_from_velocity(
leader_velocities_rad_per_sec,
friction_scale=FRICTION_SCALE
)
# Combine gravity + friction torques
leader_total_torques_nm = {}
for motor_name in leader_gravity_torques_nm:
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
friction = leader_friction_torques_nm.get(motor_name, 0.0)
leader_total_torques_nm[motor_name] = gravity + friction
# Apply gravity + friction compensation to leader RIGHT arm (all joints including gripper)
for motor in leader.bus_right.motors:
full_name = f"right_{motor}"
position = leader_positions_deg.get(full_name, 0.0)
torque = leader_total_torques_nm.get(full_name, 0.0)
# Get damping gain for stability
kd = leader.get_damping_kd(motor)
leader.bus_right._mit_control(
motor=motor,
kp=0.0,
kd=kd, # Add damping for stability
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque,
)
# Apply gravity + friction compensation to leader LEFT arm (all joints including gripper)
for motor in leader.bus_left.motors:
full_name = f"left_{motor}"
position = leader_positions_deg.get(full_name, 0.0)
torque = leader_total_torques_nm.get(full_name, 0.0)
# Get damping gain for stability
kd = leader.get_damping_kd(motor)
leader.bus_left._mit_control(
motor=motor,
kp=0.0,
kd=kd, # Add damping for stability
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque,
)
# Send leader positions to follower (both arms)
follower_action = {}
for joint in all_joints:
pos_key = f"{joint}.pos"
if pos_key in leader_action:
follower_action[pos_key] = leader_action[pos_key]
if follower_action:
follower.send_action(follower_action)
# Performance monitoring
loop_end = time.perf_counter()
loop_time = loop_end - loop_start
loop_times.append(loop_time)
if loop_end - last_print_time >= 2.0:
if loop_times:
avg_time = sum(loop_times) / len(loop_times)
current_hz = 1.0 / avg_time if avg_time > 0 else 0
print(f"{current_hz:.1f} Hz ({avg_time*1000:.1f} ms)")
loop_times = []
last_print_time = loop_end
except KeyboardInterrupt:
print("\n\nStopping...")
finally:
try:
leader.bus_right.disable_torque()
leader.bus_left.disable_torque()
time.sleep(0.1)
leader.disconnect()
follower.disconnect()
print("✓ Shutdown complete")
except Exception as e:
print(f"Shutdown error: {e}")
if __name__ == "__main__":
main()

View File

@@ -1,745 +0,0 @@
body {
margin: 0;
padding: 0;
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
background: #f5f5f5;
}
main {
min-height: 100vh;
padding: 2rem;
}
header {
text-align: center;
margin-bottom: 2rem;
}
h1 {
font-size: 2rem;
font-weight: 600;
color: #333;
margin: 0;
}
h2 {
font-size: 1.25rem;
font-weight: 600;
color: #333;
margin: 0 0 1rem 0;
}
h3 {
font-size: 0.875rem;
font-weight: 600;
color: #666;
margin: 0 0 0.5rem 0;
text-transform: uppercase;
letter-spacing: 0.5px;
}
.container {
max-width: 1920px;
margin: 0 auto;
display: grid;
grid-template-columns: minmax(500px, 600px) 1fr;
gap: 2rem;
align-items: start;
}
/* Left column container */
.left-column {
display: flex;
flex-direction: column;
gap: 1.5rem;
}
/* Right column container */
.right-column {
display: flex;
flex-direction: column;
gap: 1.5rem;
}
/* Responsive: Stack on smaller screens */
@media (max-width: 1200px) {
.container {
grid-template-columns: 1fr;
}
}
.panel {
background: white;
border-radius: 8px;
padding: 1.5rem;
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}
.config-panel {
border: 2px solid #e5e7eb;
}
.config-header {
display: flex;
justify-content: space-between;
align-items: center;
cursor: pointer;
user-select: none;
padding: 0.5rem 0;
}
.config-header:hover {
opacity: 0.7;
}
.toggle-icon {
font-size: 1rem;
color: #6b7280;
transition: transform 0.2s;
}
.config-content {
margin-top: 1rem;
padding-top: 1rem;
border-top: 1px solid #e5e7eb;
}
.robot-setup {
margin-bottom: 0.5rem;
}
.robot-status {
display: flex;
align-items: center;
justify-content: space-between;
padding: 1rem;
border-radius: 6px;
font-weight: 500;
gap: 1rem;
}
.robot-status.ready {
background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%);
color: #065f46;
border: 1px solid #10b981;
}
.robot-status.not-ready {
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
color: #92400e;
border: 1px solid #f59e0b;
}
.btn-setup {
background: #10b981;
color: white;
border: none;
padding: 0.5rem 1rem;
border-radius: 4px;
font-size: 0.875rem;
font-weight: 500;
cursor: pointer;
transition: background 0.2s;
}
.btn-setup:hover:not(:disabled) {
background: #059669;
}
.btn-setup:disabled {
background: #d1d5db;
cursor: not-allowed;
}
.btn-zero {
background: #8b5cf6;
color: white;
border: none;
padding: 0.5rem 1rem;
border-radius: 4px;
font-size: 0.875rem;
font-weight: 500;
cursor: pointer;
transition: background 0.2s;
}
.btn-zero:hover:not(:disabled) {
background: #7c3aed;
}
.btn-zero:disabled {
background: #d1d5db;
cursor: not-allowed;
}
.zero-position-section {
margin-top: 1rem;
padding-top: 1rem;
border-top: 1px solid #e5e7eb;
}
.btn-zero-large {
width: 100%;
background: #8b5cf6;
color: white;
border: none;
padding: 0.875rem 1.5rem;
border-radius: 8px;
font-size: 1rem;
font-weight: 600;
cursor: pointer;
transition: all 0.2s;
box-shadow: 0 2px 4px rgba(139, 92, 246, 0.2);
}
.btn-zero-large:hover:not(:disabled) {
background: #7c3aed;
box-shadow: 0 4px 8px rgba(139, 92, 246, 0.3);
transform: translateY(-1px);
}
.btn-zero-large:disabled {
background: #d1d5db;
cursor: not-allowed;
box-shadow: none;
transform: none;
}
.delete-episode-section {
margin-top: 1rem;
padding-top: 1rem;
border-top: 1px solid #e5e7eb;
}
.btn-delete {
width: 100%;
background: #ef4444;
color: white;
border: none;
padding: 0.875rem 1.5rem;
border-radius: 8px;
font-size: 1rem;
font-weight: 600;
cursor: pointer;
transition: all 0.2s;
box-shadow: 0 2px 4px rgba(239, 68, 68, 0.2);
}
.btn-delete:hover:not(:disabled) {
background: #dc2626;
box-shadow: 0 4px 8px rgba(239, 68, 68, 0.3);
transform: translateY(-1px);
}
.btn-delete:disabled {
background: #d1d5db;
cursor: not-allowed;
box-shadow: none;
transform: none;
}
.delete-info {
margin-top: 0.5rem;
font-size: 0.875rem;
color: #666;
text-align: center;
font-style: italic;
}
.btn-disconnect {
background: #ef4444;
color: white;
border: none;
padding: 0.5rem 1rem;
border-radius: 4px;
font-size: 0.875rem;
font-weight: 500;
cursor: pointer;
transition: background 0.2s;
}
.btn-disconnect:hover {
background: #dc2626;
}
.btn-refresh {
background: #3b82f6;
color: white;
border: none;
padding: 0.4rem 0.8rem;
border-radius: 4px;
font-size: 0.75rem;
font-weight: 500;
cursor: pointer;
transition: background 0.2s;
}
.btn-refresh:hover:not(:disabled) {
background: #2563eb;
}
.btn-refresh:disabled {
background: #d1d5db;
cursor: not-allowed;
}
.control-panel {
border: 2px solid #10b981;
}
.status-banner {
display: flex;
align-items: center;
gap: 1rem;
padding: 1rem 1.5rem;
border-radius: 6px;
margin-bottom: 1.5rem;
font-weight: 500;
font-size: 0.95rem;
}
.status-banner.initializing {
background: linear-gradient(135deg, #dbeafe 0%, #bfdbfe 100%);
color: #1e40af;
border-left: 4px solid #3b82f6;
}
.status-banner.encoding {
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
color: #92400e;
border-left: 4px solid #f59e0b;
}
.status-banner.uploading {
background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%);
color: #3730a3;
border-left: 4px solid #6366f1;
}
.status-banner.success {
background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%);
color: #065f46;
border-left: 4px solid #10b981;
}
.status-banner.warning {
background: linear-gradient(135deg, #fee2e2 0%, #fecaca 100%);
color: #991b1b;
border-left: 4px solid #ef4444;
}
.spinner {
width: 20px;
height: 20px;
border: 3px solid rgba(0, 0, 0, 0.1);
border-top-color: currentColor;
border-radius: 50%;
animation: spin 0.8s linear infinite;
}
@keyframes spin {
to { transform: rotate(360deg); }
}
.control-horizontal {
display: flex;
flex-direction: column;
gap: 1.5rem;
}
.control-left {
display: flex;
flex-direction: column;
gap: 1rem;
}
.control-right {
display: flex;
align-items: center;
justify-content: center;
}
.input-group {
display: flex;
gap: 0.5rem;
margin-bottom: 0;
}
input[type="text"] {
flex: 1;
padding: 0.75rem;
border: 1px solid #ddd;
border-radius: 4px;
font-size: 1rem;
}
input[type="text"]:disabled {
background: #f5f5f5;
cursor: not-allowed;
}
input[type="text"]:focus {
outline: none;
border-color: #10b981;
}
button {
padding: 0.75rem 1.5rem;
border: none;
border-radius: 4px;
font-size: 1rem;
font-weight: 500;
cursor: pointer;
transition: all 0.2s;
}
.btn-set-task {
background: #3b82f6;
color: white;
min-width: 120px;
}
.btn-set-task:hover:not(:disabled) {
background: #2563eb;
}
.btn-set-task:disabled {
background: #d1d5db;
cursor: not-allowed;
}
.btn-start {
background: #10b981;
color: white;
}
.btn-start:hover:not(:disabled) {
background: #059669;
}
.btn-start:disabled {
background: #d1d5db;
cursor: not-allowed;
}
.btn-stop {
background: #ef4444;
color: white;
}
.btn-stop:hover {
background: #dc2626;
}
.btn-reset {
padding: 0.5rem 1rem;
background: #6b7280;
color: white;
font-size: 0.875rem;
}
.btn-reset:hover {
background: #4b5563;
}
.status {
display: flex;
align-items: center;
gap: 0.75rem;
padding: 1rem;
border-radius: 4px;
margin-bottom: 1rem;
}
.status.recording {
background: #fee2e2;
color: #991b1b;
}
.status.recording.recording-active {
display: flex;
flex-direction: column;
gap: 1rem;
background: #dc2626;
color: white;
padding: 1.5rem;
border: 4px solid #991b1b;
box-shadow: 0 4px 12px rgba(220, 38, 38, 0.4);
font-weight: 700;
font-size: 1rem;
}
.status.recording.recording-active .indicator {
width: 20px;
height: 20px;
background: #fef2f2;
animation: pulse-strong 1s ease-in-out infinite;
}
@keyframes pulse-strong {
0%, 100% {
opacity: 1;
transform: scale(1);
}
50% {
opacity: 0.7;
transform: scale(1.1);
}
}
.status.recording.recording-active .time-display {
display: flex;
flex-direction: column;
gap: 0.5rem;
font-size: 1.5rem;
font-weight: 700;
color: white;
}
.fps-display {
font-size: 1rem;
font-weight: 500;
opacity: 0.95;
}
.fps-warning {
color: #fef2f2;
animation: pulse-warning 1s ease-in-out infinite;
}
@keyframes pulse-warning {
0%, 100% { opacity: 1; }
50% { opacity: 0.5; }
}
.status.recording.recording-active .btn-stop {
align-self: stretch;
}
.ramp-up-countdown {
display: flex;
justify-content: center;
margin-bottom: 1rem;
}
.countdown-box {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
padding: 2rem 3rem;
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
border: 4px solid #f59e0b;
border-radius: 16px;
box-shadow: 0 6px 20px rgba(245, 158, 11, 0.4);
min-width: 280px;
animation: pulse-warm 1.5s ease-in-out infinite;
}
@keyframes pulse-warm {
0%, 100% {
box-shadow: 0 6px 20px rgba(245, 158, 11, 0.4);
}
50% {
box-shadow: 0 6px 25px rgba(245, 158, 11, 0.6);
}
}
.countdown-label {
font-size: 1rem;
color: #92400e;
text-transform: uppercase;
letter-spacing: 1.5px;
font-weight: 800;
margin-bottom: 1rem;
text-align: center;
}
.countdown-value {
font-size: 4.5rem;
font-weight: 900;
color: #d97706;
font-family: 'Courier New', monospace;
line-height: 1;
text-shadow: 2px 2px 6px rgba(0, 0, 0, 0.15);
margin-bottom: 0.5rem;
}
.countdown-subtitle {
font-size: 0.875rem;
color: #78350f;
font-weight: 600;
font-style: italic;
text-align: center;
margin-top: 0.5rem;
}
.status.idle {
background: #f3f4f6;
color: #374151;
}
.indicator {
width: 12px;
height: 12px;
border-radius: 50%;
background: #ef4444;
animation: pulse 1.5s ease-in-out infinite;
}
@keyframes pulse {
0%, 100% { opacity: 1; }
50% { opacity: 0.5; }
}
.counter {
display: flex;
flex-direction: column;
align-items: center;
gap: 0.75rem;
padding: 1.5rem;
background: linear-gradient(135deg, #f9fafb 0%, #f3f4f6 100%);
border-radius: 8px;
border: 2px solid #e5e7eb;
min-width: 200px;
}
.counter-label {
font-size: 0.75rem;
color: #6b7280;
text-transform: uppercase;
letter-spacing: 0.5px;
font-weight: 600;
}
.counter-value {
font-size: 3rem;
font-weight: 700;
color: #10b981;
line-height: 1;
}
.time-display {
font-size: 1.5rem;
font-weight: 600;
font-family: 'Courier New', monospace;
}
.error-box {
padding: 1rem;
background: #fee2e2;
color: #991b1b;
border-radius: 4px;
border-left: 4px solid #ef4444;
font-size: 0.875rem;
}
.config-section {
margin-bottom: 1.5rem;
}
.config-section:last-child {
margin-bottom: 0;
}
.config-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 1rem;
}
label {
display: flex;
flex-direction: column;
gap: 0.5rem;
font-size: 0.875rem;
color: #374151;
font-weight: 500;
}
select {
padding: 0.5rem;
border: 1px solid #ddd;
border-radius: 4px;
font-size: 0.875rem;
background: white;
}
select:disabled {
background: #f5f5f5;
cursor: not-allowed;
}
/* Camera Layout */
.camera-layout {
display: flex;
flex-direction: column;
gap: 1.5rem;
}
.camera-base {
width: 100%;
}
.camera-wrist-container {
display: grid;
grid-template-columns: repeat(2, 1fr);
gap: 1.5rem;
}
.camera-wrist {
width: 100%;
}
.camera {
border: 1px solid #e5e7eb;
border-radius: 4px;
overflow: hidden;
}
.camera h3 {
padding: 0.75rem;
background: #f9fafb;
border-bottom: 1px solid #e5e7eb;
margin: 0;
}
.camera img {
width: 100%;
height: auto;
display: block;
background: #000;
min-height: 300px;
object-fit: cover;
}
.camera-placeholder {
text-align: center;
padding: 4rem 2rem;
background: #f9fafb;
border-radius: 4px;
border: 2px dashed #d1d5db;
}
.camera-placeholder p {
margin: 0.5rem 0;
font-size: 1rem;
color: #6b7280;
}
.camera-placeholder p:first-child {
font-size: 1.25rem;
font-weight: 500;
color: #374151;
}
.hint {
margin-top: 0.5rem;
font-size: 0.75rem;
color: #6b7280;
display: flex;
align-items: center;
gap: 0.5rem;
flex-wrap: wrap;
}

View File

@@ -1,857 +0,0 @@
import { useState, useEffect, useCallback, useRef } from 'react';
import './App.css';
const API_BASE = 'http://localhost:8000/api';
function App() {
// State
const [task, setTask] = useState('');
const [isRecording, setIsRecording] = useState(false);
const [isInitializing, setIsInitializing] = useState(false);
const [isEncoding, setIsEncoding] = useState(false);
const [isUploading, setIsUploading] = useState(false);
const [robotsReady, setRobotsReady] = useState(false);
const [elapsedTime, setElapsedTime] = useState(0);
const [currentFps, setCurrentFps] = useState(0);
const [loopFps, setLoopFps] = useState(0);
const [episodeCount, setEpisodeCount] = useState(0);
const [error, setError] = useState(null);
const [statusMessage, setStatusMessage] = useState('Ready');
const [uploadStatus, setUploadStatus] = useState(null);
const [rampUpRemaining, setRampUpRemaining] = useState(0);
const [movingToZero, setMovingToZero] = useState(false);
const [configExpanded, setConfigExpanded] = useState(false);
const [latestRepoId, setLatestRepoId] = useState(null);
// Configuration
const [config, setConfig] = useState({
leader_type: 'openarms', // 'openarms' or 'openarms_mini'
leader_left: 'can0',
leader_right: 'can1',
follower_left: 'can2',
follower_right: 'can3',
left_wrist: '/dev/video0',
right_wrist: '/dev/video1',
base: '/dev/video4'
});
// Available options
const [availableCameras, setAvailableCameras] = useState([]);
const [availableUsbPorts, setAvailableUsbPorts] = useState([]);
const canInterfaces = ['can0', 'can1', 'can2', 'can3'];
const statusIntervalRef = useRef(null);
const hasInitializedRef = useRef(false);
const loadConfig = () => {
try {
const saved = localStorage.getItem('openarms_config');
if (saved) {
const loadedConfig = JSON.parse(saved);
setConfig(prev => ({ ...prev, ...loadedConfig }));
}
} catch (e) {
console.error('Load config error:', e);
}
};
const saveConfig = (newConfig) => {
try {
localStorage.setItem('openarms_config', JSON.stringify(newConfig || config));
} catch (e) {
console.error('Save config error:', e);
}
};
// Fetch status periodically
const fetchStatus = async () => {
try {
const response = await fetch(`${API_BASE}/status`);
const data = await response.json();
setIsRecording(data.is_recording);
setIsInitializing(data.is_initializing);
setIsEncoding(data.is_encoding);
setIsUploading(data.is_uploading);
setRobotsReady(data.robots_ready);
setElapsedTime(data.elapsed_time);
setCurrentFps(data.current_fps || 0);
setLoopFps(data.loop_fps || 0);
setEpisodeCount(data.episode_count);
setError(data.error);
setStatusMessage(data.status_message || 'Ready');
setUploadStatus(data.upload_status);
setRampUpRemaining(data.ramp_up_remaining || 0);
setMovingToZero(data.moving_to_zero || false);
// Track the latest repo_id from the backend
if (data.latest_repo_id) {
setLatestRepoId(data.latest_repo_id);
}
if (data.config) {
// Only merge server config if we don't have a saved config (first load)
if (!localStorage.getItem('openarms_config')) {
setConfig(prev => {
const merged = { ...data.config, ...prev };
localStorage.setItem('openarms_config', JSON.stringify(merged));
return merged;
});
}
}
} catch (e) {
console.error('Failed to fetch status:', e);
}
};
const setupRobots = async () => {
// Show warning to verify camera positions
const confirmed = window.confirm(
'⚠️ IMPORTANT: Before connecting robots, please verify:\n\n' +
'📹 Check that cameras are correctly positioned:\n' +
' • LEFT wrist camera is actually on the LEFT arm\n' +
' • RIGHT wrist camera is actually on the RIGHT arm\n' +
' • BASE camera is actually the BASE/overhead camera\n\n' +
'Incorrect camera positioning will result in invalid training data!\n\n' +
'Click OK to continue with robot setup, or Cancel to review configuration.'
);
if (!confirmed) {
return; // User cancelled, don't proceed
}
setError(null);
try {
const response = await fetch(`${API_BASE}/robots/setup`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(config)
});
if (!response.ok) {
const data = await response.json();
throw new Error(data.detail || 'Failed to setup robots');
}
await response.json();
saveConfig(config);
} catch (e) {
setError(`Robot setup failed: ${e.message}`);
}
};
// Disconnect robots
const disconnectRobots = async () => {
try {
await fetch(`${API_BASE}/robots/disconnect`, { method: 'POST' });
setRobotsReady(false);
} catch (e) {
console.error('Failed to disconnect robots:', e);
}
};
// Discover cameras
const discoverCameras = async () => {
try {
const response = await fetch(`${API_BASE}/cameras/discover`);
const data = await response.json();
const cameras = data.cameras || [];
setAvailableCameras(cameras);
// Get list of valid camera IDs
const validCameraIds = cameras.map(cam => String(cam.id));
// Auto-fix config if current values are invalid or not set
const updated = { ...config };
let changed = false;
// Auto-fix invalid camera config
if (!config.left_wrist || !validCameraIds.includes(config.left_wrist)) {
if (cameras.length >= 1) {
updated.left_wrist = String(cameras[0].id);
changed = true;
}
}
if (!config.right_wrist || !validCameraIds.includes(config.right_wrist)) {
if (cameras.length >= 2) {
updated.right_wrist = String(cameras[1].id);
changed = true;
}
}
if (!config.base || !validCameraIds.includes(config.base)) {
if (cameras.length >= 3) {
updated.base = String(cameras[2].id);
changed = true;
}
}
if (changed) {
setConfig(updated);
saveConfig(updated);
}
if (cameras.length === 0) {
setError('No cameras detected! Please connect cameras and refresh.');
}
} catch (e) {
console.error('Failed to discover cameras:', e);
setError(`Camera discovery failed: ${e.message}`);
}
};
// Discover USB ports
const discoverUsbPorts = async () => {
try {
const response = await fetch(`${API_BASE}/usb/discover`);
const data = await response.json();
const ports = data.ports || [];
setAvailableUsbPorts(ports);
// Auto-fix config if OpenArms Mini is selected and ports are invalid
if (config.leader_type === 'openarms_mini') {
const updated = { ...config };
let changed = false;
if (ports.length >= 1 && !ports.includes(config.leader_left)) {
updated.leader_left = ports[0];
changed = true;
}
if (ports.length >= 2 && !ports.includes(config.leader_right)) {
updated.leader_right = ports[1];
changed = true;
}
if (changed) {
setConfig(updated);
saveConfig(updated);
}
}
if (ports.length === 0) {
console.warn('No USB ports detected for OpenArms Mini');
}
} catch (e) {
console.error('Failed to discover USB ports:', e);
}
};
// Set task only (for pedal use)
const setTaskOnly = async () => {
if (!task.trim()) {
setError('Please enter a task description');
return;
}
setError(null);
try {
const response = await fetch(`${API_BASE}/recording/set-task`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ task, ...config })
});
if (!response.ok) {
const data = await response.json();
throw new Error(data.detail || 'Failed to set task');
}
const result = await response.json();
setStatusMessage(result.message || `Task set: ${task}`);
saveConfig(config);
// Clear success message after 3 seconds
setTimeout(() => {
if (!isRecording && !isInitializing) {
setStatusMessage('Ready');
}
}, 3000);
} catch (e) {
setError(e.message);
}
};
// Start recording
const startRecording = async () => {
if (!task.trim()) {
setError('Please enter a task description');
return;
}
setError(null);
try {
const response = await fetch(`${API_BASE}/recording/start`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ task, ...config })
});
if (!response.ok) {
const data = await response.json();
throw new Error(data.detail || 'Failed to start recording');
}
await response.json();
saveConfig(config);
} catch (e) {
setError(e.message);
}
};
// Stop recording
const stopRecording = async () => {
try {
const response = await fetch(`${API_BASE}/recording/stop`, {
method: 'POST'
});
if (!response.ok) {
const data = await response.json();
throw new Error(data.detail || 'Failed to stop recording');
}
const data = await response.json();
setError(null);
// Update latest repo_id after recording
if (data.dataset_name) {
setLatestRepoId(`lerobot-data-collection/${data.dataset_name}`);
}
} catch (e) {
setError(e.message);
}
};
const deleteLatestEpisode = async () => {
if (!latestRepoId) {
setError('No episode to delete');
return;
}
const confirmed = window.confirm(
`WARNING: This will permanently delete the repository:\n\n${latestRepoId}\n\nThis action cannot be undone. Continue?`
);
if (!confirmed) {
return;
}
try {
const response = await fetch(`${API_BASE}/recording/delete-latest`, { method: 'POST' });
if (!response.ok) {
const data = await response.json();
throw new Error(data.detail || 'Failed to delete episode');
}
const data = await response.json();
setLatestRepoId(null);
setEpisodeCount(Math.max(0, episodeCount - 1));
setStatusMessage(`Deleted: ${data.deleted_repo}`);
setTimeout(() => {
if (!isRecording && !isInitializing) {
setStatusMessage('Ready');
}
}, 3000);
} catch (e) {
setError(`Delete failed: ${e.message}`);
}
};
// Reset counter
const resetCounter = async () => {
try {
await fetch(`${API_BASE}/counter/reset`, { method: 'POST' });
setEpisodeCount(0);
} catch (e) {
console.error('Failed to reset counter:', e);
}
};
// Move robot to zero position
const moveToZero = async () => {
setError(null);
try {
const response = await fetch(`${API_BASE}/robots/move-to-zero`, { method: 'POST' });
if (!response.ok) {
const data = await response.json();
throw new Error(data.detail || 'Failed to move to zero position');
}
await response.json();
} catch (e) {
setError(`Move to zero failed: ${e.message}`);
}
};
// Format time as MM:SS
const formatTime = (seconds) => {
const mins = Math.floor(seconds / 60);
const secs = Math.floor(seconds % 60);
return `${mins.toString().padStart(2, '0')}:${secs.toString().padStart(2, '0')}`;
};
// Update config and save
const updateConfig = (key, value) => {
const updated = { ...config, [key]: value };
setConfig(updated);
saveConfig(updated);
};
// Initialize on mount only
useEffect(() => {
// Prevent double-initialization in development
if (hasInitializedRef.current) {
return;
}
hasInitializedRef.current = true;
loadConfig();
discoverCameras();
discoverUsbPorts();
fetchStatus();
statusIntervalRef.current = setInterval(fetchStatus, 1000);
return () => {
if (statusIntervalRef.current) {
clearInterval(statusIntervalRef.current);
}
};
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []); // Run only once on mount
// Discover USB ports when leader type changes to Mini
useEffect(() => {
if (config.leader_type === 'openarms_mini') {
discoverUsbPorts();
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [config.leader_type]);
return (
<main>
<header>
<h1>OpenArms Recording</h1>
</header>
<div className="container">
{/* Left Column: Configuration and Recording Control */}
<div className="left-column">
{/* Configuration Panel */}
<section className="panel config-panel">
<div
className="config-header"
onClick={() => setConfigExpanded(!configExpanded)}
role="button"
tabIndex={0}
onKeyDown={(e) => e.key === 'Enter' && setConfigExpanded(!configExpanded)}
>
<h2> Configuration</h2>
<span className="toggle-icon">{configExpanded ? '▼' : '▶'}</span>
</div>
{configExpanded && (
<div className="config-content">
{/* Robot Setup */}
<div className="config-section">
<h3>🤖 Robot Setup</h3>
<div className="robot-setup">
{robotsReady ? (
<div className="robot-status ready">
<span> Robots Ready - Recording will start instantly</span>
<button onClick={disconnectRobots} className="btn-disconnect">
Disconnect Robots
</button>
</div>
) : (
<div className="robot-status not-ready">
<span> Robots not initialized - Recording will take ~10 seconds</span>
<button
onClick={setupRobots}
disabled={isRecording || isInitializing}
className="btn-setup"
>
🚀 Setup Robots
</button>
</div>
)}
</div>
</div>
{/* Leader Type Selection */}
<div className="config-section">
<h3>🎮 Leader Type</h3>
<div className="config-grid">
<label style={{gridColumn: '1 / -1'}}>
Leader Arm Type
<select
value={config.leader_type}
onChange={(e) => updateConfig('leader_type', e.target.value)}
disabled={isRecording || robotsReady}
>
<option value="openarms">OpenArms (CAN Bus - Damiao Motors)</option>
<option value="openarms_mini">OpenArms Mini (USB - Feetech Motors)</option>
</select>
</label>
</div>
</div>
{/* Leader Interfaces (CAN or USB based on type) */}
<div className="config-section">
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '0.5rem' }}>
<h3>
{config.leader_type === 'openarms_mini'
? `Leader Ports (USB/Serial) ${availableUsbPorts.length > 0 ? `(${availableUsbPorts.length} detected)` : ''}`
: 'Leader Interfaces (CAN)'}
</h3>
{config.leader_type === 'openarms_mini' && (
<button
onClick={discoverUsbPorts}
className="btn-refresh"
disabled={isRecording || robotsReady}
>
🔄 Refresh
</button>
)}
</div>
<div className="config-grid">
<label>
Leader Left
<select
value={config.leader_left}
onChange={(e) => updateConfig('leader_left', e.target.value)}
disabled={isRecording || robotsReady}
>
{config.leader_type === 'openarms_mini' ? (
availableUsbPorts.length > 0 ? (
availableUsbPorts.map((port) => (
<option key={port} value={port}>{port}</option>
))
) : (
<option value="">No USB ports detected</option>
)
) : (
canInterfaces.map((iface) => (
<option key={iface} value={iface}>{iface}</option>
))
)}
</select>
</label>
<label>
Leader Right
<select
value={config.leader_right}
onChange={(e) => updateConfig('leader_right', e.target.value)}
disabled={isRecording || robotsReady}
>
{config.leader_type === 'openarms_mini' ? (
availableUsbPorts.length > 0 ? (
availableUsbPorts.map((port) => (
<option key={port} value={port}>{port}</option>
))
) : (
<option value="">No USB ports detected</option>
)
) : (
canInterfaces.map((iface) => (
<option key={iface} value={iface}>{iface}</option>
))
)}
</select>
</label>
</div>
</div>
{/* Follower CAN Interfaces */}
<div className="config-section">
<h3>Follower Interfaces (CAN)</h3>
<div className="config-grid">
<label>
Follower Left
<select
value={config.follower_left}
onChange={(e) => updateConfig('follower_left', e.target.value)}
disabled={isRecording || robotsReady}
>
{canInterfaces.map((iface) => (
<option key={iface} value={iface}>{iface}</option>
))}
</select>
</label>
<label>
Follower Right
<select
value={config.follower_right}
onChange={(e) => updateConfig('follower_right', e.target.value)}
disabled={isRecording || robotsReady}
>
{canInterfaces.map((iface) => (
<option key={iface} value={iface}>{iface}</option>
))}
</select>
</label>
</div>
</div>
{/* Camera Configuration */}
<div className="config-section">
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '0.5rem' }}>
<h3>Cameras {availableCameras.length > 0 && `(${availableCameras.length} detected)`}</h3>
<button
onClick={discoverCameras}
className="btn-refresh"
disabled={isRecording || robotsReady}
>
🔄 Refresh
</button>
</div>
<div className="config-grid">
<label>
Left Wrist
<select
value={config.left_wrist}
onChange={(e) => updateConfig('left_wrist', e.target.value)}
disabled={isRecording || robotsReady}
>
{availableCameras.map((cam) => (
<option key={cam.id} value={String(cam.id)}>
{cam.name || `Camera @ ${cam.id}`}
</option>
))}
</select>
</label>
<label>
Right Wrist
<select
value={config.right_wrist}
onChange={(e) => updateConfig('right_wrist', e.target.value)}
disabled={isRecording || robotsReady}
>
{availableCameras.map((cam) => (
<option key={cam.id} value={String(cam.id)}>
{cam.name || `Camera @ ${cam.id}`}
</option>
))}
</select>
</label>
<label>
Base Camera
<select
value={config.base}
onChange={(e) => updateConfig('base', e.target.value)}
disabled={isRecording || robotsReady}
>
{availableCameras.map((cam) => (
<option key={cam.id} value={String(cam.id)}>
{cam.name || `Camera @ ${cam.id}`}
</option>
))}
</select>
</label>
</div>
</div>
</div>
)}
</section>
{/* Control Panel */}
<section className="panel control-panel">
<h2>🎬 Recording Control</h2>
{/* Status Banner - Always show important statuses */}
{isInitializing && (
<div className="status-banner initializing">
<div className="spinner"></div>
<span>{statusMessage}</span>
</div>
)}
{isEncoding && (
<div className="status-banner encoding">
<div className="spinner"></div>
<span>📹 {statusMessage}</span>
</div>
)}
{isUploading && (
<div className="status-banner uploading">
<div className="spinner"></div>
<span> {statusMessage}</span>
</div>
)}
{uploadStatus && !isRecording && !isEncoding && !isUploading && (
<div className={`status-banner ${uploadStatus.startsWith('✓') ? 'success' : 'warning'}`}>
<span>{uploadStatus}</span>
</div>
)}
<div className="control-horizontal">
{/* Task Input and Status */}
<div className="control-left">
<div className="input-group">
<input
type="text"
value={task}
onChange={(e) => setTask(e.target.value)}
placeholder="Task description (e.g., 'pick and place')"
disabled={isRecording || isInitializing || isEncoding || isUploading}
onKeyPress={(e) => {
if (e.key === 'Enter' && robotsReady) {
setTaskOnly();
}
}}
/>
<button
onClick={setTaskOnly}
disabled={isRecording || isInitializing || isEncoding || isUploading || !robotsReady}
className="btn-set-task"
title={!robotsReady ? 'Please setup robots first' : 'Store task for pedal use (Enter key)'}
>
💾 Set Task
</button>
<button
onClick={startRecording}
disabled={isRecording || isInitializing || isEncoding || isUploading || !robotsReady}
className="btn-start"
title={!robotsReady ? 'Please setup robots first' : ''}
>
{isInitializing
? '⏳ Initializing...'
: isRecording
? '⏺ Recording...'
: robotsReady
? '⏺ Start Recording'
: '⏺ Setup Robots First'}
</button>
</div>
{/* Ramp-up Countdown */}
{isRecording && rampUpRemaining > 0 && (
<div className="ramp-up-countdown">
<div className="countdown-box">
<div className="countdown-label"> WARMING UP - PID RAMP-UP</div>
<div className="countdown-value">{rampUpRemaining.toFixed(1)}s</div>
<div className="countdown-subtitle">Recording will start automatically...</div>
</div>
</div>
)}
{/* Recording Status - Only show after ramp-up */}
{isRecording && rampUpRemaining <= 0 && (
<div className="status recording recording-active">
<div className="indicator"></div>
<div className="time-display">
<span>{formatTime(elapsedTime)}</span>
<span className="fps-display">
Loop: {loopFps.toFixed(1)} Hz
{loopFps > 0 && loopFps < 29 && <span className="fps-warning"> </span>}
</span>
<span className="fps-display">Recording: {currentFps.toFixed(1)} FPS</span>
</div>
<button onClick={stopRecording} className="btn-stop">
Stop
</button>
</div>
)}
</div>
{/* Episode Counter */}
<div className="control-right">
<div className="counter">
<div className="counter-label">Episodes Recorded</div>
<div className="counter-value">{episodeCount}</div>
<button onClick={resetCounter} className="btn-reset">
Reset
</button>
</div>
</div>
</div>
{/* Delete Latest Episode Button */}
{!isRecording && !isInitializing && latestRepoId && (
<div className="delete-episode-section">
<button
onClick={deleteLatestEpisode}
className="btn-delete"
title="Delete the latest recorded episode from HuggingFace Hub"
>
Delete Latest Episode
</button>
<div className="delete-info">Will delete: {latestRepoId}</div>
</div>
)}
{/* Move to Zero Button */}
{robotsReady && !isRecording && !isInitializing && (
<div className="zero-position-section">
<button
onClick={moveToZero}
disabled={movingToZero}
className="btn-zero-large"
title="Move both leader and follower robots to zero position (2s)"
>
{movingToZero ? '⏳ Moving to Zero Position...' : '🎯 Move to Zero Position (Leader + Follower)'}
</button>
</div>
)}
{/* Error Display */}
{error && (
<div className="error-box">
{error}
</div>
)}
</section>
</div>
{/* Right Column: Camera Feeds */}
<div className="right-column">
<section className="panel cameras">
<h2>📹 Camera Views</h2>
{robotsReady || isRecording || isInitializing ? (
<div className="camera-layout">
{/* Base camera - full width */}
<div className="camera camera-base">
<h3>Base Camera</h3>
<img src={`${API_BASE}/camera/stream/base`} alt="Base Camera" />
</div>
{/* Wrist cameras - side by side */}
<div className="camera-wrist-container">
<div className="camera camera-wrist">
<h3>Left Wrist</h3>
<img src={`${API_BASE}/camera/stream/left_wrist`} alt="Left Wrist Camera" />
</div>
<div className="camera camera-wrist">
<h3>Right Wrist</h3>
<img src={`${API_BASE}/camera/stream/right_wrist`} alt="Right Wrist Camera" />
</div>
</div>
</div>
) : (
<div className="camera-placeholder">
<p>📷 Camera feeds will appear when robots are set up</p>
<p className="hint">Click "Setup Robots" above to preview camera feeds</p>
</div>
)}
</section>
</div>
</div>
</main>
);
}
export default App;

View File

@@ -1,41 +0,0 @@
# OpenArms Web Recording Interface
A web interface for recording OpenArms datasets.
## Installation
```bash
cd examples/openarms_web_interface
npm install
```
## Usage
**Start everything with one command:**
```bash
./launch.sh
```
This will:
- Start the FastAPI backend on port 8000
- Start the React frontend on port 5173
- Show live logs from both services
Then open your browser to: **http://localhost:5173**
**Stop with:** `Ctrl+C`
---
## Workflow
1. **Configure CAN interfaces** and **camera paths** in the dropdowns
2. Click **"Setup Robots"** to initialize (once at start)
3. Enter a **task description**
4. Click **"Start Recording"** to begin an episode
5. Click **"Stop Recording"** when done
6. Dataset is automatically encoded and uploaded to HuggingFace Hub as **private**
7. Repeat steps 3-6 for more episodes (no need to re-setup robots!)
---

View File

@@ -1,12 +0,0 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>OpenArms Recording Interface</title>
</head>
<body>
<div id="root"></div>
<script type="module" src="/main.jsx"></script>
</body>
</html>

View File

@@ -1,142 +0,0 @@
#!/bin/bash
# OpenArms Web Interface Launcher
# Starts Rerun viewer, FastAPI backend, and React frontend
set -e
# Colors for output
GREEN='\033[0;32m'
BLUE='\033[0;34m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
NC='\033[0m' # No Color
# Get script directory
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd "$SCRIPT_DIR"
echo -e "${BLUE}╔════════════════════════════════════════╗${NC}"
echo -e "${BLUE}║ OpenArms Web Recording Interface ║${NC}"
echo -e "${BLUE}╚════════════════════════════════════════╝${NC}"
echo ""
# Function to cleanup on exit
cleanup() {
echo ""
echo -e "${YELLOW}Shutting down services...${NC}"
# Kill all child processes
pkill -P $$ 2>/dev/null || true
# Kill specific services by port
lsof -ti:8000 | xargs kill -9 2>/dev/null || true # Backend
lsof -ti:5173 | xargs kill -9 2>/dev/null || true # Frontend
lsof -ti:9876 | xargs kill -9 2>/dev/null || true # Rerun (if spawned)
echo -e "${GREEN}✓ Services stopped${NC}"
exit 0
}
# Register cleanup on script exit
trap cleanup EXIT INT TERM
# Check if required commands exist
command -v rerun >/dev/null 2>&1 || {
echo -e "${RED}✗ Error: 'rerun' not found. Please install: pip install rerun-sdk${NC}"
exit 1
}
command -v python >/dev/null 2>&1 || {
echo -e "${RED}✗ Error: 'python' not found${NC}"
exit 1
}
command -v npm >/dev/null 2>&1 || {
echo -e "${RED}✗ Error: 'npm' not found${NC}"
exit 1
}
# Check if node_modules exists
if [ ! -d "node_modules" ]; then
echo -e "${YELLOW}⚠ node_modules not found. Running npm install...${NC}"
npm install
echo -e "${GREEN}✓ Dependencies installed${NC}"
echo ""
fi
echo -e "${GREEN}Starting services...${NC}"
echo ""
# 1. Start FastAPI backend (Rerun will start when recording begins)
echo -e "${BLUE}[1/2]${NC} Starting FastAPI backend on port 8000..."
cd "$SCRIPT_DIR"
# Use Python from current environment (if lerobot env is active, it will use that)
# Otherwise, check if we need to use conda run
if [[ "$CONDA_DEFAULT_ENV" == "lerobot" ]]; then
# Already in lerobot environment
echo -e "${GREEN}✓ Using active lerobot environment${NC}"
PYTHON_CMD="python"
elif command -v conda >/dev/null 2>&1 && conda env list | grep -q "^lerobot "; then
# lerobot env exists but not active - use conda run
echo -e "${YELLOW}Using conda run with lerobot environment...${NC}"
PYTHON_CMD="conda run -n lerobot --no-capture-output python"
else
# Fall back to system python
echo -e "${YELLOW}⚠ Warning: lerobot environment not found, using system python${NC}"
PYTHON_CMD="python"
fi
$PYTHON_CMD web_record_server.py > /tmp/openarms_backend.log 2>&1 &
BACKEND_PID=$!
sleep 3
if ps -p $BACKEND_PID > /dev/null; then
echo -e "${GREEN}✓ Backend started${NC} (PID: $BACKEND_PID)"
echo -e " URL: ${BLUE}http://localhost:8000${NC}"
else
echo -e "${RED}✗ Failed to start backend${NC}"
echo -e "${YELLOW}Check logs: tail -f /tmp/openarms_backend.log${NC}"
exit 1
fi
echo ""
# 2. Start React frontend
echo -e "${BLUE}[2/2]${NC} Starting React frontend on port 5173..."
cd "$SCRIPT_DIR"
npm run dev > /tmp/openarms_frontend.log 2>&1 &
FRONTEND_PID=$!
sleep 3
if ps -p $FRONTEND_PID > /dev/null; then
echo -e "${GREEN}✓ Frontend started${NC} (PID: $FRONTEND_PID)"
echo -e " URL: ${BLUE}http://localhost:5173${NC}"
else
echo -e "${RED}✗ Failed to start frontend${NC}"
echo -e "${YELLOW}Check logs: tail -f /tmp/openarms_frontend.log${NC}"
exit 1
fi
echo ""
# Display status
echo -e "${GREEN}╔════════════════════════════════════════╗${NC}"
echo -e "${GREEN}║ All services running! 🚀 ║${NC}"
echo -e "${GREEN}╚════════════════════════════════════════╝${NC}"
echo ""
echo -e "🔧 ${BLUE}Backend:${NC} http://localhost:8000"
echo -e "🌐 ${BLUE}Frontend:${NC} http://localhost:5173"
echo -e "📊 ${BLUE}Rerun:${NC} Will spawn automatically when recording starts"
echo ""
echo -e "${YELLOW}Open your browser to:${NC} ${BLUE}http://localhost:5173${NC}"
echo ""
echo -e "${YELLOW}Logs:${NC}"
echo -e " • Backend: tail -f /tmp/openarms_backend.log"
echo -e " • Frontend: tail -f /tmp/openarms_frontend.log"
echo ""
echo -e "${RED}Press Ctrl+C to stop all services${NC}"
echo ""
# Keep script running and wait for any service to exit
wait

View File

@@ -1,7 +0,0 @@
import { createRoot } from 'react-dom/client'
import App from './App.jsx'
createRoot(document.getElementById('root')).render(
<App />
)

File diff suppressed because it is too large Load Diff

View File

@@ -1,21 +0,0 @@
{
"name": "openarms-web-interface",
"private": true,
"version": "0.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"build": "vite build",
"preview": "vite preview"
},
"dependencies": {
"react": "^18.3.1",
"react-dom": "^18.3.1"
},
"devDependencies": {
"@types/react": "^18.3.12",
"@types/react-dom": "^18.3.1",
"@vitejs/plugin-react": "^4.3.4",
"vite": "^6.0.1"
}
}

View File

@@ -1,17 +0,0 @@
import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react'
// https://vite.dev/config/
export default defineConfig({
plugins: [react()],
server: {
port: 5173,
strictPort: false,
host: true,
open: false
},
build: {
outDir: 'dist',
sourcemap: true
}
})

File diff suppressed because it is too large Load Diff

View File

@@ -1,98 +0,0 @@
"""This script demonstrates how to train ACT Policy on a real-world dataset."""
from pathlib import Path
import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
if delta_indices is None:
return [0]
return [i / fps for i in delta_indices]
output_directory = Path("outputs/robot_learning_tutorial/act")
output_directory.mkdir(parents=True, exist_ok=True)
# Select your device
device = torch.device("mps") # or "cuda" or "cpu"
dataset_id = "lerobot/svla_so101_pickplace"
# This specifies the inputs the model will be expecting and the outputs it will produce
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
features = dataset_to_policy_features(dataset_metadata.features)
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}
cfg = ACTConfig(input_features=input_features, output_features=output_features)
policy = ACTPolicy(cfg)
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
policy.train()
policy.to(device)
# To perform action chunking, ACT expects a given number of actions as targets
delta_timestamps = {
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
}
# add image features if they are present
delta_timestamps |= {
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
}
# Instantiate the dataset
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
# Create the optimizer and dataloader for offline training
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
batch_size = 32
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=device.type != "cpu",
drop_last=True,
)
# Number of training steps and logging frequency
training_steps = 1
log_freq = 1
# Run training loop
step = 0
done = False
while not done:
for batch in dataloader:
batch = preprocessor(batch)
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
if step % log_freq == 0:
print(f"step: {step} loss: {loss.item():.3f}")
step += 1
if step >= training_steps:
done = True
break
# Save the policy checkpoint, alongside the pre/post processors
policy.save_pretrained(output_directory)
preprocessor.save_pretrained(output_directory)
postprocessor.save_pretrained(output_directory)
# Save all assets to the Hub
policy.push_to_hub("fracapuano/robot_learning_tutorial_act")
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")

View File

@@ -1,57 +0,0 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "fracapuano/robot_learning_tutorial_act"
model = ACTPolicy.from_pretrained(model_id)
dataset_id = "lerobot/svla_so101_pickplace"
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
preprocess, postprocess = make_pre_post_processors(model.config, dataset_stats=dataset_metadata.stats)
# # find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
# # the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
# Robot and environment configuration
# Camera keys must match the name and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
camera_config = {
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_metadata.features, device=device
)
obs = preprocess(obs_frame)
action = model.select_action(obs)
action = postprocess(action)
action = make_robot_action(action, dataset_metadata.features)
robot.send_action(action)
print("Episode finished! Starting new episode...")

View File

@@ -1,11 +0,0 @@
from lerobot.async_inference.configs import PolicyServerConfig
from lerobot.async_inference.policy_server import serve
host = ... # something like "127.0.0.1" if you're exposing to localhost
port = ... # something like 8080
config = PolicyServerConfig(
host=host,
port=port,
)
serve(config)

View File

@@ -1,55 +0,0 @@
import threading
from lerobot.async_inference.configs import RobotClientConfig
from lerobot.async_inference.helpers import visualize_action_queue_size
from lerobot.async_inference.robot_client import RobotClient
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.robots.so100_follower import SO100FollowerConfig
# these cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
# check the config.json on the Hub for the policy you are using to see the expected camera specs
camera_cfg = {
"up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
# # find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
# # the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)
server_address = ... # something like "127.0.0.1:8080" if using localhost
# 3. Create client configuration
client_cfg = RobotClientConfig(
robot=robot_cfg,
server_address=server_address,
policy_device="mps",
policy_type="act",
pretrained_name_or_path="fracapuano/robot_learning_tutorial_act",
chunk_size_threshold=0.5, # g
actions_per_chunk=50, # make sure this is less than the max actions of the policy
)
# 4. Create and start client
client = RobotClient(client_cfg)
# 5. Provide a textual description of the task
task = ...
if client.start():
# Start action receiver thread
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
action_receiver_thread.start()
try:
# Run the control loop
client.control_loop(task)
except KeyboardInterrupt:
client.stop()
action_receiver_thread.join()
# (Optionally) plot the action queue size
visualize_action_queue_size(client.action_queue_size)

View File

@@ -1,99 +0,0 @@
"""This script demonstrates how to train Diffusion Policy on a real-world dataset."""
from pathlib import Path
import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors
def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
if delta_indices is None:
return [0]
return [i / fps for i in delta_indices]
output_directory = Path("outputs/robot_learning_tutorial/diffusion")
output_directory.mkdir(parents=True, exist_ok=True)
# Select your device
device = torch.device("mps") # or "cuda" or "cpu"
dataset_id = "lerobot/svla_so101_pickplace"
# This specifies the inputs the model will be expecting and the outputs it will produce
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
features = dataset_to_policy_features(dataset_metadata.features)
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
policy = DiffusionPolicy(cfg)
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
policy.train()
policy.to(device)
# To perform action chunking, ACT expects a given number of actions as targets
delta_timestamps = {
"observation.state": make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps),
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
}
# add image features if they are present
delta_timestamps |= {
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
}
# Instantiate the dataset
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
# Create the optimizer and dataloader for offline training
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
batch_size = 32
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=device.type != "cpu",
drop_last=True,
)
# Number of training steps and logging frequency
training_steps = 1
log_freq = 1
# Run training loop
step = 0
done = False
while not done:
for batch in dataloader:
batch = preprocessor(batch)
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
if step % log_freq == 0:
print(f"step: {step} loss: {loss.item():.3f}")
step += 1
if step >= training_steps:
done = True
break
# Save the policy checkpoint, alongside the pre/post processors
policy.save_pretrained(output_directory)
preprocessor.save_pretrained(output_directory)
postprocessor.save_pretrained(output_directory)
# Save all assets to the Hub
policy.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")

View File

@@ -1,60 +0,0 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "fracapuano/robot_learning_tutorial_diffusion"
model = DiffusionPolicy.from_pretrained(model_id)
dataset_id = "lerobot/svla_so101_pickplace"
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
preprocess, postprocess = make_pre_post_processors(
model.config, model_id, dataset_stats=dataset_metadata.stats
)
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
# # find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
# # the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
# Robot and environment configuration
# Camera keys must match the name and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
camera_config = {
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_metadata.features, device=device
)
obs = preprocess(obs_frame)
action = model.select_action(obs)
action = postprocess(action)
action = make_robot_action(action, dataset_metadata.features)
robot.send_action(action)
print("Episode finished! Starting new episode...")

View File

@@ -1,67 +0,0 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "lerobot/pi0_base"
model = PI0Policy.from_pretrained(model_id)
preprocess, postprocess = make_pre_post_processors(
model.config,
model_id,
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
preprocessor_overrides={"device_processor": {"device": str(device)}},
)
# find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
# the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
# Robot and environment configuration
# Camera keys must match the name and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
camera_config = {
"base_0_rgb": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"left_wrist_0_rgb": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
"right_wrist_0_rgb": OpenCVCameraConfig(index_or_path=2, width=640, height=480, fps=30),
}
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
task = "" # something like "pick the red block"
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
# This is used to match the raw observation keys to the keys expected by the policy
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
)
obs = preprocess(obs_frame)
action = model.select_action(obs)
action = postprocess(action)
action = make_robot_action(action, dataset_features)
robot.send_action(action)
print("Episode finished! Starting new episode...")

View File

@@ -1,345 +0,0 @@
import multiprocessing as mp
import signal
from pathlib import Path
from queue import Empty, Full
import torch
import torch.optim as optim
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.gym_manipulator import make_robot_env
from lerobot.robots.so100_follower import SO100FollowerConfig
from lerobot.teleoperators.so100_leader import SO100LeaderConfig
from lerobot.teleoperators.utils import TeleopEvents
LOG_EVERY = 10
SEND_EVERY = 10
def run_learner(
transitions_queue: mp.Queue,
parameters_queue: mp.Queue,
shutdown_event: mp.Event,
policy_learner: SACPolicy,
online_buffer: ReplayBuffer,
offline_buffer: ReplayBuffer,
lr: float = 3e-4,
batch_size: int = 32,
device: torch.device = "mps",
):
"""The learner process - trains SAC policy on transitions streamed from the actor, updating parameters
for the actor to adopt."""
policy_learner.train()
policy_learner.to(device)
# Create Adam optimizer from scratch - simple and clean
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
training_step = 0
while not shutdown_event.is_set():
# retrieve incoming transitions from the actor process
try:
transitions = transitions_queue.get(timeout=0.1)
for transition in transitions:
# HIL-SERL: Add ALL transitions to online buffer
online_buffer.add(**transition)
# HIL-SERL: Add ONLY human intervention transitions to offline buffer
is_intervention = transition.get("complementary_info", {}).get("is_intervention", False)
if is_intervention:
offline_buffer.add(**transition)
print(
f"[LEARNER] Human intervention detected! Added to offline buffer (now {len(offline_buffer)} transitions)"
)
except Empty:
pass # No transitions available, continue
# Train if we have enough data
if len(online_buffer) >= policy_learner.config.online_step_before_learning:
# Sample from online buffer (autonomous + human data)
online_batch = online_buffer.sample(batch_size // 2)
# Sample from offline buffer (human demonstrations only, either precollected or at runtime)
offline_batch = offline_buffer.sample(batch_size // 2)
# Combine batches - this is the key HIL-SERL mechanism!
batch = {}
for key in online_batch:
if key in offline_batch:
batch[key] = torch.cat([online_batch[key], offline_batch[key]], dim=0)
else:
batch[key] = online_batch[key]
loss, _ = policy_learner.forward(batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
training_step += 1
if training_step % LOG_EVERY == 0:
print(
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
)
# Send updated parameters to actor every 10 training steps
if training_step % SEND_EVERY == 0:
try:
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
parameters_queue.put_nowait(state_dict)
print("[LEARNER] Sent updated parameters to actor")
except Full:
# Missing write due to queue not being consumed (should happen rarely)
pass
print("[LEARNER] Learner process finished")
def run_actor(
transitions_queue: mp.Queue,
parameters_queue: mp.Queue,
shutdown_event: mp.Event,
policy_actor: SACPolicy,
reward_classifier: Classifier,
env_cfg: HILSerlRobotEnvConfig,
device: torch.device = "mps",
output_directory: Path | None = None,
):
"""The actor process - interacts with environment and collects data.
The policy is frozen and only the parameters are updated, popping the most recent ones from a queue."""
policy_actor.eval()
policy_actor.to(device)
reward_classifier.eval()
reward_classifier.to(device)
# Create robot environment inside the actor process
env, teleop_device = make_robot_env(env_cfg)
try:
for episode in range(MAX_EPISODES):
if shutdown_event.is_set():
break
obs, _info = env.reset()
episode_reward = 0.0
step = 0
episode_transitions = []
print(f"[ACTOR] Starting episode {episode + 1}")
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
try:
new_params = parameters_queue.get_nowait()
policy_actor.load_state_dict(new_params)
print("[ACTOR] Updated policy parameters from learner")
except Empty: # No new updated parameters available from learner, waiting
pass
# Get action from policy
policy_obs = make_policy_obs(obs, device=device)
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
action = action_tensor.squeeze(0).cpu().numpy()
# Step environment
next_obs, _env_reward, terminated, truncated, _info = env.step(action)
done = terminated or truncated
# Predict reward
policy_next_obs = make_policy_obs(next_obs, device=device)
reward = reward_classifier.predict_reward(policy_next_obs)
if reward >= 1.0 and not done: # success detected! halt episode
terminated = True
done = True
# In HIL-SERL, human interventions come from the teleop device
is_intervention = False
if hasattr(teleop_device, "get_teleop_events"):
# Real intervention detection from teleop device
teleop_events = teleop_device.get_teleop_events()
is_intervention = teleop_events.get(TeleopEvents.IS_INTERVENTION, False)
# Store transition with intervention metadata
transition = {
"state": policy_obs,
"action": action,
"reward": float(reward) if hasattr(reward, "item") else reward,
"next_state": policy_next_obs,
"done": done,
"truncated": truncated,
"complementary_info": {
"is_intervention": is_intervention,
},
}
episode_transitions.append(transition)
episode_reward += reward
step += 1
obs = next_obs
if done:
break
# Send episode transitions to learner
transitions_queue.put_nowait(episode_transitions)
except KeyboardInterrupt:
print("[ACTOR] Interrupted by user")
finally:
# Clean up
if hasattr(env, "robot") and env.robot.is_connected:
env.robot.disconnect()
if teleop_device and hasattr(teleop_device, "disconnect"):
teleop_device.disconnect()
if output_directory is not None:
policy_actor.save_pretrained(output_directory)
print(f"[ACTOR] Latest actor policy saved at: {output_directory}")
print("[ACTOR] Actor process finished")
def make_policy_obs(obs, device: torch.device = "cpu"):
return {
"observation.state": torch.from_numpy(obs["agent_pos"]).float().unsqueeze(0).to(device),
**{
f"observation.image.{k}": torch.from_numpy(obs["pixels"][k]).float().unsqueeze(0).to(device)
for k in obs["pixels"]
},
}
"""Main function - coordinates actor and learner processes."""
device = "mps" # or "cuda" or "cpu"
output_directory = Path("outputs/robot_learning_tutorial/hil_serl")
output_directory.mkdir(parents=True, exist_ok=True)
# find ports using lerobot-find-port
follower_port = ...
leader_port = ...
# the robot ids are used the load the right calibration files
follower_id = ...
leader_id = ...
# A pretrained model (to be used in-distribution!)
reward_classifier_id = "fracapuano/reward_classifier_hil_serl_example"
reward_classifier = Classifier.from_pretrained(reward_classifier_id)
reward_classifier.to(device)
reward_classifier.eval()
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
# Robot and environment configuration
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id)
teleop_cfg = SO100LeaderConfig(port=leader_port, id=leader_id)
processor_cfg = HILSerlProcessorConfig(control_mode="leader")
env_cfg = HILSerlRobotEnvConfig(robot=robot_cfg, teleop=teleop_cfg, processor=processor_cfg)
# Create robot environment
env, teleop_device = make_robot_env(env_cfg)
obs_features = hw_to_dataset_features(env.robot.observation_features, "observation")
action_features = hw_to_dataset_features(env.robot.action_features, "action")
# Create SAC policy for action selection
policy_cfg = SACConfig(
device=device,
input_features=obs_features,
output_features=action_features,
)
policy_actor = SACPolicy(policy_cfg)
policy_learner = SACPolicy(policy_cfg)
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
# Online buffer: initialized from scratch
online_replay_buffer = ReplayBuffer(device=device, state_keys=list(obs_features.keys()))
# Offline buffer: Created from dataset (pre-populated it with demonstrations)
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=offline_dataset, device=device, state_keys=list(obs_features.keys())
)
# Create communication channels between learner and actor processes
transitions_queue = mp.Queue(maxsize=10)
parameters_queue = mp.Queue(maxsize=2)
shutdown_event = mp.Event()
# Signal handler for graceful shutdown
def signal_handler(sig):
print(f"\nSignal {sig} received, shutting down...")
shutdown_event.set()
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
# Create processes
learner_process = mp.Process(
target=run_learner,
args=(
transitions_queue,
parameters_queue,
shutdown_event,
policy_learner,
online_replay_buffer,
offline_replay_buffer,
),
kwargs={"device": device}, # can run on accelerated hardware for training
)
actor_process = mp.Process(
target=run_actor,
args=(
transitions_queue,
parameters_queue,
shutdown_event,
policy_actor,
reward_classifier,
env_cfg,
output_directory,
),
kwargs={"device": "cpu"}, # actor is frozen, can run on CPU or accelerate for inference
)
learner_process.start()
actor_process.start()
try:
# Wait for actor to finish (it controls the episode loop)
actor_process.join()
shutdown_event.set()
learner_process.join(timeout=10)
except KeyboardInterrupt:
print("Main process interrupted")
shutdown_event.set()
actor_process.join(timeout=5)
learner_process.join(timeout=10)
finally:
if learner_process.is_alive():
learner_process.terminate()
if actor_process.is_alive():
actor_process.terminate()

View File

@@ -1,62 +0,0 @@
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
# Device to use for training
device = "mps" # or "cuda", or "cpu"
# Load the dataset used for training
repo_id = "lerobot/example_hil_serl_dataset"
dataset = LeRobotDataset(repo_id)
# Configure the policy to extract features from the image frames
camera_keys = dataset.meta.camera_keys
config = RewardClassifierConfig(
num_cameras=len(camera_keys),
device=device,
# backbone model to extract features from the image frames
model_name="microsoft/resnet-18",
)
# Make policy, preprocessor, and optimizer
policy = make_policy(config, ds_meta=dataset.meta)
optimizer = config.get_optimizer_preset().build(policy.parameters())
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
classifier_id = "fracapuano/reward_classifier_hil_serl_example"
# Instantiate a dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
# Training loop
num_epochs = 5
for epoch in range(num_epochs):
total_loss = 0
total_accuracy = 0
for batch in dataloader:
# Preprocess the batch and move it to the correct device.
batch = preprocessor(batch)
# Forward pass
loss, output_dict = policy.forward(batch)
# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
total_accuracy += output_dict["accuracy"]
avg_loss = total_loss / len(dataloader)
avg_accuracy = total_accuracy / len(dataloader)
print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.2f}%")
print("Training finished!")
# You can now save the trained policy.
policy.push_to_hub(classifier_id)

View File

@@ -1,66 +0,0 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "lerobot/smolvla_base"
model = SmolVLAPolicy.from_pretrained(model_id)
preprocess, postprocess = make_pre_post_processors(
model.config,
model_id,
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
preprocessor_overrides={"device_processor": {"device": str(device)}},
)
# find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
# the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
# Robot and environment configuration
# Camera keys must match the name and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
camera_config = {
"camera1": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"camera2": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
task = "" # something like "pick the red block"
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
# This is used to match the raw observation keys to the keys expected by the policy
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
)
obs = preprocess(obs_frame)
action = model.select_action(obs)
action = postprocess(action)
action = make_robot_action(action, dataset_features)
robot.send_action(action)
print("Episode finished! Starting new episode...")

View File

@@ -1,10 +0,0 @@
from huggingface_hub import HfApi, list_datasets
api = HfApi()
datasets = list_datasets(author="lerobot-data-collection")
print('"[', end="")
i=0
for dataset in datasets:
if "three-folds-dataset" in dataset.id:
print("'" + dataset.id + "',", end="")
print(']"',)

View File

@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
[project]
name = "lerobot"
version = "0.4.1"
version = "0.3.4"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
readme = "README.md"
license = { text = "Apache-2.0" }
@@ -62,7 +62,6 @@ dependencies = [
"datasets>=4.0.0,<4.2.0",
"diffusers>=0.27.2,<0.36.0",
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
"accelerate>=1.10.0,<2.0.0",
# Core dependencies
"setuptools>=71.0.0,<81.0.0",
@@ -74,15 +73,15 @@ dependencies = [
"packaging>=24.2,<26.0",
"pynput>=1.7.7,<1.9.0",
"pyserial>=3.5,<4.0",
"wandb>=0.20.0,<0.22.0", # TODO: Bumb dependency (compatible with protobuf)
"wandb>=0.20.0,<0.23.0",
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
"draccus==0.10.0", # TODO: Remove ==
"gymnasium>=1.1.1,<2.0.0",
"rerun-sdk>=0.24.0,<0.27.0",
"gymnasium>=1.0.0",
"rerun-sdk>=0.21.0,<0.23.0", # TODO: Bumb dependency
# Support dependencies
"deepdiff>=7.0.1,<9.0.0",
@@ -97,15 +96,13 @@ dependencies = [
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.10.0"]
transformers-dep = ["transformers>=4.53.0,<5.0.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb)
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
damiao = ["python-can>=4.2.0,<5.0.0"]
# Robots
openarms = ["lerobot[damiao]"]
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
@@ -115,23 +112,17 @@ intelrealsense = [
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
"pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
]
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0"]
# stretch = [
# "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'",
# "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
# "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"
# ] # TODO: Currently not supported
# Policies
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
groot = [
"lerobot[transformers-dep]",
"peft>=0.13.0,<1.0.0",
"dm-tree>=0.1.8,<1.0.0",
"timm>=1.0.0,<1.1.0",
"safetensors>=0.4.3,<1.0.0",
"Pillow>=10.0.0,<13.0.0",
"decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
"ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
@@ -145,12 +136,11 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
aloha = ["gym-aloha>=0.1.2,<0.2.0"]
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"]
metaworld = ["metaworld==3.0.0"]
metaworld = ["metaworld>=3.0.0"]
# All
all = [
"lerobot[dynamixel]",
"lerobot[openarms]",
"lerobot[gamepad]",
"lerobot[hopejr]",
"lerobot[lekiwi]",
@@ -159,7 +149,6 @@ all = [
"lerobot[intelrealsense]",
"lerobot[pi]",
"lerobot[smolvla]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[hilserl]",
"lerobot[async]",
"lerobot[dev]",
@@ -244,6 +233,9 @@ exclude_dirs = [
"tests",
"benchmarks",
"src/lerobot/datasets/push_dataset_to_hub",
"src/lerobot/datasets/v2/convert_dataset_v1_to_v2",
"src/lerobot/policies/pi0/conversion_scripts",
"src/lerobot/scripts/push_dataset_to_hub.py",
]
skips = ["B101", "B311", "B404", "B603", "B615"]
@@ -258,8 +250,6 @@ default.extend-ignore-identifiers-re = [
"pn",
"ser",
"ein",
"thw",
"inpt",
]
# TODO: Uncomment when ready to use
@@ -298,6 +288,7 @@ ignore_errors = true
[[tool.mypy.overrides]]
module = "lerobot.envs.*"
# Enable type checking only for the envs module
ignore_errors = false
@@ -305,22 +296,17 @@ ignore_errors = false
# module = "lerobot.utils.*"
# ignore_errors = false
[[tool.mypy.overrides]]
module = "lerobot.configs.*"
ignore_errors = false
# extra strictness for configs
disallow_untyped_defs = true
disallow_incomplete_defs = true
check_untyped_defs = true
# [[tool.mypy.overrides]]
# module = "lerobot.configs.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.optim.*"
# ignore_errors = false
[[tool.mypy.overrides]]
module = "lerobot.model.*"
ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.model.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.processor.*"
@@ -330,9 +316,9 @@ ignore_errors = false
# module = "lerobot.datasets.*"
# ignore_errors = false
[[tool.mypy.overrides]]
module = "lerobot.cameras.*"
ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.cameras.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.motors.*"

View File

@@ -1,4 +1,3 @@
#
# This file is autogenerated by pip-compile with Python 3.10
# by the following command:
#
@@ -13,62 +12,47 @@ absl-py==2.3.1
# dm-tree
# labmaze
# mujoco
# tensorboard
accelerate==1.11.0
# via
# lerobot
# peft
accelerate==1.9.0
# via lerobot
aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.13.1
aiohttp==3.12.15
# via fsspec
aiosignal==1.4.0
# via aiohttp
annotated-types==0.7.0
# via pydantic
antlr4-python3-runtime==4.9.3
# via
# hydra-core
# omegaconf
anyio==4.11.0
# via
# starlette
# watchfiles
asttokens==3.0.0
# via stack-data
async-timeout==5.0.1
# via aiohttp
attrs==25.4.0
attrs==25.3.0
# via
# aiohttp
# dm-tree
# jsonlines
# jsonschema
# referencing
# rerun-sdk
av==15.1.0
av==15.0.0
# via lerobot
bddl==1.0.1
# via libero
certifi==2025.10.5
blinker==1.9.0
# via flask
certifi==2025.7.14
# via
# requests
# sentry-sdk
cffi==2.0.0
cffi==1.17.1
# via pymunk
cfgv==3.4.0
# via pre-commit
charset-normalizer==3.4.4
charset-normalizer==3.4.2
# via requests
click==8.3.0
click==8.2.1
# via
# uvicorn
# flask
# wandb
cloudpickle==3.1.1
# via
# gymnasium
# libero
cmake==4.1.0
# via gymnasium
cmake==4.0.3
# via lerobot
cmeel==0.57.3
# via
@@ -110,27 +94,27 @@ coal-library==3.0.1
# via pin
contourpy==1.3.2
# via matplotlib
coverage[toml]==7.11.0
coverage[toml]==7.10.1
# via pytest-cov
cycler==0.12.1
# via matplotlib
datasets==4.1.1
datasets==3.6.0
# via lerobot
debugpy==1.8.17
debugpy==1.8.15
# via lerobot
decorator==5.2.1
# via ipython
deepdiff==8.6.1
deepdiff==8.5.0
# via lerobot
diffusers==0.35.2
diffusers==0.34.0
# via lerobot
dill==0.4.0
dill==0.3.8
# via
# datasets
# multiprocess
distlib==0.4.0
# via virtualenv
dm-control==1.0.34
dm-control==1.0.14
# via gym-aloha
dm-env==1.6
# via dm-control
@@ -138,45 +122,29 @@ dm-tree==0.1.9
# via
# dm-control
# dm-env
# lerobot
docopt==0.6.2
# via num2words
draccus==0.10.0
# via lerobot
dynamixel-sdk==3.8.4
dynamixel-sdk==3.7.31
# via lerobot
easydict==1.13
# via libero
egl-probe @ git+https://github.com/huggingface/egl_probe.git
# via
# libero
# robomimic
eigenpy==3.10.3
# via coal-library
einops==0.8.1
# via
# lerobot
# libero
# via lerobot
eiquadprog==1.2.9
# via placo
etils[epath,epy]==1.13.0
# via mujoco
exceptiongroup==1.3.0
# via
# anyio
# ipython
# pytest
executing==2.2.1
executing==2.2.0
# via stack-data
farama-notifications==0.0.4
# via gymnasium
fastapi==0.119.1
# via teleop
fastjsonschema==2.21.2
# via nbformat
feetech-servo-sdk==1.0.0
# via lerobot
filelock==3.20.0
filelock==3.18.0
# via
# datasets
# diffusers
@@ -184,25 +152,24 @@ filelock==3.20.0
# torch
# transformers
# virtualenv
fonttools==4.60.1
flask==3.1.1
# via lerobot
fonttools==4.59.0
# via matplotlib
frozenlist==1.8.0
frozenlist==1.7.0
# via
# aiohttp
# aiosignal
fsspec[http]==2025.9.0
fsspec[http]==2025.3.0
# via
# datasets
# etils
# huggingface-hub
# torch
future==1.0.0
# via libero
gitdb==4.0.12
# via gitpython
gitpython==3.1.45
# via wandb
glfw==2.10.0
glfw==2.9.0
# via
# dm-control
# mujoco
@@ -210,79 +177,61 @@ grpcio==1.73.1
# via
# grpcio-tools
# lerobot
# reachy2-sdk
# reachy2-sdk-api
# tensorboard
grpcio-tools==1.73.1
# via
# lerobot
# reachy2-sdk-api
gym-aloha==0.1.3
# via lerobot
gym-hil==0.1.13
gym-aloha==0.1.1
# via lerobot
gym-pusht==0.1.6
gym-hil==0.1.10
# via lerobot
gymnasium==1.2.1
gym-pusht==0.1.5
# via lerobot
gym-xarm==0.1.1
# via lerobot
gymnasium==0.29.1
# via
# gym-aloha
# gym-hil
# gym-pusht
# gym-xarm
# gymnasium-robotics
# lerobot
# libero
# metaworld
h11==0.16.0
# via uvicorn
h5py==3.15.1
# via robomimic
hebi-py==2.11.0
# via lerobot
# pettingzoo
gymnasium-robotics==1.2.4
# via gym-xarm
hf-transfer==0.1.9
# via huggingface-hub
hf-xet==1.1.10
hf-xet==1.1.5
# via huggingface-hub
hidapi==0.14.0.post4
# via
# gym-hil
# lerobot
httptools==0.7.1
# via uvicorn
huggingface-hub[cli,hf-transfer]==0.35.3
huggingface-hub[cli,hf-transfer]==0.34.3
# via
# accelerate
# datasets
# diffusers
# lerobot
# peft
# timm
# tokenizers
# transformers
hydra-core==1.3.2
# via libero
identify==2.6.15
identify==2.6.12
# via pre-commit
idna==3.11
idna==3.10
# via
# anyio
# requests
# yarl
imageio[ffmpeg]==2.37.0
# via
# gym-aloha
# gym-hil
# gymnasium-robotics
# lerobot
# metaworld
# robomimic
# scikit-image
imageio-ffmpeg==0.6.0
# via
# imageio
# robomimic
# via imageio
importlib-metadata==8.7.0
# via diffusers
importlib-resources==6.5.2
# via etils
iniconfig==2.3.0
iniconfig==2.1.0
# via pytest
inquirerpy==0.3.4
# via huggingface-hub
@@ -290,71 +239,50 @@ ipython==8.37.0
# via meshcat
ischedule==1.2.7
# via placo
itsdangerous==2.2.0
# via flask
jedi==0.19.2
# via ipython
jinja2==3.1.6
# via torch
# via
# flask
# gymnasium-robotics
# torch
jsonlines==4.0.0
# via lerobot
jsonschema==4.25.1
# via nbformat
jsonschema-specifications==2025.9.1
# via jsonschema
jupyter-core==5.9.1
# via nbformat
jupytext==1.18.1
# via bddl
kiwisolver==1.4.9
kiwisolver==1.4.8
# via matplotlib
labmaze==1.0.6
# via dm-control
lazy-loader==0.4
# via scikit-image
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
# via lerobot
llvmlite==0.45.1
# via numba
lxml==6.0.2
lxml==6.0.0
# via dm-control
markdown==3.9
# via tensorboard
markdown-it-py==4.0.0
# via
# jupytext
# mdit-py-plugins
markupsafe==3.0.3
markupsafe==3.0.2
# via
# flask
# jinja2
# werkzeug
matplotlib==3.10.7
# via
# lerobot
# libero
matplotlib-inline==0.2.1
matplotlib==3.10.5
# via lerobot
matplotlib-inline==0.1.7
# via ipython
mdit-py-plugins==0.5.0
# via jupytext
mdurl==0.1.2
# via markdown-it-py
mergedeep==1.3.4
# via draccus
meshcat==0.3.2
# via placo
metaworld==3.0.0
# via lerobot
mock-serial==0.0.1
# via lerobot
mpmath==1.3.0
# via sympy
mujoco==3.3.7
mujoco==2.3.7
# via
# dm-control
# gym-aloha
# gym-hil
# libero
# metaworld
# robosuite
multidict==6.7.0
# gym-xarm
# gymnasium-robotics
multidict==6.6.3
# via
# aiohttp
# yarl
@@ -362,25 +290,17 @@ multiprocess==0.70.16
# via datasets
mypy-extensions==1.1.0
# via typing-inspect
nbformat==5.10.4
# via jupytext
networkx==3.4.2
# via
# bddl
# scikit-image
# torch
ninja==1.13.0
# via lerobot
nodeenv==1.9.1
# via pre-commit
num2words==0.5.14
# via lerobot
numba==0.62.1
# via robosuite
numpy==2.2.6
# via
# accelerate
# bddl
# cmeel-boost
# contourpy
# datasets
@@ -389,43 +309,25 @@ numpy==2.2.6
# dm-env
# dm-tree
# gymnasium
# h5py
# hebi-py
# gymnasium-robotics
# imageio
# labmaze
# libero
# matplotlib
# meshcat
# metaworld
# mujoco
# numba
# opencv-python
# opencv-python-headless
# pandas
# peft
# pyquaternion
# reachy2-sdk
# pettingzoo
# rerun-sdk
# robomimic
# robosuite
# scikit-image
# scipy
# shapely
# teleop
# tensorboard
# tensorboardx
# tifffile
# torchvision
# transformers
# transforms3d
omegaconf==2.3.0
# via hydra-core
opencv-python==4.12.0.88
# via
# gym-pusht
# libero
# reachy2-sdk
# robosuite
# via gym-pusht
opencv-python-headless==4.12.0.88
# via lerobot
orderly-set==5.5.0
@@ -435,63 +337,53 @@ packaging==25.0
# accelerate
# datasets
# huggingface-hub
# hydra-core
# jupytext
# lazy-loader
# lerobot
# matplotlib
# peft
# pytest
# reachy2-sdk
# scikit-image
# tensorboard
# tensorboardx
# transformers
# wandb
pandas==2.3.3
pandas==2.3.1
# via
# datasets
# lerobot
parso==0.8.5
parso==0.8.4
# via jedi
peft==0.17.1
# via lerobot
pettingzoo==1.24.3
# via gymnasium-robotics
pexpect==4.9.0
# via ipython
pfzy==0.3.4
# via inquirerpy
pillow==12.0.0
pillow==11.3.0
# via
# diffusers
# imageio
# lerobot
# matplotlib
# meshcat
# rerun-sdk
# robosuite
# scikit-image
# tensorboard
# torchvision
pin==3.4.0
# via placo
placo==0.9.14
# via lerobot
platformdirs==4.5.0
platformdirs==4.3.8
# via
# jupyter-core
# virtualenv
# wandb
pluggy==1.6.0
# via
# pytest
# pytest-cov
pre-commit==4.3.0
pre-commit==4.2.0
# via lerobot
prompt-toolkit==3.0.52
prompt-toolkit==3.0.51
# via
# inquirerpy
# ipython
propcache==0.4.1
propcache==0.3.2
# via
# aiohttp
# yarl
@@ -500,17 +392,11 @@ protobuf==6.31.0
# dm-control
# grpcio-tools
# lerobot
# reachy2-sdk
# reachy2-sdk-api
# tensorboard
# tensorboardx
# wandb
psutil==7.1.1
psutil==7.0.0
# via
# accelerate
# imageio
# peft
# robomimic
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.3
@@ -519,13 +405,11 @@ pyarrow==21.0.0
# via
# datasets
# rerun-sdk
pycparser==2.23
pycparser==2.22
# via cffi
pydantic==2.12.3
# via
# fastapi
# wandb
pydantic-core==2.41.4
pydantic==2.11.7
# via wandb
pydantic-core==2.33.2
# via pydantic
pygame==2.6.1
# via
@@ -540,42 +424,40 @@ pymunk==6.11.1
# via
# gym-pusht
# lerobot
pyngrok==7.4.1
pyngrok==7.2.12
# via meshcat
pynput==1.8.1
# via
# gym-hil
# lerobot
pyobjc-core==12.0
pyobjc-core==11.1
# via
# pyobjc-framework-applicationservices
# pyobjc-framework-cocoa
# pyobjc-framework-coretext
# pyobjc-framework-quartz
pyobjc-framework-applicationservices==12.0
pyobjc-framework-applicationservices==11.1
# via pynput
pyobjc-framework-cocoa==12.0
pyobjc-framework-cocoa==11.1
# via
# pyobjc-framework-applicationservices
# pyobjc-framework-coretext
# pyobjc-framework-quartz
pyobjc-framework-coretext==12.0
pyobjc-framework-coretext==11.1
# via pyobjc-framework-applicationservices
pyobjc-framework-quartz==12.0
pyobjc-framework-quartz==11.1
# via
# pynput
# pyobjc-framework-applicationservices
# pyobjc-framework-coretext
pyopengl==3.1.10
pyopengl==3.1.9
# via
# dm-control
# mujoco
pyparsing==3.2.5
pyparsing==3.2.3
# via
# dm-control
# matplotlib
pyquaternion==0.9.9
# via reachy2-sdk
pyrealsense2-macosx==2.54.2
# via lerobot
pyserial==3.5
@@ -583,14 +465,12 @@ pyserial==3.5
# dynamixel-sdk
# feetech-servo-sdk
# lerobot
pytest==8.4.2
pytest==8.4.1
# via
# bddl
# lerobot
# pytest-cov
# pytest-timeout
# teleop
pytest-cov==7.0.0
pytest-cov==6.2.1
# via lerobot
pytest-timeout==2.4.0
# via lerobot
@@ -598,73 +478,46 @@ python-dateutil==2.9.0.post0
# via
# matplotlib
# pandas
python-dotenv==1.1.1
# via uvicorn
pytz==2025.2
# via pandas
pyyaml==6.0.3
pyyaml==6.0.2
# via
# accelerate
# datasets
# draccus
# hebi-py
# huggingface-hub
# jupytext
# omegaconf
# peft
# pre-commit
# pyngrok
# pyyaml-include
# timm
# transformers
# uvicorn
# wandb
pyyaml-include==1.4.1
# via draccus
pyzmq==27.1.0
pyzmq==27.0.0
# via
# lerobot
# meshcat
reachy2-sdk==1.0.14
# via lerobot
reachy2-sdk-api==1.0.21
# via reachy2-sdk
referencing==0.37.0
# via
# jsonschema
# jsonschema-specifications
regex==2025.10.23
regex==2025.7.34
# via
# diffusers
# transformers
requests==2.32.5
requests==2.32.4
# via
# datasets
# diffusers
# dm-control
# huggingface-hub
# teleop
# transformers
# wandb
rerun-sdk==0.26.1
rerun-sdk==0.22.1
# via lerobot
rhoban-cmeel-jsoncpp==1.9.4.9
# via placo
robomimic==0.2.0
# via libero
robosuite==1.4.0
# via libero
rpds-py==0.28.0
# via
# jsonschema
# referencing
safetensors==0.6.2
safetensors==0.5.3
# via
# accelerate
# diffusers
# lerobot
# peft
# timm
# transformers
scikit-image==0.25.2
# via
@@ -673,12 +526,10 @@ scikit-image==0.25.2
scipy==1.15.3
# via
# dm-control
# metaworld
# robosuite
# scikit-image
sentry-sdk==2.42.1
sentry-sdk==2.34.1
# via wandb
shapely==2.1.2
shapely==2.1.1
# via gym-pusht
six==1.17.0
# via
@@ -686,106 +537,64 @@ six==1.17.0
# python-dateutil
smmap==5.0.2
# via gitdb
sniffio==1.3.1
# via anyio
stack-data==0.6.3
# via ipython
starlette==0.48.0
# via fastapi
sympy==1.14.0
# via torch
teleop==0.1.2
# via lerobot
tensorboard==2.20.0
# via robomimic
tensorboard-data-server==0.7.2
# via tensorboard
tensorboardx==2.6.4
# via robomimic
termcolor==3.1.0
# via
# lerobot
# robomimic
thop==0.1.1.post2209072238
# via libero
# via lerobot
tifffile==2025.5.10
# via scikit-image
timm==1.0.20
# via lerobot
tokenizers==0.22.1
tokenizers==0.21.4
# via transformers
toml==0.10.2
# via draccus
tomli==2.3.0
tomli==2.2.1
# via
# cmeel
# coverage
# jupytext
# pytest
torch==2.7.1
# via
# accelerate
# lerobot
# peft
# robomimic
# thop
# timm
# torchvision
torchcodec==0.5
# via lerobot
torchvision==0.22.1
# via
# lerobot
# robomimic
# timm
tornado==6.5.2
# via lerobot
tornado==6.5.1
# via meshcat
tqdm==4.67.1
# via
# datasets
# dm-control
# huggingface-hub
# peft
# robomimic
# transformers
traitlets==5.14.3
# via
# ipython
# jupyter-core
# matplotlib-inline
# nbformat
transformers==4.57.1
# via
# lerobot
# libero
# peft
transforms3d==0.4.2
# via teleop
typing-extensions==4.15.0
transformers==4.51.3
# via lerobot
typing-extensions==4.14.1
# via
# aiosignal
# anyio
# etils
# exceptiongroup
# fastapi
# gymnasium
# huggingface-hub
# ipython
# multidict
# pydantic
# pydantic-core
# referencing
# rerun-sdk
# starlette
# torch
# typing-inspect
# typing-inspection
# uvicorn
# virtualenv
# wandb
typing-inspect==0.9.0
# via draccus
typing-inspection==0.4.2
typing-inspection==0.4.1
# via pydantic
tzdata==2025.2
# via pandas
@@ -795,36 +604,22 @@ urllib3==2.5.0
# via
# requests
# sentry-sdk
uvicorn[standard]==0.38.0
# via teleop
uvloop==0.22.1
# via uvicorn
virtualenv==20.35.3
virtualenv==20.32.0
# via pre-commit
wandb==0.21.4
# via
# lerobot
# libero
watchfiles==1.1.1
# via uvicorn
wcwidth==0.2.14
wandb==0.21.0
# via lerobot
wcwidth==0.2.13
# via prompt-toolkit
websocket-client==1.9.0
# via teleop
websockets==15.0.1
# via uvicorn
werkzeug==3.1.3
# via tensorboard
wrapt==2.0.0
# via flask
wrapt==1.17.2
# via dm-tree
xxhash==3.6.0
xxhash==3.5.0
# via datasets
yarl==1.22.0
yarl==1.20.1
# via aiohttp
zipp==3.23.0
# via
# etils
# importlib-metadata
# via importlib-metadata
# The following packages are considered to be unsafe in a requirements file:
# setuptools

View File

@@ -13,62 +13,47 @@ absl-py==2.3.1
# dm-tree
# labmaze
# mujoco
# tensorboard
accelerate==1.11.0
# via
# lerobot
# peft
accelerate==1.9.0
# via lerobot
aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.13.1
aiohttp==3.12.15
# via fsspec
aiosignal==1.4.0
# via aiohttp
annotated-types==0.7.0
# via pydantic
antlr4-python3-runtime==4.9.3
# via
# hydra-core
# omegaconf
anyio==4.11.0
# via
# starlette
# watchfiles
asttokens==3.0.0
# via stack-data
async-timeout==5.0.1
# via aiohttp
attrs==25.4.0
attrs==25.3.0
# via
# aiohttp
# dm-tree
# jsonlines
# jsonschema
# referencing
# rerun-sdk
av==15.1.0
av==15.0.0
# via lerobot
bddl==1.0.1
# via libero
certifi==2025.10.5
blinker==1.9.0
# via flask
certifi==2025.7.14
# via
# requests
# sentry-sdk
cffi==2.0.0
cffi==1.17.1
# via pymunk
cfgv==3.4.0
# via pre-commit
charset-normalizer==3.4.4
charset-normalizer==3.4.2
# via requests
click==8.3.0
click==8.2.1
# via
# uvicorn
# flask
# wandb
cloudpickle==3.1.1
# via
# gymnasium
# libero
cmake==4.1.0
# via gymnasium
cmake==4.0.3
# via lerobot
cmeel==0.57.3
# via
@@ -110,29 +95,27 @@ coal-library==3.0.1
# via pin
contourpy==1.3.2
# via matplotlib
coverage[toml]==7.11.0
coverage[toml]==7.10.1
# via pytest-cov
cycler==0.12.1
# via matplotlib
datasets==4.1.1
datasets==3.6.0
# via lerobot
debugpy==1.8.17
debugpy==1.8.15
# via lerobot
decorator==5.2.1
# via ipython
decord==0.6.0
deepdiff==8.5.0
# via lerobot
deepdiff==8.6.1
diffusers==0.34.0
# via lerobot
diffusers==0.35.2
# via lerobot
dill==0.4.0
dill==0.3.8
# via
# datasets
# multiprocess
distlib==0.4.0
# via virtualenv
dm-control==1.0.34
dm-control==1.0.14
# via gym-aloha
dm-env==1.6
# via dm-control
@@ -140,48 +123,31 @@ dm-tree==0.1.9
# via
# dm-control
# dm-env
# lerobot
docopt==0.6.2
# via num2words
draccus==0.10.0
# via lerobot
dynamixel-sdk==3.8.4
dynamixel-sdk==3.7.31
# via lerobot
easydict==1.13
# via libero
egl-probe @ git+https://github.com/huggingface/egl_probe.git
# via
# libero
# robomimic
eigenpy==3.10.3
# via coal-library
einops==0.8.1
# via
# flash-attn
# lerobot
# libero
# via lerobot
eiquadprog==1.2.9
# via placo
etils[epath,epy]==1.13.0
# via mujoco
evdev==1.9.2
# via pynput
exceptiongroup==1.3.0
# via
# anyio
# ipython
# pytest
executing==2.2.1
executing==2.2.0
# via stack-data
farama-notifications==0.0.4
# via gymnasium
fastapi==0.119.1
# via teleop
fastjsonschema==2.21.2
# via nbformat
feetech-servo-sdk==1.0.0
# via lerobot
filelock==3.20.0
filelock==3.18.0
# via
# datasets
# diffusers
@@ -189,27 +155,24 @@ filelock==3.20.0
# torch
# transformers
# virtualenv
flash-attn==2.8.3
flask==3.1.1
# via lerobot
fonttools==4.60.1
fonttools==4.59.0
# via matplotlib
frozenlist==1.8.0
frozenlist==1.7.0
# via
# aiohttp
# aiosignal
fsspec[http]==2025.9.0
fsspec[http]==2025.3.0
# via
# datasets
# etils
# huggingface-hub
# torch
future==1.0.0
# via libero
gitdb==4.0.12
# via gitpython
gitpython==3.1.45
# via wandb
glfw==2.10.0
glfw==2.9.0
# via
# dm-control
# mujoco
@@ -217,79 +180,61 @@ grpcio==1.73.1
# via
# grpcio-tools
# lerobot
# reachy2-sdk
# reachy2-sdk-api
# tensorboard
grpcio-tools==1.73.1
# via
# lerobot
# reachy2-sdk-api
gym-aloha==0.1.3
# via lerobot
gym-hil==0.1.13
gym-aloha==0.1.1
# via lerobot
gym-pusht==0.1.6
gym-hil==0.1.10
# via lerobot
gymnasium==1.2.1
gym-pusht==0.1.5
# via lerobot
gym-xarm==0.1.1
# via lerobot
gymnasium==0.29.1
# via
# gym-aloha
# gym-hil
# gym-pusht
# gym-xarm
# gymnasium-robotics
# lerobot
# libero
# metaworld
h11==0.16.0
# via uvicorn
h5py==3.15.1
# via robomimic
hebi-py==2.11.0
# via lerobot
# pettingzoo
gymnasium-robotics==1.2.4
# via gym-xarm
hf-transfer==0.1.9
# via huggingface-hub
hf-xet==1.1.10
hf-xet==1.1.5
# via huggingface-hub
hidapi==0.14.0.post4
# via
# gym-hil
# lerobot
httptools==0.7.1
# via uvicorn
huggingface-hub[cli,hf-transfer]==0.35.3
huggingface-hub[cli,hf-transfer]==0.34.3
# via
# accelerate
# datasets
# diffusers
# lerobot
# peft
# timm
# tokenizers
# transformers
hydra-core==1.3.2
# via libero
identify==2.6.15
identify==2.6.12
# via pre-commit
idna==3.11
idna==3.10
# via
# anyio
# requests
# yarl
imageio[ffmpeg]==2.37.0
# via
# gym-aloha
# gym-hil
# gymnasium-robotics
# lerobot
# metaworld
# robomimic
# scikit-image
imageio-ffmpeg==0.6.0
# via
# imageio
# robomimic
# via imageio
importlib-metadata==8.7.0
# via diffusers
importlib-resources==6.5.2
# via etils
iniconfig==2.3.0
iniconfig==2.1.0
# via pytest
inquirerpy==0.3.4
# via huggingface-hub
@@ -297,71 +242,50 @@ ipython==8.37.0
# via meshcat
ischedule==1.2.7
# via placo
itsdangerous==2.2.0
# via flask
jedi==0.19.2
# via ipython
jinja2==3.1.6
# via torch
# via
# flask
# gymnasium-robotics
# torch
jsonlines==4.0.0
# via lerobot
jsonschema==4.25.1
# via nbformat
jsonschema-specifications==2025.9.1
# via jsonschema
jupyter-core==5.9.1
# via nbformat
jupytext==1.18.1
# via bddl
kiwisolver==1.4.9
kiwisolver==1.4.8
# via matplotlib
labmaze==1.0.6
# via dm-control
lazy-loader==0.4
# via scikit-image
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
# via lerobot
llvmlite==0.45.1
# via numba
lxml==6.0.2
lxml==6.0.0
# via dm-control
markdown==3.9
# via tensorboard
markdown-it-py==4.0.0
# via
# jupytext
# mdit-py-plugins
markupsafe==3.0.3
markupsafe==3.0.2
# via
# flask
# jinja2
# werkzeug
matplotlib==3.10.7
# via
# lerobot
# libero
matplotlib-inline==0.2.1
matplotlib==3.10.5
# via lerobot
matplotlib-inline==0.1.7
# via ipython
mdit-py-plugins==0.5.0
# via jupytext
mdurl==0.1.2
# via markdown-it-py
mergedeep==1.3.4
# via draccus
meshcat==0.3.2
# via placo
metaworld==3.0.0
# via lerobot
mock-serial==0.0.1
# via lerobot
mpmath==1.3.0
# via sympy
mujoco==3.3.7
mujoco==2.3.7
# via
# dm-control
# gym-aloha
# gym-hil
# libero
# metaworld
# robosuite
multidict==6.7.0
# gym-xarm
# gymnasium-robotics
multidict==6.6.3
# via
# aiohttp
# yarl
@@ -369,63 +293,42 @@ multiprocess==0.70.16
# via datasets
mypy-extensions==1.1.0
# via typing-inspect
nbformat==5.10.4
# via jupytext
networkx==3.4.2
# via
# bddl
# scikit-image
# torch
ninja==1.13.0
# via lerobot
nodeenv==1.9.1
# via pre-commit
num2words==0.5.14
# via lerobot
numba==0.62.1
# via robosuite
numpy==2.2.6
# via
# accelerate
# bddl
# cmeel-boost
# contourpy
# datasets
# decord
# diffusers
# dm-control
# dm-env
# dm-tree
# gymnasium
# h5py
# hebi-py
# gymnasium-robotics
# imageio
# labmaze
# libero
# matplotlib
# meshcat
# metaworld
# mujoco
# numba
# opencv-python
# opencv-python-headless
# pandas
# peft
# pyquaternion
# reachy2-sdk
# pettingzoo
# rerun-sdk
# robomimic
# robosuite
# scikit-image
# scipy
# shapely
# teleop
# tensorboard
# tensorboardx
# tifffile
# torchvision
# transformers
# transforms3d
nvidia-cublas-cu12==12.6.4.1
# via
# nvidia-cudnn-cu12
@@ -463,14 +366,8 @@ nvidia-nvjitlink-cu12==12.6.85
# torch
nvidia-nvtx-cu12==12.6.77
# via torch
omegaconf==2.3.0
# via hydra-core
opencv-python==4.12.0.88
# via
# gym-pusht
# libero
# reachy2-sdk
# robosuite
# via gym-pusht
opencv-python-headless==4.12.0.88
# via lerobot
orderly-set==5.5.0
@@ -480,63 +377,53 @@ packaging==25.0
# accelerate
# datasets
# huggingface-hub
# hydra-core
# jupytext
# lazy-loader
# lerobot
# matplotlib
# peft
# pytest
# reachy2-sdk
# scikit-image
# tensorboard
# tensorboardx
# transformers
# wandb
pandas==2.3.3
pandas==2.3.1
# via
# datasets
# lerobot
parso==0.8.5
parso==0.8.4
# via jedi
peft==0.17.1
# via lerobot
pettingzoo==1.24.3
# via gymnasium-robotics
pexpect==4.9.0
# via ipython
pfzy==0.3.4
# via inquirerpy
pillow==12.0.0
pillow==11.3.0
# via
# diffusers
# imageio
# lerobot
# matplotlib
# meshcat
# rerun-sdk
# robosuite
# scikit-image
# tensorboard
# torchvision
pin==3.4.0
# via placo
placo==0.9.14
# via lerobot
platformdirs==4.5.0
platformdirs==4.3.8
# via
# jupyter-core
# virtualenv
# wandb
pluggy==1.6.0
# via
# pytest
# pytest-cov
pre-commit==4.3.0
pre-commit==4.2.0
# via lerobot
prompt-toolkit==3.0.52
prompt-toolkit==3.0.51
# via
# inquirerpy
# ipython
propcache==0.4.1
propcache==0.3.2
# via
# aiohttp
# yarl
@@ -545,17 +432,11 @@ protobuf==6.31.0
# dm-control
# grpcio-tools
# lerobot
# reachy2-sdk
# reachy2-sdk-api
# tensorboard
# tensorboardx
# wandb
psutil==7.1.1
psutil==7.0.0
# via
# accelerate
# imageio
# peft
# robomimic
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.3
@@ -564,13 +445,11 @@ pyarrow==21.0.0
# via
# datasets
# rerun-sdk
pycparser==2.23
pycparser==2.22
# via cffi
pydantic==2.12.3
# via
# fastapi
# wandb
pydantic-core==2.41.4
pydantic==2.11.7
# via wandb
pydantic-core==2.33.2
# via pydantic
pygame==2.6.1
# via
@@ -585,22 +464,20 @@ pymunk==6.11.1
# via
# gym-pusht
# lerobot
pyngrok==7.4.1
pyngrok==7.2.12
# via meshcat
pynput==1.8.1
# via
# gym-hil
# lerobot
pyopengl==3.1.10
pyopengl==3.1.9
# via
# dm-control
# mujoco
pyparsing==3.2.5
pyparsing==3.2.3
# via
# dm-control
# matplotlib
pyquaternion==0.9.9
# via reachy2-sdk
pyrealsense2==2.56.5.9235
# via lerobot
pyserial==3.5
@@ -608,14 +485,12 @@ pyserial==3.5
# dynamixel-sdk
# feetech-servo-sdk
# lerobot
pytest==8.4.2
pytest==8.4.1
# via
# bddl
# lerobot
# pytest-cov
# pytest-timeout
# teleop
pytest-cov==7.0.0
pytest-cov==6.2.1
# via lerobot
pytest-timeout==2.4.0
# via lerobot
@@ -623,75 +498,48 @@ python-dateutil==2.9.0.post0
# via
# matplotlib
# pandas
python-dotenv==1.1.1
# via uvicorn
python-xlib==0.33
# via pynput
pytz==2025.2
# via pandas
pyyaml==6.0.3
pyyaml==6.0.2
# via
# accelerate
# datasets
# draccus
# hebi-py
# huggingface-hub
# jupytext
# omegaconf
# peft
# pre-commit
# pyngrok
# pyyaml-include
# timm
# transformers
# uvicorn
# wandb
pyyaml-include==1.4.1
# via draccus
pyzmq==27.1.0
pyzmq==27.0.0
# via
# lerobot
# meshcat
reachy2-sdk==1.0.14
# via lerobot
reachy2-sdk-api==1.0.21
# via reachy2-sdk
referencing==0.37.0
# via
# jsonschema
# jsonschema-specifications
regex==2025.10.23
regex==2025.7.34
# via
# diffusers
# transformers
requests==2.32.5
requests==2.32.4
# via
# datasets
# diffusers
# dm-control
# huggingface-hub
# teleop
# transformers
# wandb
rerun-sdk==0.26.1
rerun-sdk==0.22.1
# via lerobot
rhoban-cmeel-jsoncpp==1.9.4.9
# via placo
robomimic==0.2.0
# via libero
robosuite==1.4.0
# via libero
rpds-py==0.28.0
# via
# jsonschema
# referencing
safetensors==0.6.2
safetensors==0.5.3
# via
# accelerate
# diffusers
# lerobot
# peft
# timm
# transformers
scikit-image==0.25.2
# via
@@ -700,12 +548,10 @@ scikit-image==0.25.2
scipy==1.15.3
# via
# dm-control
# metaworld
# robosuite
# scikit-image
sentry-sdk==2.42.1
sentry-sdk==2.34.1
# via wandb
shapely==2.1.2
shapely==2.1.1
# via gym-pusht
six==1.17.0
# via
@@ -714,109 +560,66 @@ six==1.17.0
# python-xlib
smmap==5.0.2
# via gitdb
sniffio==1.3.1
# via anyio
stack-data==0.6.3
# via ipython
starlette==0.48.0
# via fastapi
sympy==1.14.0
# via torch
teleop==0.1.2
# via lerobot
tensorboard==2.20.0
# via robomimic
tensorboard-data-server==0.7.2
# via tensorboard
tensorboardx==2.6.4
# via robomimic
termcolor==3.1.0
# via
# lerobot
# robomimic
thop==0.1.1.post2209072238
# via libero
# via lerobot
tifffile==2025.5.10
# via scikit-image
timm==1.0.20
# via lerobot
tokenizers==0.22.1
tokenizers==0.21.4
# via transformers
toml==0.10.2
# via draccus
tomli==2.3.0
tomli==2.2.1
# via
# cmeel
# coverage
# jupytext
# pytest
torch==2.7.1
# via
# accelerate
# flash-attn
# lerobot
# peft
# robomimic
# thop
# timm
# torchvision
torchcodec==0.5
# via lerobot
torchvision==0.22.1
# via
# lerobot
# robomimic
# timm
tornado==6.5.2
# via lerobot
tornado==6.5.1
# via meshcat
tqdm==4.67.1
# via
# datasets
# dm-control
# huggingface-hub
# peft
# robomimic
# transformers
traitlets==5.14.3
# via
# ipython
# jupyter-core
# matplotlib-inline
# nbformat
transformers==4.57.1
# via
# lerobot
# libero
# peft
transforms3d==0.4.2
# via teleop
transformers==4.51.3
# via lerobot
triton==3.3.1
# via torch
typing-extensions==4.15.0
typing-extensions==4.14.1
# via
# aiosignal
# anyio
# etils
# exceptiongroup
# fastapi
# gymnasium
# huggingface-hub
# ipython
# multidict
# pydantic
# pydantic-core
# referencing
# rerun-sdk
# starlette
# torch
# typing-inspect
# typing-inspection
# uvicorn
# virtualenv
# wandb
typing-inspect==0.9.0
# via draccus
typing-inspection==0.4.2
typing-inspection==0.4.1
# via pydantic
tzdata==2025.2
# via pandas
@@ -826,36 +629,22 @@ urllib3==2.5.0
# via
# requests
# sentry-sdk
uvicorn[standard]==0.38.0
# via teleop
uvloop==0.22.1
# via uvicorn
virtualenv==20.35.3
virtualenv==20.32.0
# via pre-commit
wandb==0.21.4
# via
# lerobot
# libero
watchfiles==1.1.1
# via uvicorn
wcwidth==0.2.14
wandb==0.21.0
# via lerobot
wcwidth==0.2.13
# via prompt-toolkit
websocket-client==1.9.0
# via teleop
websockets==15.0.1
# via uvicorn
werkzeug==3.1.3
# via tensorboard
wrapt==2.0.0
# via flask
wrapt==1.17.2
# via dm-tree
xxhash==3.6.0
xxhash==3.5.0
# via datasets
yarl==1.22.0
yarl==1.20.1
# via aiohttp
zipp==3.23.0
# via
# etils
# importlib-metadata
# via importlib-metadata
# The following packages are considered to be unsafe in a requirements file:
# setuptools

View File

@@ -1,9 +1,9 @@
# requirements.in
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.0.1 25A362 arm64).
# Darwin MacBook-Pro.local 25.0.0 Darwin Kernel Version 25.0.0: Wed Sep 17 21:42:08 PDT 2025; root:xnu-12377.1.9~141/RELEASE_ARM64_T8132 arm64
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 15.5 24F74 arm64).
# Darwin MacBook-Pro.local 24.5.0 Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:43 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T8132 arm64
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.3 LTS x86_64).
# Linux mlerobot-linux 6.14.0-33-generic #33~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 17:02:30 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.2 LTS x86_64).
# Linux mlerobot-linux 6.14.0-27-generic #27~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Jul 22 17:38:49 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
-e .[all]

View File

@@ -16,7 +16,7 @@ import logging
import logging.handlers
import os
import time
from dataclasses import dataclass, field
from dataclasses import dataclass
from pathlib import Path
import torch
@@ -268,7 +268,6 @@ class RemotePolicyConfig:
lerobot_features: dict[str, PolicyFeature]
actions_per_chunk: int
device: str = "cpu"
rename_map: dict[str, str] = field(default_factory=dict)
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:

View File

@@ -159,10 +159,7 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
self.preprocessor, self.postprocessor = make_pre_post_processors(
self.policy.config,
pretrained_path=policy_specs.pretrained_name_or_path,
preprocessor_overrides={
"device_processor": device_override,
"rename_observations_processor": {"rename_map": policy_specs.rename_map},
},
preprocessor_overrides={"device_processor": device_override},
postprocessor_overrides={"device_processor": device_override},
)

View File

@@ -17,7 +17,7 @@
import abc
from typing import Any
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
import numpy as np
from .configs import CameraConfig, ColorMode
@@ -89,7 +89,7 @@ class Camera(abc.ABC):
pass
@abc.abstractmethod
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
"""Capture and return a single frame from the camera.
Args:
@@ -102,7 +102,7 @@ class Camera(abc.ABC):
pass
@abc.abstractmethod
def async_read(self, timeout_ms: float = ...) -> NDArray[Any]:
def async_read(self, timeout_ms: float = ...) -> np.ndarray:
"""Asynchronously capture and return a single frame from the camera.
Args:

View File

@@ -18,7 +18,7 @@ import abc
from dataclasses import dataclass
from enum import Enum
import draccus # type: ignore # TODO: add type stubs for draccus
import draccus
class ColorMode(str, Enum):
@@ -34,11 +34,11 @@ class Cv2Rotation(int, Enum):
@dataclass(kw_only=True)
class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus
class CameraConfig(draccus.ChoiceRegistry, abc.ABC):
fps: int | None = None
width: int | None = None
height: int | None = None
@property
def type(self) -> str:
return str(self.get_choice_name(self.__class__))
return self.get_choice_name(self.__class__)

View File

@@ -14,5 +14,3 @@
from .camera_opencv import OpenCVCamera
from .configuration_opencv import OpenCVCameraConfig
__all__ = ["OpenCVCamera", "OpenCVCameraConfig"]

View File

@@ -25,12 +25,11 @@ from pathlib import Path
from threading import Event, Lock, Thread
from typing import Any
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
# Fix MSMF hardware transform compatibility for Windows before importing cv2
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2 # type: ignore # TODO: add type stubs for OpenCV
import cv2
import numpy as np
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
@@ -122,7 +121,7 @@ class OpenCVCamera(Camera):
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: NDArray[Any] | None = None
self.latest_frame: np.ndarray | None = None
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
@@ -141,7 +140,7 @@ class OpenCVCamera(Camera):
"""Checks if the camera is currently connected and opened."""
return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened()
def connect(self, warmup: bool = True) -> None:
def connect(self, warmup: bool = True):
"""
Connects to the OpenCV camera specified in the configuration.
@@ -181,14 +180,12 @@ class OpenCVCamera(Camera):
def _configure_capture_settings(self) -> None:
"""
Applies the specified FOURCC, FPS, width, and height settings to the connected camera.
Applies the specified FPS, width, and height settings to the connected camera.
This method attempts to set the camera properties via OpenCV. It checks if
the camera successfully applied the settings and raises an error if not.
FOURCC is set first (if specified) as it can affect the available FPS and resolution options.
Args:
fourcc: The desired FOURCC code (e.g., "MJPG", "YUYV"). If None, auto-detect.
fps: The desired frames per second. If None, the setting is skipped.
width: The desired capture width. If None, the setting is skipped.
height: The desired capture height. If None, the setting is skipped.
@@ -202,11 +199,10 @@ class OpenCVCamera(Camera):
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
if self.config.fourcc is not None:
self._validate_fourcc()
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
if self.fps is None:
self.fps = self.videocapture.get(cv2.CAP_PROP_FPS)
else:
self._validate_fps()
default_width = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_WIDTH)))
default_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
@@ -220,56 +216,18 @@ class OpenCVCamera(Camera):
else:
self._validate_width_and_height()
if self.fps is None:
self.fps = self.videocapture.get(cv2.CAP_PROP_FPS)
else:
self._validate_fps()
def _validate_fps(self) -> None:
"""Validates and sets the camera's frames per second (FPS)."""
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
if self.fps is None:
raise ValueError(f"{self} FPS is not set")
success = self.videocapture.set(cv2.CAP_PROP_FPS, float(self.fps))
actual_fps = self.videocapture.get(cv2.CAP_PROP_FPS)
# Use math.isclose for robust float comparison
if not success or not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
raise RuntimeError(f"{self} failed to set fps={self.fps} ({actual_fps=}).")
def _validate_fourcc(self) -> None:
"""Validates and sets the camera's FOURCC code."""
fourcc_code = cv2.VideoWriter_fourcc(*self.config.fourcc)
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
success = self.videocapture.set(cv2.CAP_PROP_FOURCC, fourcc_code)
actual_fourcc_code = self.videocapture.get(cv2.CAP_PROP_FOURCC)
# Convert actual FOURCC code back to string for comparison
actual_fourcc_code_int = int(actual_fourcc_code)
actual_fourcc = "".join([chr((actual_fourcc_code_int >> 8 * i) & 0xFF) for i in range(4)])
if not success or actual_fourcc != self.config.fourcc:
logger.warning(
f"{self} failed to set fourcc={self.config.fourcc} (actual={actual_fourcc}, success={success}). "
f"Continuing with default format."
)
def _validate_width_and_height(self) -> None:
"""Validates and sets the camera's frame capture width and height."""
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
if self.capture_width is None or self.capture_height is None:
raise ValueError(f"{self} capture_width or capture_height is not set")
width_success = self.videocapture.set(cv2.CAP_PROP_FRAME_WIDTH, float(self.capture_width))
height_success = self.videocapture.set(cv2.CAP_PROP_FRAME_HEIGHT, float(self.capture_height))
@@ -300,12 +258,11 @@ class OpenCVCamera(Camera):
"""
found_cameras_info = []
targets_to_scan: list[str | int]
if platform.system() == "Linux":
possible_paths = sorted(Path("/dev").glob("video*"), key=lambda p: p.name)
targets_to_scan = [str(p) for p in possible_paths]
else:
targets_to_scan = [int(i) for i in range(MAX_OPENCV_INDEX)]
targets_to_scan = list(range(MAX_OPENCV_INDEX))
for target in targets_to_scan:
camera = cv2.VideoCapture(target)
@@ -314,12 +271,6 @@ class OpenCVCamera(Camera):
default_height = int(camera.get(cv2.CAP_PROP_FRAME_HEIGHT))
default_fps = camera.get(cv2.CAP_PROP_FPS)
default_format = camera.get(cv2.CAP_PROP_FORMAT)
# Get FOURCC code and convert to string
default_fourcc_code = camera.get(cv2.CAP_PROP_FOURCC)
default_fourcc_code_int = int(default_fourcc_code)
default_fourcc = "".join([chr((default_fourcc_code_int >> 8 * i) & 0xFF) for i in range(4)])
camera_info = {
"name": f"OpenCV Camera @ {target}",
"type": "OpenCV",
@@ -327,7 +278,6 @@ class OpenCVCamera(Camera):
"backend_api": camera.getBackendName(),
"default_stream_profile": {
"format": default_format,
"fourcc": default_fourcc,
"width": default_width,
"height": default_height,
"fps": default_fps,
@@ -339,7 +289,7 @@ class OpenCVCamera(Camera):
return found_cameras_info
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
"""
Reads a single frame synchronously from the camera.
@@ -367,9 +317,6 @@ class OpenCVCamera(Camera):
start_time = time.perf_counter()
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
ret, frame = self.videocapture.read()
if not ret or frame is None:
@@ -382,7 +329,7 @@ class OpenCVCamera(Camera):
return processed_frame
def _postprocess_image(self, image: NDArray[Any], color_mode: ColorMode | None = None) -> NDArray[Any]:
def _postprocess_image(self, image: np.ndarray, color_mode: ColorMode | None = None) -> np.ndarray:
"""
Applies color conversion, dimension validation, and rotation to a raw frame.
@@ -425,7 +372,7 @@ class OpenCVCamera(Camera):
return processed_image
def _read_loop(self) -> None:
def _read_loop(self):
"""
Internal loop run by the background thread for asynchronous reading.
@@ -436,9 +383,6 @@ class OpenCVCamera(Camera):
Stops on DeviceNotConnectedError, logs other errors and continues.
"""
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
while not self.stop_event.is_set():
try:
color_image = self.read()
@@ -475,7 +419,7 @@ class OpenCVCamera(Camera):
self.thread = None
self.stop_event = None
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
"""
Reads the latest available frame asynchronously.
@@ -518,7 +462,7 @@ class OpenCVCamera(Camera):
return frame
def disconnect(self) -> None:
def disconnect(self):
"""
Disconnects from the camera and cleans up resources.

View File

@@ -17,8 +17,6 @@ from pathlib import Path
from ..configs import CameraConfig, ColorMode, Cv2Rotation
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"]
@CameraConfig.register_subclass("opencv")
@dataclass
@@ -35,9 +33,8 @@ class OpenCVCameraConfig(CameraConfig):
OpenCVCameraConfig(0, 30, 1280, 720) # 1280x720 @ 30FPS
OpenCVCameraConfig(/dev/video4, 60, 640, 480) # 640x480 @ 60FPS
# Advanced configurations with FOURCC format
OpenCVCameraConfig(128422271347, 30, 640, 480, rotation=Cv2Rotation.ROTATE_90, fourcc="MJPG") # With 90° rotation and MJPG format
OpenCVCameraConfig(0, 30, 1280, 720, fourcc="YUYV") # With YUYV format
# Advanced configurations
OpenCVCameraConfig(128422271347, 30, 640, 480, rotation=Cv2Rotation.ROTATE_90) # With 90° rotation
```
Attributes:
@@ -49,21 +46,17 @@ class OpenCVCameraConfig(CameraConfig):
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation.
warmup_s: Time reading frames before returning from connect (in seconds)
fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect).
Note:
- Only 3-channel color output (RGB/BGR) is currently supported.
- FOURCC codes must be 4-character strings (e.g., "MJPG", "YUYV"). Some common FOUCC codes: https://learn.microsoft.com/en-us/windows/win32/medfound/video-fourccs#fourcc-constants
- Setting FOURCC can help achieve higher frame rates on some cameras.
"""
index_or_path: int | Path
color_mode: ColorMode = ColorMode.RGB
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
warmup_s: int = 1
fourcc: str | None = None
def __post_init__(self) -> None:
def __post_init__(self):
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
@@ -78,8 +71,3 @@ class OpenCVCameraConfig(CameraConfig):
raise ValueError(
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
)
if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4):
raise ValueError(
f"`fourcc` must be a 4-character string (e.g., 'MJPG', 'YUYV'), but '{self.fourcc}' is provided."
)

View File

@@ -16,8 +16,6 @@ from dataclasses import dataclass
from ..configs import CameraConfig, ColorMode
__all__ = ["CameraConfig", "ColorMode", "Reachy2CameraConfig"]
@CameraConfig.register_subclass("reachy2_camera")
@dataclass
@@ -64,7 +62,7 @@ class Reachy2CameraConfig(CameraConfig):
port: int = 50065
# use_depth: bool = False
def __post_init__(self) -> None:
def __post_init__(self):
if self.name not in ["teleop", "depth"]:
raise ValueError(f"`name` is expected to be 'teleop' or 'depth', but {self.name} is provided.")
if (self.name == "teleop" and self.image_type not in ["left", "right"]) or (

View File

@@ -23,17 +23,13 @@ import time
from threading import Event, Lock, Thread
from typing import Any
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
# Fix MSMF hardware transform compatibility for Windows before importing cv2
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2 # type: ignore # TODO: add type stubs for OpenCV
import numpy as np # type: ignore # TODO: add type stubs for numpy
from reachy2_sdk.media.camera import CameraView # type: ignore # TODO: add type stubs for reachy2_sdk
from reachy2_sdk.media.camera_manager import ( # type: ignore # TODO: add type stubs for reachy2_sdk
CameraManager,
)
import cv2
import numpy as np
from reachy2_sdk.media.camera import CameraView
from reachy2_sdk.media.camera_manager import CameraManager
from lerobot.utils.errors import DeviceNotConnectedError
@@ -77,7 +73,7 @@ class Reachy2Camera(Camera):
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: NDArray[Any] | None = None
self.latest_frame: np.ndarray | None = None
self.new_frame_event: Event = Event()
def __str__(self) -> str:
@@ -87,17 +83,13 @@ class Reachy2Camera(Camera):
def is_connected(self) -> bool:
"""Checks if the camera is currently connected and opened."""
if self.config.name == "teleop":
return bool(
self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
)
return self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
elif self.config.name == "depth":
return bool(
self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
)
return self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
else:
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
def connect(self, warmup: bool = True) -> None:
def connect(self, warmup: bool = True):
"""
Connects to the Reachy2 CameraManager as specified in the configuration.
"""
@@ -139,7 +131,7 @@ class Reachy2Camera(Camera):
camera_manager.disconnect()
return initialized_cameras
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
"""
Reads a single frame synchronously from the camera.
@@ -160,7 +152,7 @@ class Reachy2Camera(Camera):
start_time = time.perf_counter()
frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8)
frame = None
if self.cam_manager is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
@@ -187,7 +179,7 @@ class Reachy2Camera(Camera):
return frame
def _read_loop(self) -> None:
def _read_loop(self):
"""
Internal loop run by the background thread for asynchronous reading.
@@ -198,9 +190,6 @@ class Reachy2Camera(Camera):
Stops on DeviceNotConnectedError, logs other errors and continues.
"""
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
while not self.stop_event.is_set():
try:
color_image = self.read()
@@ -237,7 +226,7 @@ class Reachy2Camera(Camera):
self.thread = None
self.stop_event = None
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
"""
Reads the latest available frame asynchronously.
@@ -280,7 +269,7 @@ class Reachy2Camera(Camera):
return frame
def disconnect(self) -> None:
def disconnect(self):
"""
Stops the background read thread (if running).

View File

@@ -21,12 +21,11 @@ import time
from threading import Event, Lock, Thread
from typing import Any
import cv2 # type: ignore # TODO: add type stubs for OpenCV
import numpy as np # type: ignore # TODO: add type stubs for numpy
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
import cv2
import numpy as np
try:
import pyrealsense2 as rs # type: ignore # TODO: add type stubs for pyrealsense2
import pyrealsense2 as rs
except Exception as e:
logging.info(f"Could not import realsense: {e}")
@@ -133,7 +132,7 @@ class RealSenseCamera(Camera):
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: NDArray[Any] | None = None
self.latest_frame: np.ndarray | None = None
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
@@ -151,7 +150,7 @@ class RealSenseCamera(Camera):
"""Checks if the camera pipeline is started and streams are active."""
return self.rs_pipeline is not None and self.rs_profile is not None
def connect(self, warmup: bool = True) -> None:
def connect(self, warmup: bool = True):
"""
Connects to the RealSense camera specified in the configuration.
@@ -265,7 +264,7 @@ class RealSenseCamera(Camera):
serial_number = str(found_devices[0]["serial_number"])
return serial_number
def _configure_rs_pipeline_config(self, rs_config: Any) -> None:
def _configure_rs_pipeline_config(self, rs_config):
"""Creates and configures the RealSense pipeline configuration object."""
rs.config.enable_device(rs_config, self.serial_number)
@@ -294,9 +293,6 @@ class RealSenseCamera(Camera):
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.")
if self.rs_profile is None:
raise RuntimeError(f"{self}: rs_profile must be initialized before use.")
stream = self.rs_profile.get_stream(rs.stream.color).as_video_stream_profile()
if self.fps is None:
@@ -312,7 +308,7 @@ class RealSenseCamera(Camera):
self.width, self.height = actual_width, actual_height
self.capture_width, self.capture_height = actual_width, actual_height
def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]:
def read_depth(self, timeout_ms: int = 200) -> np.ndarray:
"""
Reads a single frame (depth) synchronously from the camera.
@@ -340,9 +336,6 @@ class RealSenseCamera(Camera):
start_time = time.perf_counter()
if self.rs_pipeline is None:
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
if not ret or frame is None:
@@ -358,7 +351,7 @@ class RealSenseCamera(Camera):
return depth_map_processed
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> NDArray[Any]:
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> np.ndarray:
"""
Reads a single frame (color) synchronously from the camera.
@@ -383,9 +376,6 @@ class RealSenseCamera(Camera):
start_time = time.perf_counter()
if self.rs_pipeline is None:
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
if not ret or frame is None:
@@ -402,8 +392,8 @@ class RealSenseCamera(Camera):
return color_image_processed
def _postprocess_image(
self, image: NDArray[Any], color_mode: ColorMode | None = None, depth_frame: bool = False
) -> NDArray[Any]:
self, image: np.ndarray, color_mode: ColorMode | None = None, depth_frame: bool = False
) -> np.ndarray:
"""
Applies color conversion, dimension validation, and rotation to a raw color frame.
@@ -448,7 +438,7 @@ class RealSenseCamera(Camera):
return processed_image
def _read_loop(self) -> None:
def _read_loop(self):
"""
Internal loop run by the background thread for asynchronous reading.
@@ -459,9 +449,6 @@ class RealSenseCamera(Camera):
Stops on DeviceNotConnectedError, logs other errors and continues.
"""
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
while not self.stop_event.is_set():
try:
color_image = self.read(timeout_ms=500)
@@ -487,7 +474,7 @@ class RealSenseCamera(Camera):
self.thread.daemon = True
self.thread.start()
def _stop_read_thread(self) -> None:
def _stop_read_thread(self):
"""Signals the background read thread to stop and waits for it to join."""
if self.stop_event is not None:
self.stop_event.set()
@@ -499,7 +486,7 @@ class RealSenseCamera(Camera):
self.stop_event = None
# NOTE(Steven): Missing implementation for depth for now
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
"""
Reads the latest available frame data (color) asynchronously.
@@ -542,7 +529,7 @@ class RealSenseCamera(Camera):
return frame
def disconnect(self) -> None:
def disconnect(self):
"""
Disconnects from the camera, stops the pipeline, and cleans up resources.

View File

@@ -59,7 +59,7 @@ class RealSenseCameraConfig(CameraConfig):
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
warmup_s: int = 1
def __post_init__(self) -> None:
def __post_init__(self):
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."

View File

@@ -53,14 +53,14 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
def get_cv2_rotation(rotation: Cv2Rotation) -> int | None:
import cv2 # type: ignore # TODO: add type stubs for OpenCV
import cv2
if rotation == Cv2Rotation.ROTATE_90:
return int(cv2.ROTATE_90_CLOCKWISE)
return cv2.ROTATE_90_CLOCKWISE
elif rotation == Cv2Rotation.ROTATE_180:
return int(cv2.ROTATE_180)
return cv2.ROTATE_180
elif rotation == Cv2Rotation.ROTATE_270:
return int(cv2.ROTATE_90_COUNTERCLOCKWISE)
return cv2.ROTATE_90_COUNTERCLOCKWISE
else:
return None
@@ -69,8 +69,8 @@ def get_cv2_backend() -> int:
import cv2
if platform.system() == "Windows":
return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION
return cv2.CAP_MSMF # Use MSMF for Windows instead of AVFOUNDATION
# elif platform.system() == "Darwin": # macOS
# return cv2.CAP_AVFOUNDATION
else: # Linux and others
return int(cv2.CAP_ANY)
return cv2.CAP_ANY

View File

@@ -57,7 +57,7 @@ class EvalConfig:
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
use_async_envs: bool = False
def __post_init__(self) -> None:
def __post_init__(self):
if self.batch_size > self.n_episodes:
raise ValueError(
"The eval batch size is greater than the number of eval episodes "

View File

@@ -13,8 +13,8 @@
# limitations under the License.
import datetime as dt
import logging
from dataclasses import dataclass, field
from logging import getLogger
from pathlib import Path
from lerobot import envs, policies # noqa: F401
@@ -22,8 +22,6 @@ from lerobot.configs import parser
from lerobot.configs.default import EvalConfig
from lerobot.configs.policies import PreTrainedConfig
logger = getLogger(__name__)
@dataclass
class EvalPipelineConfig:
@@ -36,31 +34,25 @@ class EvalPipelineConfig:
output_dir: Path | None = None
job_name: str | None = None
seed: int | None = 1000
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
def __post_init__(self) -> None:
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = Path(policy_path)
self.policy.pretrained_path = policy_path
else:
logger.warning(
logging.warning(
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
)
if not self.job_name:
if self.env is None:
self.job_name = f"{self.policy.type if self.policy is not None else 'scratch'}"
self.job_name = f"{self.policy.type}"
else:
self.job_name = (
f"{self.env.type}_{self.policy.type if self.policy is not None else 'scratch'}"
)
logger.warning(f"No job name provided, using '{self.job_name}' as job name.")
self.job_name = f"{self.env.type}_{self.policy.type}"
if not self.output_dir:
now = dt.datetime.now()

View File

@@ -16,19 +16,14 @@ import inspect
import pkgutil
import sys
from argparse import ArgumentError
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Sequence
from functools import wraps
from pathlib import Path
from pkgutil import ModuleInfo
from types import ModuleType
from typing import Any, TypeVar, cast
import draccus
from lerobot.utils.utils import has_method
F = TypeVar("F", bound=Callable[..., object])
PATH_KEY = "path"
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
@@ -65,7 +60,7 @@ def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
return None
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict[str, str]:
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
"""Parse plugin-related arguments from command-line arguments.
This function extracts arguments from command-line arguments that match a specified suffix pattern.
@@ -132,7 +127,7 @@ def load_plugin(plugin_path: str) -> None:
f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
) from e
def iter_namespace(ns_pkg: ModuleType) -> Iterable[ModuleInfo]:
def iter_namespace(ns_pkg):
return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".")
try:
@@ -153,8 +148,6 @@ def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | No
def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
if args is None:
return []
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
@@ -178,8 +171,7 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
if isinstance(fields_to_filter, str):
fields_to_filter = [fields_to_filter]
filtered_args = [] if args is None else list(args)
filtered_args = args
for field in fields_to_filter:
if get_path_arg(field, args):
if get_type_arg(field, args):
@@ -192,7 +184,7 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
return filtered_args
def wrap(config_path: Path | None = None) -> Callable[[F], F]:
def wrap(config_path: Path | None = None):
"""
HACK: Similar to draccus.wrap but does three additional things:
- Will remove '.path' arguments from CLI in order to process them later on.
@@ -203,9 +195,9 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
from the CLI '.type' arguments
"""
def wrapper_outer(fn: F) -> F:
def wrapper_outer(fn):
@wraps(fn)
def wrapper_inner(*args: Any, **kwargs: Any) -> Any:
def wrapper_inner(*args, **kwargs):
argspec = inspect.getfullargspec(fn)
argtype = argspec.annotations[argspec.args[0]]
if len(args) > 0 and type(args[0]) is argtype:
@@ -233,6 +225,6 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
response = fn(cfg, *args, **kwargs)
return response
return cast(F, wrapper_inner)
return wrapper_inner
return cast(Callable[[F], F], wrapper_outer)
return wrapper_outer

View File

@@ -14,12 +14,12 @@
import abc
import builtins
import json
import logging
import os
import tempfile
from dataclasses import dataclass, field
from logging import getLogger
from pathlib import Path
from typing import Any, TypeVar
from typing import TypeVar
import draccus
from huggingface_hub import hf_hub_download
@@ -34,11 +34,10 @@ from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
T = TypeVar("T", bound="PreTrainedConfig")
logger = getLogger(__name__)
@dataclass
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: ignore[misc,name-defined] #TODO: draccus issue
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
"""
Base configuration class for policy models.
@@ -58,12 +57,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
device: str | None = None # e.g. "cuda", "cuda:0", "cpu", or "mps"
device: str | None = None # cuda | cpu | mp
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: bool = False
push_to_hub: bool = True # type: ignore[assignment] # TODO: use a different name to avoid override
push_to_hub: bool = True
repo_id: str | None = None
# Upload on private repository on the Hugging Face hub.
@@ -74,41 +73,38 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
license: str | None = None
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
pretrained_path: Path | None = None
pretrained_path: str | None = None
def __post_init__(self) -> None:
def __post_init__(self):
if not self.device or not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
self.device = auto_device.type
# Automatically deactivate AMP if necessary
if self.use_amp and not is_amp_available(self.device):
logger.warning(
logging.warning(
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
)
self.use_amp = False
@property
def type(self) -> str:
choice_name = self.get_choice_name(self.__class__)
if not isinstance(choice_name, str):
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
return choice_name
return self.get_choice_name(self.__class__)
@property
@abc.abstractmethod
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
def observation_delta_indices(self) -> list | None:
raise NotImplementedError
@property
@abc.abstractmethod
def action_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
def action_delta_indices(self) -> list | None:
raise NotImplementedError
@property
@abc.abstractmethod
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
def reward_delta_indices(self) -> list | None:
raise NotImplementedError
@abc.abstractmethod
@@ -158,13 +154,13 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict[Any, Any] | None = None,
resume_download: bool = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
**policy_kwargs: Any,
**policy_kwargs,
) -> T:
model_id = str(pretrained_name_or_path)
config_file: str | None = None
@@ -172,7 +168,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
if CONFIG_NAME in os.listdir(model_id):
config_file = os.path.join(model_id, CONFIG_NAME)
else:
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
else:
try:
config_file = hf_hub_download(
@@ -198,9 +194,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
with draccus.config_type("json"):
orig_config = draccus.parse(cls, config_file, args=[])
if config_file is None:
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
with open(config_file) as f:
config = json.load(f)

View File

@@ -16,7 +16,6 @@ import datetime as dt
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import draccus
from huggingface_hub import hf_hub_download
@@ -64,18 +63,18 @@ class TrainPipelineConfig(HubMixin):
scheduler: LRSchedulerConfig | None = None
eval: EvalConfig = field(default_factory=EvalConfig)
wandb: WandBConfig = field(default_factory=WandBConfig)
checkpoint_path: Path | None = field(init=False, default=None)
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
def validate(self) -> None:
def __post_init__(self):
self.checkpoint_path = None
def validate(self):
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
policy_path = parser.get_path_arg("policy")
if policy_path:
# Only load the policy config
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = Path(policy_path)
self.policy.pretrained_path = policy_path
elif self.resume:
# The entire train config is already loaded, we just need to get the checkpoint dir
config_path = parser.parse_arg("config_path")
@@ -83,22 +82,14 @@ class TrainPipelineConfig(HubMixin):
raise ValueError(
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
)
if not Path(config_path).resolve().exists():
raise NotADirectoryError(
f"{config_path=} is expected to be a local path. "
"Resuming from the hub is not supported for now."
)
policy_dir = Path(config_path).parent
if self.policy is not None:
self.policy.pretrained_path = policy_dir
self.checkpoint_path = policy_dir.parent
if self.policy is None:
raise ValueError(
"Policy is not configured. Please specify a pretrained policy with `--policy.path`."
)
policy_path = Path(config_path).parent
self.policy.pretrained_path = policy_path
self.checkpoint_path = policy_path.parent
if not self.job_name:
if self.env is None:
@@ -135,8 +126,8 @@ class TrainPipelineConfig(HubMixin):
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
def to_dict(self) -> dict[str, Any]:
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
def to_dict(self) -> dict:
return draccus.encode(self)
def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
@@ -148,13 +139,13 @@ class TrainPipelineConfig(HubMixin):
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict[Any, Any] | None = None,
resume_download: bool = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
**kwargs: Any,
**kwargs,
) -> "TrainPipelineConfig":
model_id = str(pretrained_name_or_path)
config_file: str | None = None
@@ -190,6 +181,4 @@ class TrainPipelineConfig(HubMixin):
@dataclass(kw_only=True)
class TrainRLServerPipelineConfig(TrainPipelineConfig):
# NOTE: In RL, we don't need an offline dataset
# TODO: Make `TrainPipelineConfig.dataset` optional
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset

View File

@@ -42,4 +42,4 @@ class NormalizationMode(str, Enum):
@dataclass
class PolicyFeature:
type: FeatureType
shape: tuple[int, ...]
shape: tuple

View File

@@ -22,13 +22,11 @@ from pathlib import Path
import datasets
import numpy as np
import os
import packaging.version
import pandas as pd
import PIL.Image
import pyarrow as pa
import pyarrow.parquet as pq
from concurrent.futures import ProcessPoolExecutor
import torch
import torch.utils
from huggingface_hub import HfApi, snapshot_download
@@ -688,7 +686,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_buffer = None
self.writer = None
self.latest_episode = None
self._current_file_start_frame = None # Track the starting frame index of the current parquet file
self.root.mkdir(exist_ok=True, parents=True)
@@ -711,8 +708,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if not self._check_cached_episodes_sufficient():
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
except (AssertionError, FileNotFoundError, NotADirectoryError):
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
self.revision = get_safe_version(self.repo_id, self.revision)
self.download(download_videos)
self.hf_dataset = self.load_hf_dataset()
@@ -839,14 +835,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
return hf_dataset
def _check_cached_episodes_sufficient(self) -> bool:
"""Check if the cached dataset contains all requested episodes and their video files."""
"""Check if the cached dataset contains all requested episodes."""
if self.hf_dataset is None or len(self.hf_dataset) == 0:
return False
# Get available episode indices from cached dataset
available_episodes = {
ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx
for ep_idx in self.hf_dataset.unique("episode_index")
for ep_idx in self.hf_dataset["episode_index"]
}
# Determine requested episodes
@@ -858,18 +854,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
requested_episodes = set(self.episodes)
# Check if all requested episodes are available in cached data
if not requested_episodes.issubset(available_episodes):
return False
# Check if all required video files exist
if len(self.meta.video_keys) > 0:
for ep_idx in requested_episodes:
for vid_key in self.meta.video_keys:
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
if not video_path.exists():
return False
return True
return requested_episodes.issubset(available_episodes)
def create_hf_dataset(self) -> datasets.Dataset:
features = get_hf_features_from_features(self.features)
@@ -1151,9 +1136,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
use_batched_encoding = self.batch_encoding_size > 1
if has_video_keys and not use_batched_encoding:
video_paths = self._encode_multiple_temporary_episode_videos(self.meta.video_keys, episode_index)
for (video_key, video_path) in zip(self.meta.video_keys, video_paths):
ep_metadata.update(self._save_episode_video(video_key, episode_index, video_path))
for video_key in self.meta.video_keys:
ep_metadata.update(self._save_episode_video(video_key, episode_index))
# `meta.save_episode` need to be executed after encoding the videos
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
@@ -1247,7 +1231,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Initialize indices and frame count for a new dataset made of the first episode data
chunk_idx, file_idx = 0, 0
global_frame_index = 0
self._current_file_start_frame = 0
# However, if the episodes already exists
# It means we are resuming recording, so we need to load the latest episode
# Update the indices to avoid overwriting the latest episode
@@ -1259,7 +1242,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
# When resuming, move to the next file
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
self._current_file_start_frame = global_frame_index
else:
# Retrieve information from the latest parquet file
latest_ep = self.latest_episode
@@ -1270,7 +1252,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
latest_size_in_mb = get_file_size_in_mb(latest_path)
frames_in_current_file = global_frame_index - self._current_file_start_frame
frames_in_current_file = global_frame_index - latest_ep["dataset_from_index"]
av_size_per_frame = (
latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0
)
@@ -1284,7 +1266,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
self._close_writer()
self._writer_closed_for_reading = False
self._current_file_start_frame = global_frame_index
ep_dict["data/chunk_index"] = chunk_idx
ep_dict["data/file_index"] = file_idx
@@ -1318,12 +1299,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
return metadata
def _save_episode_video(self, video_key: str, episode_index: int, video_path: str | Path | None = None) -> dict:
def _save_episode_video(self, video_key: str, episode_index: int) -> dict:
# Encode episode frames into a temporary video
if video_path is None:
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
else:
ep_path = video_path
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
ep_size_in_mb = get_file_size_in_mb(ep_path)
ep_duration_in_s = get_video_duration_in_s(ep_path)
@@ -1447,22 +1425,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
shutil.rmtree(img_dir)
return temp_path
def _encode_multiple_temporary_episode_videos(self, video_keys, episode_index):
temp_paths = []
img_dirs = []
for video_key in video_keys:
temp_paths.append(Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4")
img_dirs.append(self._get_image_file_dir(episode_index, video_key))
fps = [self.fps]*len(video_keys)
with ProcessPoolExecutor(max_workers=len(video_keys)) as executor:
executor.map(encode_video_frames,img_dirs,temp_paths,fps)
for img_dir in img_dirs:
shutil.rmtree(img_dir)
return temp_paths
@classmethod
def create(
cls,
@@ -1510,7 +1472,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
obj.writer = None
obj.latest_episode = None
obj._current_file_start_frame = None
# Initialize tracking for incremental recording
obj._lazy_loading = False
obj._recorded_frames = 0

View File

@@ -206,11 +206,6 @@ class ImageTransformsConfig:
type="SharpnessJitter",
kwargs={"sharpness": (0.5, 1.5)},
),
"affine": ImageTransformConfig(
weight=1.0,
type="RandomAffine",
kwargs={"degrees": (-5.0, 5.0), "translate": (0.05, 0.05)},
),
}
)
@@ -222,8 +217,6 @@ def make_transform_from_config(cfg: ImageTransformConfig):
return v2.ColorJitter(**cfg.kwargs)
elif cfg.type == "SharpnessJitter":
return SharpnessJitter(**cfg.kwargs)
elif cfg.type == "RandomAffine":
return v2.RandomAffine(**cfg.kwargs)
else:
raise ValueError(f"Transform '{cfg.type}' is not valid.")

View File

@@ -98,7 +98,7 @@ OLD
videos/chunk-000/CAMERA/episode_000000.mp4
NEW
videos/CAMERA/chunk-000/file_000.mp4
videos/chunk-000/file_000.mp4
-------------------------
OLD
episodes.jsonl

View File

@@ -310,7 +310,7 @@ def encode_video_frames(
crf: int | None = 30,
fast_decode: int = 0,
log_level: int | None = av.logging.ERROR,
overwrite: bool = True,
overwrite: bool = False,
) -> None:
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
# Check encoder availability
@@ -342,8 +342,8 @@ def encode_video_frames(
# Define video output frame size (assuming all input frames are the same size)
if len(input_list) == 0:
raise FileNotFoundError(f"No images found in {imgs_dir}.")
with Image.open(input_list[0]) as dummy_image:
width, height = dummy_image.size
dummy_image = Image.open(input_list[0])
width, height = dummy_image.size
# Define video codec options
video_options = {}
@@ -354,9 +354,6 @@ def encode_video_frames(
if crf is not None:
video_options["crf"] = str(crf)
#TEMPORARY FIX
video_options["preset"] = "12"
if fast_decode:
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
@@ -376,12 +373,11 @@ def encode_video_frames(
# Loop through input frames and encode them
for input_data in input_list:
with Image.open(input_data) as input_image:
input_image = input_image.convert("RGB")
input_frame = av.VideoFrame.from_image(input_image)
packet = output_stream.encode(input_frame)
if packet:
output.mux(packet)
input_image = Image.open(input_data).convert("RGB")
input_frame = av.VideoFrame.from_image(input_image)
packet = output_stream.encode(input_frame)
if packet:
output.mux(packet)
# Flush the encoder
packet = output_stream.encode()

View File

@@ -37,16 +37,6 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
def type(self) -> str:
return self.get_choice_name(self.__class__)
@property
def package_name(self) -> str:
"""Package name to import if environment not found in gym registry"""
return f"gym_{self.type}"
@property
def gym_id(self) -> str:
"""ID string used in gym.make() to instantiate the environment"""
return f"{self.package_name}/{self.task}"
@property
@abc.abstractmethod
def gym_kwargs(self) -> dict:

View File

@@ -16,7 +16,6 @@
import importlib
import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
@@ -85,24 +84,17 @@ def make_env(
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
)
package_name = f"gym_{cfg.type}"
try:
importlib.import_module(package_name)
except ModuleNotFoundError as e:
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
raise e
if cfg.gym_id not in gym_registry:
print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...")
try:
importlib.import_module(cfg.package_name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Package '{cfg.package_name}' required for env '{cfg.type}' not found. "
f"Please install it or check PYTHONPATH."
) from e
if cfg.gym_id not in gym_registry:
raise gym.error.NameNotFound(
f"Environment '{cfg.gym_id}' not registered even after importing '{cfg.package_name}'."
)
gym_handle = f"{package_name}/{cfg.task}"
def _make_one():
return gym.make(cfg.gym_id, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
return gym.make(gym_handle, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)

View File

@@ -22,18 +22,18 @@ class RobotKinematics:
self,
urdf_path: str,
target_frame_name: str = "gripper_frame_link",
joint_names: list[str] | None = None,
joint_names: list[str] = None,
):
"""
Initialize placo-based kinematics solver.
Args:
urdf_path (str): Path to the robot URDF file
target_frame_name (str): Name of the end-effector frame in the URDF
joint_names (list[str] | None): List of joint names to use for the kinematics solver
urdf_path: Path to the robot URDF file
target_frame_name: Name of the end-effector frame in the URDF
joint_names: List of joint names to use for the kinematics solver
"""
try:
import placo # type: ignore[import-not-found] # C++ library with Python bindings, no type stubs available. TODO: Create stub file or request upstream typing support.
import placo
except ImportError as e:
raise ImportError(
"placo is required for RobotKinematics. "
@@ -52,7 +52,7 @@ class RobotKinematics:
# Initialize frame task for IK
self.tip_frame = self.solver.add_frame_task(self.target_frame_name, np.eye(4))
def forward_kinematics(self, joint_pos_deg: np.ndarray) -> np.ndarray:
def forward_kinematics(self, joint_pos_deg):
"""
Compute forward kinematics for given joint configuration given the target frame name in the constructor.
@@ -77,12 +77,8 @@ class RobotKinematics:
return self.robot.get_T_world_frame(self.target_frame_name)
def inverse_kinematics(
self,
current_joint_pos: np.ndarray,
desired_ee_pose: np.ndarray,
position_weight: float = 1.0,
orientation_weight: float = 0.01,
) -> np.ndarray:
self, current_joint_pos, desired_ee_pose, position_weight=1.0, orientation_weight=0.01
):
"""
Compute inverse kinematics using placo solver.

View File

@@ -14,11 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .motors_bus import (
Motor,
MotorCalibration,
MotorNormMode,
MotorsBus, # Backward compatibility (alias for SerialMotorsBus)
MotorsBusBase,
SerialMotorsBus,
)
from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus

View File

@@ -1,905 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(pepijn): add license of: https://github.com/cmjang/DM_Control_Python?tab=MIT-1-ov-file#readme
import logging
import time
from contextlib import contextmanager
from copy import deepcopy
from functools import cached_property
from typing import Dict, List, Optional, Tuple, Union
import can
import numpy as np
from lerobot.motors import Motor, MotorCalibration, MotorNormMode, MotorsBusBase
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.utils import enter_pressed, move_cursor_up
from .tables import (
AVAILABLE_BAUDRATES,
CAN_CMD_DISABLE,
CAN_CMD_ENABLE,
CAN_CMD_REFRESH,
CAN_CMD_SET_ZERO,
CAN_PARAM_ID,
DEFAULT_BAUDRATE,
DEFAULT_TIMEOUT_MS,
MODEL_RESOLUTION,
MOTOR_LIMIT_PARAMS,
NORMALIZED_DATA,
MotorType,
)
logger = logging.getLogger(__name__)
NameOrID = Union[str, int]
Value = Union[int, float]
class DamiaoMotorsBus(MotorsBusBase):
"""
The Damiao implementation for a MotorsBus using CAN bus communication.
This class uses python-can for CAN bus communication with Damiao motors.
For more info, see:
- python-can documentation: https://python-can.readthedocs.io/en/stable/
- Seedstudio documentation: https://wiki.seeedstudio.com/damiao_series/
- DM_Control_Python repo: https://github.com/cmjang/DM_Control_Python
"""
# CAN-specific settings
available_baudrates = deepcopy(AVAILABLE_BAUDRATES)
default_baudrate = DEFAULT_BAUDRATE
default_timeout = DEFAULT_TIMEOUT_MS
# Motor configuration
model_resolution_table = deepcopy(MODEL_RESOLUTION)
normalized_data = deepcopy(NORMALIZED_DATA)
def __init__(
self,
port: str,
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
can_interface: str = "auto",
use_can_fd: bool = True,
bitrate: int = 1000000,
data_bitrate: int | None = 5000000,
):
"""
Initialize the Damiao motors bus.
Args:
port: CAN interface name (e.g., "can0" for Linux, "/dev/cu.usbmodem*" for macOS)
motors: Dictionary mapping motor names to Motor objects
calibration: Optional calibration data
can_interface: CAN interface type - "auto" (default), "socketcan" (Linux), or "slcan" (macOS/serial)
use_can_fd: Whether to use CAN FD mode (default: True for OpenArms)
bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps)
data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False
"""
super().__init__(port, motors, calibration)
self.port = port
self.can_interface = can_interface
self.use_can_fd = use_can_fd
self.bitrate = bitrate
self.data_bitrate = data_bitrate
self.canbus = None
self._is_connected = False
# Map motor names to CAN IDs
self._motor_can_ids = {}
self._recv_id_to_motor = {}
# Store motor types and recv IDs
self._motor_types = {}
for name, motor in self.motors.items():
if hasattr(motor, "motor_type"):
self._motor_types[name] = motor.motor_type
else:
# Default to DM4310 if not specified
self._motor_types[name] = MotorType.DM4310
# Map recv_id to motor name for filtering responses
if hasattr(motor, "recv_id"):
self._recv_id_to_motor[motor.recv_id] = name
@property
def is_connected(self) -> bool:
"""Check if the CAN bus is connected."""
return self._is_connected and self.canbus is not None
def connect(self, handshake: bool = True) -> None:
"""
Open the CAN bus and initialize communication.
Args:
handshake: If True, ping all motors to verify they're present
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(
f"{self.__class__.__name__}('{self.port}') is already connected."
)
try:
# Auto-detect interface type based on port name
if self.can_interface == "auto":
if self.port.startswith("/dev/"):
# Serial device (macOS/Windows)
self.can_interface = "slcan"
logger.info(f"Auto-detected slcan interface for port {self.port}")
else:
# Network interface (Linux)
self.can_interface = "socketcan"
logger.info(f"Auto-detected socketcan interface for port {self.port}")
# Connect to CAN bus
if self.can_interface == "socketcan":
# Linux SocketCAN with CAN FD support
if self.use_can_fd and self.data_bitrate is not None:
self.canbus = can.interface.Bus(
channel=self.port,
interface="socketcan",
bitrate=self.bitrate,
data_bitrate=self.data_bitrate,
fd=True
)
logger.info(f"Connected to {self.port} with CAN FD (bitrate={self.bitrate}, data_bitrate={self.data_bitrate})")
else:
self.canbus = can.interface.Bus(
channel=self.port,
interface="socketcan",
bitrate=self.bitrate
)
logger.info(f"Connected to {self.port} with CAN 2.0 (bitrate={self.bitrate})")
elif self.can_interface == "slcan":
# Serial Line CAN (macOS, Windows, or USB adapters)
# Note: SLCAN typically doesn't support CAN FD
self.canbus = can.interface.Bus(
channel=self.port,
interface="slcan",
bitrate=self.bitrate
)
logger.info(f"Connected to {self.port} with SLCAN (bitrate={self.bitrate})")
else:
# Generic interface (vector, pcan, etc.)
if self.use_can_fd and self.data_bitrate is not None:
self.canbus = can.interface.Bus(
channel=self.port,
interface=self.can_interface,
bitrate=self.bitrate,
data_bitrate=self.data_bitrate,
fd=True
)
else:
self.canbus = can.interface.Bus(
channel=self.port,
interface=self.can_interface,
bitrate=self.bitrate
)
self._is_connected = True
if handshake:
self._handshake()
logger.debug(f"{self.__class__.__name__} connected via {self.can_interface}.")
except Exception as e:
self._is_connected = False
raise ConnectionError(f"Failed to connect to CAN bus: {e}")
def _handshake(self) -> None:
"""Verify all motors are present by refreshing their status."""
for motor_name in self.motors:
self._refresh_motor(motor_name)
time.sleep(0.01) # Small delay between motors
def disconnect(self, disable_torque: bool = True) -> None:
"""
Close the CAN bus connection.
Args:
disable_torque: If True, disable torque on all motors before disconnecting
"""
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected."
)
if disable_torque:
try:
self.disable_torque()
except Exception as e:
logger.warning(f"Failed to disable torque during disconnect: {e}")
if self.canbus:
self.canbus.shutdown()
self.canbus = None
self._is_connected = False
logger.debug(f"{self.__class__.__name__} disconnected.")
def configure_motors(self) -> None:
"""Configure all motors with default settings."""
# Damiao motors don't require much configuration in MIT mode
# Just ensure they're enabled
for motor in self.motors:
self._enable_motor(motor)
time.sleep(0.01)
def _enable_motor(self, motor: NameOrID) -> None:
"""Enable a single motor."""
motor_id = self._get_motor_id(motor)
recv_id = self._get_motor_recv_id(motor)
data = [0xFF] * 7 + [CAN_CMD_ENABLE]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self.canbus.send(msg)
self._recv_motor_response(expected_recv_id=recv_id)
def _disable_motor(self, motor: NameOrID) -> None:
"""Disable a single motor."""
motor_id = self._get_motor_id(motor)
recv_id = self._get_motor_recv_id(motor)
data = [0xFF] * 7 + [CAN_CMD_DISABLE]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self.canbus.send(msg)
self._recv_motor_response(expected_recv_id=recv_id)
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Enable torque on selected motors."""
motors = self._get_motors_list(motors)
for motor in motors:
for _ in range(num_retry + 1):
try:
self._enable_motor(motor)
break
except Exception as e:
if _ == num_retry:
raise e
time.sleep(0.01)
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Disable torque on selected motors."""
motors = self._get_motors_list(motors)
for motor in motors:
for _ in range(num_retry + 1):
try:
self._disable_motor(motor)
break
except Exception as e:
if _ == num_retry:
raise e
time.sleep(0.01)
@contextmanager
def torque_disabled(self, motors: str | list[str] | None = None):
"""
Context manager that guarantees torque is re-enabled.
This helper is useful to temporarily disable torque when configuring motors.
Examples:
>>> with bus.torque_disabled():
... # Safe operations here with torque disabled
... pass
"""
self.disable_torque(motors)
try:
yield
finally:
self.enable_torque(motors)
def set_zero_position(self, motors: str | list[str] | None = None) -> None:
"""Set current position as zero for selected motors."""
motors = self._get_motors_list(motors)
for motor in motors:
motor_id = self._get_motor_id(motor)
recv_id = self._get_motor_recv_id(motor)
data = [0xFF] * 7 + [CAN_CMD_SET_ZERO]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self.canbus.send(msg)
self._recv_motor_response(expected_recv_id=recv_id)
time.sleep(0.01)
def _refresh_motor(self, motor: NameOrID) -> Optional[can.Message]:
"""Refresh motor status and return the response."""
motor_id = self._get_motor_id(motor)
recv_id = self._get_motor_recv_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
self.canbus.send(msg)
return self._recv_motor_response(expected_recv_id=recv_id)
def _recv_motor_response(self, expected_recv_id: Optional[int] = None, timeout: float = 0.001) -> Optional[can.Message]:
"""
Receive a response from a motor.
Args:
expected_recv_id: If provided, only return messages from this CAN ID
timeout: Timeout in seconds (default: 1ms for high-speed operation)
Returns:
CAN message if received, None otherwise
"""
try:
start_time = time.time()
messages_seen = []
while time.time() - start_time < timeout:
msg = self.canbus.recv(timeout=0.0001) # 100us timeout for fast polling
if msg:
messages_seen.append(f"0x{msg.arbitration_id:02X}")
# If no filter specified, return any message
if expected_recv_id is None:
return msg
# Otherwise, only return if it matches the expected recv_id
if msg.arbitration_id == expected_recv_id:
return msg
else:
logger.debug(f"Ignoring message from CAN ID 0x{msg.arbitration_id:02X}, expected 0x{expected_recv_id:02X}")
# Only log warnings if we're in debug mode to reduce overhead
if logger.isEnabledFor(logging.DEBUG):
if messages_seen:
logger.debug(f"Received {len(messages_seen)} message(s) from IDs {set(messages_seen)}, but expected 0x{expected_recv_id:02X}")
else:
logger.debug(f"No CAN messages received (expected from 0x{expected_recv_id:02X})")
except Exception as e:
logger.debug(f"Failed to receive CAN message: {e}")
return None
def _recv_all_responses(self, expected_recv_ids: list[int], timeout: float = 0.002) -> dict[int, can.Message]:
"""
Efficiently receive responses from multiple motors at once.
Uses the OpenArms pattern: collect all available messages within timeout.
Args:
expected_recv_ids: List of CAN IDs we expect responses from
timeout: Total timeout in seconds (default: 2ms)
Returns:
Dictionary mapping recv_id to CAN message
"""
responses = {}
expected_set = set(expected_recv_ids)
start_time = time.time()
try:
while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout:
msg = self.canbus.recv(timeout=0.0002) # 200us poll timeout (increased from 100us for better reliability)
if msg and msg.arbitration_id in expected_set:
responses[msg.arbitration_id] = msg
if len(responses) == len(expected_recv_ids):
break # Got all responses, exit early
except Exception as e:
logger.debug(f"Error receiving responses: {e}")
return responses
def _mit_control(
self,
motor: NameOrID,
kp: float,
kd: float,
position_degrees: float,
velocity_deg_per_sec: float,
torque: float,
) -> None:
"""
Send MIT control command to a motor.
Args:
motor: Motor name or ID
kp: Position gain
kd: Velocity gain
position_degrees: Target position (degrees)
velocity_deg_per_sec: Target velocity (degrees/s)
torque: Target torque (N·m)
"""
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types.get(motor_name, MotorType.DM4310)
# Convert degrees to radians for motor control
position_rad = np.radians(position_degrees)
velocity_rad_per_sec = np.radians(velocity_deg_per_sec)
# Get motor limits
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
# Encode parameters
kp_uint = self._float_to_uint(kp, 0, 500, 12)
kd_uint = self._float_to_uint(kd, 0, 5, 12)
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
dq_uint = self._float_to_uint(velocity_rad_per_sec, -vmax, vmax, 12)
tau_uint = self._float_to_uint(torque, -tmax, tmax, 12)
# Pack data
data = [0] * 8
data[0] = (q_uint >> 8) & 0xFF
data[1] = q_uint & 0xFF
data[2] = dq_uint >> 4
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
data[4] = kp_uint & 0xFF
data[5] = kd_uint >> 4
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
data[7] = tau_uint & 0xFF
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self.canbus.send(msg)
recv_id = self._get_motor_recv_id(motor)
self._recv_motor_response(expected_recv_id=recv_id)
def _mit_control_batch(
self,
commands: Dict[NameOrID, Tuple[float, float, float, float, float]],
) -> None:
"""
Send MIT control commands to multiple motors in batch (optimized).
Sends all commands first, then collects responses. Much faster than sequential.
Args:
commands: Dict mapping motor name/ID to (kp, kd, position_deg, velocity_deg/s, torque)
Example: {'joint_1': (10.0, 0.5, 45.0, 0.0, 0.0), ...}
"""
if not commands:
return
expected_recv_ids = []
# Step 1: Send all MIT control commands (no waiting)
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types.get(motor_name, MotorType.DM4310)
# Convert degrees to radians
position_rad = np.radians(position_degrees)
velocity_rad_per_sec = np.radians(velocity_deg_per_sec)
# Get motor limits
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
# Encode parameters
kp_uint = self._float_to_uint(kp, 0, 500, 12)
kd_uint = self._float_to_uint(kd, 0, 5, 12)
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
dq_uint = self._float_to_uint(velocity_rad_per_sec, -vmax, vmax, 12)
tau_uint = self._float_to_uint(torque, -tmax, tmax, 12)
# Pack data
data = [0] * 8
data[0] = (q_uint >> 8) & 0xFF
data[1] = q_uint & 0xFF
data[2] = dq_uint >> 4
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
data[4] = kp_uint & 0xFF
data[5] = kd_uint >> 4
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
data[7] = tau_uint & 0xFF
# Send command
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self.canbus.send(msg)
# Track expected response
recv_id = self._get_motor_recv_id(motor)
expected_recv_ids.append(recv_id)
# Step 2: Collect all responses at once
self._recv_all_responses(expected_recv_ids, timeout=0.002)
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
"""Convert float to unsigned integer for CAN transmission."""
x = max(x_min, min(x_max, x)) # Clamp to range
span = x_max - x_min
data_norm = (x - x_min) / span
return int(data_norm * ((1 << bits) - 1))
def _uint_to_float(self, x: int, x_min: float, x_max: float, bits: int) -> float:
"""Convert unsigned integer from CAN to float."""
span = x_max - x_min
data_norm = float(x) / ((1 << bits) - 1)
return data_norm * span + x_min
def _decode_motor_state(self, data: bytes, motor_type: MotorType) -> Tuple[float, float, float, int, int]:
"""
Decode motor state from CAN data.
Returns:
Tuple of (position_degrees, velocity_deg_per_sec, torque, temp_mos, temp_rotor)
"""
if len(data) < 8:
raise ValueError("Invalid motor state data")
# Extract encoded values
q_uint = (data[1] << 8) | data[2]
dq_uint = (data[3] << 4) | (data[4] >> 4)
tau_uint = ((data[4] & 0x0F) << 8) | data[5]
t_mos = data[6]
t_rotor = data[7]
# Get motor limits
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
# Decode to physical values (radians)
position_rad = self._uint_to_float(q_uint, -pmax, pmax, 16)
velocity_rad_per_sec = self._uint_to_float(dq_uint, -vmax, vmax, 12)
torque = self._uint_to_float(tau_uint, -tmax, tmax, 12)
# Convert to degrees
position_degrees = np.degrees(position_rad)
velocity_deg_per_sec = np.degrees(velocity_rad_per_sec)
return position_degrees, velocity_deg_per_sec, torque, t_mos, t_rotor
def read(
self,
data_name: str,
motor: str,
*,
normalize: bool = True,
num_retry: int = 0,
) -> Value:
"""Read a value from a single motor. Positions are always in degrees."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Refresh motor to get latest state
msg = self._refresh_motor(motor)
if msg is None:
motor_id = self._get_motor_id(motor)
recv_id = self._get_motor_recv_id(motor)
raise ConnectionError(
f"No response from motor '{motor}' (send ID: 0x{motor_id:02X}, recv ID: 0x{recv_id:02X}). "
f"Check that: 1) Motor is powered (24V), 2) CAN wiring is correct, "
f"3) Motor IDs are configured correctly using Damiao Debugging Tools"
)
motor_type = self._motor_types.get(motor, MotorType.DM4310)
position_degrees, velocity_deg_per_sec, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
# Return requested data (already in degrees for position/velocity)
if data_name == "Present_Position":
value = position_degrees
elif data_name == "Present_Velocity":
value = velocity_deg_per_sec
elif data_name == "Present_Torque":
value = torque
elif data_name == "Temperature_MOS":
value = t_mos
elif data_name == "Temperature_Rotor":
value = t_rotor
else:
raise ValueError(f"Unknown data_name: {data_name}")
# For Damiao, positions are always in degrees, no normalization needed
# We keep the normalize parameter for compatibility but don't use it
return value
def write(
self,
data_name: str,
motor: str,
value: Value,
*,
normalize: bool = True,
num_retry: int = 0,
) -> None:
"""Write a value to a single motor. Positions are always in degrees."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Value is expected to be in degrees for positions
if data_name == "Goal_Position":
# Use MIT control with position in degrees
self._mit_control(motor, 10.0, 0.5, value, 0, 0)
else:
raise ValueError(f"Writing {data_name} not supported in MIT mode")
def sync_read(
self,
data_name: str,
motors: str | list[str] | None = None,
*,
normalize: bool = True,
num_retry: int = 0,
) -> Dict[str, Value]:
"""
Read the same value from multiple motors simultaneously.
Uses batched operations: sends all refresh commands, then collects all responses.
This is MUCH faster than sequential reads (OpenArms pattern).
"""
motors = self._get_motors_list(motors)
result = {}
# Step 1: Send refresh commands to ALL motors first (no waiting)
for motor in motors:
motor_id = self._get_motor_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
self.canbus.send(msg)
# Step 2: Collect all responses at once (batch receive)
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in motors]
responses = self._recv_all_responses(expected_recv_ids, timeout=0.01) # 10ms total timeout
# Step 3: Parse responses
for motor in motors:
try:
recv_id = self._get_motor_recv_id(motor)
msg = responses.get(recv_id)
if msg is None:
logger.warning(f"No response from motor '{motor}' (recv ID: 0x{recv_id:02X})")
result[motor] = 0.0
continue
motor_type = self._motor_types.get(motor, MotorType.DM4310)
position_degrees, velocity_deg_per_sec, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
# Return requested data
if data_name == "Present_Position":
value = position_degrees
elif data_name == "Present_Velocity":
value = velocity_deg_per_sec
elif data_name == "Present_Torque":
value = torque
elif data_name == "Temperature_MOS":
value = t_mos
elif data_name == "Temperature_Rotor":
value = t_rotor
else:
raise ValueError(f"Unknown data_name: {data_name}")
result[motor] = value
except Exception as e:
logger.warning(f"Failed to read {data_name} from {motor}: {e}")
result[motor] = 0.0
return result
def sync_read_all_states(
self,
motors: str | list[str] | None = None,
*,
num_retry: int = 0,
) -> Dict[str, Dict[str, Value]]:
"""
Read ALL motor states (position, velocity, torque) from multiple motors in ONE refresh cycle.
This is 3x faster than calling sync_read() three times separately.
Returns:
Dictionary mapping motor names to state dicts with keys: 'position', 'velocity', 'torque'
Example: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...}
"""
motors = self._get_motors_list(motors)
result = {}
# Step 1: Send refresh commands to ALL motors first (with small delays to reduce bus congestion)
for motor in motors:
motor_id = self._get_motor_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
self.canbus.send(msg)
time.sleep(0.0001) # 100us delay between commands to reduce bus congestion
# Step 2: Collect all responses at once (batch receive)
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in motors]
responses = self._recv_all_responses(expected_recv_ids, timeout=0.015) # 15ms timeout (increased for reliability)
# Step 3: Parse responses and extract ALL state values
for motor in motors:
try:
recv_id = self._get_motor_recv_id(motor)
msg = responses.get(recv_id)
if msg is None:
logger.warning(f"No response from motor '{motor}' (recv ID: 0x{recv_id:02X})")
result[motor] = {"position": 0.0, "velocity": 0.0, "torque": 0.0}
continue
motor_type = self._motor_types.get(motor, MotorType.DM4310)
position_degrees, velocity_deg_per_sec, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
# Return all state values in one dict
result[motor] = {
"position": position_degrees,
"velocity": velocity_deg_per_sec,
"torque": torque,
"temp_mos": t_mos,
"temp_rotor": t_rotor,
}
except Exception as e:
logger.warning(f"Failed to read state from {motor}: {e}")
result[motor] = {"position": 0.0, "velocity": 0.0, "torque": 0.0}
return result
def sync_write(
self,
data_name: str,
values: Dict[str, Value],
*,
normalize: bool = True,
num_retry: int = 0,
) -> None:
"""
Write different values to multiple motors simultaneously. Positions are always in degrees.
Uses batched operations: sends all commands first, then collects responses (OpenArms pattern).
"""
if data_name == "Goal_Position":
# Step 1: Send all MIT control commands first (no waiting)
for motor, value_degrees in values.items():
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types.get(motor_name, MotorType.DM4310)
# Convert degrees to radians
position_rad = np.radians(value_degrees)
# Default gains for position control
kp, kd = 10.0, 0.5
# Get motor limits and encode parameters
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
kp_uint = self._float_to_uint(kp, 0, 500, 12)
kd_uint = self._float_to_uint(kd, 0, 5, 12)
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
dq_uint = self._float_to_uint(0, -vmax, vmax, 12)
tau_uint = self._float_to_uint(0, -tmax, tmax, 12)
# Pack data
data = [0] * 8
data[0] = (q_uint >> 8) & 0xFF
data[1] = q_uint & 0xFF
data[2] = dq_uint >> 4
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
data[4] = kp_uint & 0xFF
data[5] = kd_uint >> 4
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
data[7] = tau_uint & 0xFF
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self.canbus.send(msg)
time.sleep(0.0001) # 100us delay between commands to reduce bus congestion
# Step 2: Collect all responses at once
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in values.keys()]
self._recv_all_responses(expected_recv_ids, timeout=0.015) # 15ms timeout (increased for reliability)
else:
# Fall back to individual writes for other data types
for motor, value in values.items():
self.write(data_name, motor, value, normalize=normalize, num_retry=num_retry)
def read_calibration(self) -> dict[str, MotorCalibration]:
"""Read calibration data from motors."""
# Damiao motors don't store calibration internally
# Return existing calibration or empty dict
return self.calibration if self.calibration else {}
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
"""Write calibration data to motors."""
# Damiao motors don't store calibration internally
# Just cache it in memory
if cache:
self.calibration = calibration_dict
def record_ranges_of_motion(
self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
"""
Interactively record the min/max values of each motor in degrees.
Move the joints by hand (with torque disabled) while the method streams live positions.
Press Enter to finish.
"""
if motors is None:
motors = list(self.motors.keys())
elif isinstance(motors, (str, int)):
motors = [motors]
# Disable torque for manual movement
self.disable_torque(motors)
time.sleep(0.1)
# Get initial positions (already in degrees)
start_positions = self.sync_read("Present_Position", motors, normalize=False)
mins = start_positions.copy()
maxes = start_positions.copy()
print("\nMove joints through their full range of motion. Press ENTER when done.")
user_pressed_enter = False
while not user_pressed_enter:
positions = self.sync_read("Present_Position", motors, normalize=False)
for motor in motors:
if motor in positions:
mins[motor] = min(positions[motor], mins.get(motor, positions[motor]))
maxes[motor] = max(positions[motor], maxes.get(motor, positions[motor]))
if display_values:
print("\n" + "=" * 50)
print(f"{'MOTOR':<20} | {'MIN (deg)':>12} | {'POS (deg)':>12} | {'MAX (deg)':>12}")
print("-" * 50)
for motor in motors:
if motor in positions:
print(f"{motor:<20} | {mins[motor]:>12.1f} | {positions[motor]:>12.1f} | {maxes[motor]:>12.1f}")
if enter_pressed():
user_pressed_enter = True
if display_values and not user_pressed_enter:
# Move cursor up to overwrite the previous output
move_cursor_up(len(motors) + 4)
time.sleep(0.05)
# Re-enable torque
self.enable_torque(motors)
# Validate ranges
for motor in motors:
if motor in mins and motor in maxes:
if abs(maxes[motor] - mins[motor]) < 5.0: # At least 5 degrees of range
raise ValueError(f"Motor {motor} has insufficient range of motion (< 5 degrees)")
return mins, maxes
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
"""Convert motor specification to list of motor names."""
if motors is None:
return list(self.motors.keys())
elif isinstance(motors, str):
return [motors]
elif isinstance(motors, list):
return motors
else:
raise TypeError(f"Invalid motors type: {type(motors)}")
def _get_motor_id(self, motor: NameOrID) -> int:
"""Get CAN ID for a motor."""
if isinstance(motor, str):
if motor in self.motors:
return self.motors[motor].id
else:
raise ValueError(f"Unknown motor: {motor}")
else:
return motor
def _get_motor_name(self, motor: NameOrID) -> str:
"""Get motor name from name or ID."""
if isinstance(motor, str):
return motor
else:
for name, m in self.motors.items():
if m.id == motor:
return name
raise ValueError(f"Unknown motor ID: {motor}")
def _get_motor_recv_id(self, motor: NameOrID) -> Optional[int]:
"""Get motor recv_id from name or ID."""
motor_name = self._get_motor_name(motor)
motor_obj = self.motors.get(motor_name)
if motor_obj and hasattr(motor_obj, "recv_id"):
return motor_obj.recv_id
return None
@cached_property
def is_calibrated(self) -> bool:
"""Check if motors are calibrated."""
return bool(self.calibration)

View File

@@ -1,209 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration tables for Damiao motors."""
from enum import IntEnum
from typing import Dict, List, Tuple
# Motor type definitions
class MotorType(IntEnum):
DM3507 = 0
DM4310 = 1
DM4310_48V = 2
DM4340 = 3
DM4340_48V = 4
DM6006 = 5
DM8006 = 6
DM8009 = 7
DM10010L = 8
DM10010 = 9
DMH3510 = 10
DMH6215 = 11
DMG6220 = 12
# Control modes
class ControlMode(IntEnum):
MIT = 1
POS_VEL = 2
VEL = 3
TORQUE_POS = 4
# Motor variable IDs (RID)
class MotorVariable(IntEnum):
UV_VALUE = 0
KT_VALUE = 1
OT_VALUE = 2
OC_VALUE = 3
ACC = 4
DEC = 5
MAX_SPD = 6
MST_ID = 7
ESC_ID = 8
TIMEOUT = 9
CTRL_MODE = 10
DAMP = 11
INERTIA = 12
HW_VER = 13
SW_VER = 14
SN = 15
NPP = 16
RS = 17
LS = 18
FLUX = 19
GR = 20
PMAX = 21
VMAX = 22
TMAX = 23
I_BW = 24
KP_ASR = 25
KI_ASR = 26
KP_APR = 27
KI_APR = 28
OV_VALUE = 29
GREF = 30
DETA = 31
V_BW = 32
IQ_C1 = 33
VL_C1 = 34
CAN_BR = 35
SUB_VER = 36
U_OFF = 50
V_OFF = 51
K1 = 52
K2 = 53
M_OFF = 54
DIR = 55
P_M = 80
XOUT = 81
# Motor limit parameters [PMAX, VMAX, TMAX]
# PMAX: Maximum position (rad)
# VMAX: Maximum velocity (rad/s)
# TMAX: Maximum torque (N·m)
MOTOR_LIMIT_PARAMS = {
MotorType.DM3507: (12.5, 30, 10),
MotorType.DM4310: (12.5, 30, 10),
MotorType.DM4310_48V: (12.5, 50, 10),
MotorType.DM4340: (12.5, 8, 28),
MotorType.DM4340_48V: (12.5, 10, 28),
MotorType.DM6006: (12.5, 45, 20),
MotorType.DM8006: (12.5, 45, 40),
MotorType.DM8009: (12.5, 45, 54),
MotorType.DM10010L: (12.5, 25, 200),
MotorType.DM10010: (12.5, 20, 200),
MotorType.DMH3510: (12.5, 280, 1),
MotorType.DMH6215: (12.5, 45, 10),
MotorType.DMG6220: (12.5, 45, 10),
}
# Motor model names
MODEL_NAMES = {
MotorType.DM3507: "dm3507",
MotorType.DM4310: "dm4310",
MotorType.DM4310_48V: "dm4310_48v",
MotorType.DM4340: "dm4340",
MotorType.DM4340_48V: "dm4340_48v",
MotorType.DM6006: "dm6006",
MotorType.DM8006: "dm8006",
MotorType.DM8009: "dm8009",
MotorType.DM10010L: "dm10010l",
MotorType.DM10010: "dm10010",
MotorType.DMH3510: "dmh3510",
MotorType.DMH6215: "dmh6215",
MotorType.DMG6220: "dmg6220",
}
# Motor resolution table (encoder counts per revolution)
MODEL_RESOLUTION = {
"dm3507": 65536,
"dm4310": 65536,
"dm4310_48v": 65536,
"dm4340": 65536,
"dm4340_48v": 65536,
"dm6006": 65536,
"dm8006": 65536,
"dm8009": 65536,
"dm10010l": 65536,
"dm10010": 65536,
"dmh3510": 65536,
"dmh6215": 65536,
"dmg6220": 65536,
}
# CAN baudrates supported by Damiao motors
AVAILABLE_BAUDRATES = [
125000, # 0: 125 kbps
200000, # 1: 200 kbps
250000, # 2: 250 kbps
500000, # 3: 500 kbps
1000000, # 4: 1 mbps (default for OpenArms)
2000000, # 5: 2 mbps
2500000, # 6: 2.5 mbps
3200000, # 7: 3.2 mbps
4000000, # 8: 4 mbps
5000000, # 9: 5 mbps
]
DEFAULT_BAUDRATE = 1000000 # 1 Mbps is standard for OpenArms
# Default timeout in milliseconds
DEFAULT_TIMEOUT_MS = 1000
# Data that should be normalized
NORMALIZED_DATA = ["Present_Position", "Goal_Position"]
# OpenArms specific configurations
# Based on: https://docs.openarm.dev/software/setup/configure-test
# OpenArms has 7 DOF per arm (14 total for dual arm)
OPENARMS_ARM_MOTOR_IDS = {
"joint_1": {"send": 0x01, "recv": 0x11}, # J1 - Shoulder pan
"joint_2": {"send": 0x02, "recv": 0x12}, # J2 - Shoulder lift
"joint_3": {"send": 0x03, "recv": 0x13}, # J3 - Elbow flex
"joint_4": {"send": 0x04, "recv": 0x14}, # J4 - Wrist flex
"joint_5": {"send": 0x05, "recv": 0x15}, # J5 - Wrist roll
"joint_6": {"send": 0x06, "recv": 0x16}, # J6 - Wrist pitch
"joint_7": {"send": 0x07, "recv": 0x17}, # J7 - Wrist rotation
}
OPENARMS_GRIPPER_MOTOR_IDS = {
"gripper": {"send": 0x08, "recv": 0x18}, # J8 - Gripper
}
# Default motor types for OpenArms
OPENARMS_DEFAULT_MOTOR_TYPES = {
"joint_1": MotorType.DM8009, # Shoulder pan - high torque
"joint_2": MotorType.DM8009, # Shoulder lift - high torque
"joint_3": MotorType.DM4340, # Shoulder rotation
"joint_4": MotorType.DM4340, # Elbow flex
"joint_5": MotorType.DM4310, # Wrist roll
"joint_6": MotorType.DM4310, # Wrist pitch
"joint_7": MotorType.DM4310, # Wrist rotation
"gripper": MotorType.DM4310, # Gripper
}
# MIT control parameter ranges
MIT_KP_RANGE = (0.0, 500.0)
MIT_KD_RANGE = (0.0, 5.0)
# CAN frame command IDs
CAN_CMD_ENABLE = 0xFC
CAN_CMD_DISABLE = 0xFD
CAN_CMD_SET_ZERO = 0xFE
CAN_CMD_REFRESH = 0xCC
CAN_CMD_QUERY_PARAM = 0x33
CAN_CMD_WRITE_PARAM = 0x55
CAN_CMD_SAVE_PARAM = 0xAA
# CAN ID for parameter operations
CAN_PARAM_ID = 0x7FF

View File

@@ -24,7 +24,7 @@ from enum import Enum
from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
from .tables import (
AVAILABLE_BAUDRATES,
MODEL_BAUDRATE_TABLE,
@@ -60,7 +60,7 @@ class OperatingMode(Enum):
# This mode controls position. This mode is identical to the Multi-turn Position Control from existing
# DYNAMIXEL. 512 turns are supported(-256[rev] ~ 256[rev]). This mode is ideal for multi-turn wrists or
# conveyor systems or a system that requires an additional reduction gear. Note that Max Position
# conveyer systems or a system that requires an additional reduction gear. Note that Max Position
# Limit(48), Min Position Limit(52) are not used on Extended Position Control Mode.
EXTENDED_POSITION = 4
@@ -100,7 +100,7 @@ def _split_into_byte_chunks(value: int, length: int) -> list[int]:
return data
class DynamixelMotorsBus(SerialMotorsBus):
class DynamixelMotorsBus(MotorsBus):
"""
The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with
the motors. For more info, see the Dynamixel SDK Documentation:

View File

@@ -19,7 +19,7 @@ from pprint import pformat
from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
from .tables import (
FIRMWARE_MAJOR_VERSION,
FIRMWARE_MINOR_VERSION,
@@ -96,7 +96,7 @@ def patch_setPacketTimeout(self, packet_length): # noqa: N802
self.packet_timeout = (self.tx_time_per_byte * packet_length) + (self.tx_time_per_byte * 3.0) + 50
class FeetechMotorsBus(SerialMotorsBus):
class FeetechMotorsBus(MotorsBus):
"""
The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on the
python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk.
@@ -165,7 +165,7 @@ class FeetechMotorsBus(SerialMotorsBus):
def _handshake(self) -> None:
self._assert_motors_exist()
#self._assert_same_firmware()
self._assert_same_firmware()
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
if self.protocol_version == 0:

View File

@@ -206,12 +206,8 @@ MODEL_BAUDRATE_TABLE = {
# Sign-Magnitude encoding bits
STS_SMS_SERIES_ENCODINGS_TABLE = {
"Homing_Offset": 11,
"Goal_Position": 15,
"Goal_Velocity": 15,
"Goal_Speed": 15,
"Present_Position": 15,
"Present_Velocity": 15,
"Present_Speed": 15,
}
MODEL_ENCODING_TABLE = {

View File

@@ -19,8 +19,6 @@
# TODO(aliberts): Add block noqa when feature below is available
# https://github.com/astral-sh/ruff/issues/3711
from __future__ import annotations
import abc
import logging
from contextlib import contextmanager
@@ -43,92 +41,6 @@ Value: TypeAlias = int | float
logger = logging.getLogger(__name__)
class MotorsBusBase(abc.ABC):
"""
Base class for all motor bus implementations.
This is a minimal interface that all motor buses must implement, regardless of their
communication protocol (serial, CAN, etc.).
"""
def __init__(
self,
port: str,
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
):
self.port = port
self.motors = motors
self.calibration = calibration if calibration else {}
@abc.abstractmethod
def connect(self, handshake: bool = True) -> None:
"""Establish connection to the motors."""
pass
@abc.abstractmethod
def disconnect(self, disable_torque: bool = True) -> None:
"""Disconnect from the motors."""
pass
@property
@abc.abstractmethod
def is_connected(self) -> bool:
"""Check if connected to the motors."""
pass
@abc.abstractmethod
def read(self, data_name: str, motor: str, *, normalize: bool = True, num_retry: int = 0) -> Value:
"""Read a value from a single motor."""
pass
@abc.abstractmethod
def write(
self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0
) -> None:
"""Write a value to a single motor."""
pass
@abc.abstractmethod
def sync_read(
self, data_name: str, motors: str | list[str] | None = None, *, normalize: bool = True
) -> dict[str, Value]:
"""Read a value from multiple motors."""
pass
@abc.abstractmethod
def sync_write(
self,
data_name: str,
values: Value | dict[str, Value],
motors: str | list[str] | None = None,
*,
normalize: bool = True,
) -> None:
"""Write values to multiple motors."""
pass
@abc.abstractmethod
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Enable torque on selected motors."""
pass
@abc.abstractmethod
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
"""Disable torque on selected motors."""
pass
@abc.abstractmethod
def read_calibration(self) -> dict[str, MotorCalibration]:
"""Read calibration parameters from the motors."""
pass
@abc.abstractmethod
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
"""Write calibration parameters to the motors."""
pass
def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]:
ctrl_table = model_ctrl_table.get(model)
if ctrl_table is None:
@@ -291,15 +203,15 @@ class GroupSyncWrite(Protocol):
def txPacket(self): ...
class SerialMotorsBus(MotorsBusBase):
class MotorsBus(abc.ABC):
"""
A SerialMotorsBus allows to efficiently read and write to motors connected via serial communication.
A MotorsBus allows to efficiently read and write to the attached motors.
It represents several motors daisy-chained together and connected through a serial port.
There are currently two implementations of this class:
There are currently two implementations of this abstract class:
- DynamixelMotorsBus
- FeetechMotorsBus
This class is specifically for serial-based motor protocols (Dynamixel, Feetech, etc.).
Note: This class may evolve in the future should we add support for other types of bus.
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
To find the port, you can run our utility script:
@@ -1300,7 +1212,3 @@ class SerialMotorsBus(MotorsBusBase):
for id_, value in ids_values.items():
data = self._serialize_data(value, length)
self.sync_writer.addParam(id_, data)
# Backward compatibility alias
MotorsBus = SerialMotorsBus

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
import math
from dataclasses import asdict, dataclass
from pathlib import Path
@@ -80,11 +79,7 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
@dataclass
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
"""Used by Physical Intelligence to train Pi0.
Automatically scales warmup and decay steps if num_training_steps < num_decay_steps.
This ensures the learning rate schedule completes properly even with shorter training runs.
"""
"""Used by Physical Intelligence to train Pi0"""
num_warmup_steps: int
num_decay_steps: int
@@ -92,39 +87,23 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
decay_lr: float
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
# Auto-scale scheduler parameters if training steps are shorter than configured decay steps
actual_warmup_steps = self.num_warmup_steps
actual_decay_steps = self.num_decay_steps
if num_training_steps < self.num_decay_steps:
# Calculate scaling factor to fit the schedule into the available training steps
scale_factor = num_training_steps / self.num_decay_steps
actual_warmup_steps = int(self.num_warmup_steps * scale_factor)
actual_decay_steps = num_training_steps
logging.info(
f"Auto-scaling LR scheduler: "
f"num_training_steps ({num_training_steps}) < num_decay_steps ({self.num_decay_steps}). "
f"Scaling warmup: {self.num_warmup_steps}{actual_warmup_steps}, "
f"decay: {self.num_decay_steps}{actual_decay_steps} "
f"(scale factor: {scale_factor:.3f})"
)
del num_training_steps
def lr_lambda(current_step):
def linear_warmup_schedule(current_step):
if current_step <= 0:
return 1 / (actual_warmup_steps + 1)
frac = 1 - current_step / actual_warmup_steps
return (1 / (actual_warmup_steps + 1) - 1) * frac + 1
return 1 / (self.num_warmup_steps + 1)
frac = 1 - current_step / self.num_warmup_steps
return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
def cosine_decay_schedule(current_step):
step = min(current_step, actual_decay_steps)
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / actual_decay_steps))
step = min(current_step, self.num_decay_steps)
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
alpha = self.decay_lr / self.peak_lr
decayed = (1 - alpha) * cosine_decay + alpha
return decayed
if current_step < actual_warmup_steps:
if current_step < self.num_warmup_steps:
return linear_warmup_schedule(current_step)
return cosine_decay_schedule(current_step)

View File

@@ -14,7 +14,6 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
@@ -30,5 +29,4 @@ __all__ = [
"SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",
"GrootConfig",
]

View File

@@ -626,8 +626,8 @@ class ACTDecoderLayer(nn.Module):
x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
cross-attending with.
encoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
decoder_pos_embed: (DS, 1, C) positional embedding for the queries (from the decoder).
decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder).
Returns:
(DS, B, C) tensor of decoder output features.
"""

View File

@@ -30,7 +30,6 @@ from lerobot.envs.configs import EnvConfig
from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
@@ -102,10 +101,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
elif name == "groot":
from lerobot.policies.groot.modeling_groot import GrootPolicy
return GrootPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -147,8 +142,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return SmolVLAConfig(**kwargs)
elif policy_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
elif policy_type == "groot":
return GrootConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")
@@ -206,27 +199,6 @@ def make_pre_post_processors(
policy configuration type.
"""
if pretrained_path:
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
if isinstance(policy_cfg, GrootConfig):
# GROOT handles normalization in groot_pack_inputs_v3 step
# Need to override both stats AND normalize_min_max since saved config might be empty
preprocessor_overrides = {}
postprocessor_overrides = {}
preprocessor_overrides["groot_pack_inputs_v3"] = {
"stats": kwargs.get("dataset_stats"),
"normalize_min_max": True,
}
# Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats
env_action_dim = policy_cfg.output_features["action"].shape[0]
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = {
"stats": kwargs.get("dataset_stats"),
"normalize_min_max": True,
"env_action_dim": env_action_dim,
}
kwargs["preprocessor_overrides"] = preprocessor_overrides
kwargs["postprocessor_overrides"] = postprocessor_overrides
return (
PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
@@ -321,14 +293,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, GrootConfig):
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
processors = make_groot_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
@@ -339,7 +303,6 @@ def make_policy(
cfg: PreTrainedConfig,
ds_meta: LeRobotDatasetMetadata | None = None,
env_cfg: EnvConfig | None = None,
rename_map: dict[str, str] | None = None,
) -> PreTrainedPolicy:
"""
Instantiate a policy model.
@@ -356,8 +319,6 @@ def make_policy(
statistics for normalization layers.
env_cfg: Environment configuration used to infer feature shapes and types.
One of `ds_meta` or `env_cfg` must be provided.
rename_map: Optional mapping of dataset or environment feature keys to match
expected policy feature names (e.g., `"left"` → `"camera1"`).
Returns:
An instantiated and device-placed policy model.
@@ -399,10 +360,8 @@ def make_policy(
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
features = env_to_policy_features(env_cfg)
if not cfg.output_features:
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
if not cfg.input_features:
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg
if cfg.pretrained_path:
@@ -419,21 +378,4 @@ def make_policy(
# policy = torch.compile(policy, mode="reduce-overhead")
if not rename_map:
expected_features = set(cfg.input_features.keys()) | set(cfg.output_features.keys())
provided_features = set(features.keys())
if expected_features and provided_features != expected_features:
missing = expected_features - provided_features
extra = provided_features - expected_features
# TODO (jadechoghari): provide a dynamic rename map suggestion to the user.
raise ValueError(
f"Feature mismatch between dataset/environment and policy config.\n"
f"- Missing features: {sorted(missing) if missing else 'None'}\n"
f"- Extra features: {sorted(extra) if extra else 'None'}\n\n"
f"Please ensure your dataset and policy use consistent feature names.\n"
f"If your dataset uses different observation keys (e.g., cameras named differently), "
f"use the `--rename_map` argument, for example:\n"
f' --rename_map=\'{{"observation.images.left": "observation.images.camera1", '
f'"observation.images.top": "observation.images.camera2"}}\''
)
return policy

View File

@@ -1 +0,0 @@
../../../../docs/source/policy_groot_README.md

View File

@@ -1,21 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Nvidia and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_groot import GrootConfig
from .modeling_groot import GrootPolicy
from .processor_groot import make_groot_pre_post_processors
__all__ = ["GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"]

View File

@@ -1,14 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -1,54 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
def swish(x):
return x * torch.sigmoid(x)
class SinusoidalPositionalEncoding(nn.Module):
"""
Produces a sinusoidal encoding of shape (B, T, w)
given timesteps of shape (B, T).
"""
def __init__(self, embedding_dim):
super().__init__()
self.embedding_dim = embedding_dim
def forward(self, timesteps):
# timesteps: shape (B, T)
# We'll compute sin/cos frequencies across dim T
timesteps = timesteps.float() # ensure float
b, t = timesteps.shape
device = timesteps.device
half_dim = self.embedding_dim // 2
# typical log space frequencies for sinusoidal encoding
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
torch.log(torch.tensor(10000.0)) / half_dim
)
# Expand timesteps to (B, T, 1) then multiply
freqs = timesteps.unsqueeze(-1) * exponent.exp() # (B, T, half_dim)
sin = torch.sin(freqs)
cos = torch.cos(freqs)
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
return enc

View File

@@ -1,370 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F # noqa: N812
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import (
SinusoidalPositionalEmbedding,
TimestepEmbedding,
Timesteps,
)
from torch import nn
class TimestepEncoder(nn.Module):
def __init__(self, embedding_dim, compute_dtype=torch.float32):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timesteps):
dtype = next(self.parameters()).dtype
timesteps_proj = self.time_proj(timesteps).to(dtype)
timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
return timesteps_emb
class AdaLayerNorm(nn.Module):
def __init__(
self,
embedding_dim: int,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-5,
chunk_dim: int = 0,
):
super().__init__()
self.chunk_dim = chunk_dim
output_dim = embedding_dim * 2
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, output_dim)
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
def forward(
self,
x: torch.Tensor,
temb: torch.Tensor | None = None,
) -> torch.Tensor:
temb = self.linear(self.silu(temb))
scale, shift = temb.chunk(2, dim=1)
x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
return x
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: int | None = None,
activation_fn: str = "geglu",
attention_bias: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: str | None = None,
num_positional_embeddings: int | None = None,
ff_inner_dim: int | None = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.dropout = dropout
self.cross_attention_dim = cross_attention_dim
self.activation_fn = activation_fn
self.attention_bias = attention_bias
self.norm_elementwise_affine = norm_elementwise_affine
self.positional_embeddings = positional_embeddings
self.num_positional_embeddings = num_positional_embeddings
self.norm_type = norm_type
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if norm_type == "ada_norm":
self.norm1 = AdaLayerNorm(dim)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
)
# 3. Feed-forward
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
if final_dropout:
self.final_dropout = nn.Dropout(dropout)
else:
self.final_dropout = None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
encoder_hidden_states: torch.Tensor | None = None,
encoder_attention_mask: torch.Tensor | None = None,
temb: torch.LongTensor | None = None,
) -> torch.Tensor:
# 0. Self-Attention
if self.norm_type == "ada_norm":
norm_hidden_states = self.norm1(hidden_states, temb)
else:
norm_hidden_states = self.norm1(hidden_states)
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
# encoder_attention_mask=encoder_attention_mask,
)
if self.final_dropout:
attn_output = self.final_dropout(attn_output)
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 4. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
ff_output = self.ff(norm_hidden_states)
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class DiT(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 8,
attention_head_dim: int = 64,
output_dim: int = 26,
num_layers: int = 12,
dropout: float = 0.1,
attention_bias: bool = True,
activation_fn: str = "gelu-approximate",
num_embeds_ada_norm: int | None = 1000,
upcast_attention: bool = False,
norm_type: str = "ada_norm",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-5,
max_num_positional_embeddings: int = 512,
compute_dtype=torch.float32,
final_dropout: bool = True,
positional_embeddings: str | None = "sinusoidal",
interleave_self_attention=False,
cross_attention_dim: int | None = None,
):
super().__init__()
self.attention_head_dim = attention_head_dim
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.gradient_checkpointing = False
# Timestep encoder
self.timestep_encoder = TimestepEncoder(
embedding_dim=self.inner_dim, compute_dtype=self.config.compute_dtype
)
all_blocks = []
for idx in range(self.config.num_layers):
use_self_attn = idx % 2 == 1 and interleave_self_attention
curr_cross_attention_dim = cross_attention_dim if not use_self_attn else None
all_blocks += [
BasicTransformerBlock(
self.inner_dim,
self.config.num_attention_heads,
self.config.attention_head_dim,
dropout=self.config.dropout,
activation_fn=self.config.activation_fn,
attention_bias=self.config.attention_bias,
upcast_attention=self.config.upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=self.config.norm_elementwise_affine,
norm_eps=self.config.norm_eps,
positional_embeddings=positional_embeddings,
num_positional_embeddings=self.config.max_num_positional_embeddings,
final_dropout=final_dropout,
cross_attention_dim=curr_cross_attention_dim,
)
]
self.transformer_blocks = nn.ModuleList(all_blocks)
# Output blocks
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
print(
"Total number of DiT parameters: ",
sum(p.numel() for p in self.parameters() if p.requires_grad),
)
def forward(
self,
hidden_states: torch.Tensor, # Shape: (B, T, D)
encoder_hidden_states: torch.Tensor, # Shape: (B, S, D)
timestep: torch.LongTensor | None = None,
encoder_attention_mask: torch.Tensor | None = None,
return_all_hidden_states: bool = False,
):
# Encode timesteps
temb = self.timestep_encoder(timestep)
# Process through transformer blocks - single pass through the blocks
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
all_hidden_states = [hidden_states]
# Process through transformer blocks
for idx, block in enumerate(self.transformer_blocks):
if idx % 2 == 1 and self.config.interleave_self_attention:
hidden_states = block(
hidden_states,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
temb=temb,
)
else:
hidden_states = block(
hidden_states,
attention_mask=None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=None,
temb=temb,
)
all_hidden_states.append(hidden_states)
# Output processing
conditioning = temb
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
if return_all_hidden_states:
return self.proj_out_2(hidden_states), all_hidden_states
else:
return self.proj_out_2(hidden_states)
class SelfAttentionTransformer(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 8,
attention_head_dim: int = 64,
output_dim: int = 26,
num_layers: int = 12,
dropout: float = 0.1,
attention_bias: bool = True,
activation_fn: str = "gelu-approximate",
num_embeds_ada_norm: int | None = 1000,
upcast_attention: bool = False,
max_num_positional_embeddings: int = 512,
compute_dtype=torch.float32,
final_dropout: bool = True,
positional_embeddings: str | None = "sinusoidal",
interleave_self_attention=False,
):
super().__init__()
self.attention_head_dim = attention_head_dim
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.gradient_checkpointing = False
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
self.inner_dim,
self.config.num_attention_heads,
self.config.attention_head_dim,
dropout=self.config.dropout,
activation_fn=self.config.activation_fn,
attention_bias=self.config.attention_bias,
upcast_attention=self.config.upcast_attention,
positional_embeddings=positional_embeddings,
num_positional_embeddings=self.config.max_num_positional_embeddings,
final_dropout=final_dropout,
)
for _ in range(self.config.num_layers)
]
)
print(
"Total number of SelfAttentionTransformer parameters: ",
sum(p.numel() for p in self.parameters() if p.requires_grad),
)
def forward(
self,
hidden_states: torch.Tensor, # Shape: (B, T, D)
return_all_hidden_states: bool = False,
):
# Process through transformer blocks - single pass through the blocks
hidden_states = hidden_states.contiguous()
all_hidden_states = [hidden_states]
# Process through transformer blocks
for _idx, block in enumerate(self.transformer_blocks):
hidden_states = block(hidden_states)
all_hidden_states.append(hidden_states)
if return_all_hidden_states:
return hidden_states, all_hidden_states
else:
return hidden_states

View File

@@ -1,406 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
from torch.distributions import Beta
from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers import PretrainedConfig
from transformers.feature_extraction_utils import BatchFeature
else:
PretrainedConfig = object
BatchFeature = None
from lerobot.policies.groot.action_head.action_encoder import (
SinusoidalPositionalEncoding,
swish,
)
from .cross_attention_dit import DiT, SelfAttentionTransformer
class CategorySpecificLinear(nn.Module):
def __init__(self, num_categories, input_dim, hidden_dim):
super().__init__()
self.num_categories = num_categories
# For each category, we have separate weights and biases.
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
def forward(self, x, cat_ids):
selected_w = self.W[cat_ids]
selected_b = self.b[cat_ids]
return torch.bmm(x, selected_w) + selected_b.unsqueeze(1)
class CategorySpecificMLP(nn.Module):
def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
super().__init__()
self.num_categories = num_categories
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
def forward(self, x, cat_ids):
hidden = F.relu(self.layer1(x, cat_ids))
return self.layer2(hidden, cat_ids)
class MultiEmbodimentActionEncoder(nn.Module):
def __init__(self, action_dim, hidden_size, num_embodiments):
super().__init__()
self.hidden_size = hidden_size
self.num_embodiments = num_embodiments
# W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size) # (d -> w)
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size) # (2w -> w)
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size) # (w -> w)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
def forward(self, actions, timesteps, cat_ids):
"""
actions: shape (B, T, action_dim)
timesteps: shape (B,) -- a single scalar per batch item
cat_ids: shape (B,)
returns: shape (B, T, hidden_size)
"""
b, t, _ = actions.shape
# 1) Expand each batch's single scalar time 'tau' across all T steps
# so that shape => (B, T)
# e.g. if timesteps is (B,), replicate across T
if timesteps.dim() == 1 and timesteps.shape[0] == b:
# shape (B,) => (B,T)
timesteps = timesteps.unsqueeze(1).expand(-1, t)
else:
raise ValueError("Expected `timesteps` to have shape (B,) so we can replicate across T.")
# 2) Standard action MLP step for shape => (B, T, w)
a_emb = self.W1(actions, cat_ids)
# 3) Get the sinusoidal encoding (B, T, w)
tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
x = torch.cat([a_emb, tau_emb], dim=-1)
x = swish(self.W2(x, cat_ids))
# 5) Finally W3 => (B, T, w)
x = self.W3(x, cat_ids)
return x
@dataclass
class FlowmatchingActionHeadConfig(PretrainedConfig):
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
add_pos_embed: bool = field(default=True, metadata={"help": "Whether to add positional embedding"})
model_dtype: str = field(default="float32", metadata={"help": "Model data type."})
diffusion_model_cfg: dict = field(default=None, metadata={"help": "Diffusion model configuration."})
input_embedding_dim: int = field(default=1536, metadata={"help": "Input embedding channel dimension."})
backbone_embedding_dim: int = field(
default=1536, metadata={"help": "Backbone embedding channel dimension."}
)
hidden_size: int = field(default=1024, metadata={"help": "Input embedding dimension."})
max_seq_len: int = field(default=1024, metadata={"help": "Maximum Sequence Length"})
action_dim: int = field(default=None, metadata={"help": "Action dimension."})
action_horizon: int = field(default=None, metadata={"help": "Action horizon."})
noise_beta_alpha: float = field(default=1.5, metadata={"help": ""})
noise_beta_beta: float = field(default=1.0, metadata={"help": ""})
noise_s: float = field(default=0.999, metadata={"help": "Flow matching noise Beta distribution s."})
num_timestep_buckets: int = field(
default=1000, metadata={"help": "Number of timestep discretization buckets."}
)
num_inference_timesteps: int = field(
default=None,
metadata={"help": "Number of inference steps for noise diffusion."},
)
max_num_embodiments: int = field(default=32, metadata={"help": "Number of embodiments."})
tune_projector: bool = field(default=True, metadata={"help": "Whether to tune the projector."})
tune_diffusion_model: bool = field(
default=True, metadata={"help": "Whether to tune the diffusion model."}
)
load_pretrained_det_decode_layer_path: str = field(
default=None, metadata={"help": "Path to pretrained detection model."}
)
detection_coeff: float = field(default=1.0, metadata={"help": "Detection coefficient."})
freeze_decode_layer: bool = field(default=False)
expand_batch: int = field(default=None)
use_vlln: bool = field(default=True)
vl_self_attention_cfg: dict = field(default=None)
num_target_vision_tokens: int = field(default=32, metadata={"help": "Number of target vision tokens."})
def __init__(self, **kwargs):
super().__init__(**kwargs)
for key, value in kwargs.items():
setattr(self, key, value)
class FlowmatchingActionHead(nn.Module):
config_class = FlowmatchingActionHeadConfig
supports_gradient_checkpointing = True
def __init__(
self,
config: FlowmatchingActionHeadConfig,
):
super().__init__()
self.hidden_size = config.hidden_size
self.input_embedding_dim = config.input_embedding_dim
self.model = DiT(**config.diffusion_model_cfg)
self.action_dim = config.action_dim
self.action_horizon = config.action_horizon
self.num_inference_timesteps = config.num_inference_timesteps
self.state_encoder = CategorySpecificMLP(
num_categories=config.max_num_embodiments,
input_dim=config.max_state_dim,
hidden_dim=self.hidden_size,
output_dim=self.input_embedding_dim,
)
self.action_encoder = MultiEmbodimentActionEncoder(
action_dim=config.action_dim,
hidden_size=self.input_embedding_dim,
num_embodiments=config.max_num_embodiments,
)
self.action_decoder = CategorySpecificMLP(
num_categories=config.max_num_embodiments,
input_dim=self.hidden_size,
hidden_dim=self.hidden_size,
output_dim=self.action_dim,
)
self.future_tokens = nn.Embedding(config.num_target_vision_tokens, self.input_embedding_dim)
nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)
self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
self.vl_self_attention = (
SelfAttentionTransformer(**config.vl_self_attention_cfg) if config.use_vlln else nn.Identity()
)
if config.add_pos_embed:
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
self.num_timestep_buckets = config.num_timestep_buckets
self.config = config
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
def set_trainable_parameters(self, tune_projector: bool, tune_diffusion_model: bool):
self.tune_projector = tune_projector
self.tune_diffusion_model = tune_diffusion_model
for p in self.parameters():
p.requires_grad = True
if not tune_projector:
self.state_encoder.requires_grad_(False)
self.action_encoder.requires_grad_(False)
self.action_decoder.requires_grad_(False)
if self.config.add_pos_embed:
self.position_embedding.requires_grad_(False)
if not tune_diffusion_model:
self.model.requires_grad_(False)
print(f"Tune action head projector: {self.tune_projector}")
print(f"Tune action head diffusion model: {self.tune_diffusion_model}")
# Check if any parameters are still trainable. If not, print a warning.
if not tune_projector and not tune_diffusion_model:
for name, p in self.named_parameters():
if p.requires_grad:
print(f"Action head trainable parameter: {name}")
if not any(p.requires_grad for p in self.parameters()):
print("Warning: No action head trainable parameters found.")
def set_frozen_modules_to_eval_mode(self):
"""
Huggingface will call model.train() at each training_step. To ensure
the expected behaviors for modules like dropout, batchnorm, etc., we
need to call model.eval() for the frozen modules.
"""
if self.training:
if not self.tune_projector:
self.state_encoder.eval()
self.action_encoder.eval()
self.action_decoder.eval()
if self.config.add_pos_embed:
self.position_embedding.eval()
if not self.tune_diffusion_model:
self.model.eval()
def sample_time(self, batch_size, device, dtype):
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
return (self.config.noise_s - sample) / self.config.noise_s
def prepare_input(self, batch: dict) -> BatchFeature:
return BatchFeature(data=batch)
def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
backbone_features = backbone_output["backbone_features"]
backbone_features = self.vlln(backbone_features)
backbone_features = self.vl_self_attention(backbone_features)
backbone_output["backbone_features"] = backbone_features
return backbone_output
def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
# Set frozen modules to eval
self.set_frozen_modules_to_eval_mode()
backbone_output = self.process_backbone_output(backbone_output)
if self.config.expand_batch is not None:
for k, v in backbone_output.items():
ndim = len(v.shape)
factors = [self.config.expand_batch]
while len(factors) < ndim:
factors.append(1)
factors = tuple(factors)
expanded = v.repeat(*factors)
backbone_output[k] = expanded
for k, v in action_input.items():
ndim = len(v.shape)
factors = [self.config.expand_batch]
while len(factors) < ndim:
factors.append(1)
factors = tuple(factors)
expanded = v.repeat(*factors)
action_input[k] = expanded
# Get vision and language embeddings.
vl_embs = backbone_output.backbone_features
device = vl_embs.device
# Get embodiment ID.
embodiment_id = action_input.embodiment_id
# Embed state.
state_features = self.state_encoder(action_input.state, embodiment_id)
# Embed noised action trajectory.
actions = action_input.action
noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
t = t[:, None, None] # shape (B,1,1) for broadcast
noisy_trajectory = (1 - t) * noise + t * actions
velocity = actions - noise
# Convert (continuous) t -> discrete if needed
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
# Maybe add position embedding.
if self.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
action_features = action_features + pos_embs
# Join vision, language, state and action embedding along sequence dimension.
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
vl_attn_mask = backbone_output.backbone_attention_mask
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embs,
encoder_attention_mask=vl_attn_mask,
timestep=t_discretized,
return_all_hidden_states=False, # NOTE (YL): not using flare now
)
pred = self.action_decoder(model_output, embodiment_id)
pred_actions = pred[:, -actions.shape[1] :]
# Slice out only the action portion of pred and target.
action_mask = action_input.action_mask
loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
loss = loss.sum() / action_mask.sum()
output_dict = {
"loss": loss,
}
return BatchFeature(data=output_dict)
@torch.no_grad()
def get_action(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
backbone_output = self.process_backbone_output(backbone_output)
# Get vision and language embeddings.
vl_embs = backbone_output.backbone_features
embodiment_id = action_input.embodiment_id
# Embed state.
state_features = self.state_encoder(action_input.state, embodiment_id)
# Set initial actions as the sampled noise.
batch_size = vl_embs.shape[0]
device = vl_embs.device
actions = torch.randn(
size=(batch_size, self.config.action_horizon, self.config.action_dim),
dtype=vl_embs.dtype,
device=device,
)
num_steps = self.num_inference_timesteps
dt = 1.0 / num_steps
# Run denoising steps.
for t in range(num_steps):
t_cont = t / float(num_steps) # e.g. goes 0, 1/N, 2/N, ...
t_discretized = int(t_cont * self.num_timestep_buckets)
# Embed noised action trajectory.
timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
# Maybe add position embedding.
if self.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
action_features = action_features + pos_embs
# Join vision, language, state and action embedding along sequence dimension.
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
# Run model forward.
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embs,
timestep=timesteps_tensor,
)
pred = self.action_decoder(model_output, embodiment_id)
pred_velocity = pred[:, -self.action_horizon :]
# Update actions using euler integration.
actions = actions + dt * pred_velocity
return BatchFeature(data={"action_pred": actions})
@property
def device(self):
return next(iter(self.parameters())).device
@property
def dtype(self):
return next(iter(self.parameters())).dtype

View File

@@ -1,201 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("groot")
@dataclass
class GrootConfig(PreTrainedConfig):
"""Configuration for Groot policy wrapper."""
# Basic policy settings
n_obs_steps: int = 1
chunk_size: int = 50
n_action_steps: int = 50
# Dimension settings (must match pretrained GR00T model expectations)
# Maximum state dimension. Shorter states will be zero-padded.
max_state_dim: int = 64
# Maximum action dimension. Shorter actions will be zero-padded.
max_action_dim: int = 32
# Normalization (start with identity, adjust as needed)
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Image preprocessing (adjust to match Groot's expected input)
image_size: tuple[int, int] = (224, 224)
# Groot-specific model parameters (from groot_finetune_script.py)
# Path or HuggingFace model ID for the base Groot model
base_model_path: str = "nvidia/GR00T-N1.5-3B"
# HF repo ID (or local path) that hosts vocab.json and merges.txt for Eagle tokenizer.
tokenizer_assets_repo: str = "lerobot/eagle2hg-processor-groot-n1p5"
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
embodiment_tag: str = "new_embodiment"
# Fine-tuning control arguments
# Whether to fine-tune the llm backbone
tune_llm: bool = False
# Whether to fine-tune the vision tower
tune_visual: bool = False
# Whether to fine-tune the projector
tune_projector: bool = True
# Whether to fine-tune the diffusion model
tune_diffusion_model: bool = True
# LoRA parameters (from groot_finetune_script.py)
# Rank for the LORA model. If 0, no LORA will be used.
lora_rank: int = 0
# Alpha value for the LORA model
lora_alpha: int = 16
# Dropout rate for the LORA model
lora_dropout: float = 0.1
# Whether to use the full model for LORA
lora_full_model: bool = False
# Training parameters (matching groot_finetune_script.py)
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.95, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-5
warmup_ratio: float = 0.05
use_bf16: bool = True
# Dataset parameters
# Video backend to use for training ('decord' or 'torchvision_av')
video_backend: str = "decord"
# Whether to balance dataset weights in mixture datasets
balance_dataset_weights: bool = True
# Whether to sample trajectories weighted by their length
balance_trajectory_weights: bool = True
# Optional dataset paths for delegating training to Isaac-GR00T runner
dataset_paths: list[str] | None = None
output_dir: str = "./tmp/gr00t"
save_steps: int = 1000
max_steps: int = 10000
batch_size: int = 32
dataloader_num_workers: int = 8
report_to: str = "wandb"
resume: bool = False
def __post_init__(self):
super().__post_init__()
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})"
)
# groot_repo_path is now optional since we ported the components
# No validation needed
def validate_features(self) -> None:
"""Validate and set up input/output features for Groot."""
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
if not image_features:
raise ValueError(
"Groot policy requires at least one visual input feature. "
"No features of type FeatureType.VISUAL found in input_features."
)
if "observation.state" not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,),
)
self.input_features["observation.state"] = state_feature
else:
state_shape = self.input_features["observation.state"].shape
state_dim = state_shape[0] if state_shape else 0
if state_dim > self.max_state_dim:
raise ValueError(
f"State dimension {state_dim} exceeds max_state_dim {self.max_state_dim}. "
f"Either reduce state dimension or increase max_state_dim in config."
)
if "action" not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,),
)
self.output_features["action"] = action_feature
else:
action_shape = self.output_features["action"].shape
action_dim = action_shape[0] if action_shape else 0
if action_dim > self.max_action_dim:
raise ValueError(
f"Action dimension {action_dim} exceeds max_action_dim {self.max_action_dim}. "
f"Either reduce action dimension or increase max_action_dim in config."
)
def get_optimizer_preset(self) -> AdamWConfig:
"""Return optimizer configuration."""
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
"""Return scheduler configuration."""
return CosineDecayWithWarmupSchedulerConfig(
num_warmup_steps=int(10000 * self.warmup_ratio), # 5% warmup by default
num_decay_steps=10000, # Adjust based on training steps
peak_lr=self.optimizer_lr,
decay_lr=self.optimizer_lr * 0.1,
)
@property
def observation_delta_indices(self) -> None:
"""Return indices for delta observations (None for Groot)."""
return None
@property
def action_delta_indices(self) -> list[int]:
"""Return indices for delta actions."""
return list(range(min(self.chunk_size, 16)))
@property
def reward_delta_indices(self) -> None:
"""Return indices for delta rewards (None for Groot)."""
return None

View File

@@ -1,135 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from transformers.configuration_utils import PretrainedConfig
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Eagle25VLConfig(PretrainedConfig):
model_type = "eagle_2_5_vl"
is_composition = True
sub_configs = {"vision_config": SiglipVisionConfig, "text_config": Qwen2Config}
def __init__(
self,
vision_config=None,
text_config=None,
use_backbone_lora=0,
use_llm_lora=0,
pad2square=False,
select_layer=-4,
force_image_size=None,
downsample_ratio=0.5,
template=None,
dynamic_image_size=False,
use_thumbnail=False,
loss_version="v1",
min_dynamic_tiles=1,
max_dynamic_tiles=6,
mlp_checkpoint=False,
initializer_range=0.02,
_attn_implementation="flash_attention_2",
_attn_implementation_autoset=False,
llm_config=None,
image_token_index=None,
use_pixel_shuffle=True,
mlp_connector_layers=2,
**kwargs,
):
super().__init__(**kwargs)
if vision_config is None:
vision_config = {"model_type": "siglip_vision_model"}
logger.info("vision_config is None. Initializing the InternVisionConfig with default values.")
if text_config is None:
text_config = {"architectures": ["Qwen2ForCausalLM"]}
logger.info(
"text_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`)."
)
if vision_config["model_type"] == "siglip_vision_model":
self.vision_config = SiglipVisionConfig(**vision_config)
else:
raise ValueError("Unsupported model_type: {}".format(vision_config["model_type"]))
if text_config["architectures"][0] == "LlamaForCausalLM":
self.text_config = LlamaConfig(**text_config)
elif text_config["architectures"][0] == "Qwen2ForCausalLM":
self.text_config = Qwen2Config(**text_config)
elif text_config["architectures"][0] == "Qwen3ForCausalLM":
self.text_config = Qwen3Config(**text_config)
else:
raise ValueError("Unsupported architecture: {}".format(text_config["architectures"][0]))
self.use_backbone_lora = use_backbone_lora
self.use_llm_lora = use_llm_lora
self.mlp_checkpoint = mlp_checkpoint
self.pad2square = pad2square
self.select_layer = select_layer
self.force_image_size = force_image_size
self.downsample_ratio = downsample_ratio
self.template = template
self.dynamic_image_size = dynamic_image_size
self.use_thumbnail = use_thumbnail
self.loss_version = loss_version
self.initializer_range = initializer_range
self.min_dynamic_tiles = min_dynamic_tiles
self.max_dynamic_tiles = max_dynamic_tiles
self.tie_word_embeddings = self.text_config.tie_word_embeddings
self._attn_implementation = _attn_implementation
self._attn_implementation_autoset = _attn_implementation_autoset
self.image_token_index = image_token_index
self.use_pixel_shuffle = use_pixel_shuffle
self.mlp_connector_layers = mlp_connector_layers
logger.info(f"min_dynamic_tiles: {self.min_dynamic_tiles}")
logger.info(f"max_dynamic_tiles: {self.max_dynamic_tiles}")
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["vision_config"] = self.vision_config.to_dict()
output["text_config"] = self.text_config.to_dict()
output["model_type"] = self.__class__.model_type
output["use_backbone_lora"] = self.use_backbone_lora
output["use_llm_lora"] = self.use_llm_lora
output["pad2square"] = self.pad2square
output["select_layer"] = self.select_layer
output["force_image_size"] = self.force_image_size
output["downsample_ratio"] = self.downsample_ratio
output["template"] = self.template
output["dynamic_image_size"] = self.dynamic_image_size
output["use_thumbnail"] = self.use_thumbnail
output["min_dynamic_tiles"] = self.min_dynamic_tiles
output["max_dynamic_tiles"] = self.max_dynamic_tiles
output["tie_word_embeddings"] = self.tie_word_embeddings
output["_attn_implementation"] = self._attn_implementation
output["_attn_implementation_autoset"] = self._attn_implementation_autoset
output["use_pixel_shuffle"] = self.use_pixel_shuffle
output["mlp_connector_layers"] = self.mlp_connector_layers
return output

View File

@@ -1,504 +0,0 @@
# --------------------------------------------------------
# NVIDIA
# Copyright (c) 2025 NVIDIA
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
from typing import Optional
from transformers.image_processing_utils import (
BatchFeature,
get_patch_output_size,
)
from transformers.image_processing_utils_fast import (
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
group_images_by_shape,
reorder_images,
)
from transformers.image_utils import (
IMAGENET_STANDARD_MEAN, # 0.5, 0.5, 0.5
IMAGENET_STANDARD_STD, # 0.5, 0.5, 0.5
ChannelDimension,
ImageInput,
PILImageResampling,
SizeDict,
get_image_size,
make_flat_list_of_images,
validate_kwargs,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_v2_available,
)
from transformers.video_utils import VideoInput
if is_torch_available():
import torch
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F # noqa: N812
from transformers.image_utils import pil_torch_interpolation_mapping
else:
from torchvision.transforms import functional as F # noqa: N812
def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> torch.Tensor:
"""Crop the given numpy array.
Args:
img (torch.Tensor): Image to be cropped. Format should be (C, H, W).
left (int): The left coordinate of the crop box.
top (int): The top coordinate of the crop box.
right (int): The right coordinate of the crop box.
bottom (int): The bottom coordinate of the crop box.
Returns:
torch.Tensor: Cropped image.
"""
if not isinstance(img, torch.Tensor):
raise TypeError(f"img should be torch.Tensor. Got {type(img)}")
if img.ndim not in [2, 3]:
raise ValueError(f"Image should have 2 or 3 dimensions. Got {img.ndim}")
img_height = img.shape[1]
img_width = img.shape[2]
if top < 0 or left < 0 or bottom > img_height or right > img_width:
raise ValueError("Crop coordinates out of bounds")
if top >= bottom or left >= right:
raise ValueError("Invalid crop coordinates")
return img[:, top:bottom, left:right]
class Eagle25VLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
max_dynamic_tiles: int | None
min_dynamic_tiles: int | None
use_thumbnail: bool | None
pad_during_tiling: bool | None
do_pad: bool | None
@add_start_docstrings(
"Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.",
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, TODO: this was depreciated from transformers remove!
"""
image_grid_pinpoints (`List[List[int]]`, *optional*):
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
method. Not used for processing videos.
do_pad (`bool`, *optional*):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
""",
)
class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BICUBIC
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
size = {"height": 448, "width": 448}
default_to_square = False
crop_size = None
do_resize = True
do_center_crop = None
do_rescale = True
do_normalize = True
do_convert_rgb = True
do_pad = True
max_dynamic_tiles = 12
min_dynamic_tiles = 1
use_thumbnail = True
pad_during_tiling = False
valid_kwargs = Eagle25VLFastImageProcessorKwargs
model_input_names = ["pixel_values_videos"]
def __init__(self, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]):
super().__init__(**kwargs)
@add_start_docstrings(
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, TODO: this was depreciated from transformers remove!
"""
max_dynamic_tiles (`int`, *optional*):
The maximum number of dynamic tiles to use for processing high resolution images.
min_dynamic_tiles (`int`, *optional*):
The minimum number of dynamic tiles to use for processing high resolution images.
use_thumbnail (`bool`, *optional*):
Whether to use a thumbnail for processing high resolution images.
pad_during_tiling (`bool`, *optional*):
Whether to pad the image during tiling.
do_pad (`bool`, *optional*):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
""",
)
# NOTE(YL): we will overload the preprocess method to add the image_flags
# def preprocess(
# self, images: ImageInput, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]
# ) -> BatchFeature:
# return super().preprocess(images, **kwargs)
def _prepare_images_structure(
self,
images: ImageInput,
expected_ndims: int = 3,
) -> ImageInput:
"""
Prepare the images structure for processing.
Args:
images (`ImageInput`):
The input images to process.
expected_ndims (`int`, *optional*, defaults to 3):
Expected number of dimensions for the images (added for transformers >=4.53.0 compatibility).
Returns:
`ImageInput`: The images with a valid nesting.
"""
return make_flat_list_of_images(images)
def _resize_for_patching(
self,
image: "torch.Tensor",
target_resolution: tuple,
interpolation: "F.InterpolationMode",
input_data_format: ChannelDimension,
) -> "torch.Tensor":
"""
Resizes an image to a target resolution while maintaining aspect ratio.
Args:
image ("torch.Tensor"):
The input image.
target_resolution (tuple):
The target resolution (height, width) of the image.
interpolation (`InterpolationMode`):
Resampling filter to use if resizing the image.
input_data_format (`ChannelDimension` or `str`):
The channel dimension format of the input image.
Returns:
"torch.Tensor": The resized and padded image.
"""
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
return resized_image
def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
"""
previous version mainly focus on ratio.
We also consider area ratio here.
"""
best_factor = float("-inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
# ratio_diff = abs(aspect_ratio - target_aspect_ratio)
# area_ratio = (ratio[0] * ratio[1] * image_size * image_size) / area
"""
new area > 60% of original image area is enough.
"""
factor_based_on_area_n_ratio = min(
(ratio[0] * ratio[1] * image_size * image_size) / area, 0.6
) * min(target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio)
if factor_based_on_area_n_ratio > best_factor:
best_factor = factor_based_on_area_n_ratio
best_ratio = ratio
return best_ratio
def _pad_for_patching(
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
) -> "torch.Tensor":
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
return padded_image
def _get_image_patches(
self,
image: "torch.Tensor",
min_num: int,
max_num: int,
size: tuple,
tile_size: int,
use_thumbnail: bool,
interpolation: "F.InterpolationMode",
pad_during_tiling: bool,
) -> list["torch.Tensor"]:
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
orig_height, orig_width = image_size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = {
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
}
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = self.find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
)
# calculate the target width and height
target_width = tile_size * target_aspect_ratio[0]
target_height = tile_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
if pad_during_tiling:
resized_image = self._resize_for_patching(
image,
(target_height, target_width),
interpolation=interpolation,
input_data_format=ChannelDimension.FIRST,
)
padded_image = self._pad_for_patching(
resized_image,
(target_height, target_width),
input_data_format=ChannelDimension.FIRST,
)
image_used_to_split = padded_image
else:
image_used_to_split = F.resize(image, (target_height, target_width), interpolation=interpolation)
processed_tiles = []
for i in range(blocks):
box = (
(i % (target_width // tile_size)) * tile_size,
(i // (target_width // tile_size)) * tile_size,
((i % (target_width // tile_size)) + 1) * tile_size,
((i // (target_width // tile_size)) + 1) * tile_size,
)
# split the image
split_img = crop(image_used_to_split, box[0], box[1], box[2], box[3])
processed_tiles.append(split_img)
assert len(processed_tiles) == blocks
if use_thumbnail and len(processed_tiles) != 1:
thumbnail_img = F.resize(image, (tile_size, tile_size), interpolation=interpolation)
processed_tiles.append(thumbnail_img)
return processed_tiles
def _pad_for_batching(
self,
pixel_values: list["torch.Tensor"],
) -> list["torch.Tensor"]:
"""
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
Args:
pixel_values (`List[torch.Tensor]`):
An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`)
Returns:
List[`torch.Tensor`]: The padded images.
"""
max_patch = max(len(x) for x in pixel_values)
pixel_values = [
torch.nn.functional.pad(image, pad=[0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[0]])
for image in pixel_values
]
return pixel_values
def _preprocess(
self,
images: list["torch.Tensor"],
do_resize: bool,
size: SizeDict,
max_dynamic_tiles: int,
min_dynamic_tiles: int,
use_thumbnail: bool,
pad_during_tiling: bool,
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: float | list[float] | None,
image_std: float | list[float] | None,
do_pad: bool,
return_tensors: str | TensorType | None,
pad_size: SizeDict | None = None, # Added for transformers >=4.53.0 compatibility
disable_grouping: bool | None = None, # Added for transformers >=4.53.0 compatibility
) -> BatchFeature:
processed_images = []
image_sizes = []
# Determine the size tuple
if size and size.height and size.width:
size_tuple = (size.height, size.width)
else:
size_tuple = (size.shortest_edge, size.shortest_edge)
# Determine the patch size
if crop_size and crop_size.height:
tile_size = crop_size.height
elif size and size.height:
tile_size = size.height
else:
tile_size = size.shortest_edge
for image in images:
image_patches = self._get_image_patches(
image,
min_num=min_dynamic_tiles,
max_num=max_dynamic_tiles,
size=size_tuple,
tile_size=tile_size,
use_thumbnail=use_thumbnail,
interpolation=interpolation,
pad_during_tiling=pad_during_tiling,
)
# Group images by size for batched processing
processed_image_patches_grouped = {}
# Added for transformers >=4.53.0 compatibility
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
image_patches,
disable_grouping=disable_grouping,
)
for shape, stacked_image_patches in grouped_image_patches.items():
if do_resize:
stacked_image_patches = self.resize(
image=stacked_image_patches,
size=size,
interpolation=interpolation,
)
if do_center_crop:
stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
# Fused rescale and normalize
stacked_image_patches = self.rescale_and_normalize(
stacked_image_patches,
do_rescale,
rescale_factor,
do_normalize,
image_mean,
image_std,
)
processed_image_patches_grouped[shape] = stacked_image_patches
processed_image_patches = reorder_images(
processed_image_patches_grouped, grouped_image_patches_index
)
processed_image_patches = (
torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
)
processed_images.append(processed_image_patches)
image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
if do_pad:
processed_images = self._pad_for_batching(processed_images)
# processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(
data={"pixel_values": processed_images, "image_sizes": image_sizes},
tensor_type=return_tensors,
)
def preprocess(
self,
images: ImageInput,
videos: VideoInput = None,
**kwargs: Unpack[Eagle25VLFastImageProcessorKwargs],
) -> BatchFeature:
validate_kwargs(
captured_kwargs=kwargs.keys(),
valid_processor_keys=self.valid_kwargs.__annotations__.keys(),
)
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
for kwarg_name in self.valid_kwargs.__annotations__:
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
# Extract parameters that are only used for preparing the input images
do_convert_rgb = kwargs.pop("do_convert_rgb")
input_data_format = kwargs.pop("input_data_format")
device = kwargs.pop("device")
# Prepare input images
# transformers >= 4.53.0: uses _prepare_image_like_inputs instead of _prepare_input_images
if images is not None:
images = self._prepare_image_like_inputs(
images=images,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
)
if videos is not None:
videos = self._prepare_image_like_inputs(
images=videos,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
)
# Update kwargs that need further processing before being validated
kwargs = self._further_process_kwargs(**kwargs)
# Validate kwargs
self._validate_preprocess_kwargs(**kwargs)
# torch resize uses interpolation instead of resample
# Added for transformers >=4.53.0 compatibility
resample = kwargs.pop("resample", self.resample)
kwargs["interpolation"] = (
pil_torch_interpolation_mapping[resample]
if isinstance(resample, PILImageResampling | int)
else resample
)
# Filter kwargs to only include those accepted by _preprocess
valid_preprocess_kwargs = {
"do_resize",
"size",
"max_dynamic_tiles",
"min_dynamic_tiles",
"use_thumbnail",
"pad_during_tiling",
"interpolation",
"do_center_crop",
"crop_size",
"do_rescale",
"rescale_factor",
"do_normalize",
"image_mean",
"image_std",
"do_pad",
"return_tensors",
"pad_size",
"disable_grouping",
}
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_preprocess_kwargs}
if images is not None:
return self._preprocess(images, **filtered_kwargs)
elif videos is not None:
return self._preprocess(videos, **filtered_kwargs)
__all__ = ["Eagle25VLImageProcessorFast"]

View File

@@ -1,395 +0,0 @@
# --------------------------------------------------------
# NVIDIA
# Copyright (c) 2025 NVIDIA
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import inspect
import torch
import torch.utils.checkpoint as cp
from peft import LoraConfig, get_peft_model
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import GenerationConfig
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
from transformers.models.siglip.modeling_siglip import SiglipVisionModel
from transformers.utils import add_start_docstrings, logging
from .configuration_eagle2_5_vl import Eagle25VLConfig
logger = logging.get_logger(__name__)
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/modeling_llava_onevision.py#L241C1-L280C1
EAGLE2_5_VL_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`Eagle25VLConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare Eagle2_5_VL Model outputting raw hidden-states without any specific head on top.",
EAGLE2_5_VL_START_DOCSTRING,
)
class Eagle25VLPreTrainedModel(PreTrainedModel):
config_class = Eagle25VLConfig
base_model_prefix = "model"
main_input_name = "input_ids"
supports_gradient_checkpointing = True
_no_split_modules = [
"Qwen2DecoderLayer",
"LlamaDecoderLayer",
"Siglip2EncoderLayer",
"SiglipEncoderLayer",
]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = True
_supports_sdpa = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear | nn.Conv2d):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class Eagle25VLForConditionalGeneration(Eagle25VLPreTrainedModel, GenerationMixin):
config_class = Eagle25VLConfig
def __init__(self, config: Eagle25VLConfig, vision_model=None, language_model=None):
super().__init__(config)
image_size = config.force_image_size or config.vision_config.image_size
patch_size = config.vision_config.patch_size
self.patch_size = patch_size
if config.use_pixel_shuffle:
self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio**2))
else:
self.num_image_token = int((image_size // patch_size) ** 2)
self.select_layer = config.select_layer
self.downsample_ratio = config.downsample_ratio
self.loss_version = config.loss_version
self.mlp_checkpoint = config.mlp_checkpoint
self.use_pixel_shuffle = config.use_pixel_shuffle
self.mlp_connector_layers = config.mlp_connector_layers
logger.info(f"num_image_token: {self.num_image_token}")
logger.info(f"mlp_checkpoint: {self.mlp_checkpoint}")
if vision_model is not None:
self.vision_model = vision_model
else:
if config.vision_config.model_type == "siglip_vision_model":
config.vision_config._attn_implementation = "flash_attention_2"
self.vision_model = SiglipVisionModel(config.vision_config)
else:
raise NotImplementedError(f"{config.vision_config.model_type} is not implemented.")
if language_model is not None:
self.language_model = language_model
else:
if config.text_config.architectures[0] == "LlamaForCausalLM":
self.language_model = LlamaForCausalLM(config.text_config)
elif config.text_config.architectures[0] == "Phi3ForCausalLM":
raise NotImplementedError("Phi3 is not implemented.")
# self.language_model = Phi3ForCausalLM(config.text_config)
elif config.text_config.architectures[0] == "Qwen2ForCausalLM":
assert config.text_config._attn_implementation == "flash_attention_2", (
f"Qwen2 must use flash_attention_2 but got {config.text_config._attn_implementation}"
)
self.language_model = Qwen2ForCausalLM(config.text_config)
elif config.text_config.architectures[0] == "Qwen3ForCausalLM":
self.language_model = Qwen3ForCausalLM(config.text_config)
else:
raise NotImplementedError(f"{config.text_config.architectures[0]} is not implemented.")
vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size
if config.mlp_connector_layers == 2:
self.mlp1 = nn.Sequential(
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
nn.GELU(),
nn.Linear(llm_hidden_size, llm_hidden_size),
)
elif config.mlp_connector_layers == 1 and config.use_pixel_shuffle:
self.mlp1 = nn.Sequential(
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
)
elif config.mlp_connector_layers == 1 and not config.use_pixel_shuffle:
self.mlp1 = nn.Sequential(
nn.Linear(vit_hidden_size, llm_hidden_size),
)
else:
raise NotImplementedError(f"{config.mlp_connector_layers} is not implemented.")
self.image_token_index = config.image_token_index
self.neftune_alpha = None
if config.use_backbone_lora:
self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
self.use_llm_lora = config.use_llm_lora
if config.use_llm_lora:
self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
self.check_forward_kwargs()
def check_forward_kwargs(self):
# We intentionally avoid using **kwargs in forward because Hugging Face Transformers
# has special handling for functions with **kwargs parameters that would affect
# how our model is processed during training and inference.
forward_params = inspect.signature(self.forward).parameters
assert not any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())
def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
lora_config = LoraConfig(
r=r,
target_modules=[
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
"self_attn.out_proj",
"mlp.fc1",
"mlp.fc2",
],
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
self.vision_model = get_peft_model(self.vision_model, lora_config)
self.vision_model.print_trainable_parameters()
def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
lora_config = LoraConfig(
r=r,
target_modules=[
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
"self_attn.o_proj",
"mlp.gate_proj",
"mlp.down_proj",
"mlp.up_proj",
],
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
task_type="CAUSAL_LM",
)
self.language_model = get_peft_model(self.language_model, lora_config)
self.language_model.enable_input_require_grads()
self.language_model.print_trainable_parameters()
self.use_llm_lora = True
def forward(
self,
pixel_values: torch.FloatTensor,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
image_flags: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
num_tiles_list: list[torch.Tensor] | None = None,
) -> tuple | CausalLMOutputWithPast:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
input_embeds = self.language_model.get_input_embeddings()(input_ids)
vit_embeds = self.extract_feature(pixel_values)
if image_flags is not None:
image_flags = image_flags.view(-1)
vit_embeds = vit_embeds[image_flags == 1]
b, n, c = input_embeds.shape
input_embeds = input_embeds.reshape(b * n, c)
input_ids = input_ids.reshape(b * n)
selected = input_ids == self.image_token_index
try:
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, c)
except Exception as e:
vit_embeds = vit_embeds.reshape(-1, c)
print(
f"warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, "
f"vit_embeds.shape={vit_embeds.shape}"
)
n_token = selected.sum()
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
input_embeds = input_embeds.reshape(b, n, c)
outputs = self.language_model(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
logits = outputs.logits
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor)))
x = x.permute(0, 2, 1, 3).contiguous()
return x
def extract_feature(self, pixel_values):
if self.select_layer == -1:
vit_embeds = self.vision_model(
pixel_values=pixel_values, output_hidden_states=False, return_dict=True
)
if hasattr(vit_embeds, "last_hidden_state"):
vit_embeds = vit_embeds.last_hidden_state
else:
vit_embeds = self.vision_model(
pixel_values=pixel_values, output_hidden_states=True, return_dict=True
).hidden_states[self.select_layer]
if self.use_pixel_shuffle:
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.pixel_shuffle(
vit_embeds, scale_factor=self.downsample_ratio
) # torch.Size([B, 1024, 1024]) -> torch.Size([B, 16, 16, 4096])
vit_embeds = vit_embeds.reshape(
vit_embeds.shape[0], -1, vit_embeds.shape[-1]
) # torch.Size([B, 16, 16, 4096]) -> torch.Size([B, 256, 4096])
if self.mlp_checkpoint and vit_embeds.requires_grad:
vit_embeds = cp.checkpoint(self.mlp1, vit_embeds)
else:
vit_embeds = self.mlp1(vit_embeds)
return vit_embeds
@torch.no_grad()
def generate(
self,
pixel_values: torch.FloatTensor | None = None,
input_ids: torch.FloatTensor | None = None,
attention_mask: torch.LongTensor | None = None,
visual_features: torch.FloatTensor | None = None,
generation_config: GenerationConfig | None = None,
output_hidden_states: bool | None = None,
image_sizes: list[tuple[int, int]] | None = None,
**generate_kwargs,
) -> torch.LongTensor:
if pixel_values is not None:
if visual_features is not None:
vit_embeds = visual_features
else:
vit_embeds = self.extract_feature(pixel_values)
input_embeds = self.language_model.get_input_embeddings()(input_ids)
b, n, c = input_embeds.shape
input_embeds = input_embeds.reshape(b * n, c)
input_ids = input_ids.reshape(b * n)
selected = input_ids == self.config.image_token_index
assert selected.sum() != 0
input_embeds[selected] = vit_embeds.reshape(-1, c).to(input_embeds.device)
input_embeds = input_embeds.reshape(b, n, c)
else:
input_embeds = self.language_model.get_input_embeddings()(input_ids)
if "use_cache" not in generate_kwargs:
generate_kwargs["use_cache"] = True
outputs = self.language_model.generate(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
generation_config=generation_config,
output_hidden_states=output_hidden_states,
**generate_kwargs,
)
return outputs
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_input_embeddings
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_output_embeddings
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_decoder
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_decoder
def get_decoder(self):
return self.language_model.get_decoder()

View File

@@ -1,518 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Eagle25VL.
copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/processing_llava_onevision.py
"""
import base64
import os
import re
from io import BytesIO
import requests
import torch
from PIL import Image
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging
from transformers.video_utils import VideoInput
logger = logging.get_logger(__name__)
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 256
def to_rgb(pil_image: Image.Image) -> Image.Image:
if pil_image.mode == "RGBA":
white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
return white_background
else:
return pil_image.convert("RGB")
def fetch_image(ele: dict[str, str | Image.Image]) -> Image.Image:
image = ele["image"] if "image" in ele else ele["image_url"]
image_obj = None
if isinstance(image, Image.Image):
image_obj = image
elif image.startswith("http://") or image.startswith("https://"):
response = requests.get(image, stream=True, timeout=10)
image_obj = Image.open(BytesIO(response.content))
elif image.startswith("file://"):
image_obj = Image.open(image[7:])
elif image.startswith("data:image"):
if "base64," in image:
_, base64_data = image.split("base64,", 1)
data = base64.b64decode(base64_data)
image_obj = Image.open(BytesIO(data))
else:
image_obj = Image.open(image)
if image_obj is None:
raise ValueError(
f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
)
image = to_rgb(image_obj)
if "scale_factor" in ele:
scale_factor = ele["scale_factor"]
image = image.resize((image.width * scale_factor, image.height * scale_factor), Image.BILINEAR)
return image
class Eagle25VLProcessorKwargs(ProcessingKwargs, total=False):
# see processing_utils.ProcessingKwargs documentation for usage.
_defaults = {
"text_kwargs": {
"padding": False,
},
"images_kwargs": {},
"videos_kwargs": {"max_dynamic_tiles": 1},
}
class Eagle25VLProcessor(ProcessorMixin):
r"""
Constructs a Eagle25VL processor which wraps a Eagle25VL video processor, Eagle25VL image processor and a Eagle25VL tokenizer into a single processor.
[`Eagle25VLProcessor`] offers all the functionalities of [`Eagle25VLVideoProcessor`], [`Eagle25VLImageProcessor`] and [`Eagle25VLTokenizer`]. See the
[`~Eagle25VLVideoProcessor.__call__`], [`~Eagle25VLProcessor.__call__`] and [`~Eagle25VLProcessor.decode`] for more information.
Args:
image_processor ([`LlavaOnevisionImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input.
num_image_tokens (`int`, *optional*):
Number of image tokens for one imagethat will be returned by vision tower.
vision_feature_select_strategy (`str`, *optional*):
The feature selection strategy used to select the vision feature from the vision backbone.
Should be same as in model's config
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
image_token (`str`, *optional*, defaults to `"<image>"`):
Special token used to denote image location.
video_token (`str`, *optional*, defaults to `"<video>"`):
Special token used to denote video location.
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = [
"chat_template",
"num_image_tokens",
"vision_feature_select_strategy",
"image_token",
"video_token",
"images_kwargs",
"videos_kwargs",
"text_kwargs",
]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor=None,
tokenizer=None,
vision_feature_select_strategy=None,
chat_template=None,
image_token="<IMG_CONTEXT>", # nosec: B107
video_token="<IMG_CONTEXT>", # nosec: B107
tokens_per_tile=256,
image_placeholder="image",
video_placeholder="video",
image_start_token="<img>",
image_end_token="</img>",
**kwargs,
):
self.vision_feature_select_strategy = vision_feature_select_strategy
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
self.image_token_id = (
tokenizer.image_token_id
if getattr(tokenizer, "image_token_id", None)
else tokenizer.convert_tokens_to_ids(self.image_token)
)
self.video_token_id = (
tokenizer.video_token_id
if getattr(tokenizer, "video_token_id", None)
else tokenizer.convert_tokens_to_ids(self.video_token)
)
self.image_placeholder = image_placeholder
self.video_placeholder = video_placeholder
self.tokens_per_tile = tokens_per_tile
self.image_start_token = image_start_token
self.image_end_token = image_end_token
if "auto_map" in kwargs:
self.auto_map = kwargs["auto_map"]
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def replace_media_placeholder(
self, text, image_list, video_list, timestamps_list, fps_list, **output_kwargs
):
num_of_images_in_this_sample = 0
num_of_videos_in_this_sample = 0
# Regular expression pattern to match formats like <image-1> or <video-2>
pattern = re.compile(rf"<({self.image_placeholder}|{self.video_placeholder})-(\d+)>")
unified_frame_list = []
# image_min_dynamic_tiles = output_kwargs["images_kwargs"].get(
# "min_dynamic_tiles", self.image_processor.min_dynamic_tiles
# )
# image_max_dynamic_tiles = output_kwargs["images_kwargs"].get(
# "max_dynamic_tiles", self.image_processor.max_dynamic_tiles
# )
# image_use_thumbnail = output_kwargs["images_kwargs"].get(
# "use_thumbnail", self.image_processor.use_thumbnail
# )
video_min_dynamic_tiles = output_kwargs["videos_kwargs"].get(
"min_dynamic_tiles", self.image_processor.min_dynamic_tiles
)
video_max_dynamic_tiles = output_kwargs["videos_kwargs"].get(
"max_dynamic_tiles", self.image_processor.max_dynamic_tiles
)
video_use_thumbnail = output_kwargs["videos_kwargs"].get(
"use_thumbnail", self.image_processor.use_thumbnail
)
tile_size = self.image_processor.size.get("height", 448)
# Function to replace tags in a single text
def replace_in_text(text):
# repl callback function for each match replacement operation
def repl(match):
nonlocal unified_frame_list
nonlocal num_of_images_in_this_sample
nonlocal num_of_videos_in_this_sample
media_type = match.group(1) # 'image' or 'video'
idx_in_list = int(match.group(2)) - 1 # Convert to list index (0-based)
# Select the corresponding path based on media type
idx_mapper = {
0: "first",
1: "second",
2: "third",
3: "fourth",
4: "fifth",
5: "sixth",
6: "seventh",
7: "eighth",
8: "ninth",
9: "tenth",
}
if media_type == "image":
image_inputs = self.image_processor(
images=[image_list[idx_in_list]],
videos=None,
**output_kwargs["images_kwargs"],
)
num_all_tiles = image_inputs["pixel_values"].shape[0]
special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
unified_frame_list.append(image_inputs)
num_of_images_in_this_sample += 1
elif media_type == "video":
video_inputs = self.image_processor(
images=None,
videos=[video_list[idx_in_list]],
**output_kwargs["videos_kwargs"],
)
num_all_tiles = video_inputs["pixel_values"].shape[0]
image_sizes = video_inputs["image_sizes"]
if timestamps_list is not None and -1 not in timestamps_list:
frame_timestamps = timestamps_list[idx_in_list]
else:
frame_timestamps = None
sampled_fps = fps_list[idx_in_list] if fps_list is not None else None
num_of_tiles_each_frame = [
self.get_number_tiles_based_on_image_size(
image_size,
video_min_dynamic_tiles,
video_max_dynamic_tiles,
video_use_thumbnail,
tile_size,
)
for image_size in image_sizes
]
assert sum(num_of_tiles_each_frame) == num_all_tiles, (
f"The number of tiles in each frame is not equal to the total number of tiles: {sum(num_of_tiles_each_frame)} != {num_all_tiles}"
)
if frame_timestamps is not None:
assert len(frame_timestamps) == len(num_of_tiles_each_frame), (
f"The number of timestamps is not equal to the number of frames: {len(frame_timestamps)} != {len(num_of_tiles_each_frame)}"
)
special_placeholder = [
f"Frame {i + 1} sample at {frame_timestamps[i]:.2f}s: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
]
else:
special_placeholder = [
f"Frame {i + 1}: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
]
if sampled_fps is not None:
special_placeholder = (
f"The {idx_mapper[idx_in_list]} video sampled with {sampled_fps:.2f} fps: "
+ "".join(special_placeholder)
)
else:
special_placeholder = f"The {idx_mapper[idx_in_list]} video: " + "".join(
special_placeholder
)
unified_frame_list.append(video_inputs)
num_of_videos_in_this_sample += 1
else:
raise ValueError(f"Unknown media type: {media_type}")
return special_placeholder
return pattern.sub(repl, text)
text = replace_in_text(text)
if len(unified_frame_list) > 0:
pixel_values = torch.cat([frame["pixel_values"] for frame in unified_frame_list])
image_sizes = torch.cat([frame["image_sizes"] for frame in unified_frame_list])
else:
pixel_values = None
image_sizes = None
return (
text,
pixel_values,
image_sizes,
num_of_images_in_this_sample,
num_of_videos_in_this_sample,
)
def __call__(
self,
images: ImageInput = None,
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
audio=None,
videos: VideoInput = None,
**kwargs: Unpack[Eagle25VLProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
LlavaNextImageProcessor's [`~LlavaNextImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
of the above two methods for more information.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **pixel_values_videos** -- Pixel values of a video input to be fed to a model. Returned when `videos` is not `None`.
- **image_sizes** -- Size of each image that will be used to unpad an image. Returned when `images` is not `None`.
"""
output_kwargs = self._merge_kwargs(
Eagle25VLProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if isinstance(text, str):
text_list = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
elif isinstance(text, list) and isinstance(text[0], str):
text_list = text
if images is None:
images = []
if videos is None:
videos = []
pixel_values_list = []
image_sizes_list = []
new_sample_list = []
image_start_idx = 0
video_start_idx = 0
timestamps_batch = output_kwargs["videos_kwargs"].pop("timestamps", None)
fps_batch = output_kwargs["videos_kwargs"].pop("fps", None)
for sample in text_list:
timestamps_list = timestamps_batch[video_start_idx:] if timestamps_batch is not None else None
fps_list = fps_batch[video_start_idx:] if fps_batch is not None else None
(
sample,
pixel_values,
image_sizes,
num_of_images_in_this_sample,
num_of_videos_in_this_sample,
) = self.replace_media_placeholder(
sample,
images[image_start_idx:],
videos[video_start_idx:],
timestamps_list,
fps_list,
**output_kwargs,
)
new_sample_list.append(sample)
if pixel_values is not None:
pixel_values_list.append(pixel_values)
image_sizes_list.append(image_sizes)
image_start_idx += num_of_images_in_this_sample
video_start_idx += num_of_videos_in_this_sample
if len(pixel_values_list) > 0:
image_inputs = {
"pixel_values": torch.cat(pixel_values_list),
"image_sizes": torch.cat(image_sizes_list),
}
else:
image_inputs = {}
video_inputs = {}
text_inputs = self.tokenizer(new_sample_list, **output_kwargs["text_kwargs"])
return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs})
def get_number_tiles_based_on_image_size(
self, image_size: tuple, min_num: int, max_num: int, use_thumbnail: bool, tile_size: int
) -> int:
"""
Get the number of tiles based on the image size.
"""
orig_height, orig_width = image_size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = {
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
}
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = self.image_processor.find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
)
tiles_num = target_aspect_ratio[0] * target_aspect_ratio[1]
if use_thumbnail and tiles_num > 1:
tiles_num += 1
return tiles_num
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
# override to save video-config in a separate config file
def save_pretrained(self, save_directory, **kwargs):
if os.path.isfile(save_directory):
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)
outputs = super().save_pretrained(save_directory, **kwargs)
return outputs
# override to load video-config from a separate config file
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
# if return_unused_kwargs a tuple is returned where the second element is 'unused_kwargs'
if isinstance(processor, tuple):
processor = processor[0]
return processor
# Copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
def process_vision_info(
self,
conversations: list[dict] | list[list[dict]],
return_video_kwargs: bool = False,
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, dict | None]:
vision_infos = self.extract_vision_info(conversations)
## Read images or videos
image_inputs = []
video_inputs = []
video_sample_fps_list = []
video_timestamps_list = []
for vision_info in vision_infos:
if "image" in vision_info or "image_url" in vision_info:
image_inputs.append(fetch_image(vision_info))
else:
raise ValueError("image, image_url or video should in content.")
if len(image_inputs) == 0:
image_inputs = None
if len(video_inputs) == 0:
video_inputs = None
if return_video_kwargs:
return (
image_inputs,
video_inputs,
{"fps": video_sample_fps_list, "timestamps": video_timestamps_list},
)
return image_inputs, video_inputs
def extract_vision_info(self, conversations: list[dict] | list[list[dict]]) -> list[dict]:
vision_infos = []
if isinstance(conversations[0], dict):
conversations = [conversations]
for conversation in conversations:
for message in conversation:
if isinstance(message["content"], list):
for ele in message["content"]:
if (
"image" in ele
or "image_url" in ele
or "video" in ele
or ele["type"] in ("image", "image_url", "video")
):
vision_infos.append(ele)
return vision_infos
__all__ = ["Eagle25VLProcessor"]

Some files were not shown because too many files have changed in this diff Show More