Compare commits

..

3 Commits

Author SHA1 Message Date
Steven Palma
224be5be9a Merge branch 'main' into feat/add_macos_ci 2025-10-14 18:52:04 +02:00
Steven Palma
67269e33a5 ci: add more env flags 2025-10-08 17:14:20 +02:00
Steven Palma
66936f278f feat(ci): add macos runner testing 2025-10-08 14:58:55 +02:00
67 changed files with 1608 additions and 1075 deletions

View File

@@ -57,7 +57,11 @@ jobs:
# It runs everytime we commit to a PR or push to main # It runs everytime we commit to a PR or push to main
fast-pytest-tests: fast-pytest-tests:
name: Fast Pytest Tests name: Fast Pytest Tests
runs-on: ubuntu-latest runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
env: env:
MUJOCO_GL: egl MUJOCO_GL: egl
steps: steps:
@@ -67,12 +71,21 @@ jobs:
lfs: true lfs: true
# TODO(Steven): Evaluate the need of these dependencies # TODO(Steven): Evaluate the need of these dependencies
- name: Install apt dependencies - name: Install dependencies
run: | run: |
sudo apt-get update && sudo apt-get install -y build-essential git \ if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
curl libglib2.0-0 libegl1-mesa-dev ffmpeg \ sudo apt-get update && sudo apt-get install -y build-essential \
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
speech-dispatcher libgeos-dev portaudio19-dev
elif [[ "${{ matrix.os }}" == 'macos-latest' ]]; then
brew update && brew install git geos portaudio ffmpeg@7
# Add ffmpeg@7 paths for subsequent steps
echo "PATH=/opt/homebrew/opt/ffmpeg@7/bin:$PATH" >> $GITHUB_ENV
echo "LDFLAGS=-L/opt/homebrew/opt/ffmpeg@7/lib" >> $GITHUB_ENV
echo "CPPFLAGS=-I/opt/homebrew/opt/ffmpeg@7/include" >> $GITHUB_ENV
echo "PKG_CONFIG_PATH=/opt/homebrew/opt/ffmpeg@7/lib/pkgconfig" >> $GITHUB_ENV
echo "DYLD_LIBRARY_PATH=/opt/homebrew/opt/ffmpeg@7/lib:/opt/homebrew/lib:/usr/local/lib:$DYLD_LIBRARY_PATH" >> $GITHUB_ENV
fi
- name: Setup uv and Python - name: Setup uv and Python
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses] uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
with: with:

View File

@@ -51,7 +51,11 @@ jobs:
# It runs everytime a PR is approved or a push to main # It runs everytime a PR is approved or a push to main
full-tests: full-tests:
name: Full Tests name: Full Tests
runs-on: ubuntu-latest runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
if: | if: |
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved') || (github.event_name == 'pull_request_review' && github.event.review.state == 'approved') ||
github.event_name == 'push' || github.event_name == 'push' ||
@@ -64,11 +68,16 @@ jobs:
lfs: true lfs: true
persist-credentials: false persist-credentials: false
- name: Install apt dependencies - name: Install dependencies
run: | run: |
sudo apt-get update && sudo apt-get install -y build-essential \ if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \ sudo apt-get update && sudo apt-get install -y build-essential \
speech-dispatcher libgeos-dev portaudio19-dev git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
speech-dispatcher libgeos-dev portaudio19-dev
elif [[ "${{ matrix.os }}" == 'macos-latest' ]]; then
brew update && brew install git geos portaudio ffmpeg@7
echo "DYLD_LIBRARY_PATH=/opt/homebrew/opt/ffmpeg@7/lib:/opt/homebrew/lib:/usr/local/lib:$DYLD_LIBRARY_PATH" >> $GITHUB_ENV
fi
- name: Setup uv and Python - name: Setup uv and Python
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses] uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]

View File

@@ -119,7 +119,6 @@ jobs:
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
container: container:
image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images] image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --shm-size "16gb"
credentials: credentials:
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }} username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }} password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
@@ -159,35 +158,3 @@ jobs:
run: pytest tests -vv --maxfail=10 run: pytest tests -vv --maxfail=10
- name: Run end-to-end tests - name: Run end-to-end tests
run: make test-end-to-end run: make test-end-to-end
# This job runs multi-GPU training tests with 4 GPUs
nightly-multi-gpu-tests:
name: Nightly Multi-GPU Tests
needs: [build-docker-gpu-nightly]
runs-on:
group: aws-g4dn-12xlarge # Instance with 4 GPUs
env:
HF_HOME: /home/user_lerobot/.cache/huggingface
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
CUDA_VISIBLE_DEVICES: "0,1,2,3"
container:
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
credentials:
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
defaults:
run:
shell: bash
working-directory: /lerobot
steps:
- name: Verify GPU availability
run: |
nvidia-smi
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
- name: Run multi-GPU training tests
run: pytest tests/training/test_multi_gpu.py -vv --maxfail=3
timeout-minutes: 10

View File

@@ -103,7 +103,7 @@ jobs:
- name: Publish to TestPyPI for pre-releases - name: Publish to TestPyPI for pre-releases
# True for tags like 'v0.2.0-rc1' # True for tags like 'v0.2.0-rc1'
if: startsWith(github.ref, 'refs/tags/v') && contains(github.ref, '-') if: startsWith(github.ref, 'refs/tags/v') && contains(github.ref, '-')
uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing] uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
with: with:
repository-url: https://test.pypi.org/legacy/ repository-url: https://test.pypi.org/legacy/
verbose: true verbose: true
@@ -111,7 +111,7 @@ jobs:
- name: Publish to PyPI - name: Publish to PyPI
if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-') if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-')
uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing] uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
with: with:
verbose: true verbose: true
print-hash: true print-hash: true
@@ -120,7 +120,11 @@ jobs:
test-release: test-release:
name: Test Release name: Test Release
needs: [build-and-publish] needs: [build-and-publish]
runs-on: ubuntu-latest runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
permissions: permissions:
contents: read contents: read
env: env:
@@ -130,15 +134,20 @@ jobs:
with: with:
lfs: true lfs: true
persist-credentials: false persist-credentials: false
- name: Install apt dependencies - name: Install dependencies
run: | run: |
sudo apt-get update && sudo apt-get install -y build-essential \ if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \ sudo apt-get update && sudo apt-get install -y build-essential \
speech-dispatcher libgeos-dev portaudio19-dev git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
speech-dispatcher libgeos-dev portaudio19-dev
elif [[ "${{ matrix.os }}" == 'macos-latest' ]]; then
brew update && brew install git geos portaudio ffmpeg@7
echo "DYLD_LIBRARY_PATH=/opt/homebrew/opt/ffmpeg@7/lib:/opt/homebrew/lib:/usr/local/lib:$DYLD_LIBRARY_PATH" >> $GITHUB_ENV
fi
- name: Setup uv and Python - name: Setup uv and Python
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses] uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
with: with:
enable-cache: true # zizmor: ignore[cache-poisoning] enable-cache: true
version: ${{ env.UV_VERSION }} version: ${{ env.UV_VERSION }}
python-version: ${{ env.PYTHON_VERSION }} python-version: ${{ env.PYTHON_VERSION }}
- name: Create uv virtual environment - name: Create uv virtual environment

View File

@@ -27,17 +27,15 @@ env:
This issue was closed because it has been stalled for 14 days with no activity. This issue was closed because it has been stalled for 14 days with no activity.
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions. Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
CLOSE_PR_MESSAGE: > CLOSE_PR_MESSAGE: >
This PR was closed because it has been stalled for 21 days with no activity. This PR was closed because it has been stalled for 14 days with no activity.
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions. Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
WARN_ISSUE_MESSAGE: > WARN_ISSUE_MESSAGE: >
This issue has been automatically marked as stale because it has not had This issue has been automatically marked as stale because it has not had
recent activity (6 months). It will be closed if no further activity occurs. recent activity (6 months). It will be closed if no further activity occurs.
Any change, comment or update to this issue will reset this count.
Thank you for your contributions. Thank you for your contributions.
WARN_PR_MESSAGE: > WARN_PR_MESSAGE: >
This PR has been automatically marked as stale because it has not had This PR has been automatically marked as stale because it has not had
recent activity (1 year). It will be closed if no further activity occurs. recent activity (6 months). It will be closed if no further activity occurs.
Any change, comment or update to this PR will reset this count.
Thank you for your contributions. Thank you for your contributions.
jobs: jobs:
@@ -58,10 +56,10 @@ jobs:
stale-pr-label: stale stale-pr-label: stale
exempt-issue-labels: never-stale exempt-issue-labels: never-stale
exempt-pr-labels: never-stale exempt-pr-labels: never-stale
days-before-issue-stale: 180 days-before-issue-stale: 180 # TODO(Steven): Will modify this to 90 after initial cleanup
days-before-issue-close: 14 days-before-issue-close: 14
days-before-pr-stale: 365 days-before-pr-stale: 180
days-before-pr-close: 21 days-before-pr-close: 14
delete-branch: true delete-branch: true
close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }} close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }}
close-pr-message: ${{ env.CLOSE_PR_MESSAGE }} close-pr-message: ${{ env.CLOSE_PR_MESSAGE }}

View File

@@ -42,7 +42,11 @@ jobs:
# This job runs the E2E tests + pytest with all unbound extras # This job runs the E2E tests + pytest with all unbound extras
full-tests: full-tests:
name: Full Unbound Tests name: Full Unbound Tests
runs-on: ubuntu-latest runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
env: env:
MUJOCO_GL: egl MUJOCO_GL: egl
steps: steps:
@@ -51,11 +55,16 @@ jobs:
lfs: true lfs: true
persist-credentials: false persist-credentials: false
- name: Install apt dependencies - name: Install dependencies
run: | run: |
sudo apt-get update && sudo apt-get install -y build-essential \ if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \ sudo apt-get update && sudo apt-get install -y build-essential \
speech-dispatcher libgeos-dev portaudio19-dev git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
speech-dispatcher libgeos-dev portaudio19-dev
elif [[ "${{ matrix.os }}" == 'macos-latest' ]]; then
brew update && brew install git geos portaudio ffmpeg@7
echo "DYLD_LIBRARY_PATH=/opt/homebrew/opt/ffmpeg@7/lib:/opt/homebrew/lib:/usr/local/lib:$DYLD_LIBRARY_PATH" >> $GITHUB_ENV
fi
- name: Setup uv and Python - name: Setup uv and Python
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses] uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]

View File

@@ -26,7 +26,7 @@ repos:
##### General Code Quality & Formatting ##### ##### General Code Quality & Formatting #####
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0 rev: v5.0.0
hooks: hooks:
- id: check-added-large-files - id: check-added-large-files
args: ['--maxkb=1024'] args: ['--maxkb=1024']
@@ -39,20 +39,20 @@ repos:
- id: trailing-whitespace - id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.1 rev: v0.12.4
hooks: hooks:
- id: ruff-format - id: ruff-format
- id: ruff - id: ruff
args: [--fix, --exit-non-zero-on-fix] args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/adhtruong/mirrors-typos - repo: https://github.com/adhtruong/mirrors-typos
rev: v1.38.1 rev: v1.34.0
hooks: hooks:
- id: typos - id: typos
args: [--force-exclude] args: [--force-exclude]
- repo: https://github.com/asottile/pyupgrade - repo: https://github.com/asottile/pyupgrade
rev: v3.21.0 rev: v3.20.0
hooks: hooks:
- id: pyupgrade - id: pyupgrade
args: [--py310-plus] args: [--py310-plus]
@@ -68,12 +68,12 @@ repos:
##### Security ##### ##### Security #####
- repo: https://github.com/gitleaks/gitleaks - repo: https://github.com/gitleaks/gitleaks
rev: v8.28.0 rev: v8.27.2
hooks: hooks:
- id: gitleaks - id: gitleaks
- repo: https://github.com/woodruffw/zizmor-pre-commit - repo: https://github.com/woodruffw/zizmor-pre-commit
rev: v1.15.2 rev: v1.11.0
hooks: hooks:
- id: zizmor - id: zizmor
@@ -87,7 +87,7 @@ repos:
# TODO(Steven): Uncomment when ready to use # TODO(Steven): Uncomment when ready to use
##### Static Analysis & Typing ##### ##### Static Analysis & Typing #####
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.18.2 rev: v1.16.0
hooks: hooks:
- id: mypy - id: mypy
args: [--config-file=pyproject.toml] args: [--config-file=pyproject.toml]

View File

@@ -137,7 +137,7 @@ Follow these steps to start contributing:
4. for development, we advise to use a tool like `poetry` or `uv` instead of just `pip` to easily track our dependencies. 4. for development, we advise to use a tool like `poetry` or `uv` instead of just `pip` to easily track our dependencies.
Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already. Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already.
Set up a development environment with conda: Set up a development environment with conda or miniconda:
```bash ```bash
conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev

View File

@@ -104,14 +104,14 @@ LeRobot works with Python 3.10+ and PyTorch 2.2+.
### Environment Setup ### Environment Setup
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniforge`](https://conda-forge.org/download/): Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
```bash ```bash
conda create -y -n lerobot python=3.10 conda create -y -n lerobot python=3.10
conda activate lerobot conda activate lerobot
``` ```
When using `conda`, install `ffmpeg` in your environment: When using `miniconda`, install `ffmpeg` in your environment:
```bash ```bash
conda install ffmpeg -c conda-forge conda install ffmpeg -c conda-forge
@@ -207,13 +207,13 @@ lerobot-dataset-viz \
--episode-index 0 --episode-index 0
``` ```
or from a dataset in a local folder with the `root` option and the `--mode local` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`) or from a dataset in a local folder with the `root` option and the `--local-files-only` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
```bash ```bash
lerobot-dataset-viz \ lerobot-dataset-viz \
--repo-id lerobot/pusht \ --repo-id lerobot/pusht \
--root ./my_local_data_dir \ --root ./my_local_data_dir \
--mode local \ --local-files-only 1 \
--episode-index 0 --episode-index 0
``` ```
@@ -310,7 +310,7 @@ To upload these to the hub, run the following:
huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model
``` ```
See [lerobot_eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_eval.py) for an example of how other people may use your policy. See [eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/eval.py) for an example of how other people may use your policy.
### Acknowledgment ### Acknowledgment

View File

@@ -17,8 +17,6 @@
title: Train RL in Simulation title: Train RL in Simulation
- local: async - local: async
title: Use Async Inference title: Use Async Inference
- local: multi_gpu_training
title: Multi GPU training
title: "Tutorials" title: "Tutorials"
- sections: - sections:
- local: lerobot-dataset-v3 - local: lerobot-dataset-v3

View File

@@ -165,7 +165,7 @@ huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
Then store your Hugging Face repository name in a variable: Then store your Hugging Face repository name in a variable:
```bash ```bash
HF_USER=$(hf auth whoami | head -n 1) HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER echo $HF_USER
``` ```

View File

@@ -1,15 +1,8 @@
# Installation # Installation
## Install [`miniforge`](https://conda-forge.org/download/)
```bash
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
bash Miniforge3-$(uname)-$(uname -m).sh
```
## Environment Setup ## Environment Setup
Create a virtual environment with Python 3.10, using conda: Create a virtual environment with Python 3.10, using [`Miniconda`](https://docs.anaconda.com/miniconda/install/#quick-command-line-install)
```bash ```bash
conda create -y -n lerobot python=3.10 conda create -y -n lerobot python=3.10
@@ -21,7 +14,7 @@ Then activate your conda environment, you have to do this each time you open a s
conda activate lerobot conda activate lerobot
``` ```
When using `conda`, install `ffmpeg` in your environment: When using `miniconda`, install `ffmpeg` in your environment:
```bash ```bash
conda install ffmpeg -c conda-forge conda install ffmpeg -c conda-forge

View File

@@ -208,36 +208,34 @@ LeRobot supports saving and loading calibration data automatically. This is usef
<!-- prettier-ignore-start --> <!-- prettier-ignore-start -->
```python ```python
@property > @property
def is_calibrated(self) -> bool: > def is_calibrated(self) -> bool:
return True > return True
>
def calibrate(self) -> None: > def calibrate(self) -> None:
pass > pass
``` > ```
<!-- prettier-ignore-end -->
### `is_calibrated` ### `is_calibrated`
This should reflect whether your robot has the required calibration loaded. This should reflect whether your robot has the required calibration loaded.
<!-- prettier-ignore-start --> ```
```python <!-- prettier-ignore-end -->python
@property @property
def is_calibrated(self) -> bool: def is_calibrated(self) -> bool:
return self.bus.is_calibrated return self.bus.is_calibrated
``` ```
<!-- prettier-ignore-end -->
### `calibrate()` ### `calibrate()`
The goal of the calibration is twofold: The goal of the calibration is twofold:
- Know the physical range of motion of each motors in order to only send commands within this range.
- Know the physical range of motion of each motors in order to only send commands within this range. - Normalize raw motors positions to sensible continuous values (e.g. percentages, degrees) instead of arbitrary discrete value dependant on the specific motor used that will not replicate elsewhere.
- Normalize raw motors positions to sensible continuous values (e.g. percentages, degrees) instead of arbitrary discrete value dependant on the specific motor used that will not replicate elsewhere.
It should implement the logic for calibration (if relevant) and update the `self.calibration` dictionary. If you are using Feetech or Dynamixel motors, our bus interfaces already include methods to help with this. It should implement the logic for calibration (if relevant) and update the `self.calibration` dictionary. If you are using Feetech or Dynamixel motors, our bus interfaces already include methods to help with this.
<!-- prettier-ignore-start --> <!-- prettier-ignore-start -->
```python ```python
def calibrate(self) -> None: def calibrate(self) -> None:

View File

@@ -1,125 +0,0 @@
# Multi-GPU Training
This guide shows you how to train policies on multiple GPUs using [Hugging Face Accelerate](https://huggingface.co/docs/accelerate).
## Installation
First, ensure you have accelerate installed:
```bash
pip install accelerate
```
## Training with Multiple GPUs
You can launch training in two ways:
### Option 1: Without config (specify parameters directly)
You can specify all parameters directly in the command without running `accelerate config`:
```bash
accelerate launch \
--multi_gpu \
--num_processes=2 \
$(which lerobot-train) \
--dataset.repo_id=${HF_USER}/my_dataset \
--policy.type=act \
--policy.repo_id=${HF_USER}/my_trained_policy \
--output_dir=outputs/train/act_multi_gpu \
--job_name=act_multi_gpu \
--wandb.enable=true
```
**Key accelerate parameters:**
- `--multi_gpu`: Enable multi-GPU training
- `--num_processes=2`: Number of GPUs to use
- `--mixed_precision=fp16`: Use fp16 mixed precision (or `bf16` if supported)
### Option 2: Using accelerate config
If you prefer to save your configuration, you can optionally configure accelerate for your hardware setup by running:
```bash
accelerate config
```
This interactive setup will ask you questions about your training environment (number of GPUs, mixed precision settings, etc.) and saves the configuration for future use. For a simple multi-GPU setup on a single machine, you can use these recommended settings:
- Compute environment: This machine
- Number of machines: 1
- Number of processes: (number of GPUs you want to use)
- GPU ids to use: (leave empty to use all)
- Mixed precision: fp16 or bf16 (recommended for faster training)
Then launch training with:
```bash
accelerate launch $(which lerobot-train) \
--dataset.repo_id=${HF_USER}/my_dataset \
--policy.type=act \
--policy.repo_id=${HF_USER}/my_trained_policy \
--output_dir=outputs/train/act_multi_gpu \
--job_name=act_multi_gpu \
--wandb.enable=true
```
## How It Works
When you launch training with accelerate:
1. **Automatic detection**: LeRobot automatically detects if it's running under accelerate
2. **Data distribution**: Your batch is automatically split across GPUs
3. **Gradient synchronization**: Gradients are synchronized across GPUs during backpropagation
4. **Single process logging**: Only the main process logs to wandb and saves checkpoints
## Learning Rate and Training Steps Scaling
**Important:** LeRobot does **NOT** automatically scale learning rates or training steps based on the number of GPUs. This gives you full control over your training hyperparameters.
### Why No Automatic Scaling?
Many distributed training frameworks automatically scale the learning rate by the number of GPUs (e.g., `lr = base_lr × num_gpus`).
However, LeRobot keeps the learning rate exactly as you specify it.
### When and How to Scale
If you want to scale your hyperparameters when using multiple GPUs, you should do it manually:
**Learning Rate Scaling:**
```bash
# Example: 2 GPUs with linear LR scaling
# Base LR: 1e-4, with 2 GPUs -> 2e-4
accelerate launch --num_processes=2 $(which lerobot-train) \
--optimizer.lr=2e-4 \
--dataset.repo_id=lerobot/pusht \
--policy=act
```
**Training Steps Scaling:**
Since the effective batch size `bs` increases with multiple GPUs (batch_size × num_gpus), you may want to reduce the number of training steps proportionally:
```bash
# Example: 2 GPUs with effective batch size 2x larger
# Original: batch_size=8, steps=100000
# With 2 GPUs: batch_size=8 (16 in total), steps=50000
accelerate launch --num_processes=2 $(which lerobot-train) \
--batch_size=8 \
--steps=50000 \
--dataset.repo_id=lerobot/pusht \
--policy=act
```
## Notes
- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
- Training logs, checkpoints, and hub uploads are only done by the main process to avoid conflicts. Non-main processes have console logging disabled to prevent duplicate output.
- The effective batch size is `batch_size × num_gpus`. If you use 4 GPUs with `--batch_size=8`, your effective batch size is 32.
- Learning rate scheduling is handled correctly across multiple processes—LeRobot sets `step_scheduler_with_optimizer=False` to prevent accelerate from adjusting scheduler steps based on the number of processes.
- When saving or pushing models, LeRobot automatically unwraps the model from accelerate's distributed wrapper to ensure compatibility.
- WandB integration automatically initializes only on the main process, preventing multiple runs from being created.
For more advanced configurations and troubleshooting, see the [Accelerate documentation](https://huggingface.co/docs/accelerate). If you want to learn more about how to train on a large number of GPUs, checkout this awesome guide: [Ultrascale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook).

View File

@@ -1,112 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script demonstrates how to evaluate pretrained vision-language-action (VLA) policies
such as SmolVLA on Libero benchmark tasks using the LeRobot framework.
It showcases the full evaluation pipeline — from environment creation to policy inference,
visualization, and result logging — and is intended as a reference for benchmarking or
integrating new robotic policies.
Features included in this script:
- loading Libero environments (e.g., libero_spatial, libero_object) via `make_env`.
- initializing pretrained policies (e.g., SmolVLA) from Hugging Face using `make_policy`.
- applying preprocessing and postprocessing transformations for model compatibility.
- running evaluation rollouts and recording rendered frames from the simulator.
- computing success metrics and saving rollout videos as MP4 for qualitative analysis.
The script ends by saving a rollout video (`rollout.mp4`) and printing per-environment
success indicators for quick visual and numerical evaluation.
"""
import numpy as np
import torch
import imageio.v2 as imageio
from lerobot.envs.factory import make_env, make_env_config
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.factory import make_policy_config
from lerobot.envs.utils import (
add_envs_task,
preprocess_observation,
)
import os
os.environ["MUJOCO_GL"] = "egl"
SMOLVLA_LIBERO_PATH = "HuggingFaceVLA/smolvla_libero"
LIBERO_CONFIG = make_env_config("libero", task="libero_spatial")
breakpoint()
POLICY_CONFIG = make_policy_config("smolvla", pretrained_path=SMOLVLA_LIBERO_PATH)
policy = make_policy(
cfg=POLICY_CONFIG,
env_cfg=LIBERO_CONFIG,
)
breakpoint()
libero_env = make_env(LIBERO_CONFIG)
breakpoint()
print(type(libero_env)) # <class 'dict'>
print(libero_env.keys()) # dict_keys(['libero_spatial', 'libero_object'])
# initilize your policy, here we use smolvla
breakpoint()
policy.eval()
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=POLICY_CONFIG,
pretrained_path=SMOLVLA_LIBERO_PATH,
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
)
policy.reset()
# for the sake of this exemple we only use one env from each task
libero_spatial_env = libero_env['libero_spatial'][0]
# libero_object_env = libero_env['libero_object'][0]
# let's first run an evaluation throgut the first task
observation, info = libero_spatial_env.reset() # you can pass seeds
max_steps = 220
step = 0
all_images = []
done = np.array([False] * libero_spatial_env.num_envs)
while not np.all(done) and step < max_steps:
observation = preprocess_observation(observation)
observation = add_envs_task(libero_spatial_env, observation)
observation = preprocessor(observation)
with torch.inference_mode():
action = policy.select_action(observation)
action = postprocessor(action)
# Convert to CPU / numpy.
action_numpy = action.to("cpu").numpy()
# Apply the next action.
# let's render the video
image = libero_spatial_env.call("render")[0]
all_images.append(image)
observation, reward, terminated, truncated, info = libero_spatial_env.step(action_numpy)
if "final_info" in info:
final_info = info["final_info"]
if not isinstance(final_info, dict):
raise RuntimeError(
"Unsupported `final_info` format: expected dict (Gymnasium >= 1.0). "
"You're likely using an older version of gymnasium (< 1.0). Please upgrade."
)
successes = final_info["is_success"].tolist()
else:
successes = [False] * libero_spatial_env.num_envs
done = terminated | truncated | done
if step + 1 == max_steps:
done = np.ones_like(done, dtype=bool)
step += 1
print("The success: ", successes)

View File

@@ -62,7 +62,6 @@ dependencies = [
"datasets>=4.0.0,<4.2.0", "datasets>=4.0.0,<4.2.0",
"diffusers>=0.27.2,<0.36.0", "diffusers>=0.27.2,<0.36.0",
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0", "huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
"accelerate>=1.10.0,<2.0.0",
# Core dependencies # Core dependencies
"setuptools>=71.0.0,<81.0.0", "setuptools>=71.0.0,<81.0.0",
@@ -74,7 +73,7 @@ dependencies = [
"packaging>=24.2,<26.0", "packaging>=24.2,<26.0",
"pynput>=1.7.7,<1.9.0", "pynput>=1.7.7,<1.9.0",
"pyserial>=3.5,<4.0", "pyserial>=3.5,<4.0",
"wandb>=0.20.0,<0.22.0", # TODO: Bumb dependency (compatible with protobuf) "wandb>=0.20.0,<0.23.0",
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency "torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency "torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
@@ -82,7 +81,7 @@ dependencies = [
"draccus==0.10.0", # TODO: Remove == "draccus==0.10.0", # TODO: Remove ==
"gymnasium>=1.0.0", "gymnasium>=1.0.0",
"rerun-sdk>=0.24.0,<0.27.0", "rerun-sdk>=0.21.0,<0.23.0", # TODO: Bumb dependency
# Support dependencies # Support dependencies
"deepdiff>=7.0.1,<9.0.0", "deepdiff>=7.0.1,<9.0.0",
@@ -97,7 +96,7 @@ dependencies = [
pygame-dep = ["pygame>=2.5.1,<2.7.0"] pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.10.0"] placo-dep = ["placo>=0.9.6,<0.10.0"]
transformers-dep = ["transformers>=4.53.0,<5.0.0"] transformers-dep = ["transformers>=4.53.0,<5.0.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb) grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
# Motors # Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"] feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
@@ -114,6 +113,11 @@ intelrealsense = [
"pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'", "pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
] ]
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0"] phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0"]
# stretch = [
# "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'",
# "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
# "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"
# ] # TODO: Currently not supported
# Policies # Policies
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"] pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
@@ -229,6 +233,9 @@ exclude_dirs = [
"tests", "tests",
"benchmarks", "benchmarks",
"src/lerobot/datasets/push_dataset_to_hub", "src/lerobot/datasets/push_dataset_to_hub",
"src/lerobot/datasets/v2/convert_dataset_v1_to_v2",
"src/lerobot/policies/pi0/conversion_scripts",
"src/lerobot/scripts/push_dataset_to_hub.py",
] ]
skips = ["B101", "B311", "B404", "B603", "B615"] skips = ["B101", "B311", "B404", "B603", "B615"]
@@ -243,7 +250,6 @@ default.extend-ignore-identifiers-re = [
"pn", "pn",
"ser", "ser",
"ein", "ein",
"inpt",
] ]
# TODO: Uncomment when ready to use # TODO: Uncomment when ready to use
@@ -282,6 +288,7 @@ ignore_errors = true
[[tool.mypy.overrides]] [[tool.mypy.overrides]]
module = "lerobot.envs.*" module = "lerobot.envs.*"
# Enable type checking only for the envs module
ignore_errors = false ignore_errors = false
@@ -297,9 +304,9 @@ ignore_errors = false
# module = "lerobot.optim.*" # module = "lerobot.optim.*"
# ignore_errors = false # ignore_errors = false
[[tool.mypy.overrides]] # [[tool.mypy.overrides]]
module = "lerobot.model.*" # module = "lerobot.model.*"
ignore_errors = false # ignore_errors = false
# [[tool.mypy.overrides]] # [[tool.mypy.overrides]]
# module = "lerobot.processor.*" # module = "lerobot.processor.*"

View File

@@ -842,7 +842,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Get available episode indices from cached dataset # Get available episode indices from cached dataset
available_episodes = { available_episodes = {
ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx
for ep_idx in self.hf_dataset.unique("episode_index") for ep_idx in self.hf_dataset["episode_index"]
} }
# Determine requested episodes # Determine requested episodes

View File

@@ -206,11 +206,6 @@ class ImageTransformsConfig:
type="SharpnessJitter", type="SharpnessJitter",
kwargs={"sharpness": (0.5, 1.5)}, kwargs={"sharpness": (0.5, 1.5)},
), ),
"affine": ImageTransformConfig(
weight=1.0,
type="RandomAffine",
kwargs={"degrees": (-5.0, 5.0), "translate": (0.05, 0.05)},
),
} }
) )
@@ -222,8 +217,6 @@ def make_transform_from_config(cfg: ImageTransformConfig):
return v2.ColorJitter(**cfg.kwargs) return v2.ColorJitter(**cfg.kwargs)
elif cfg.type == "SharpnessJitter": elif cfg.type == "SharpnessJitter":
return SharpnessJitter(**cfg.kwargs) return SharpnessJitter(**cfg.kwargs)
elif cfg.type == "RandomAffine":
return v2.RandomAffine(**cfg.kwargs)
else: else:
raise ValueError(f"Transform '{cfg.type}' is not valid.") raise ValueError(f"Transform '{cfg.type}' is not valid.")

View File

@@ -98,7 +98,7 @@ OLD
videos/chunk-000/CAMERA/episode_000000.mp4 videos/chunk-000/CAMERA/episode_000000.mp4
NEW NEW
videos/CAMERA/chunk-000/file_000.mp4 videos/chunk-000/file_000.mp4
------------------------- -------------------------
OLD OLD
episodes.jsonl episodes.jsonl

View File

@@ -342,8 +342,8 @@ def encode_video_frames(
# Define video output frame size (assuming all input frames are the same size) # Define video output frame size (assuming all input frames are the same size)
if len(input_list) == 0: if len(input_list) == 0:
raise FileNotFoundError(f"No images found in {imgs_dir}.") raise FileNotFoundError(f"No images found in {imgs_dir}.")
with Image.open(input_list[0]) as dummy_image: dummy_image = Image.open(input_list[0])
width, height = dummy_image.size width, height = dummy_image.size
# Define video codec options # Define video codec options
video_options = {} video_options = {}
@@ -373,12 +373,11 @@ def encode_video_frames(
# Loop through input frames and encode them # Loop through input frames and encode them
for input_data in input_list: for input_data in input_list:
with Image.open(input_data) as input_image: input_image = Image.open(input_data).convert("RGB")
input_image = input_image.convert("RGB") input_frame = av.VideoFrame.from_image(input_image)
input_frame = av.VideoFrame.from_image(input_image) packet = output_stream.encode(input_frame)
packet = output_stream.encode(input_frame) if packet:
if packet: output.mux(packet)
output.mux(packet)
# Flush the encoder # Flush the encoder
packet = output_stream.encode() packet = output_stream.encode()

View File

@@ -37,16 +37,6 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
def type(self) -> str: def type(self) -> str:
return self.get_choice_name(self.__class__) return self.get_choice_name(self.__class__)
@property
def package_name(self) -> str:
"""Package name to import if environment not found in gym registry"""
return f"gym_{self.type}"
@property
def gym_id(self) -> str:
"""ID string used in gym.make() to instantiate the environment"""
return f"{self.package_name}/{self.task}"
@property @property
@abc.abstractmethod @abc.abstractmethod
def gym_kwargs(self) -> dict: def gym_kwargs(self) -> dict:

View File

@@ -16,7 +16,6 @@
import importlib import importlib
import gymnasium as gym import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
@@ -85,24 +84,17 @@ def make_env(
gym_kwargs=cfg.gym_kwargs, gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls, env_cls=env_cls,
) )
package_name = f"gym_{cfg.type}"
try:
importlib.import_module(package_name)
except ModuleNotFoundError as e:
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
raise e
if cfg.gym_id not in gym_registry: gym_handle = f"{package_name}/{cfg.task}"
print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...")
try:
importlib.import_module(cfg.package_name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Package '{cfg.package_name}' required for env '{cfg.type}' not found. "
f"Please install it or check PYTHONPATH."
) from e
if cfg.gym_id not in gym_registry:
raise gym.error.NameNotFound(
f"Environment '{cfg.gym_id}' not registered even after importing '{cfg.package_name}'."
)
def _make_one(): def _make_one():
return gym.make(cfg.gym_id, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {})) return gym.make(gym_handle, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP) vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)

View File

@@ -22,18 +22,18 @@ class RobotKinematics:
self, self,
urdf_path: str, urdf_path: str,
target_frame_name: str = "gripper_frame_link", target_frame_name: str = "gripper_frame_link",
joint_names: list[str] | None = None, joint_names: list[str] = None,
): ):
""" """
Initialize placo-based kinematics solver. Initialize placo-based kinematics solver.
Args: Args:
urdf_path (str): Path to the robot URDF file urdf_path: Path to the robot URDF file
target_frame_name (str): Name of the end-effector frame in the URDF target_frame_name: Name of the end-effector frame in the URDF
joint_names (list[str] | None): List of joint names to use for the kinematics solver joint_names: List of joint names to use for the kinematics solver
""" """
try: try:
import placo # type: ignore[import-not-found] # C++ library with Python bindings, no type stubs available. TODO: Create stub file or request upstream typing support. import placo
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"placo is required for RobotKinematics. " "placo is required for RobotKinematics. "
@@ -52,7 +52,7 @@ class RobotKinematics:
# Initialize frame task for IK # Initialize frame task for IK
self.tip_frame = self.solver.add_frame_task(self.target_frame_name, np.eye(4)) self.tip_frame = self.solver.add_frame_task(self.target_frame_name, np.eye(4))
def forward_kinematics(self, joint_pos_deg: np.ndarray) -> np.ndarray: def forward_kinematics(self, joint_pos_deg):
""" """
Compute forward kinematics for given joint configuration given the target frame name in the constructor. Compute forward kinematics for given joint configuration given the target frame name in the constructor.
@@ -77,12 +77,8 @@ class RobotKinematics:
return self.robot.get_T_world_frame(self.target_frame_name) return self.robot.get_T_world_frame(self.target_frame_name)
def inverse_kinematics( def inverse_kinematics(
self, self, current_joint_pos, desired_ee_pose, position_weight=1.0, orientation_weight=0.01
current_joint_pos: np.ndarray, ):
desired_ee_pose: np.ndarray,
position_weight: float = 1.0,
orientation_weight: float = 0.01,
) -> np.ndarray:
""" """
Compute inverse kinematics using placo solver. Compute inverse kinematics using placo solver.

View File

@@ -60,7 +60,7 @@ class OperatingMode(Enum):
# This mode controls position. This mode is identical to the Multi-turn Position Control from existing # This mode controls position. This mode is identical to the Multi-turn Position Control from existing
# DYNAMIXEL. 512 turns are supported(-256[rev] ~ 256[rev]). This mode is ideal for multi-turn wrists or # DYNAMIXEL. 512 turns are supported(-256[rev] ~ 256[rev]). This mode is ideal for multi-turn wrists or
# conveyor systems or a system that requires an additional reduction gear. Note that Max Position # conveyer systems or a system that requires an additional reduction gear. Note that Max Position
# Limit(48), Min Position Limit(52) are not used on Extended Position Control Mode. # Limit(48), Min Position Limit(52) are not used on Extended Position Control Mode.
EXTENDED_POSITION = 4 EXTENDED_POSITION = 4

View File

@@ -206,12 +206,8 @@ MODEL_BAUDRATE_TABLE = {
# Sign-Magnitude encoding bits # Sign-Magnitude encoding bits
STS_SMS_SERIES_ENCODINGS_TABLE = { STS_SMS_SERIES_ENCODINGS_TABLE = {
"Homing_Offset": 11, "Homing_Offset": 11,
"Goal_Position": 15,
"Goal_Velocity": 15, "Goal_Velocity": 15,
"Goal_Speed": 15,
"Present_Position": 15,
"Present_Velocity": 15, "Present_Velocity": 15,
"Present_Speed": 15,
} }
MODEL_ENCODING_TABLE = { MODEL_ENCODING_TABLE = {

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc import abc
import logging
import math import math
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from pathlib import Path from pathlib import Path
@@ -80,11 +79,7 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup") @LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
@dataclass @dataclass
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig): class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
"""Used by Physical Intelligence to train Pi0. """Used by Physical Intelligence to train Pi0"""
Automatically scales warmup and decay steps if num_training_steps < num_decay_steps.
This ensures the learning rate schedule completes properly even with shorter training runs.
"""
num_warmup_steps: int num_warmup_steps: int
num_decay_steps: int num_decay_steps: int
@@ -92,39 +87,23 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
decay_lr: float decay_lr: float
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR: def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
# Auto-scale scheduler parameters if training steps are shorter than configured decay steps del num_training_steps
actual_warmup_steps = self.num_warmup_steps
actual_decay_steps = self.num_decay_steps
if num_training_steps < self.num_decay_steps:
# Calculate scaling factor to fit the schedule into the available training steps
scale_factor = num_training_steps / self.num_decay_steps
actual_warmup_steps = int(self.num_warmup_steps * scale_factor)
actual_decay_steps = num_training_steps
logging.info(
f"Auto-scaling LR scheduler: "
f"num_training_steps ({num_training_steps}) < num_decay_steps ({self.num_decay_steps}). "
f"Scaling warmup: {self.num_warmup_steps}{actual_warmup_steps}, "
f"decay: {self.num_decay_steps}{actual_decay_steps} "
f"(scale factor: {scale_factor:.3f})"
)
def lr_lambda(current_step): def lr_lambda(current_step):
def linear_warmup_schedule(current_step): def linear_warmup_schedule(current_step):
if current_step <= 0: if current_step <= 0:
return 1 / (actual_warmup_steps + 1) return 1 / (self.num_warmup_steps + 1)
frac = 1 - current_step / actual_warmup_steps frac = 1 - current_step / self.num_warmup_steps
return (1 / (actual_warmup_steps + 1) - 1) * frac + 1 return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
def cosine_decay_schedule(current_step): def cosine_decay_schedule(current_step):
step = min(current_step, actual_decay_steps) step = min(current_step, self.num_decay_steps)
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / actual_decay_steps)) cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
alpha = self.decay_lr / self.peak_lr alpha = self.decay_lr / self.peak_lr
decayed = (1 - alpha) * cosine_decay + alpha decayed = (1 - alpha) * cosine_decay + alpha
return decayed return decayed
if current_step < actual_warmup_steps: if current_step < self.num_warmup_steps:
return linear_warmup_schedule(current_step) return linear_warmup_schedule(current_step)
return cosine_decay_schedule(current_step) return cosine_decay_schedule(current_step)

View File

@@ -626,8 +626,8 @@ class ACTDecoderLayer(nn.Module):
x: (Decoder Sequence, Batch, Channel) tensor of input tokens. x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
cross-attending with. cross-attending with.
encoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder). decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
decoder_pos_embed: (DS, 1, C) positional embedding for the queries (from the decoder). encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder).
Returns: Returns:
(DS, B, C) tensor of decoder output features. (DS, B, C) tensor of decoder output features.
""" """

View File

@@ -360,12 +360,10 @@ def make_policy(
raise ValueError("env_cfg cannot be None when ds_meta is not provided") raise ValueError("env_cfg cannot be None when ds_meta is not provided")
features = env_to_policy_features(env_cfg) features = env_to_policy_features(env_cfg)
if not cfg.output_features: cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
if not cfg.input_features:
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg kwargs["config"] = cfg
breakpoint()
if cfg.pretrained_path: if cfg.pretrained_path:
# Load a pretrained policy and override the config if needed (for example, if there are inference-time # Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary). # hyperparameters that we want to vary).

View File

@@ -75,8 +75,6 @@ class PI0Config(PreTrainedConfig):
optimizer_grad_clip_norm: float = 1.0 optimizer_grad_clip_norm: float = 1.0
# Scheduler settings: see openpi `CosineDecaySchedule` # Scheduler settings: see openpi `CosineDecaySchedule`
# Note: These will auto-scale if --steps < scheduler_decay_steps
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
scheduler_warmup_steps: int = 1_000 scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000 scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6 scheduler_decay_lr: float = 2.5e-6

View File

@@ -75,8 +75,6 @@ class PI05Config(PreTrainedConfig):
optimizer_grad_clip_norm: float = 1.0 optimizer_grad_clip_norm: float = 1.0
# Scheduler settings: see openpi `CosineDecaySchedule` # Scheduler settings: see openpi `CosineDecaySchedule`
# Note: These will auto-scale if --steps < scheduler_decay_steps
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
scheduler_warmup_steps: int = 1_000 scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000 scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6 scheduler_decay_lr: float = 2.5e-6

View File

@@ -65,7 +65,7 @@ def main(cfg: TrainRLServerPipelineConfig):
# env_cfg=cfg.env, # env_cfg=cfg.env,
ds_meta=dataset_meta, ds_meta=dataset_meta,
) )
policy = policy.from_pretrained(env_cfg.pretrained_policy_name_or_path) policy.from_pretrained(env_cfg.pretrained_policy_name_or_path)
policy.eval() policy.eval()
eval_policy(env, policy=policy, n_episodes=10) eval_policy(env, policy=policy, n_episodes=10)

View File

@@ -99,7 +99,7 @@ class WandBLogger:
cfg.wandb.run_id = run_id cfg.wandb.run_id = run_id
# Handle custom step key for rl asynchronous training. # Handle custom step key for rl asynchronous training.
self._wandb_custom_step_key: set[str] | None = None self._wandb_custom_step_key: set[str] | None = None
logging.info(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb self._wandb = wandb

View File

@@ -153,7 +153,7 @@ class LeKiwi(Robot):
homing_offsets.update(dict.fromkeys(self.base_motors, 0)) homing_offsets.update(dict.fromkeys(self.base_motors, 0))
full_turn_motor = [ full_turn_motor = [
motor for motor in motors if any(keyword in motor for keyword in ["wheel", "wrist_roll"]) motor for motor in motors if any(keyword in motor for keyword in ["wheel", "wrist"])
] ]
unknown_range_motors = [motor for motor in motors if motor not in full_turn_motor] unknown_range_motors = [motor for motor in motors if motor not in full_turn_motor]

View File

@@ -18,24 +18,14 @@ import base64
import json import json
import logging import logging
import time import time
from dataclasses import dataclass, field
import cv2 import cv2
import draccus
import zmq import zmq
from .config_lekiwi import LeKiwiConfig, LeKiwiHostConfig from .config_lekiwi import LeKiwiConfig, LeKiwiHostConfig
from .lekiwi import LeKiwi from .lekiwi import LeKiwi
@dataclass
class LeKiwiServerConfig:
"""Configuration for the LeKiwi host script."""
robot: LeKiwiConfig = field(default_factory=LeKiwiConfig)
host: LeKiwiHostConfig = field(default_factory=LeKiwiHostConfig)
class LeKiwiHost: class LeKiwiHost:
def __init__(self, config: LeKiwiHostConfig): def __init__(self, config: LeKiwiHostConfig):
self.zmq_context = zmq.Context() self.zmq_context = zmq.Context()
@@ -57,16 +47,17 @@ class LeKiwiHost:
self.zmq_context.term() self.zmq_context.term()
@draccus.wrap() def main():
def main(cfg: LeKiwiServerConfig):
logging.info("Configuring LeKiwi") logging.info("Configuring LeKiwi")
robot = LeKiwi(cfg.robot) robot_config = LeKiwiConfig()
robot = LeKiwi(robot_config)
logging.info("Connecting LeKiwi") logging.info("Connecting LeKiwi")
robot.connect() robot.connect()
logging.info("Starting HostAgent") logging.info("Starting HostAgent")
host = LeKiwiHost(cfg.host) host_config = LeKiwiHostConfig()
host = LeKiwiHost(host_config)
last_cmd_time = time.time() last_cmd_time = time.time()
watchdog_active = False watchdog_active = False

View File

@@ -0,0 +1,173 @@
This tutorial explains how to use [Stretch 3](https://hello-robot.com/stretch-3-product) with LeRobot.
## Setup
Familiarize yourself with Stretch by following its [tutorials](https://docs.hello-robot.com/0.3/getting_started/hello_robot/) (recommended).
To use LeRobot on Stretch, 3 options are available:
- [tethered setup](https://docs.hello-robot.com/0.3/getting_started/connecting_to_stretch/#tethered-setup)
- [untethered setup](https://docs.hello-robot.com/0.3/getting_started/connecting_to_stretch/#untethered-setup)
- ssh directly into Stretch (you will first need to install and configure openssh-server on stretch using one of the two above setups)
## Install LeRobot
On Stretch's CLI, follow these steps:
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
```bash
mkdir -p ~/miniconda3
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
~/miniconda3/bin/conda init bash
```
2. Comment out these lines in `~/.profile` (this can mess up paths used by conda and ~/.local/bin should already be in your PATH)
```
# set PATH so it includes user's private bin if it exists
if [ -d "$HOME/.local/bin" ] ; then
PATH="$HOME/.local/bin:$PATH"
fi
```
3. Restart shell or `source ~/.bashrc`
4. Create and activate a fresh conda environment for lerobot
```bash
conda create -y -n lerobot python=3.10 && conda activate lerobot
```
5. Clone LeRobot:
```bash
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```
6. When using `miniconda`, install `ffmpeg` in your environment:
```bash
conda install ffmpeg -c conda-forge
```
7. Install LeRobot with stretch dependencies:
```bash
cd ~/lerobot && pip install -e ".[stretch]"
```
> **Note:** If you get this message, you can ignore it: `ERROR: pip's dependency resolver does not currently take into account all the packages that are installed.`
8. Run a [system check](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#system-check) to make sure your robot is ready:
```bash
stretch_system_check.py
```
> **Note:** You may need to free the "robot process" after booting Stretch by running `stretch_free_robot_process.py`. For more info this Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#turning-off-gamepad-teleoperation).
You should get something like this:
```bash
For use with S T R E T C H (R) from Hello Robot Inc.
---------------------------------------------------------------------
Model = Stretch 3
Tool = DexWrist 3 w/ Gripper
Serial Number = stretch-se3-3054
---- Checking Hardware ----
[Pass] Comms are ready
[Pass] Actuators are ready
[Warn] Sensors not ready (IMU AZ = -10.19 out of range -10.1 to -9.5)
[Pass] Battery voltage is 13.6 V
---- Checking Software ----
[Pass] Ubuntu 22.04 is ready
[Pass] All APT pkgs are setup correctly
[Pass] Firmware is up-to-date
[Pass] Python pkgs are up-to-date
[Pass] ROS2 Humble is ready
```
## Teleoperate, record a dataset and run a policy
**Calibrate (Optional)**
Before operating Stretch, you need to [home](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#homing) it first. Be mindful about giving Stretch some space as this procedure will move the robot's arm and gripper. Now run this command:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=stretch \
--control.type=calibrate
```
This is equivalent to running `stretch_robot_home.py`
> **Note:** If you run any of the LeRobot scripts below and Stretch is not properly homed, it will automatically home/calibrate first.
**Teleoperate**
Before trying teleoperation, you need to activate the gamepad controller by pressing the middle button. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/hello_robot/#gamepad-teleoperation).
Now try out teleoperation (see above documentation to learn about the gamepad controls):
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
```bash
python lerobot/scripts/control_robot.py \
--robot.type=stretch \
--control.type=teleoperate
```
This is essentially the same as running `stretch_gamepad_teleop.py`
**Record a dataset**
Once you're familiar with the gamepad controls and after a bit of practice, you can try to record your first dataset with Stretch.
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Store your Hugging Face repository name in a variable to run these commands:
```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```
Record one episode:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=stretch \
--control.type=record \
--control.fps=30 \
--control.single_task="Grasp a lego block and put it in the bin." \
--control.repo_id=${HF_USER}/stretch_test \
--control.tags='["tutorial"]' \
--control.warmup_time_s=5 \
--control.episode_time_s=30 \
--control.reset_time_s=30 \
--control.num_episodes=2 \
--control.push_to_hub=true
```
> **Note:** If you're using ssh to connect to Stretch and run this script, you won't be able to visualize its cameras feed (though they will still be recording). To see the cameras stream, use [tethered](https://docs.hello-robot.com/0.3/getting_started/connecting_to_stretch/#tethered-setup) or [untethered setup](https://docs.hello-robot.com/0.3/getting_started/connecting_to_stretch/#untethered-setup).
**Replay an episode**
Now try to replay this episode (make sure the robot's initial position is the same):
```bash
python lerobot/scripts/control_robot.py \
--robot.type=stretch \
--control.type=replay \
--control.fps=30 \
--control.repo_id=${HF_USER}/stretch_test \
--control.episode=0
```
If you need help, please reach out on Discord in the channel `#stretch3-mobile-arm`.

View File

@@ -0,0 +1,18 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_stretch3 import Stretch3RobotConfig
from .robot_stretch3 import Stretch3Robot

View File

@@ -0,0 +1,51 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.cameras.realsense import RealSenseCameraConfig
from ..config import RobotConfig
@RobotConfig.register_subclass("stretch3")
@dataclass
class Stretch3RobotConfig(RobotConfig):
# cameras
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"navigation": OpenCVCameraConfig(
index_or_path="/dev/hello-nav-head-camera",
fps=10,
width=1280,
height=720,
rotation=-90,
),
"head": RealSenseCameraConfig(
name="Intel RealSense D435I",
fps=30,
width=640,
height=480,
rotation=90,
),
"wrist": RealSenseCameraConfig(
name="Intel RealSense D405",
fps=30,
width=640,
height=480,
),
}
)

View File

@@ -0,0 +1,180 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import numpy as np
from stretch_body.gamepad_teleop import GamePadTeleop
from stretch_body.robot import Robot as StretchAPI
from stretch_body.robot_params import RobotParams
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.datasets.utils import get_nested_item
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from ..robot import Robot
from .configuration_stretch3 import Stretch3RobotConfig
# {lerobot_keys: stretch.api.keys}
STRETCH_MOTORS = {
"head_pan.pos": "head.head_pan.pos",
"head_tilt.pos": "head.head_tilt.pos",
"lift.pos": "lift.pos",
"arm.pos": "arm.pos",
"wrist_pitch.pos": "end_of_arm.wrist_pitch.pos",
"wrist_roll.pos": "end_of_arm.wrist_roll.pos",
"wrist_yaw.pos": "end_of_arm.wrist_yaw.pos",
"gripper.pos": "end_of_arm.stretch_gripper.pos",
"base_x.vel": "base.x_vel",
"base_y.vel": "base.y_vel",
"base_theta.vel": "base.theta_vel",
}
class Stretch3Robot(Robot):
"""[Stretch 3](https://hello-robot.com/stretch-3-product), by Hello Robot."""
config_class = Stretch3RobotConfig
name = "stretch3"
def __init__(self, config: Stretch3RobotConfig):
raise NotImplementedError
super().__init__(config)
self.config = config
self.robot_type = self.config.type
self.api = StretchAPI()
self.cameras = make_cameras_from_configs(config.cameras)
self.is_connected = False
self.logs = {}
self.teleop = None # TODO remove
# TODO(aliberts): test this
RobotParams.set_logging_level("WARNING")
RobotParams.set_logging_formatter("brief_console_formatter")
self.state_keys = None
self.action_keys = None
@property
def observation_features(self) -> dict:
return {
"dtype": "float32",
"shape": (len(STRETCH_MOTORS),),
"names": {"motors": list(STRETCH_MOTORS)},
}
@property
def action_features(self) -> dict:
return self.observation_features
@property
def camera_features(self) -> dict[str, dict]:
cam_ft = {}
for cam_key, cam in self.cameras.items():
cam_ft[cam_key] = {
"shape": (cam.height, cam.width, cam.channels),
"names": ["height", "width", "channels"],
"info": None,
}
return cam_ft
def connect(self) -> None:
self.is_connected = self.api.startup()
if not self.is_connected:
print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
raise ConnectionError()
for cam in self.cameras.values():
cam.connect()
self.is_connected = self.is_connected and cam.is_connected
if not self.is_connected:
print("Could not connect to the cameras, check that all cameras are plugged-in.")
raise ConnectionError()
self.calibrate()
def calibrate(self) -> None:
if not self.api.is_homed():
self.api.home()
def _get_state(self) -> dict:
status = self.api.get_status()
return {k: get_nested_item(status, v, sep=".") for k, v in STRETCH_MOTORS.items()}
def get_observation(self) -> dict[str, np.ndarray]:
obs_dict = {}
# Read Stretch state
before_read_t = time.perf_counter()
state = self._get_state()
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
if self.state_keys is None:
self.state_keys = list(state)
state = np.asarray(list(state.values()))
obs_dict[OBS_STATE] = state
# Capture images from cameras
for cam_key, cam in self.cameras.items():
before_camread_t = time.perf_counter()
obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read()
self.logs[f"read_camera_{cam_key}_dt_s"] = cam.logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{cam_key}_dt_s"] = time.perf_counter() - before_camread_t
return obs_dict
def send_action(self, action: np.ndarray) -> np.ndarray:
if not self.is_connected:
raise ConnectionError()
if self.teleop is None:
self.teleop = GamePadTeleop(robot_instance=False)
self.teleop.startup(robot=self)
if self.action_keys is None:
dummy_action = self.teleop.gamepad_controller.get_state()
self.action_keys = list(dummy_action.keys())
action_dict = dict(zip(self.action_keys, action.tolist(), strict=True))
before_write_t = time.perf_counter()
self.teleop.do_motion(state=action_dict, robot=self)
self.push_command()
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
# TODO(aliberts): return action_sent when motion is limited
return action
def teleop_safety_stop(self) -> None:
if self.teleop is not None:
self.teleop._safety_stop(robot=self)
def disconnect(self) -> None:
self.api.stop()
if self.teleop is not None:
self.teleop.gamepad_controller.stop()
self.teleop.stop()
for cam in self.cameras.values():
cam.disconnect()
self.is_connected = False

View File

@@ -40,6 +40,14 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
from .lekiwi import LeKiwi from .lekiwi import LeKiwi
return LeKiwi(config) return LeKiwi(config)
elif config.type == "stretch3":
from .stretch3 import Stretch3Robot
return Stretch3Robot(config)
elif config.type == "viperx":
from .viperx import ViperX
return ViperX(config)
elif config.type == "hope_jr_hand": elif config.type == "hope_jr_hand":
from .hope_jr import HopeJrHand from .hope_jr import HopeJrHand

View File

@@ -0,0 +1,196 @@
This tutorial explains how to use [Aloha and Aloha 2 stationary](https://www.trossenrobotics.com/aloha-stationary) with LeRobot.
## Setup
Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/aloha_docs/2.0/getting_started/stationary/hardware_setup.html) for setting up the hardware and plugging the 4 arms and 4 cameras to your computer.
## Install LeRobot
On your computer:
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
```bash
mkdir -p ~/miniconda3
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
~/miniconda3/bin/conda init bash
```
2. Restart shell or `source ~/.bashrc`
3. Create and activate a fresh conda environment for lerobot
```bash
conda create -y -n lerobot python=3.10 && conda activate lerobot
```
4. Clone LeRobot:
```bash
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```
5. When using `miniconda`, install `ffmpeg` in your environment:
```bash
conda install ffmpeg -c conda-forge
```
6. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense):
```bash
cd ~/lerobot && pip install -e ".[dynamixel, intelrealsense]"
```
## Teleoperate
\*\*/!\ FOR SAFETY, READ THIS /!\*\*
Teleoperation consists in manually operating the leader arms to move the follower arms. Importantly:
1. Make sure your leader arms are in the same position as the follower arms, so that the follower arms don't move too fast to match the leader arms,
2. Our code assumes that your robot has been assembled following Trossen Robotics instructions. This allows us to skip calibration, as we use the pre-defined calibration files in `.cache/calibration/aloha_default`. If you replace a motor, make sure you follow the exact instructions from Trossen Robotics.
By running the following code, you can start your first **SAFE** teleoperation:
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
```bash
python lerobot/scripts/control_robot.py \
--robot.type=aloha \
--robot.max_relative_target=5 \
--control.type=teleoperate
```
By adding `--robot.max_relative_target=5`, we override the default value for `max_relative_target` defined in [`ViperXConfig`](./config_viperx.py). It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot.max_relative_target=null` to the command line:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=aloha \
--robot.max_relative_target=null \
--control.type=teleoperate
```
## Record a dataset
Once you're familiar with teleoperation, you can record your first dataset with Aloha.
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Store your Hugging Face repository name in a variable to run these commands:
```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```
Record 2 episodes and upload your dataset to the hub:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=aloha \
--robot.max_relative_target=null \
--control.type=record \
--control.fps=30 \
--control.single_task="Grasp a lego block and put it in the bin." \
--control.repo_id=${HF_USER}/aloha_test \
--control.tags='["tutorial"]' \
--control.warmup_time_s=5 \
--control.episode_time_s=30 \
--control.reset_time_s=30 \
--control.num_episodes=2 \
--control.push_to_hub=true
```
## Visualize a dataset
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
```bash
echo ${HF_USER}/aloha_test
```
If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with [Rerun](https://github.com/rerun-io/rerun):
```bash
lerobot-dataset-viz \
--repo-id ${HF_USER}/aloha_test --episode 0
```
## Replay an episode
\*\*/!\ FOR SAFETY, READ THIS /!\*\*
Replay consists in automatically replaying the sequence of actions (i.e. goal positions for your motors) recorded in a given dataset episode. Make sure the current initial position of your robot is similar to the one in your episode, so that your follower arms don't move too fast to go to the first goal positions. For safety, you might want to add `--robot.max_relative_target=5` to your command line as explained above.
Now try to replay the first episode on your robot:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=aloha \
--robot.max_relative_target=null \
--control.type=replay \
--control.fps=30 \
--control.repo_id=${HF_USER}/aloha_test \
--control.episode=0
```
## Train a policy
To train a policy to control your robot, use the [`lerobot-train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
lerobot-train \
--dataset.repo_id=${HF_USER}/aloha_test \
--policy.type=act \
--output_dir=outputs/train/act_aloha_test \
--job_name=act_aloha_test \
--policy.device=cuda \
--wandb.enable=true
```
Let's explain it:
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/aloha_test`.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
3. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
Training should take several hours. You will find checkpoints in `outputs/train/act_aloha_test/checkpoints`.
## Evaluate your policy
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../src/lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=aloha \
--control.type=record \
--control.fps=30 \
--control.single_task="Grasp a lego block and put it in the bin." \
--control.repo_id=${HF_USER}/eval_act_aloha_test \
--control.tags='["tutorial"]' \
--control.warmup_time_s=5 \
--control.episode_time_s=30 \
--control.reset_time_s=30 \
--control.num_episodes=10 \
--control.push_to_hub=true \
--control.policy.path=outputs/train/act_aloha_test/checkpoints/last/pretrained_model \
--control.num_image_writer_processes=1
```
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_aloha_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_aloha_test`).
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_aloha_test`).
3. We use `--control.num_image_writer_processes=1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows to reach constant 30 fps during inference. Feel free to explore different values for `--control.num_image_writer_processes`.
## More
If you have any question or need help, please reach out on Discord in the channel `#aloha-arm`.

View File

@@ -0,0 +1,18 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_viperx import ViperXConfig
from .viperx import ViperX

View File

@@ -0,0 +1,45 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from ..config import RobotConfig
@RobotConfig.register_subclass("viperx")
@dataclass
class ViperXConfig(RobotConfig):
port: str # Port to connect to the arm
disable_torque_on_disconnect: bool = True
# /!\ FOR SAFETY, READ THIS /!\
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
# For Aloha, for every goal position request, motor rotations are capped at 5 degrees by default.
# When you feel more confident with teleoperation or running the policy, you can extend
# this safety limit and even removing it by setting it to `null`.
# Also, everything is expected to work safely out-of-the-box, but we highly advise to
# first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml),
# then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully
max_relative_target: float | dict[str, float] = 5.0
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)
# Troubleshooting: If one of your IntelRealSense cameras freeze during
# data recording due to bandwidth limit, you might need to plug the camera
# on another USB hub or PCIe card.

View File

@@ -0,0 +1,233 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from functools import cached_property
from typing import Any
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.dynamixel import (
DynamixelMotorsBus,
OperatingMode,
)
from lerobot.utils.constants import OBS_STATE
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from ..utils import ensure_safe_goal_position
from .config_viperx import ViperXConfig
logger = logging.getLogger(__name__)
class ViperX(Robot):
"""
[ViperX](https://www.trossenrobotics.com/viperx-300) developed by Trossen Robotics
"""
config_class = ViperXConfig
name = "viperx"
def __init__(
self,
config: ViperXConfig,
):
raise NotImplementedError
super().__init__(config)
self.config = config
self.bus = DynamixelMotorsBus(
port=self.config.port,
motors={
"waist": Motor(1, "xm540-w270", MotorNormMode.RANGE_M100_100),
"shoulder": Motor(2, "xm540-w270", MotorNormMode.RANGE_M100_100),
"shoulder_shadow": Motor(3, "xm540-w270", MotorNormMode.RANGE_M100_100),
"elbow": Motor(4, "xm540-w270", MotorNormMode.RANGE_M100_100),
"elbow_shadow": Motor(5, "xm540-w270", MotorNormMode.RANGE_M100_100),
"forearm_roll": Motor(6, "xm540-w270", MotorNormMode.RANGE_M100_100),
"wrist_angle": Motor(7, "xm540-w270", MotorNormMode.RANGE_M100_100),
"wrist_rotate": Motor(8, "xm430-w350", MotorNormMode.RANGE_M100_100),
"gripper": Motor(9, "xm430-w350", MotorNormMode.RANGE_0_100),
},
)
self.cameras = make_cameras_from_configs(config.cameras)
@property
def _motors_ft(self) -> dict[str, type]:
return {f"{motor}.pos": float for motor in self.bus.motors}
@property
def _cameras_ft(self) -> dict[str, tuple]:
return {
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
return {**self._motors_ft, **self._cameras_ft}
@cached_property
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
def connect(self, calibrate: bool = True) -> None:
"""
We assume that at connection time, arm is in a rest position,
and torque can be safely disabled to run calibration.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.bus.connect()
if not self.is_calibrated and calibrate:
self.calibrate()
for cam in self.cameras.values():
cam.connect()
self.configure()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
return self.bus.is_calibrated
def calibrate(self) -> None:
raise NotImplementedError # TODO(aliberts): adapt code below (copied from koch
logger.info(f"\nRunning calibration of {self}")
self.bus.disable_torque()
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
input("Move robot to the middle of its range of motion and press ENTER....")
homing_offsets = self.bus.set_half_turn_homings()
full_turn_motors = ["shoulder_pan", "wrist_roll"]
unknown_range_motors = [motor for motor in self.bus.motors if motor not in full_turn_motors]
print(
f"Move all joints except {full_turn_motors} sequentially through their entire "
"ranges of motion.\nRecording positions. Press ENTER to stop..."
)
range_mins, range_maxes = self.bus.record_ranges_of_motion(unknown_range_motors)
for motor in full_turn_motors:
range_mins[motor] = 0
range_maxes[motor] = 4095
self.calibration = {}
for motor, m in self.bus.motors.items():
self.calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=0,
homing_offset=homing_offsets[motor],
range_min=range_mins[motor],
range_max=range_maxes[motor],
)
self.bus.write_calibration(self.calibration)
self._save_calibration()
logger.info(f"Calibration saved to {self.calibration_fpath}")
def configure(self) -> None:
with self.bus.torque_disabled():
self.bus.configure_motors()
# Set secondary/shadow ID for shoulder and elbow. These joints have two motors.
# As a result, if only one of them is required to move to a certain position,
# the other will follow. This is to avoid breaking the motors.
self.bus.write("Secondary_ID", "shoulder_shadow", 2)
self.bus.write("Secondary_ID", "elbow_shadow", 4)
# Set a velocity limit of 131 as advised by Trossen Robotics
# TODO(aliberts): remove as it's actually useless in position control
self.bus.write("Velocity_Limit", 131)
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos
# can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling
# the arm, you could end up with a servo with a position 0 or 4095 at a crucial point.
# See: https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11
for motor in self.bus.motors:
if motor != "gripper":
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
# Use 'position control current based' for follower gripper to be limited by the limit of the
# current. It can grasp an object without forcing too much even tho, it's goal position is a
# complete grasp (both gripper fingers are ordered to join and reach a touch).
self.bus.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
def get_observation(self) -> dict[str, Any]:
"""The returned observations do not have a batch dimension."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
obs_dict = {}
# Read arm position
start = time.perf_counter()
obs_dict[OBS_STATE] = self.bus.sync_read("Present_Position")
obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()}
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
return obs_dict
def send_action(self, action: dict[str, float]) -> dict[str, float]:
"""Command arm to move to a target joint configuration.
The relative action magnitude may be clipped depending on the configuration parameter
`max_relative_target`. In this case, the action sent differs from original action.
Thus, this function always returns the action actually sent.
Args:
action (dict[str, float]): The goal positions for the motors.
Returns:
dict[str, float]: The action sent to the motors, potentially clipped.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
# Cap goal position when too far away from present position.
# /!\ Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
present_pos = self.bus.sync_read("Present_Position")
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()}
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
# Send goal position to the arm
self.bus.sync_write("Goal_Position", goal_pos)
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
def disconnect(self):
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.bus.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values():
cam.disconnect()
logger.info(f"{self} disconnected.")

View File

@@ -141,15 +141,15 @@ def visualize_dataset(
gc.collect() gc.collect()
if mode == "distant": if mode == "distant":
rr.serve_web_viewer(open_browser=False, web_port=web_port) rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
logging.info("Logging to Rerun") logging.info("Logging to Rerun")
for batch in tqdm.tqdm(dataloader, total=len(dataloader)): for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
# iterate over the batch # iterate over the batch
for i in range(len(batch["index"])): for i in range(len(batch["index"])):
rr.set_time("frame_index", sequence=batch["frame_index"][i].item()) rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
rr.set_time("timestamp", timestamp=batch["timestamp"][i].item()) rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
# display each camera image # display each camera image
for key in dataset.meta.camera_keys: for key in dataset.meta.camera_keys:
@@ -159,21 +159,21 @@ def visualize_dataset(
# display each dimension of action space (e.g. actuators command) # display each dimension of action space (e.g. actuators command)
if ACTION in batch: if ACTION in batch:
for dim_idx, val in enumerate(batch[ACTION][i]): for dim_idx, val in enumerate(batch[ACTION][i]):
rr.log(f"{ACTION}/{dim_idx}", rr.Scalars(val.item())) rr.log(f"{ACTION}/{dim_idx}", rr.Scalar(val.item()))
# display each dimension of observed state space (e.g. agent position in joint space) # display each dimension of observed state space (e.g. agent position in joint space)
if OBS_STATE in batch: if OBS_STATE in batch:
for dim_idx, val in enumerate(batch[OBS_STATE][i]): for dim_idx, val in enumerate(batch[OBS_STATE][i]):
rr.log(f"state/{dim_idx}", rr.Scalars(val.item())) rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
if DONE in batch: if DONE in batch:
rr.log(DONE, rr.Scalars(batch[DONE][i].item())) rr.log(DONE, rr.Scalar(batch[DONE][i].item()))
if REWARD in batch: if REWARD in batch:
rr.log(REWARD, rr.Scalars(batch[REWARD][i].item())) rr.log(REWARD, rr.Scalar(batch[REWARD][i].item()))
if "next.success" in batch: if "next.success" in batch:
rr.log("next.success", rr.Scalars(batch["next.success"][i].item())) rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))
if mode == "local" and save: if mode == "local" and save:
# save .rrd locally # save .rrd locally

View File

@@ -26,8 +26,8 @@ lerobot-eval \
--env.type=pusht \ --env.type=pusht \
--eval.batch_size=10 \ --eval.batch_size=10 \
--eval.n_episodes=10 \ --eval.n_episodes=10 \
--policy.use_amp=false \ --use_amp=false \
--policy.device=cuda --device=cuda
``` ```
OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes. OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes.
@@ -37,8 +37,8 @@ lerobot-eval \
--env.type=pusht \ --env.type=pusht \
--eval.batch_size=10 \ --eval.batch_size=10 \
--eval.n_episodes=10 \ --eval.n_episodes=10 \
--policy.use_amp=false \ --use_amp=false \
--policy.device=cuda --device=cuda
``` ```
Note that in both examples, the repo/folder should contain at least `config.json` and `model.safetensors` files. Note that in both examples, the repo/folder should contain at least `config.json` and `model.safetensors` files.
@@ -502,6 +502,7 @@ def eval_main(cfg: EvalPipelineConfig):
cfg=cfg.policy, cfg=cfg.policy,
env_cfg=cfg.env, env_cfg=cfg.env,
) )
policy.eval() policy.eval()
preprocessor, postprocessor = make_pre_post_processors( preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy, policy_cfg=cfg.policy,

View File

@@ -180,7 +180,7 @@ def create_camera_instance(cam_meta: dict[str, Any]) -> dict[str, Any] | None:
if instance: if instance:
logger.info(f"Connecting to {cam_type} camera: {cam_id}...") logger.info(f"Connecting to {cam_type} camera: {cam_id}...")
instance.connect(warmup=True) instance.connect(warmup=False)
return {"instance": instance, "meta": cam_meta} return {"instance": instance, "meta": cam_meta}
except Exception as e: except Exception as e:
logger.error(f"Failed to connect or configure {cam_type} camera {cam_id}: {e}") logger.error(f"Failed to connect or configure {cam_type} camera {cam_id}: {e}")

View File

@@ -20,8 +20,8 @@ from pprint import pformat
from typing import Any from typing import Any
import torch import torch
from accelerate import Accelerator
from termcolor import colored from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer from torch.optim import Optimizer
from lerobot.configs import parser from lerobot.configs import parser
@@ -34,6 +34,7 @@ from lerobot.envs.utils import close_envs
from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.rl.wandb_utils import WandBLogger from lerobot.rl.wandb_utils import WandBLogger
from lerobot.scripts.lerobot_eval import eval_policy_all from lerobot.scripts.lerobot_eval import eval_policy_all
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
@@ -47,6 +48,7 @@ from lerobot.utils.train_utils import (
) )
from lerobot.utils.utils import ( from lerobot.utils.utils import (
format_big_number, format_big_number,
get_safe_torch_device,
has_method, has_method,
init_logging, init_logging,
) )
@@ -58,15 +60,16 @@ def update_policy(
batch: Any, batch: Any,
optimizer: Optimizer, optimizer: Optimizer,
grad_clip_norm: float, grad_clip_norm: float,
accelerator: Accelerator, grad_scaler: GradScaler,
lr_scheduler=None, lr_scheduler=None,
use_amp: bool = False,
lock=None, lock=None,
) -> tuple[MetricsTracker, dict]: ) -> tuple[MetricsTracker, dict]:
""" """
Performs a single training step to update the policy's weights. Performs a single training step to update the policy's weights.
This function executes the forward and backward passes, clips gradients, and steps the optimizer and This function executes the forward and backward passes, clips gradients, and steps the optimizer and
learning rate scheduler. Accelerator handles mixed-precision training automatically. learning rate scheduler. It also handles mixed-precision training via a GradScaler.
Args: Args:
train_metrics: A MetricsTracker instance to record training statistics. train_metrics: A MetricsTracker instance to record training statistics.
@@ -74,8 +77,9 @@ def update_policy(
batch: A batch of training data. batch: A batch of training data.
optimizer: The optimizer used to update the policy's parameters. optimizer: The optimizer used to update the policy's parameters.
grad_clip_norm: The maximum norm for gradient clipping. grad_clip_norm: The maximum norm for gradient clipping.
accelerator: The Accelerator instance for distributed training and mixed precision. grad_scaler: The GradScaler for automatic mixed-precision training.
lr_scheduler: An optional learning rate scheduler. lr_scheduler: An optional learning rate scheduler.
use_amp: A boolean indicating whether to use automatic mixed precision.
lock: An optional lock for thread-safe optimizer updates. lock: An optional lock for thread-safe optimizer updates.
Returns: Returns:
@@ -84,27 +88,28 @@ def update_policy(
- A dictionary of outputs from the policy's forward pass, for logging purposes. - A dictionary of outputs from the policy's forward pass, for logging purposes.
""" """
start_time = time.perf_counter() start_time = time.perf_counter()
device = get_device_from_parameters(policy)
policy.train() policy.train()
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
# Let accelerator handle mixed precision
with accelerator.autocast():
loss, output_dict = policy.forward(batch) loss, output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict) # TODO(rcadene): policy.unnormalize_outputs(out_dict)
grad_scaler.scale(loss).backward()
# Use accelerator's backward method # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
accelerator.backward(loss) grad_scaler.unscale_(optimizer)
# Clip gradients if specified grad_norm = torch.nn.utils.clip_grad_norm_(
if grad_clip_norm > 0: policy.parameters(),
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm) grad_clip_norm,
else: error_if_nonfinite=False,
grad_norm = torch.nn.utils.clip_grad_norm_( )
policy.parameters(), float("inf"), error_if_nonfinite=False
)
# Optimizer step # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
with lock if lock is not None else nullcontext(): with lock if lock is not None else nullcontext():
optimizer.step() grad_scaler.step(optimizer)
# Updates the scale for next iteration.
grad_scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
@@ -112,9 +117,9 @@ def update_policy(
if lr_scheduler is not None: if lr_scheduler is not None:
lr_scheduler.step() lr_scheduler.step()
# Update internal buffers if policy has update method if has_method(policy, "update"):
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"): # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update() policy.update()
train_metrics.loss = loss.item() train_metrics.loss = loss.item()
train_metrics.grad_norm = grad_norm.item() train_metrics.grad_norm = grad_norm.item()
@@ -124,7 +129,7 @@ def update_policy(
@parser.wrap() @parser.wrap()
def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): def train(cfg: TrainPipelineConfig):
""" """
Main function to train a policy. Main function to train a policy.
@@ -138,76 +143,41 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
Args: Args:
cfg: A `TrainPipelineConfig` object containing all training configurations. cfg: A `TrainPipelineConfig` object containing all training configurations.
accelerator: Optional Accelerator instance. If None, one will be created automatically.
""" """
cfg.validate() cfg.validate()
logging.info(pformat(cfg.to_dict()))
# Create Accelerator if not provided if cfg.wandb.enable and cfg.wandb.project:
# It will automatically detect if running in distributed mode or single-process mode
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
# We set find_unused_parameters=True to handle models with conditional computation
if accelerator is None:
from accelerate.utils import DistributedDataParallelKwargs
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
init_logging(accelerator=accelerator)
# Determine if this is the main process (for logging and checkpointing)
# When using accelerate, only the main process should log to avoid duplicate outputs
is_main_process = accelerator.is_main_process
# Only log on main process
if is_main_process:
logging.info(pformat(cfg.to_dict()))
# Initialize wandb only on main process
if cfg.wandb.enable and cfg.wandb.project and is_main_process:
wandb_logger = WandBLogger(cfg) wandb_logger = WandBLogger(cfg)
else: else:
wandb_logger = None wandb_logger = None
if is_main_process: logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
if cfg.seed is not None: if cfg.seed is not None:
set_seed(cfg.seed, accelerator=accelerator) set_seed(cfg.seed)
# Use accelerator's device # Check device is available
device = accelerator.device device = get_safe_torch_device(cfg.policy.device, log=True)
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
# Dataset loading synchronization: main process downloads first to avoid race conditions logging.info("Creating dataset")
if is_main_process: dataset = make_dataset(cfg)
logging.info("Creating dataset")
dataset = make_dataset(cfg)
accelerator.wait_for_everyone()
# Now all other processes can safely load the dataset
if not is_main_process:
dataset = make_dataset(cfg)
# Create environment used for evaluating checkpoints during training on simulation data. # Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py, # On real-world data, no need to create an environment as evaluations are done outside train.py,
# using the eval.py instead, with gym_dora environment and dora-rs. # using the eval.py instead, with gym_dora environment and dora-rs.
eval_env = None eval_env = None
if cfg.eval_freq > 0 and cfg.env is not None: if cfg.eval_freq > 0 and cfg.env is not None:
if is_main_process: logging.info("Creating env")
logging.info("Creating env")
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
if is_main_process: logging.info("Creating policy")
logging.info("Creating policy")
policy = make_policy( policy = make_policy(
cfg=cfg.policy, cfg=cfg.policy,
ds_meta=dataset.meta, ds_meta=dataset.meta,
) )
# Wait for all processes to finish policy creation before continuing
accelerator.wait_for_everyone()
# Create processors - only provide dataset_stats if not resuming from saved processors # Create processors - only provide dataset_stats if not resuming from saved processors
processor_kwargs = {} processor_kwargs = {}
postprocessor_kwargs = {} postprocessor_kwargs = {}
@@ -239,9 +209,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
**postprocessor_kwargs, **postprocessor_kwargs,
) )
if is_main_process: logging.info("Creating optimizer and scheduler")
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
step = 0 # number of policy updates (forward + backward + optim) step = 0 # number of policy updates (forward + backward + optim)
@@ -251,18 +221,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters()) num_total_params = sum(p.numel() for p in policy.parameters())
if is_main_process: logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") if cfg.env is not None:
if cfg.env is not None: logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.env.task=}") logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})") logging.info(f"{dataset.num_episodes=}")
logging.info(f"{dataset.num_episodes=}") logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
num_processes = accelerator.num_processes logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
effective_bs = cfg.batch_size * num_processes
logging.info(f"Effective batch size: {cfg.batch_size} x {num_processes} = {effective_bs}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training # create dataloader for offline training
if hasattr(cfg.policy, "drop_n_last_frames"): if hasattr(cfg.policy, "drop_n_last_frames"):
@@ -285,13 +251,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
sampler=sampler, sampler=sampler,
pin_memory=device.type == "cuda", pin_memory=device.type == "cuda",
drop_last=False, drop_last=False,
prefetch_factor=2 if cfg.num_workers > 0 else None, prefetch_factor=2,
)
# Prepare everything with accelerator
accelerator.wait_for_everyone()
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
policy, optimizer, dataloader, lr_scheduler
) )
dl_iter = cycle(dataloader) dl_iter = cycle(dataloader)
@@ -305,20 +265,11 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
"dataloading_s": AverageMeter("data_s", ":.3f"), "dataloading_s": AverageMeter("data_s", ":.3f"),
} }
# Use effective batch size for proper epoch calculation in distributed training
effective_batch_size = cfg.batch_size * accelerator.num_processes
train_tracker = MetricsTracker( train_tracker = MetricsTracker(
effective_batch_size, cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
dataset.num_frames,
dataset.num_episodes,
train_metrics,
initial_step=step,
accelerator=accelerator,
) )
if is_main_process: logging.info("Start offline training on a fixed dataset")
logging.info("Start offline training on a fixed dataset")
for _ in range(step, cfg.steps): for _ in range(step, cfg.steps):
start_time = time.perf_counter() start_time = time.perf_counter()
batch = next(dl_iter) batch = next(dl_iter)
@@ -331,15 +282,16 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
batch, batch,
optimizer, optimizer,
cfg.optimizer.grad_clip_norm, cfg.optimizer.grad_clip_norm,
accelerator=accelerator, grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler, lr_scheduler=lr_scheduler,
use_amp=cfg.policy.use_amp,
) )
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
# increment `step` here. # increment `step` here.
step += 1 step += 1
train_tracker.step() train_tracker.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
@@ -353,90 +305,69 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
train_tracker.reset_averages() train_tracker.reset_averages()
if cfg.save_checkpoint and is_saving_step: if cfg.save_checkpoint and is_saving_step:
if is_main_process: logging.info(f"Checkpoint policy after step {step}")
logging.info(f"Checkpoint policy after step {step}") checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) save_checkpoint(
save_checkpoint( checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor, postprocessor
checkpoint_dir=checkpoint_dir, )
step=step, update_last_checkpoint(checkpoint_dir)
cfg=cfg, if wandb_logger:
policy=accelerator.unwrap_model(policy), wandb_logger.log_policy(checkpoint_dir)
optimizer=optimizer,
scheduler=lr_scheduler,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
update_last_checkpoint(checkpoint_dir)
if wandb_logger:
wandb_logger.log_policy(checkpoint_dir)
accelerator.wait_for_everyone()
if cfg.env and is_eval_step: if cfg.env and is_eval_step:
if is_main_process: step_id = get_step_identifier(step, cfg.steps)
step_id = get_step_identifier(step, cfg.steps) logging.info(f"Eval policy at step {step}")
logging.info(f"Eval policy at step {step}") with (
with torch.no_grad(), accelerator.autocast(): torch.no_grad(),
eval_info = eval_policy_all( torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
envs=eval_env, # dict[suite][task_id] -> vec_env ):
policy=accelerator.unwrap_model(policy), eval_info = eval_policy_all(
preprocessor=preprocessor, envs=eval_env, # dict[suite][task_id] -> vec_env
postprocessor=postprocessor, policy=policy,
n_episodes=cfg.eval.n_episodes, preprocessor=preprocessor,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", postprocessor=postprocessor,
max_episodes_rendered=4, n_episodes=cfg.eval.n_episodes,
start_seed=cfg.seed, videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
max_parallel_tasks=cfg.env.max_parallel_tasks, max_episodes_rendered=4,
) start_seed=cfg.seed,
# overall metrics (suite-agnostic) max_parallel_tasks=cfg.env.max_parallel_tasks,
aggregated = eval_info["overall"]
# optional: per-suite logging
for suite, suite_info in eval_info.items():
logging.info("Suite %s aggregated: %s", suite, suite_info)
# meters/tracker
eval_metrics = {
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
"pc_success": AverageMeter("success", ":.1f"),
"eval_s": AverageMeter("eval_s", ":.3f"),
}
eval_tracker = MetricsTracker(
cfg.batch_size,
dataset.num_frames,
dataset.num_episodes,
eval_metrics,
initial_step=step,
accelerator=accelerator,
) )
eval_tracker.eval_s = aggregated.pop("eval_s") # overall metrics (suite-agnostic)
eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward") aggregated = eval_info["overall"]
eval_tracker.pc_success = aggregated.pop("pc_success")
if wandb_logger:
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
accelerator.wait_for_everyone() # optional: per-suite logging
for suite, suite_info in eval_info.items():
logging.info("Suite %s aggregated: %s", suite, suite_info)
# meters/tracker
eval_metrics = {
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
"pc_success": AverageMeter("success", ":.1f"),
"eval_s": AverageMeter("eval_s", ":.3f"),
}
eval_tracker = MetricsTracker(
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
)
eval_tracker.eval_s = aggregated.pop("eval_s")
eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
eval_tracker.pc_success = aggregated.pop("pc_success")
if wandb_logger:
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
if eval_env: if eval_env:
close_envs(eval_env) close_envs(eval_env)
logging.info("End of training")
if is_main_process: if cfg.policy.push_to_hub:
logging.info("End of training") policy.push_model_to_hub(cfg)
preprocessor.push_to_hub(cfg.policy.repo_id)
if cfg.policy.push_to_hub: postprocessor.push_to_hub(cfg.policy.repo_id)
unwrapped_policy = accelerator.unwrap_model(policy)
unwrapped_policy.push_model_to_hub(cfg)
preprocessor.push_to_hub(cfg.policy.repo_id)
postprocessor.push_to_hub(cfg.policy.repo_id)
# Properly clean up the distributed process group
accelerator.wait_for_everyone()
accelerator.end_training()
def main(): def main():
init_logging()
train() train()

View File

@@ -270,15 +270,8 @@ class HomunculusArm(Teleoperator):
raw_values = None raw_values = None
with self.serial_lock: with self.serial_lock:
if self.serial.in_waiting > 0: if self.serial.in_waiting > 0:
lines = [] self.serial.flush()
while self.serial.in_waiting > 0: raw_values = self.serial.readline().decode("utf-8").strip().split(" ")
line = self.serial.read_until().decode("utf-8").strip()
if line:
lines.append(line.split(" "))
if lines:
raw_values = lines[-1]
if raw_values is None or len(raw_values) != 21: # 16 raw + 5 angle values if raw_values is None or len(raw_values) != 21: # 16 raw + 5 angle values
continue continue

View File

@@ -304,15 +304,8 @@ class HomunculusGlove(Teleoperator):
positions = None positions = None
with self.serial_lock: with self.serial_lock:
if self.serial.in_waiting > 0: if self.serial.in_waiting > 0:
lines = [] self.serial.flush()
while self.serial.in_waiting > 0: positions = self.serial.readline().decode("utf-8").strip().split(" ")
line = self.serial.read_until().decode("utf-8").strip()
if line:
lines.append(line.split(" "))
if lines:
positions = lines[-1]
if positions is None or len(positions) != len(self.joints): if positions is None or len(positions) != len(self.joints):
continue continue

View File

@@ -0,0 +1,18 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_stretch3 import Stretch3GamePadConfig
from .stretch3_gamepad import Stretch3GamePad

View File

@@ -0,0 +1,25 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("stretch3")
@dataclass
class Stretch3GamePadConfig(TeleoperatorConfig):
"""Stretch3GamePadConfig"""

View File

@@ -0,0 +1,117 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import numpy as np
from stretch_body.gamepad_teleop import GamePadTeleop
from stretch_body.robot_params import RobotParams
from lerobot.utils.errors import DeviceAlreadyConnectedError
from ..teleoperator import Teleoperator
from .configuration_stretch3 import Stretch3GamePadConfig
# from stretch_body.gamepad_controller.GamePadController
GAMEPAD_BUTTONS = [
"middle_led_ring_button_pressed",
"left_stick_x",
"left_stick_y",
"right_stick_x",
"right_stick_y",
"left_stick_button_pressed",
"right_stick_button_pressed",
"bottom_button_pressed",
"top_button_pressed",
"left_button_pressed",
"right_button_pressed",
"left_shoulder_button_pressed",
"right_shoulder_button_pressed",
"select_button_pressed",
"start_button_pressed",
"left_trigger_pulled",
"right_trigger_pulled",
"bottom_pad_pressed",
"top_pad_pressed",
"left_pad_pressed",
"right_pad_pressed",
]
class Stretch3GamePad(Teleoperator):
"""[Stretch 3](https://hello-robot.com/stretch-3-product), by Hello Robot."""
config_class = Stretch3GamePadConfig
name = "stretch3"
def __init__(self, config: Stretch3GamePadConfig):
raise NotImplementedError
super().__init__(config)
self.config = config
self.robot_type = self.config.type
self.api = GamePadTeleop(robot_instance=False)
self.is_connected = False
self.logs = {}
# TODO(aliberts): test this
RobotParams.set_logging_level("WARNING")
RobotParams.set_logging_formatter("brief_console_formatter")
@property
def action_features(self) -> dict:
return {
"dtype": "float32",
"shape": (len(GAMEPAD_BUTTONS),),
"names": {"buttons": GAMEPAD_BUTTONS},
}
@property
def feedback_features(self) -> dict:
return {}
def connect(self) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(
"ManipulatorRobot is already connected. Do not run `robot.connect()` twice."
)
self.api.startup()
self.api._update_state() # Check controller can be read & written
self.api._update_modes()
self.is_connected = True
def calibrate(self) -> None:
pass
def get_action(self) -> np.ndarray:
# Read Stretch state
before_read_t = time.perf_counter()
action = self.api.gamepad_controller.get_state()
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
action = np.asarray(list(action.values()))
return action
def send_feedback(self, feedback: np.ndarray) -> None:
pass
def disconnect(self) -> None:
self.api.stop()
self.is_connected = False

View File

@@ -49,6 +49,14 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
from .so101_leader import SO101Leader from .so101_leader import SO101Leader
return SO101Leader(config) return SO101Leader(config)
elif config.type == "stretch3":
from .stretch3_gamepad import Stretch3GamePad
return Stretch3GamePad(config)
elif config.type == "widowx":
from .widowx import WidowX
return WidowX(config)
elif config.type == "mock_teleop": elif config.type == "mock_teleop":
from tests.mocks.mock_teleop import MockTeleop from tests.mocks.mock_teleop import MockTeleop

View File

@@ -0,0 +1,18 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_widowx import WidowXConfig
from .widowx import WidowX

View File

@@ -0,0 +1,25 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("widowx")
@dataclass
class WidowXConfig(TeleoperatorConfig):
port: str # Port to connect to the arm

View File

@@ -0,0 +1,155 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.dynamixel import (
DriveMode,
DynamixelMotorsBus,
OperatingMode,
)
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..teleoperator import Teleoperator
from .config_widowx import WidowXConfig
logger = logging.getLogger(__name__)
class WidowX(Teleoperator):
"""
[WidowX](https://www.trossenrobotics.com/widowx-250) developed by Trossen Robotics
"""
config_class = WidowXConfig
name = "widowx"
def __init__(self, config: WidowXConfig):
raise NotImplementedError
super().__init__(config)
self.config = config
self.bus = DynamixelMotorsBus(
port=self.config.port,
motors={
"waist": Motor(1, "xm430-w350", MotorNormMode.RANGE_M100_100),
"shoulder": Motor(2, "xm430-w350", MotorNormMode.RANGE_M100_100),
"shoulder_shadow": Motor(3, "xm430-w350", MotorNormMode.RANGE_M100_100),
"elbow": Motor(4, "xm430-w350", MotorNormMode.RANGE_M100_100),
"elbow_shadow": Motor(5, "xm430-w350", MotorNormMode.RANGE_M100_100),
"forearm_roll": Motor(6, "xm430-w350", MotorNormMode.RANGE_M100_100),
"wrist_angle": Motor(7, "xm430-w350", MotorNormMode.RANGE_M100_100),
"wrist_rotate": Motor(8, "xl430-w250", MotorNormMode.RANGE_M100_100),
"gripper": Motor(9, "xc430-w150", MotorNormMode.RANGE_0_100),
},
)
@property
def action_features(self) -> dict[str, type]:
return {f"{motor}.pos": float for motor in self.bus.motors}
@property
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.bus.is_connected
def connect(self, calibrate: bool = True):
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.bus.connect()
if not self.is_calibrated and calibrate:
self.calibrate()
self.configure()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
return self.bus.is_calibrated
def calibrate(self) -> None:
raise NotImplementedError # TODO(aliberts): adapt code below (copied from koch)
logger.info(f"\nRunning calibration of {self}")
self.bus.disable_torque()
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
self.bus.write("Drive_Mode", "elbow_flex", DriveMode.INVERTED.value)
drive_modes = {motor: 1 if motor == "elbow_flex" else 0 for motor in self.bus.motors}
input("Move robot to the middle of its range of motion and press ENTER....")
homing_offsets = self.bus.set_half_turn_homings()
full_turn_motors = ["shoulder_pan", "wrist_roll"]
unknown_range_motors = [motor for motor in self.bus.motors if motor not in full_turn_motors]
print(
f"Move all joints except {full_turn_motors} sequentially through their "
"entire ranges of motion.\nRecording positions. Press ENTER to stop..."
)
range_mins, range_maxes = self.bus.record_ranges_of_motion(unknown_range_motors)
for motor in full_turn_motors:
range_mins[motor] = 0
range_maxes[motor] = 4095
self.calibration = {}
for motor, m in self.bus.motors.items():
self.calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=drive_modes[motor],
homing_offset=homing_offsets[motor],
range_min=range_mins[motor],
range_max=range_maxes[motor],
)
self.bus.write_calibration(self.calibration)
self._save_calibration()
logger.info(f"Calibration saved to {self.calibration_fpath}")
def configure(self) -> None:
self.bus.disable_torque()
self.bus.configure_motors()
# Set secondary/shadow ID for shoulder and elbow. These joints have two motors.
# As a result, if only one of them is required to move to a certain position,
# the other will follow. This is to avoid breaking the motors.
self.bus.write("Secondary_ID", "shoulder_shadow", 2)
self.bus.write("Secondary_ID", "elbow_shadow", 4)
def get_action(self) -> dict[str, float]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
start = time.perf_counter()
action = self.bus.sync_read("Present_Position")
action = {f"{motor}.pos": val for motor, val in action.items()}
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return action
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError
def disconnect(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.bus.disconnect()
logger.info(f"{self} disconnected.")

View File

@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections.abc import Callable
from typing import Any from typing import Any
from lerobot.utils.utils import format_big_number from lerobot.utils.utils import format_big_number
@@ -85,7 +84,6 @@ class MetricsTracker:
"samples", "samples",
"episodes", "episodes",
"epochs", "epochs",
"accelerator",
] ]
def __init__( def __init__(
@@ -95,7 +93,6 @@ class MetricsTracker:
num_episodes: int, num_episodes: int,
metrics: dict[str, AverageMeter], metrics: dict[str, AverageMeter],
initial_step: int = 0, initial_step: int = 0,
accelerator: Callable | None = None,
): ):
self.__dict__.update(dict.fromkeys(self.__keys__)) self.__dict__.update(dict.fromkeys(self.__keys__))
self._batch_size = batch_size self._batch_size = batch_size
@@ -109,7 +106,6 @@ class MetricsTracker:
self.samples = self.steps * self._batch_size self.samples = self.steps * self._batch_size
self.episodes = self.samples / self._avg_samples_per_ep self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames self.epochs = self.samples / self._num_frames
self.accelerator = accelerator
def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any: def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
if name in self.__dict__: if name in self.__dict__:
@@ -132,7 +128,7 @@ class MetricsTracker:
Updates metrics that depend on 'step' for one step. Updates metrics that depend on 'step' for one step.
""" """
self.steps += 1 self.steps += 1
self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1) self.samples += self._batch_size
self.episodes = self.samples / self._avg_samples_per_ep self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames self.epochs = self.samples / self._num_frames

View File

@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import random import random
from collections.abc import Callable, Generator from collections.abc import Generator
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -164,20 +164,14 @@ def set_rng_state(random_state_dict: dict[str, Any]):
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"]) torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
def set_seed(seed, accelerator: Callable | None = None) -> None: def set_seed(seed) -> None:
"""Set seed for reproducibility.""" """Set seed for reproducibility."""
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
if accelerator:
from accelerate.utils import set_seed as _accelerate_set_seed
_accelerate_set_seed(seed)
@contextmanager @contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]: def seeded_context(seed: int) -> Generator[None, None, None]:

View File

@@ -27,7 +27,6 @@ from statistics import mean
import numpy as np import numpy as np
import torch import torch
from accelerate import Accelerator
from datasets.utils.logging import disable_progress_bar, enable_progress_bar from datasets.utils.logging import disable_progress_bar, enable_progress_bar
@@ -45,9 +44,6 @@ def auto_select_torch_device() -> torch.device:
elif torch.backends.mps.is_available(): elif torch.backends.mps.is_available():
logging.info("Metal backend detected, using mps.") logging.info("Metal backend detected, using mps.")
return torch.device("mps") return torch.device("mps")
elif torch.xpu.is_available():
logging.info("Intel XPU backend detected, using xpu.")
return torch.device("xpu")
else: else:
logging.warning("No accelerated backend detected. Using default cpu, this will be slow.") logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
return torch.device("cpu") return torch.device("cpu")
@@ -64,9 +60,6 @@ def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
case "mps": case "mps":
assert torch.backends.mps.is_available() assert torch.backends.mps.is_available()
device = torch.device("mps") device = torch.device("mps")
case "xpu":
assert torch.xpu.is_available()
device = torch.device("xpu")
case "cpu": case "cpu":
device = torch.device("cpu") device = torch.device("cpu")
if log: if log:
@@ -87,21 +80,6 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
device = device.type device = device.type
if device == "mps" and dtype == torch.float64: if device == "mps" and dtype == torch.float64:
return torch.float32 return torch.float32
if device == "xpu" and dtype == torch.float64:
if hasattr(torch.xpu, "get_device_capability"):
device_capability = torch.xpu.get_device_capability()
# NOTE: Some Intel XPU devices do not support double precision (FP64).
# The `has_fp64` flag is returned by `torch.xpu.get_device_capability()`
# when available; if False, we fall back to float32 for compatibility.
if not device_capability.get("has_fp64", False):
logging.warning(f"Device {device} does not support float64, using float32 instead.")
return torch.float32
else:
logging.warning(
f"Device {device} capability check failed. Assuming no support for float64, using float32 instead."
)
return torch.float32
return dtype
else: else:
return dtype return dtype
@@ -112,16 +90,14 @@ def is_torch_device_available(try_device: str) -> bool:
return torch.cuda.is_available() return torch.cuda.is_available()
elif try_device == "mps": elif try_device == "mps":
return torch.backends.mps.is_available() return torch.backends.mps.is_available()
elif try_device == "xpu":
return torch.xpu.is_available()
elif try_device == "cpu": elif try_device == "cpu":
return True return True
else: else:
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps, xpu or cpu.") raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
def is_amp_available(device: str): def is_amp_available(device: str):
if device in ["cuda", "xpu", "cpu"]: if device in ["cuda", "cpu"]:
return True return True
elif device == "mps": elif device == "mps":
return False return False
@@ -134,50 +110,36 @@ def init_logging(
display_pid: bool = False, display_pid: bool = False,
console_level: str = "INFO", console_level: str = "INFO",
file_level: str = "DEBUG", file_level: str = "DEBUG",
accelerator: Accelerator | None = None,
): ):
"""Initialize logging configuration for LeRobot.
In multi-GPU training, only the main process logs to console to avoid duplicate output.
Non-main processes have console logging suppressed but can still log to file.
Args:
log_file: Optional file path to write logs to
display_pid: Include process ID in log messages (useful for debugging multi-process)
console_level: Logging level for console output
file_level: Logging level for file output
accelerator: Optional Accelerator instance (for multi-GPU detection)
"""
def custom_format(record: logging.LogRecord) -> str: def custom_format(record: logging.LogRecord) -> str:
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
fnameline = f"{record.pathname}:{record.lineno}" fnameline = f"{record.pathname}:{record.lineno}"
pid_str = f"[PID: {os.getpid()}] " if display_pid else ""
return f"{record.levelname} {pid_str}{dt} {fnameline[-15:]:>15} {record.getMessage()}" # NOTE: Display PID is useful for multi-process logging.
if display_pid:
pid_str = f"[PID: {os.getpid()}]"
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
else:
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
return message
formatter = logging.Formatter() formatter = logging.Formatter()
formatter.format = custom_format formatter.format = custom_format
logger = logging.getLogger() logger = logging.getLogger()
logger.setLevel(logging.NOTSET) logger.setLevel(logging.NOTSET) # Set the logger to the lowest level to capture all messages
# Clear any existing handlers # Remove unused default handlers
logger.handlers.clear() for handler in logger.handlers[:]:
logger.removeHandler(handler)
# Determine if this is a non-main process in distributed training # Write logs to console
is_main_process = accelerator.is_main_process if accelerator is not None else True console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
# Console logging (main process only) console_handler.setLevel(console_level.upper())
if is_main_process: logger.addHandler(console_handler)
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(console_level.upper())
logger.addHandler(console_handler)
else:
# Suppress console output for non-main processes
logger.addHandler(logging.NullHandler())
logger.setLevel(logging.ERROR)
# Additionally write logs to file
if log_file is not None: if log_file is not None:
file_handler = logging.FileHandler(log_file) file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)

View File

@@ -46,7 +46,7 @@ def log_rerun_data(
This function iterates through the provided observation and action dictionaries and sends their contents This function iterates through the provided observation and action dictionaries and sends their contents
to the Rerun viewer. It handles different data types appropriately: to the Rerun viewer. It handles different data types appropriately:
- Scalars values (floats, ints) are logged as `rr.Scalars`. - Scalar values (floats, ints) are logged as `rr.Scalar`.
- 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed - 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed
from CHW to HWC format and logged as `rr.Image`. from CHW to HWC format and logged as `rr.Image`.
- 1D NumPy arrays are logged as a series of individual scalars, with each element indexed. - 1D NumPy arrays are logged as a series of individual scalars, with each element indexed.
@@ -65,7 +65,7 @@ def log_rerun_data(
key = k if str(k).startswith(OBS_PREFIX) else f"{OBS_STR}.{k}" key = k if str(k).startswith(OBS_PREFIX) else f"{OBS_STR}.{k}"
if _is_scalar(v): if _is_scalar(v):
rr.log(key, rr.Scalars(float(v))) rr.log(key, rr.Scalar(float(v)))
elif isinstance(v, np.ndarray): elif isinstance(v, np.ndarray):
arr = v arr = v
# Convert CHW -> HWC when needed # Convert CHW -> HWC when needed
@@ -73,7 +73,7 @@ def log_rerun_data(
arr = np.transpose(arr, (1, 2, 0)) arr = np.transpose(arr, (1, 2, 0))
if arr.ndim == 1: if arr.ndim == 1:
for i, vi in enumerate(arr): for i, vi in enumerate(arr):
rr.log(f"{key}_{i}", rr.Scalars(float(vi))) rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
else: else:
rr.log(key, rr.Image(arr), static=True) rr.log(key, rr.Image(arr), static=True)
@@ -84,13 +84,13 @@ def log_rerun_data(
key = k if str(k).startswith("action.") else f"action.{k}" key = k if str(k).startswith("action.") else f"action.{k}"
if _is_scalar(v): if _is_scalar(v):
rr.log(key, rr.Scalars(float(v))) rr.log(key, rr.Scalar(float(v)))
elif isinstance(v, np.ndarray): elif isinstance(v, np.ndarray):
if v.ndim == 1: if v.ndim == 1:
for i, vi in enumerate(v): for i, vi in enumerate(v):
rr.log(f"{key}_{i}", rr.Scalars(float(vi))) rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
else: else:
# Fall back to flattening higher-dimensional arrays # Fall back to flattening higher-dimensional arrays
flat = v.flatten() flat = v.flatten()
for i, vi in enumerate(flat): for i, vi in enumerate(flat):
rr.log(f"{key}_{i}", rr.Scalars(float(vi))) rr.log(f"{key}_{i}", rr.Scalar(float(vi)))

View File

@@ -389,7 +389,7 @@ def test_raw_observation_to_observation_device_handling():
# Check that all expected keys produce tensors (device placement handled by preprocessor later) # Check that all expected keys produce tensors (device placement handled by preprocessor later)
for key, value in observation.items(): for key, value in observation.items():
if isinstance(value, torch.Tensor): if isinstance(value, torch.Tensor):
assert value.device.type in ["cpu", "cuda", "mps", "xpu"], f"Tensor {key} on unexpected device" assert value.device.type in ["cpu", "cuda", "mps"], f"Tensor {key} on unexpected device"
def test_raw_observation_to_observation_deterministic(): def test_raw_observation_to_observation_deterministic():

View File

@@ -134,25 +134,6 @@ def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor)) torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("degrees, translate", [((-5.0, 5.0), (0.05, 0.05)), ((10.0, 10.0), (0.1, 0.1))])
def test_get_image_transforms_affine(img_tensor_factory, degrees, translate):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True,
tfs={
"affine": ImageTransformConfig(
type="RandomAffine", kwargs={"degrees": degrees, "translate": translate}
)
},
)
tf = ImageTransforms(tf_cfg)
output = tf(img_tensor)
# Verify output shape is preserved
assert output.shape == img_tensor.shape
# Verify transform is type RandomAffine
assert isinstance(tf.transforms["affine"], v2.RandomAffine)
def test_get_image_transforms_max_num_transforms(img_tensor_factory): def test_get_image_transforms_max_num_transforms(img_tensor_factory):
img_tensor = img_tensor_factory() img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig( tf_cfg = ImageTransformsConfig(
@@ -281,37 +262,7 @@ def test_backward_compatibility_default_config(img_tensor, default_transforms):
# NOTE: PyTorch versions have different randomness, it might break this test. # NOTE: PyTorch versions have different randomness, it might break this test.
# See this PR: https://github.com/huggingface/lerobot/pull/1127. # See this PR: https://github.com/huggingface/lerobot/pull/1127.
# Use config without affine to match original test artifacts cfg = ImageTransformsConfig(enable=True)
cfg = ImageTransformsConfig(
enable=True,
tfs={
"brightness": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"brightness": (0.8, 1.2)},
),
"contrast": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"contrast": (0.8, 1.2)},
),
"saturation": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"saturation": (0.5, 1.5)},
),
"hue": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"hue": (-0.05, 0.05)},
),
"sharpness": ImageTransformConfig(
weight=1.0,
type="SharpnessJitter",
kwargs={"sharpness": (0.5, 1.5)},
),
},
)
default_tf = ImageTransforms(cfg) default_tf = ImageTransforms(cfg)
with seeded_context(1337): with seeded_context(1337):
@@ -417,7 +368,7 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
save_each_transform(tf_cfg, img_tensor, tmp_path, n_examples) save_each_transform(tf_cfg, img_tensor, tmp_path, n_examples)
# Check if the transformed images exist for each transform type # Check if the transformed images exist for each transform type
transforms = ["brightness", "contrast", "saturation", "hue", "sharpness", "affine"] transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"]
for transform in transforms: for transform in transforms:
transform_dir = tmp_path / transform transform_dir = tmp_path / transform
assert transform_dir.exists(), f"{transform} directory was not created." assert transform_dir.exists(), f"{transform} directory was not created."

View File

@@ -14,17 +14,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import importlib import importlib
from dataclasses import dataclass, field
import gymnasium as gym import gymnasium as gym
import pytest import pytest
import torch import torch
from gymnasium.envs.registration import register, registry as gym_registry
from gymnasium.utils.env_checker import check_env from gymnasium.utils.env_checker import check_env
import lerobot import lerobot
from lerobot.configs.types import PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.envs.factory import make_env, make_env_config from lerobot.envs.factory import make_env, make_env_config
from lerobot.envs.utils import preprocess_observation from lerobot.envs.utils import preprocess_observation
from tests.utils import require_env from tests.utils import require_env
@@ -68,43 +64,3 @@ def test_factory(env_name):
assert img.min() >= 0.0 assert img.min() >= 0.0
env.close() env.close()
def test_factory_custom_gym_id():
gym_id = "dummy_gym_pkg/DummyTask-v0"
if gym_id in gym_registry:
pytest.skip(f"Environment ID {gym_id} is already registered")
@EnvConfig.register_subclass("dummy")
@dataclass
class DummyEnv(EnvConfig):
task: str = "DummyTask-v0"
fps: int = 10
features: dict[str, PolicyFeature] = field(default_factory=dict)
@property
def package_name(self) -> str:
return "dummy_gym_pkg"
@property
def gym_id(self) -> str:
return gym_id
@property
def gym_kwargs(self) -> dict:
return {}
try:
register(id=gym_id, entry_point="gymnasium.envs.classic_control:CartPoleEnv")
cfg = DummyEnv()
envs_dict = make_env(cfg, n_envs=1)
dummy_envs = envs_dict["dummy"]
assert len(dummy_envs) == 1
env = next(iter(dummy_envs.values()))
assert env is not None and isinstance(env, gym.vector.VectorEnv)
env.close()
finally:
if gym_id in gym_registry:
del gym_registry[gym_id]

View File

@@ -20,10 +20,8 @@ from functools import cached_property
from typing import Any from typing import Any
from lerobot.cameras import CameraConfig, make_cameras_from_configs from lerobot.cameras import CameraConfig, make_cameras_from_configs
from lerobot.motors.motors_bus import Motor, MotorNormMode
from lerobot.robots import Robot, RobotConfig from lerobot.robots import Robot, RobotConfig
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from tests.mocks.mock_motors_bus import MockMotorsBus
@RobotConfig.register_subclass("mock_robot") @RobotConfig.register_subclass("mock_robot")
@@ -60,21 +58,8 @@ class MockRobot(Robot):
self.config = config self.config = config
self._is_connected = False self._is_connected = False
self._is_calibrated = config.calibrated self._is_calibrated = config.calibrated
self.cameras = make_cameras_from_configs(config.cameras)
mock_motors = {}
for i in range(config.n_motors):
motor_name = f"motor_{i + 1}"
mock_motors[motor_name] = Motor(
id=i + 1,
model="model_1", # Use model_1 which exists in MockMotorsBus tables
norm_mode=MotorNormMode.RANGE_M100_100,
)
self.bus = MockMotorsBus("/dev/dummy-port", mock_motors)
# NOTE(fracapuano): The .motors attribute was used from the previous interface
self.motors = [f"motor_{i + 1}" for i in range(config.n_motors)] self.motors = [f"motor_{i + 1}" for i in range(config.n_motors)]
self.cameras = make_cameras_from_configs(config.cameras)
@property @property
def _motors_ft(self) -> dict[str, type]: def _motors_ft(self) -> dict[str, type]:

View File

@@ -117,12 +117,12 @@ def test_send_interactions():
services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"8"), services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"8"),
] ]
def mock_interactions_stream(): def mock_intercations_stream():
yield from list_of_interaction_messages yield from list_of_interaction_messages
return services_pb2.Empty() return services_pb2.Empty()
response = client.SendInteractions(mock_interactions_stream()) response = client.SendInteractions(mock_intercations_stream())
assert response == services_pb2.Empty() assert response == services_pb2.Empty()
close_learner_service_stub(channel, server) close_learner_service_stub(channel, server)

View File

@@ -1,211 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Multi-GPU Training Tests
This module tests multi-GPU training functionality with accelerate.
These tests are designed to run on machines with 2+ GPUs and are executed
in the nightly CI workflow.
The tests automatically generate accelerate configs and launch training
with subprocess to properly test the distributed training environment.
"""
import os
import subprocess
import tempfile
from pathlib import Path
import pytest
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def get_num_available_gpus():
"""Returns the number of available GPUs."""
if not torch.cuda.is_available():
return 0
return torch.cuda.device_count()
def download_dataset(repo_id, episodes):
"""
Pre-download dataset to avoid race conditions in multi-GPU training.
Args:
repo_id: HuggingFace dataset repository ID
episodes: List of episode indices to download
"""
# Simply instantiating the dataset will download it
_ = LeRobotDataset(repo_id, episodes=episodes)
print(f"Dataset {repo_id} downloaded successfully")
def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
"""
Helper function to run training with accelerate launch.
Args:
config_args: List of config arguments to pass to lerobot_train.py
num_processes: Number of processes (GPUs) to use
temp_dir: Temporary directory for outputs
Returns:
subprocess.CompletedProcess result
"""
config_path = Path(temp_dir) / "accelerate_config.yaml"
# Write YAML config
with open(config_path, "w") as f:
f.write("compute_environment: LOCAL_MACHINE\n")
f.write("distributed_type: MULTI_GPU\n")
f.write("mixed_precision: 'no'\n")
f.write(f"num_processes: {num_processes}\n")
f.write("use_cpu: false\n")
f.write("gpu_ids: all\n")
f.write("downcast_bf16: 'no'\n")
f.write("machine_rank: 0\n")
f.write("main_training_function: main\n")
f.write("num_machines: 1\n")
f.write("rdzv_backend: static\n")
f.write("same_network: true\n")
cmd = [
"accelerate",
"launch",
"--config_file",
str(config_path),
"-m",
"lerobot.scripts.lerobot_train",
] + config_args
result = subprocess.run(
cmd,
capture_output=True,
text=True,
env={**os.environ, "CUDA_VISIBLE_DEVICES": ",".join(map(str, range(num_processes)))},
)
return result
@pytest.mark.skipif(
get_num_available_gpus() < 2,
reason="Multi-GPU tests require at least 2 GPUs",
)
class TestMultiGPUTraining:
"""Test suite for multi-GPU training functionality."""
def test_basic_multi_gpu_training(self):
"""
Test that basic multi-GPU training runs successfully.
Verifies that the training completes without errors.
"""
# Pre-download dataset to avoid race conditions
download_dataset("lerobot/pusht", episodes=[0])
with tempfile.TemporaryDirectory() as temp_dir:
output_dir = Path(temp_dir) / "outputs"
config_args = [
"--dataset.repo_id=lerobot/pusht",
"--dataset.episodes=[0]",
"--policy.type=act",
"--policy.device=cuda",
"--policy.push_to_hub=false",
f"--output_dir={output_dir}",
"--batch_size=4",
"--steps=10",
"--eval_freq=-1",
"--log_freq=5",
"--save_freq=10",
"--seed=42",
"--num_workers=0",
]
result = run_accelerate_training(config_args, num_processes=4, temp_dir=temp_dir)
# Check that training completed successfully
assert result.returncode == 0, (
f"Multi-GPU training failed with return code {result.returncode}\n"
f"STDOUT:\n{result.stdout}\n"
f"STDERR:\n{result.stderr}"
)
# Verify checkpoint was saved
checkpoints_dir = output_dir / "checkpoints"
assert checkpoints_dir.exists(), "Checkpoints directory was not created"
# Verify that training completed
assert "End of training" in result.stdout or "End of training" in result.stderr
def test_checkpoint_saving_multi_gpu(self):
"""
Test that checkpoints are correctly saved during multi-GPU training.
Only the main process (rank 0) should save checkpoints.
"""
# Pre-download dataset to avoid race conditions
download_dataset("lerobot/pusht", episodes=[0])
with tempfile.TemporaryDirectory() as temp_dir:
output_dir = Path(temp_dir) / "outputs"
config_args = [
"--dataset.repo_id=lerobot/pusht",
"--dataset.episodes=[0]",
"--policy.type=act",
"--policy.device=cuda",
"--policy.push_to_hub=false",
f"--output_dir={output_dir}",
"--batch_size=4",
"--steps=20",
"--eval_freq=-1",
"--log_freq=5",
"--save_freq=10",
"--seed=42",
"--num_workers=0",
]
result = run_accelerate_training(config_args, num_processes=2, temp_dir=temp_dir)
assert result.returncode == 0, (
f"Training failed:\nSTDOUT:\n{result.stdout}\n\nSTDERR:\n{result.stderr}"
)
# Verify checkpoint directory exists
checkpoints_dir = output_dir / "checkpoints"
assert checkpoints_dir.exists(), "Checkpoints directory not created"
# Count checkpoint directories (should have checkpoint at step 10 and 20)
checkpoint_dirs = [d for d in checkpoints_dir.iterdir() if d.is_dir()]
assert len(checkpoint_dirs) >= 1, f"Expected at least 1 checkpoint, found {len(checkpoint_dirs)}"
# Verify checkpoint contents
for checkpoint_dir in checkpoint_dirs:
# Check for model files
model_files = list(checkpoint_dir.rglob("*.safetensors"))
assert len(model_files) > 0, f"No model files in checkpoint {checkpoint_dir}"
# Check for training state
training_state_dir = checkpoint_dir / "training_state"
assert training_state_dir.exists(), f"No training state in checkpoint {checkpoint_dir}"
# Verify optimizer state exists
optimizer_state = training_state_dir / "optimizer_state.safetensors"
assert optimizer_state.exists(), f"No optimizer state in checkpoint {checkpoint_dir}"

View File

@@ -45,7 +45,7 @@ def mock_rerun(monkeypatch):
calls.append((key, obj, kwargs)) calls.append((key, obj, kwargs))
dummy_rr = SimpleNamespace( dummy_rr = SimpleNamespace(
Scalars=DummyScalar, Scalar=DummyScalar,
Image=DummyImage, Image=DummyImage,
log=dummy_log, log=dummy_log,
init=lambda *a, **k: None, init=lambda *a, **k: None,
@@ -109,9 +109,9 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
vu.log_rerun_data(observation=obs_data, action=action_data) vu.log_rerun_data(observation=obs_data, action=action_data)
# We expect: # We expect:
# - observation.state.temperature -> Scalars # - observation.state.temperature -> Scalar
# - observation.camera -> Image (HWC) with static=True # - observation.camera -> Image (HWC) with static=True
# - action.throttle -> Scalars # - action.throttle -> Scalar
# - action.vector_0, action.vector_1 -> Scalars # - action.vector_0, action.vector_1 -> Scalars
expected_keys = { expected_keys = {
f"{OBS_STATE}.temperature", f"{OBS_STATE}.temperature",