mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 21:01:26 +00:00
Compare commits
3 Commits
feat/dummy
...
feat/add_m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
224be5be9a | ||
|
|
67269e33a5 | ||
|
|
66936f278f |
25
.github/workflows/fast_tests.yml
vendored
25
.github/workflows/fast_tests.yml
vendored
@@ -57,7 +57,11 @@ jobs:
|
||||
# It runs everytime we commit to a PR or push to main
|
||||
fast-pytest-tests:
|
||||
name: Fast Pytest Tests
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
@@ -67,12 +71,21 @@ jobs:
|
||||
lfs: true
|
||||
|
||||
# TODO(Steven): Evaluate the need of these dependencies
|
||||
- name: Install apt dependencies
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt-get update && sudo apt-get install -y build-essential git \
|
||||
curl libglib2.0-0 libegl1-mesa-dev ffmpeg \
|
||||
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev
|
||||
|
||||
if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
|
||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
|
||||
speech-dispatcher libgeos-dev portaudio19-dev
|
||||
elif [[ "${{ matrix.os }}" == 'macos-latest' ]]; then
|
||||
brew update && brew install git geos portaudio ffmpeg@7
|
||||
# Add ffmpeg@7 paths for subsequent steps
|
||||
echo "PATH=/opt/homebrew/opt/ffmpeg@7/bin:$PATH" >> $GITHUB_ENV
|
||||
echo "LDFLAGS=-L/opt/homebrew/opt/ffmpeg@7/lib" >> $GITHUB_ENV
|
||||
echo "CPPFLAGS=-I/opt/homebrew/opt/ffmpeg@7/include" >> $GITHUB_ENV
|
||||
echo "PKG_CONFIG_PATH=/opt/homebrew/opt/ffmpeg@7/lib/pkgconfig" >> $GITHUB_ENV
|
||||
echo "DYLD_LIBRARY_PATH=/opt/homebrew/opt/ffmpeg@7/lib:/opt/homebrew/lib:/usr/local/lib:$DYLD_LIBRARY_PATH" >> $GITHUB_ENV
|
||||
fi
|
||||
- name: Setup uv and Python
|
||||
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
|
||||
21
.github/workflows/full_tests.yml
vendored
21
.github/workflows/full_tests.yml
vendored
@@ -51,7 +51,11 @@ jobs:
|
||||
# It runs everytime a PR is approved or a push to main
|
||||
full-tests:
|
||||
name: Full Tests
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
if: |
|
||||
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved') ||
|
||||
github.event_name == 'push' ||
|
||||
@@ -64,11 +68,16 @@ jobs:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install apt dependencies
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
|
||||
speech-dispatcher libgeos-dev portaudio19-dev
|
||||
if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
|
||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
|
||||
speech-dispatcher libgeos-dev portaudio19-dev
|
||||
elif [[ "${{ matrix.os }}" == 'macos-latest' ]]; then
|
||||
brew update && brew install git geos portaudio ffmpeg@7
|
||||
echo "DYLD_LIBRARY_PATH=/opt/homebrew/opt/ffmpeg@7/lib:/opt/homebrew/lib:/usr/local/lib:$DYLD_LIBRARY_PATH" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: Setup uv and Python
|
||||
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
|
||||
@@ -78,7 +87,7 @@ jobs:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
|
||||
run: uv sync --all-extras
|
||||
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv --maxfail=10
|
||||
|
||||
34
.github/workflows/nightly.yml
vendored
34
.github/workflows/nightly.yml
vendored
@@ -119,7 +119,6 @@ jobs:
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
container:
|
||||
image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --shm-size "16gb"
|
||||
credentials:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
@@ -159,36 +158,3 @@ jobs:
|
||||
run: pytest tests -vv --maxfail=10
|
||||
- name: Run end-to-end tests
|
||||
run: make test-end-to-end
|
||||
|
||||
# This job runs multi-GPU training tests with 4 GPUs
|
||||
nightly-multi-gpu-tests:
|
||||
name: Nightly Multi-GPU Tests
|
||||
needs: [build-docker-gpu-nightly]
|
||||
runs-on:
|
||||
group: aws-g4dn-12xlarge # Instance with 4 GPUs
|
||||
env:
|
||||
HF_HOME: /home/user_lerobot/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
CUDA_VISIBLE_DEVICES: "0,1,2,3"
|
||||
container:
|
||||
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --gpus all --shm-size "16gb"
|
||||
credentials:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Verify GPU availability
|
||||
run: |
|
||||
nvidia-smi
|
||||
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
|
||||
|
||||
- name: Run multi-GPU training tests
|
||||
# TODO(Steven): Investigate why motors tests are failing in multi-GPU setup
|
||||
run: pytest tests -vv --maxfail=10 --ignore=tests/motors/
|
||||
timeout-minutes: 10
|
||||
|
||||
33
.github/workflows/release.yml
vendored
33
.github/workflows/release.yml
vendored
@@ -82,14 +82,6 @@ jobs:
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Remove Tags with Git dependencies
|
||||
# TODO(Steven): Temporary patch to remove pi from PyPi 0.4.0 release due to its reliance on git dependencies.
|
||||
run: |
|
||||
echo "::info:: Checking for Git dependencies to remove from pyproject.toml..."
|
||||
grep -E '@ git\+https|lerobot\[pi\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
|
||||
sed -E -i '/@ git\+https|lerobot\[pi\]/d' pyproject.toml
|
||||
echo "::info:: Git dependencies removed. Proceeding with build."
|
||||
|
||||
- name: Install build dependencies
|
||||
run: python -m pip install build
|
||||
|
||||
@@ -111,7 +103,7 @@ jobs:
|
||||
- name: Publish to TestPyPI for pre-releases
|
||||
# True for tags like 'v0.2.0-rc1'
|
||||
if: startsWith(github.ref, 'refs/tags/v') && contains(github.ref, '-')
|
||||
uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
|
||||
uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
|
||||
with:
|
||||
repository-url: https://test.pypi.org/legacy/
|
||||
verbose: true
|
||||
@@ -119,7 +111,7 @@ jobs:
|
||||
|
||||
- name: Publish to PyPI
|
||||
if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-')
|
||||
uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
|
||||
uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
|
||||
with:
|
||||
verbose: true
|
||||
print-hash: true
|
||||
@@ -128,7 +120,11 @@ jobs:
|
||||
test-release:
|
||||
name: Test Release
|
||||
needs: [build-and-publish]
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
permissions:
|
||||
contents: read
|
||||
env:
|
||||
@@ -138,15 +134,20 @@ jobs:
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
- name: Install apt dependencies
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
|
||||
speech-dispatcher libgeos-dev portaudio19-dev
|
||||
if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
|
||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
|
||||
speech-dispatcher libgeos-dev portaudio19-dev
|
||||
elif [[ "${{ matrix.os }}" == 'macos-latest' ]]; then
|
||||
brew update && brew install git geos portaudio ffmpeg@7
|
||||
echo "DYLD_LIBRARY_PATH=/opt/homebrew/opt/ffmpeg@7/lib:/opt/homebrew/lib:/usr/local/lib:$DYLD_LIBRARY_PATH" >> $GITHUB_ENV
|
||||
fi
|
||||
- name: Setup uv and Python
|
||||
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
enable-cache: true # zizmor: ignore[cache-poisoning]
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
- name: Create uv virtual environment
|
||||
|
||||
12
.github/workflows/stale.yml
vendored
12
.github/workflows/stale.yml
vendored
@@ -27,17 +27,15 @@ env:
|
||||
This issue was closed because it has been stalled for 14 days with no activity.
|
||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||
CLOSE_PR_MESSAGE: >
|
||||
This PR was closed because it has been stalled for 21 days with no activity.
|
||||
This PR was closed because it has been stalled for 14 days with no activity.
|
||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||
WARN_ISSUE_MESSAGE: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity (6 months). It will be closed if no further activity occurs.
|
||||
Any change, comment or update to this issue will reset this count.
|
||||
Thank you for your contributions.
|
||||
WARN_PR_MESSAGE: >
|
||||
This PR has been automatically marked as stale because it has not had
|
||||
recent activity (1 year). It will be closed if no further activity occurs.
|
||||
Any change, comment or update to this PR will reset this count.
|
||||
recent activity (6 months). It will be closed if no further activity occurs.
|
||||
Thank you for your contributions.
|
||||
|
||||
jobs:
|
||||
@@ -58,10 +56,10 @@ jobs:
|
||||
stale-pr-label: stale
|
||||
exempt-issue-labels: never-stale
|
||||
exempt-pr-labels: never-stale
|
||||
days-before-issue-stale: 180
|
||||
days-before-issue-stale: 180 # TODO(Steven): Will modify this to 90 after initial cleanup
|
||||
days-before-issue-close: 14
|
||||
days-before-pr-stale: 365
|
||||
days-before-pr-close: 21
|
||||
days-before-pr-stale: 180
|
||||
days-before-pr-close: 14
|
||||
delete-branch: true
|
||||
close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }}
|
||||
close-pr-message: ${{ env.CLOSE_PR_MESSAGE }}
|
||||
|
||||
21
.github/workflows/unbound_deps_tests.yml
vendored
21
.github/workflows/unbound_deps_tests.yml
vendored
@@ -42,7 +42,11 @@ jobs:
|
||||
# This job runs the E2E tests + pytest with all unbound extras
|
||||
full-tests:
|
||||
name: Full Unbound Tests
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
@@ -51,11 +55,16 @@ jobs:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install apt dependencies
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
|
||||
speech-dispatcher libgeos-dev portaudio19-dev
|
||||
if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
|
||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
|
||||
speech-dispatcher libgeos-dev portaudio19-dev
|
||||
elif [[ "${{ matrix.os }}" == 'macos-latest' ]]; then
|
||||
brew update && brew install git geos portaudio ffmpeg@7
|
||||
echo "DYLD_LIBRARY_PATH=/opt/homebrew/opt/ffmpeg@7/lib:/opt/homebrew/lib:/usr/local/lib:$DYLD_LIBRARY_PATH" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: Setup uv and Python
|
||||
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
|
||||
@@ -70,7 +79,7 @@ jobs:
|
||||
echo "Dependencies unbound:" && cat pyproject.toml
|
||||
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
|
||||
run: uv sync --all-extras
|
||||
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv
|
||||
|
||||
@@ -26,7 +26,7 @@ repos:
|
||||
|
||||
##### General Code Quality & Formatting #####
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v6.0.0
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: check-added-large-files
|
||||
args: ['--maxkb=1024']
|
||||
@@ -39,20 +39,20 @@ repos:
|
||||
- id: trailing-whitespace
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.14.1
|
||||
rev: v0.12.4
|
||||
hooks:
|
||||
- id: ruff-format
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix]
|
||||
|
||||
- repo: https://github.com/adhtruong/mirrors-typos
|
||||
rev: v1.38.1
|
||||
rev: v1.34.0
|
||||
hooks:
|
||||
- id: typos
|
||||
args: [--force-exclude]
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.21.0
|
||||
rev: v3.20.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: [--py310-plus]
|
||||
@@ -68,12 +68,12 @@ repos:
|
||||
|
||||
##### Security #####
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.28.0
|
||||
rev: v8.27.2
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
|
||||
- repo: https://github.com/woodruffw/zizmor-pre-commit
|
||||
rev: v1.15.2
|
||||
rev: v1.11.0
|
||||
hooks:
|
||||
- id: zizmor
|
||||
|
||||
@@ -87,7 +87,7 @@ repos:
|
||||
# TODO(Steven): Uncomment when ready to use
|
||||
##### Static Analysis & Typing #####
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.18.2
|
||||
rev: v1.16.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
args: [--config-file=pyproject.toml]
|
||||
|
||||
@@ -137,7 +137,7 @@ Follow these steps to start contributing:
|
||||
4. for development, we advise to use a tool like `poetry` or `uv` instead of just `pip` to easily track our dependencies.
|
||||
Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already.
|
||||
|
||||
Set up a development environment with conda:
|
||||
Set up a development environment with conda or miniconda:
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev
|
||||
|
||||
19
README.md
19
README.md
@@ -104,14 +104,14 @@ LeRobot works with Python 3.10+ and PyTorch 2.2+.
|
||||
|
||||
### Environment Setup
|
||||
|
||||
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniforge`](https://conda-forge.org/download/):
|
||||
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
When using `conda`, install `ffmpeg` in your environment:
|
||||
When using `miniconda`, install `ffmpeg` in your environment:
|
||||
|
||||
```bash
|
||||
conda install ffmpeg -c conda-forge
|
||||
@@ -185,11 +185,6 @@ _Replace `[...]` with your desired features._
|
||||
For a full list of optional dependencies, see:
|
||||
https://pypi.org/project/lerobot/
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install pi tags, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
### Weights & Biases
|
||||
|
||||
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
|
||||
@@ -212,13 +207,13 @@ lerobot-dataset-viz \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
or from a dataset in a local folder with the `root` option and the `--mode local` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
|
||||
or from a dataset in a local folder with the `root` option and the `--local-files-only` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
|
||||
|
||||
```bash
|
||||
lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--root ./my_local_data_dir \
|
||||
--mode local \
|
||||
--local-files-only 1 \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
@@ -315,7 +310,7 @@ To upload these to the hub, run the following:
|
||||
huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model
|
||||
```
|
||||
|
||||
See [lerobot_eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_eval.py) for an example of how other people may use your policy.
|
||||
See [eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/eval.py) for an example of how other people may use your policy.
|
||||
|
||||
### Acknowledgment
|
||||
|
||||
@@ -342,3 +337,7 @@ If you want, you can cite this work with:
|
||||
## Star History
|
||||
|
||||
[](https://star-history.com/#huggingface/lerobot&Timeline)
|
||||
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
@@ -17,8 +17,6 @@
|
||||
title: Train RL in Simulation
|
||||
- local: async
|
||||
title: Use Async Inference
|
||||
- local: multi_gpu_training
|
||||
title: Multi GPU training
|
||||
title: "Tutorials"
|
||||
- sections:
|
||||
- local: lerobot-dataset-v3
|
||||
@@ -37,12 +35,8 @@
|
||||
title: π₀ (Pi0)
|
||||
- local: pi05
|
||||
title: π₀.₅ (Pi05)
|
||||
- local: groot
|
||||
title: NVIDIA GR00T N1.5
|
||||
title: "Policies"
|
||||
- sections:
|
||||
- local: envhub
|
||||
title: Environments from the Hub
|
||||
- local: il_sim
|
||||
title: Imitation Learning in Sim
|
||||
- local: libero
|
||||
|
||||
@@ -1,424 +0,0 @@
|
||||
# Loading Environments from the Hub
|
||||
|
||||
The **EnvHub** feature allows you to load simulation environments directly from the Hugging Face Hub with a single line of code. This unlocks a powerful new model for collaboration: instead of environments being locked away inside monolithic libraries, anyone can publish custom environments and share them with the community.
|
||||
|
||||
## Overview
|
||||
|
||||
With EnvHub, you can:
|
||||
|
||||
- Load environments from the Hub instantly
|
||||
- Share your custom simulation tasks with the community
|
||||
- Version control your environments using Git
|
||||
- Distribute complex physics simulations without packaging hassles
|
||||
|
||||
## Quick Start
|
||||
|
||||
Loading an environment from the Hub is as simple as:
|
||||
|
||||
```python
|
||||
from lerobot.envs.factory import make_env
|
||||
|
||||
# Load a hub environment (requires explicit consent to run remote code)
|
||||
env = make_env("lerobot/cartpole-env", trust_remote_code=True)
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
**Security Notice**: Loading environments from the Hub executes Python code
|
||||
from third-party repositories. Only use `trust_remote_code=True` with
|
||||
repositories you trust. We strongly recommend pinning to a specific commit
|
||||
hash for reproducibility and security.
|
||||
</Tip>
|
||||
|
||||
## What is EnvHub?
|
||||
|
||||
EnvHub is a framework that allows researchers and developers to:
|
||||
|
||||
1. **Publish environments** to the Hugging Face Hub as Git repositories
|
||||
2. **Load environments** dynamically without installing them as packages
|
||||
3. **Version and track** environment changes using Git semantics
|
||||
4. **Discover** new simulation tasks shared by the community
|
||||
|
||||
This design means you can go from discovering an interesting environment on the Hub to running experiments in seconds, without worrying about dependency conflicts or complex installation procedures.
|
||||
|
||||
## Repository Structure
|
||||
|
||||
To make your environment loadable from the Hub, your repository must contain at minimum:
|
||||
|
||||
### Required Files
|
||||
|
||||
**`env.py`** (or custom Python file)
|
||||
|
||||
- Must expose a `make_env(n_envs: int, use_async_envs: bool)` function
|
||||
- This function should return one of:
|
||||
- A `gym.vector.VectorEnv` (most common)
|
||||
- A single `gym.Env` (will be automatically wrapped)
|
||||
- A dict mapping `{suite_name: {task_id: VectorEnv}}` (for multi-task benchmarks)
|
||||
|
||||
### Optional Files
|
||||
|
||||
**`requirements.txt`**
|
||||
|
||||
- List any additional dependencies your environment needs
|
||||
- Users will need to install these manually before loading your environment
|
||||
|
||||
**`README.md`**
|
||||
|
||||
- Document your environment: what task it implements, observation/action spaces, rewards, etc.
|
||||
- Include usage examples and any special setup instructions
|
||||
|
||||
**`.gitignore`**
|
||||
|
||||
- Exclude unnecessary files from your repository
|
||||
|
||||
### Example Repository Structure
|
||||
|
||||
```
|
||||
my-environment-repo/
|
||||
├── env.py # Main environment definition (required)
|
||||
├── requirements.txt # Dependencies (optional)
|
||||
├── README.md # Documentation (recommended)
|
||||
├── assets/ # Images, videos, etc. (optional)
|
||||
│ └── demo.gif
|
||||
└── configs/ # Config files if needed (optional)
|
||||
└── task_config.yaml
|
||||
```
|
||||
|
||||
## Creating Your Environment Repository
|
||||
|
||||
### Step 1: Define Your Environment
|
||||
|
||||
Create an `env.py` file with a `make_env` function:
|
||||
|
||||
```python
|
||||
# env.py
|
||||
import gymnasium as gym
|
||||
|
||||
def make_env(n_envs: int = 1, use_async_envs: bool = False):
|
||||
"""
|
||||
Create vectorized environments for your custom task.
|
||||
|
||||
Args:
|
||||
n_envs: Number of parallel environments
|
||||
use_async_envs: Whether to use AsyncVectorEnv or SyncVectorEnv
|
||||
|
||||
Returns:
|
||||
gym.vector.VectorEnv or dict mapping suite names to vectorized envs
|
||||
"""
|
||||
def _make_single_env():
|
||||
# Create your custom environment
|
||||
return gym.make("CartPole-v1")
|
||||
|
||||
# Choose vector environment type
|
||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||
|
||||
# Create vectorized environment
|
||||
vec_env = env_cls([_make_single_env for _ in range(n_envs)])
|
||||
|
||||
return vec_env
|
||||
```
|
||||
|
||||
### Step 2: Test Locally
|
||||
|
||||
Before uploading, test your environment locally:
|
||||
|
||||
```python
|
||||
from lerobot.envs.utils import _load_module_from_path, _call_make_env, _normalize_hub_result
|
||||
|
||||
# Load your module
|
||||
module = _load_module_from_path("./env.py")
|
||||
|
||||
# Test the make_env function
|
||||
result = _call_make_env(module, n_envs=2, use_async_envs=False)
|
||||
normalized = _normalize_hub_result(result)
|
||||
|
||||
# Verify it works
|
||||
suite_name = next(iter(normalized))
|
||||
env = normalized[suite_name][0]
|
||||
obs, info = env.reset()
|
||||
print(f"Observation shape: {obs.shape if hasattr(obs, 'shape') else type(obs)}")
|
||||
env.close()
|
||||
```
|
||||
|
||||
### Step 3: Upload to the Hub
|
||||
|
||||
Upload your repository to Hugging Face:
|
||||
|
||||
```bash
|
||||
# Install huggingface_hub if needed
|
||||
pip install huggingface_hub
|
||||
|
||||
# Login to Hugging Face
|
||||
huggingface-cli login
|
||||
|
||||
# Create a new repository
|
||||
huggingface-cli repo create my-custom-env --type space --org my-org
|
||||
|
||||
# Initialize git and push
|
||||
git init
|
||||
git add .
|
||||
git commit -m "Initial environment implementation"
|
||||
git remote add origin https://huggingface.co/my-org/my-custom-env
|
||||
git push -u origin main
|
||||
```
|
||||
|
||||
Alternatively, use the `huggingface_hub` Python API:
|
||||
|
||||
```python
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
|
||||
# Create repository
|
||||
api.create_repo("my-custom-env", repo_type="space")
|
||||
|
||||
# Upload files
|
||||
api.upload_folder(
|
||||
folder_path="./my-env-folder",
|
||||
repo_id="username/my-custom-env",
|
||||
repo_type="space",
|
||||
)
|
||||
```
|
||||
|
||||
## Loading Environments from the Hub
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
from lerobot.envs.factory import make_env
|
||||
|
||||
# Load from the hub
|
||||
envs_dict = make_env(
|
||||
"username/my-custom-env",
|
||||
n_envs=4,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# Access the environment
|
||||
suite_name = next(iter(envs_dict))
|
||||
env = envs_dict[suite_name][0]
|
||||
|
||||
# Use it like any gym environment
|
||||
obs, info = env.reset()
|
||||
action = env.action_space.sample()
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
```
|
||||
|
||||
### Advanced: Pinning to Specific Versions
|
||||
|
||||
For reproducibility and security, pin to a specific Git revision:
|
||||
|
||||
```python
|
||||
# Pin to a specific branch
|
||||
env = make_env("username/my-env@main", trust_remote_code=True)
|
||||
|
||||
# Pin to a specific commit (recommended for papers/experiments)
|
||||
env = make_env("username/my-env@abc123def456", trust_remote_code=True)
|
||||
|
||||
# Pin to a tag
|
||||
env = make_env("username/my-env@v1.0.0", trust_remote_code=True)
|
||||
```
|
||||
|
||||
### Custom File Paths
|
||||
|
||||
If your environment definition is not in `env.py`:
|
||||
|
||||
```python
|
||||
# Load from a custom file
|
||||
env = make_env("username/my-env:custom_env.py", trust_remote_code=True)
|
||||
|
||||
# Combine with version pinning
|
||||
env = make_env("username/my-env@v1.0:envs/task_a.py", trust_remote_code=True)
|
||||
```
|
||||
|
||||
### Async Environments
|
||||
|
||||
For better performance with multiple environments:
|
||||
|
||||
```python
|
||||
envs_dict = make_env(
|
||||
"username/my-env",
|
||||
n_envs=8,
|
||||
use_async_envs=True, # Use AsyncVectorEnv for parallel execution
|
||||
trust_remote_code=True
|
||||
)
|
||||
```
|
||||
|
||||
## URL Format Reference
|
||||
|
||||
The hub URL format supports several patterns:
|
||||
|
||||
| Pattern | Description | Example |
|
||||
| -------------------- | ------------------------------ | -------------------------------------- |
|
||||
| `user/repo` | Load `env.py` from main branch | `make_env("lerobot/pusht-env")` |
|
||||
| `user/repo@revision` | Load from specific revision | `make_env("lerobot/pusht-env@main")` |
|
||||
| `user/repo:path` | Load custom file | `make_env("lerobot/envs:pusht.py")` |
|
||||
| `user/repo@rev:path` | Revision + custom file | `make_env("lerobot/envs@v1:pusht.py")` |
|
||||
|
||||
## Multi-Task Environments
|
||||
|
||||
For benchmarks with multiple tasks (like LIBERO), return a nested dictionary:
|
||||
|
||||
```python
|
||||
def make_env(n_envs: int = 1, use_async_envs: bool = False):
|
||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||
|
||||
# Return dict: {suite_name: {task_id: VectorEnv}}
|
||||
return {
|
||||
"suite_1": {
|
||||
0: env_cls([lambda: gym.make("Task1-v0") for _ in range(n_envs)]),
|
||||
1: env_cls([lambda: gym.make("Task2-v0") for _ in range(n_envs)]),
|
||||
},
|
||||
"suite_2": {
|
||||
0: env_cls([lambda: gym.make("Task3-v0") for _ in range(n_envs)]),
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Security Considerations
|
||||
|
||||
<Tip warning={true}>
|
||||
**Important**: The `trust_remote_code=True` flag is required to execute
|
||||
environment code from the Hub. This is by design for security.
|
||||
</Tip>
|
||||
|
||||
When loading environments from the Hub:
|
||||
|
||||
1. **Review the code first**: Visit the repository and inspect `env.py` before loading
|
||||
2. **Pin to commits**: Use specific commit hashes for reproducibility
|
||||
3. **Check dependencies**: Review `requirements.txt` for suspicious packages
|
||||
4. **Use trusted sources**: Prefer official organizations or well-known researchers
|
||||
5. **Sandbox if needed**: Run untrusted code in isolated environments (containers, VMs)
|
||||
|
||||
Example of safe usage:
|
||||
|
||||
```python
|
||||
# ❌ BAD: Loading without inspection
|
||||
env = make_env("random-user/untrusted-env", trust_remote_code=True)
|
||||
|
||||
# ✅ GOOD: Review code, then pin to specific commit
|
||||
# 1. Visit https://huggingface.co/trusted-org/verified-env
|
||||
# 2. Review the env.py file
|
||||
# 3. Copy the commit hash
|
||||
env = make_env("trusted-org/verified-env@a1b2c3d4", trust_remote_code=True)
|
||||
```
|
||||
|
||||
## Example: CartPole from the Hub
|
||||
|
||||
Here's a complete example using the reference CartPole environment:
|
||||
|
||||
```python
|
||||
from lerobot.envs.factory import make_env
|
||||
import numpy as np
|
||||
|
||||
# Load the environment
|
||||
envs_dict = make_env("lerobot/cartpole-env", n_envs=4, trust_remote_code=True)
|
||||
|
||||
# Get the vectorized environment
|
||||
suite_name = next(iter(envs_dict))
|
||||
env = envs_dict[suite_name][0]
|
||||
|
||||
# Run a simple episode
|
||||
obs, info = env.reset()
|
||||
done = np.zeros(env.num_envs, dtype=bool)
|
||||
total_reward = np.zeros(env.num_envs)
|
||||
|
||||
while not done.all():
|
||||
# Random policy
|
||||
action = env.action_space.sample()
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
total_reward += reward
|
||||
done = terminated | truncated
|
||||
|
||||
print(f"Average reward: {total_reward.mean():.2f}")
|
||||
env.close()
|
||||
```
|
||||
|
||||
## Benefits of EnvHub
|
||||
|
||||
### For Environment Authors
|
||||
|
||||
- **Easy distribution**: No PyPI packaging required
|
||||
- **Version control**: Use Git for environment versioning
|
||||
- **Rapid iteration**: Push updates instantly
|
||||
- **Documentation**: Hub README renders beautifully
|
||||
- **Community**: Reach LeRobot users directly
|
||||
|
||||
### For Researchers
|
||||
|
||||
- **Quick experiments**: Load any environment in one line
|
||||
- **Reproducibility**: Pin to specific commits
|
||||
- **Discovery**: Browse environments on the Hub
|
||||
- **No conflicts**: No need to install conflicting packages
|
||||
|
||||
### For the Community
|
||||
|
||||
- **Growing ecosystem**: More diverse simulation tasks
|
||||
- **Standardization**: Common `make_env` API
|
||||
- **Collaboration**: Fork and improve existing environments
|
||||
- **Accessibility**: Lower barrier to sharing research
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Refusing to execute remote code"
|
||||
|
||||
You must explicitly pass `trust_remote_code=True`:
|
||||
|
||||
```python
|
||||
env = make_env("user/repo", trust_remote_code=True)
|
||||
```
|
||||
|
||||
### "Module X not found"
|
||||
|
||||
The hub environment has dependencies you need to install:
|
||||
|
||||
```bash
|
||||
# Check the repo's requirements.txt and install dependencies
|
||||
pip install gymnasium numpy
|
||||
```
|
||||
|
||||
### "make_env not found in module"
|
||||
|
||||
Your `env.py` must expose a `make_env` function:
|
||||
|
||||
```python
|
||||
def make_env(n_envs: int, use_async_envs: bool):
|
||||
# Your implementation
|
||||
pass
|
||||
```
|
||||
|
||||
### Environment returns wrong type
|
||||
|
||||
The `make_env` function must return:
|
||||
|
||||
- A `gym.vector.VectorEnv`, or
|
||||
- A single `gym.Env`, or
|
||||
- A dict `{suite_name: {task_id: VectorEnv}}`
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Document your environment**: Include observation/action space descriptions, reward structure, and termination conditions in your README
|
||||
2. **Add requirements.txt**: List all dependencies with versions
|
||||
3. **Test thoroughly**: Verify your environment works locally before pushing
|
||||
4. **Use semantic versioning**: Tag releases with version numbers
|
||||
5. **Add examples**: Include usage examples in your README
|
||||
6. **Keep it simple**: Minimize dependencies when possible
|
||||
7. **License your work**: Add a LICENSE file to clarify usage terms
|
||||
|
||||
## Future Directions
|
||||
|
||||
The EnvHub ecosystem enables exciting possibilities:
|
||||
|
||||
- **GPU-accelerated physics**: Share Isaac Gym or Brax environments
|
||||
- **Photorealistic rendering**: Distribute environments with advanced graphics
|
||||
- **Multi-agent scenarios**: Complex interaction tasks
|
||||
- **Real-world simulators**: Digital twins of physical setups
|
||||
- **Procedural generation**: Infinite task variations
|
||||
- **Domain randomization**: Pre-configured DR pipelines
|
||||
|
||||
As more researchers and developers contribute, the diversity and quality of available environments will grow, benefiting the entire robotics learning community.
|
||||
|
||||
## See Also
|
||||
|
||||
- [Hugging Face Hub Documentation](https://huggingface.co/docs/hub/en/index)
|
||||
- [Gymnasium Documentation](https://gymnasium.farama.org/index.html)
|
||||
- [Example Hub Environment](https://huggingface.co/lerobot/cartpole-env)
|
||||
@@ -1,125 +0,0 @@
|
||||
# GR00T N1.5 Policy
|
||||
|
||||
GR00T N1.5 is an open foundation model from NVIDIA designed for generalized humanoid robot reasoning and skills. It is a cross-embodiment model that accepts multimodal input, including language and images, to perform manipulation tasks in diverse environments.
|
||||
|
||||
This document outlines the specifics of its integration and usage within the LeRobot framework.
|
||||
|
||||
## Model Overview
|
||||
|
||||
NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. It is built to improve generalization and language-following abilities for humanoid robots.
|
||||
|
||||
Developers and researchers can post-train GR00T N1.5 with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
|
||||
|
||||
GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception.
|
||||
|
||||
Its strong performance comes from being trained on an expansive and diverse humanoid dataset, which includes:
|
||||
|
||||
- Real captured data from robots.
|
||||
- Synthetic data generated using NVIDIA Isaac GR00T Blueprint.
|
||||
- Internet-scale video data.
|
||||
|
||||
This approach allows the model to be highly adaptable through post-training for specific embodiments, tasks, and environments.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
As of today, GR00T N1.5 requires flash attention for it's internal working.
|
||||
|
||||
We are working on making this optional, but in the meantime that means that we require an extra installation step and it can only be used in CUDA enabled devices.
|
||||
|
||||
1. Following the Environment Setup of our [Installation Guide](./installation). **Attention** don't install `lerobot` in this step.
|
||||
2. Install [Flash Attention](https://github.com/Dao-AILab/flash-attention) by running:
|
||||
|
||||
```bash
|
||||
# Check https://pytorch.org/get-started/locally/ for your system
|
||||
pip install "torch>=2.2.1,<2.8.0" "torchvision>=0.21.0,<0.23.0" # --index-url https://download.pytorch.org/whl/cu1XX
|
||||
pip install ninja "packaging>=24.2,<26.0" # flash attention dependencies
|
||||
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
|
||||
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
|
||||
```
|
||||
|
||||
3. Install LeRobot by running:
|
||||
|
||||
```bash
|
||||
pip install lerobot[groot]
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
To use GR00T in your LeRobot configuration, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=groot
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Training Command Example
|
||||
|
||||
Here's a complete training command for finetuning the base GR00T model on your own dataset:
|
||||
|
||||
```bash
|
||||
# Using a multi-GPU setup
|
||||
accelerate launch \
|
||||
--multi_gpu \
|
||||
--num_processes=$NUM_GPUS \
|
||||
$(which lerobot-train) \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--save_checkpoint=true \
|
||||
--batch_size=$BATCH_SIZE \
|
||||
--steps=$NUM_STEPS \
|
||||
--save_freq=$SAVE_FREQ \
|
||||
--log_freq=$LOG_FREQ \
|
||||
--policy.push_to_hub=true \
|
||||
--policy.type=groot \
|
||||
--policy.repo_id=$REPO_ID \
|
||||
--policy.tune_diffusion_model=false \
|
||||
--dataset.repo_id=$DATASET_ID \
|
||||
--wandb.enable=true \
|
||||
--wandb.disable_artifact=true \
|
||||
--job_name=$JOB_NAME
|
||||
```
|
||||
|
||||
## Performance Results
|
||||
|
||||
### Libero Benchmark Results
|
||||
|
||||
> [!NOTE]
|
||||
> Follow our instructions for Libero usage: [Libero](./libero)
|
||||
|
||||
GR00T has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the GR00T N1.5 model for 30k steps on the Libero dataset and compared the results to the GR00T reference results.
|
||||
|
||||
| Benchmark | LeRobot Implementation | GR00T Reference |
|
||||
| ------------------ | ---------------------- | --------------- |
|
||||
| **Libero Spatial** | 82.0% | 92.0% |
|
||||
| **Libero Object** | 99.0% | 92.0% |
|
||||
| **Libero Long** | 82.0% | 76.0% |
|
||||
| **Average** | 87.0% | 87.0% |
|
||||
|
||||
These results demonstrate GR00T's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
|
||||
|
||||
### Evaluate in your hardware setup
|
||||
|
||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=bi_so100_follower \
|
||||
--robot.left_arm_port=/dev/ttyACM1 \
|
||||
--robot.right_arm_port=/dev/ttyACM0 \
|
||||
--robot.id=bimanual_follower \
|
||||
--robot.cameras='{ right: {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30},
|
||||
left: {"type": "opencv", "index_or_path": 2, "width": 640, "height": 480, "fps": 30},
|
||||
top: {"type": "opencv", "index_or_path": 4, "width": 640, "height": 480, "fps": 30},
|
||||
}' \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=<user>/eval_groot-bimanual \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm"
|
||||
--policy.path=<user>/groot-bimanual # your trained model
|
||||
--dataset.episode_time_s=30
|
||||
--dataset.reset_time_s=10
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This model follows the **Apache 2.0 License**, consistent with the original [GR00T repository](https://github.com/NVIDIA/Isaac-GR00T).
|
||||
@@ -165,7 +165,7 @@ huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
Then store your Hugging Face repository name in a variable:
|
||||
|
||||
```bash
|
||||
HF_USER=$(hf auth whoami | head -n 1)
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
|
||||
@@ -1,15 +1,8 @@
|
||||
# Installation
|
||||
|
||||
## Install [`miniforge`](https://conda-forge.org/download/)
|
||||
|
||||
```bash
|
||||
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
|
||||
bash Miniforge3-$(uname)-$(uname -m).sh
|
||||
```
|
||||
|
||||
## Environment Setup
|
||||
|
||||
Create a virtual environment with Python 3.10, using conda:
|
||||
Create a virtual environment with Python 3.10, using [`Miniconda`](https://docs.anaconda.com/miniconda/install/#quick-command-line-install)
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
@@ -21,7 +14,7 @@ Then activate your conda environment, you have to do this each time you open a s
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
When using `conda`, install `ffmpeg` in your environment:
|
||||
When using `miniconda`, install `ffmpeg` in your environment:
|
||||
|
||||
```bash
|
||||
conda install ffmpeg -c conda-forge
|
||||
@@ -81,9 +74,6 @@ _Replace `[...]` with your desired features._
|
||||
For a full list of optional dependencies, see:
|
||||
https://pypi.org/project/lerobot/
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install pi, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
||||
|
||||
@@ -208,36 +208,34 @@ LeRobot supports saving and loading calibration data automatically. This is usef
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return True
|
||||
|
||||
def calibrate(self) -> None:
|
||||
pass
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
> @property
|
||||
> def is_calibrated(self) -> bool:
|
||||
> return True
|
||||
>
|
||||
> def calibrate(self) -> None:
|
||||
> pass
|
||||
> ```
|
||||
|
||||
### `is_calibrated`
|
||||
|
||||
This should reflect whether your robot has the required calibration loaded.
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
```
|
||||
<!-- prettier-ignore-end -->python
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.bus.is_calibrated
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### `calibrate()`
|
||||
|
||||
The goal of the calibration is twofold:
|
||||
|
||||
- Know the physical range of motion of each motors in order to only send commands within this range.
|
||||
- Normalize raw motors positions to sensible continuous values (e.g. percentages, degrees) instead of arbitrary discrete value dependant on the specific motor used that will not replicate elsewhere.
|
||||
- Know the physical range of motion of each motors in order to only send commands within this range.
|
||||
- Normalize raw motors positions to sensible continuous values (e.g. percentages, degrees) instead of arbitrary discrete value dependant on the specific motor used that will not replicate elsewhere.
|
||||
|
||||
It should implement the logic for calibration (if relevant) and update the `self.calibration` dictionary. If you are using Feetech or Dynamixel motors, our bus interfaces already include methods to help with this.
|
||||
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
def calibrate(self) -> None:
|
||||
|
||||
@@ -1,125 +0,0 @@
|
||||
# Multi-GPU Training
|
||||
|
||||
This guide shows you how to train policies on multiple GPUs using [Hugging Face Accelerate](https://huggingface.co/docs/accelerate).
|
||||
|
||||
## Installation
|
||||
|
||||
First, ensure you have accelerate installed:
|
||||
|
||||
```bash
|
||||
pip install accelerate
|
||||
```
|
||||
|
||||
## Training with Multiple GPUs
|
||||
|
||||
You can launch training in two ways:
|
||||
|
||||
### Option 1: Without config (specify parameters directly)
|
||||
|
||||
You can specify all parameters directly in the command without running `accelerate config`:
|
||||
|
||||
```bash
|
||||
accelerate launch \
|
||||
--multi_gpu \
|
||||
--num_processes=2 \
|
||||
$(which lerobot-train) \
|
||||
--dataset.repo_id=${HF_USER}/my_dataset \
|
||||
--policy.type=act \
|
||||
--policy.repo_id=${HF_USER}/my_trained_policy \
|
||||
--output_dir=outputs/train/act_multi_gpu \
|
||||
--job_name=act_multi_gpu \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
**Key accelerate parameters:**
|
||||
|
||||
- `--multi_gpu`: Enable multi-GPU training
|
||||
- `--num_processes=2`: Number of GPUs to use
|
||||
- `--mixed_precision=fp16`: Use fp16 mixed precision (or `bf16` if supported)
|
||||
|
||||
### Option 2: Using accelerate config
|
||||
|
||||
If you prefer to save your configuration, you can optionally configure accelerate for your hardware setup by running:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
This interactive setup will ask you questions about your training environment (number of GPUs, mixed precision settings, etc.) and saves the configuration for future use. For a simple multi-GPU setup on a single machine, you can use these recommended settings:
|
||||
|
||||
- Compute environment: This machine
|
||||
- Number of machines: 1
|
||||
- Number of processes: (number of GPUs you want to use)
|
||||
- GPU ids to use: (leave empty to use all)
|
||||
- Mixed precision: fp16 or bf16 (recommended for faster training)
|
||||
|
||||
Then launch training with:
|
||||
|
||||
```bash
|
||||
accelerate launch $(which lerobot-train) \
|
||||
--dataset.repo_id=${HF_USER}/my_dataset \
|
||||
--policy.type=act \
|
||||
--policy.repo_id=${HF_USER}/my_trained_policy \
|
||||
--output_dir=outputs/train/act_multi_gpu \
|
||||
--job_name=act_multi_gpu \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
When you launch training with accelerate:
|
||||
|
||||
1. **Automatic detection**: LeRobot automatically detects if it's running under accelerate
|
||||
2. **Data distribution**: Your batch is automatically split across GPUs
|
||||
3. **Gradient synchronization**: Gradients are synchronized across GPUs during backpropagation
|
||||
4. **Single process logging**: Only the main process logs to wandb and saves checkpoints
|
||||
|
||||
## Learning Rate and Training Steps Scaling
|
||||
|
||||
**Important:** LeRobot does **NOT** automatically scale learning rates or training steps based on the number of GPUs. This gives you full control over your training hyperparameters.
|
||||
|
||||
### Why No Automatic Scaling?
|
||||
|
||||
Many distributed training frameworks automatically scale the learning rate by the number of GPUs (e.g., `lr = base_lr × num_gpus`).
|
||||
However, LeRobot keeps the learning rate exactly as you specify it.
|
||||
|
||||
### When and How to Scale
|
||||
|
||||
If you want to scale your hyperparameters when using multiple GPUs, you should do it manually:
|
||||
|
||||
**Learning Rate Scaling:**
|
||||
|
||||
```bash
|
||||
# Example: 2 GPUs with linear LR scaling
|
||||
# Base LR: 1e-4, with 2 GPUs -> 2e-4
|
||||
accelerate launch --num_processes=2 $(which lerobot-train) \
|
||||
--optimizer.lr=2e-4 \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--policy=act
|
||||
```
|
||||
|
||||
**Training Steps Scaling:**
|
||||
|
||||
Since the effective batch size `bs` increases with multiple GPUs (batch_size × num_gpus), you may want to reduce the number of training steps proportionally:
|
||||
|
||||
```bash
|
||||
# Example: 2 GPUs with effective batch size 2x larger
|
||||
# Original: batch_size=8, steps=100000
|
||||
# With 2 GPUs: batch_size=8 (16 in total), steps=50000
|
||||
accelerate launch --num_processes=2 $(which lerobot-train) \
|
||||
--batch_size=8 \
|
||||
--steps=50000 \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--policy=act
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
|
||||
- Training logs, checkpoints, and hub uploads are only done by the main process to avoid conflicts. Non-main processes have console logging disabled to prevent duplicate output.
|
||||
- The effective batch size is `batch_size × num_gpus`. If you use 4 GPUs with `--batch_size=8`, your effective batch size is 32.
|
||||
- Learning rate scheduling is handled correctly across multiple processes—LeRobot sets `step_scheduler_with_optimizer=False` to prevent accelerate from adjusting scheduler steps based on the number of processes.
|
||||
- When saving or pushing models, LeRobot automatically unwraps the model from accelerate's distributed wrapper to ensure compatibility.
|
||||
- WandB integration automatically initializes only on the main process, preventing multiple runs from being created.
|
||||
|
||||
For more advanced configurations and troubleshooting, see the [Accelerate documentation](https://huggingface.co/docs/accelerate). If you want to learn more about how to train on a large number of GPUs, checkout this awesome guide: [Ultrascale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook).
|
||||
@@ -28,11 +28,6 @@ As described by Physical Intelligence, while AI has achieved remarkable success
|
||||
pip install -e ".[pi]"
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
## Training Data and Capabilities
|
||||
|
||||
π₀ is trained on the largest robot interaction dataset to date, combining three key data sources:
|
||||
|
||||
@@ -36,11 +36,6 @@ This diverse training mixture creates a "curriculum" that enables generalization
|
||||
pip install -e ".[pi]"
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
## Usage
|
||||
|
||||
To use π₀.₅ in your LeRobot configuration, specify the policy type as:
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
## Research Paper
|
||||
|
||||
Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
|
||||
|
||||
## Repository
|
||||
|
||||
Code: https://github.com/NVIDIA/Isaac-GR00T
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@inproceedings{gr00tn1_2025,
|
||||
archivePrefix = {arxiv},
|
||||
eprint = {2503.14734},
|
||||
title = {{GR00T} {N1}: An Open Foundation Model for Generalist Humanoid Robots},
|
||||
author = {NVIDIA and Johan Bjorck andFernando Castañeda, Nikita Cherniadev and Xingye Da and Runyu Ding and Linxi "Jim" Fan and Yu Fang and Dieter Fox and Fengyuan Hu and Spencer Huang and Joel Jang and Zhenyu Jiang and Jan Kautz and Kaushil Kundalia and Lawrence Lao and Zhiqi Li and Zongyu Lin and Kevin Lin and Guilin Liu and Edith Llontop and Loic Magne and Ajay Mandlekar and Avnish Narayan and Soroush Nasiriany and Scott Reed and You Liang Tan and Guanzhi Wang and Zu Wang and Jing Wang and Qi Wang and Jiannan Xiang and Yuqi Xie and Yinzhen Xu and Zhenjia Xu and Seonghyeon Ye and Zhiding Yu and Ao Zhang and Hao Zhang and Yizhou Zhao and Ruijie Zheng and Yuke Zhu},
|
||||
month = {March},
|
||||
year = {2025},
|
||||
booktitle = {ArXiv Preprint},
|
||||
}
|
||||
```
|
||||
|
||||
## Additional Resources
|
||||
|
||||
Blog: https://developer.nvidia.com/isaac/gr00t
|
||||
|
||||
Hugging Face Model: https://huggingface.co/nvidia/GR00T-N1.5-3B
|
||||
@@ -132,15 +132,17 @@ print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
|
||||
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
|
||||
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
|
||||
|
||||
if __name__ == "__main__":
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=32,
|
||||
shuffle=True,
|
||||
)
|
||||
for batch in dataloader:
|
||||
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
|
||||
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
|
||||
print(f"{batch['action'].shape=}") # (32, 64, c)
|
||||
break
|
||||
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
|
||||
# PyTorch datasets.
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=32,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
for batch in dataloader:
|
||||
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
|
||||
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
|
||||
print(f"{batch['action'].shape=}") # (32, 64, c)
|
||||
break
|
||||
|
||||
@@ -1,263 +0,0 @@
|
||||
# RTC Profiling Guide
|
||||
|
||||
This guide explains how to profile RTC (Real-Time Chunking) performance to identify bottlenecks and understand why RTC might be slower than expected.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Profile with Real Robot (Profiled Version)
|
||||
|
||||
Use `eval_with_real_robot_profiled.py` to profile actual robot execution:
|
||||
|
||||
```bash
|
||||
# With RTC enabled
|
||||
uv run examples/rtc/eval_with_real_robot_profiled.py \
|
||||
--policy.path=helper2424/pi05_check_rtc \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=30
|
||||
|
||||
# Without RTC for comparison
|
||||
uv run examples/rtc/eval_with_real_robot_profiled.py \
|
||||
--policy.path=helper2424/pi05_check_rtc \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=false \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=30
|
||||
```
|
||||
|
||||
**Output**: At the end of execution, you'll see a detailed breakdown of timing for each component:
|
||||
- `get_actions.policy_inference` - Time spent in policy inference
|
||||
- `get_actions.preprocessing` - Time spent preprocessing observations
|
||||
- `get_actions.postprocessing` - Time spent postprocessing actions
|
||||
- `get_actions.action_queue_merge` - Time spent merging actions with RTC
|
||||
- `robot.get_observation` - Time to get observations from robot
|
||||
- `robot.send_action` - Time to send actions to robot
|
||||
- And more...
|
||||
|
||||
### 2. Profile Without Robot (Comparison Script)
|
||||
|
||||
Use `profile_rtc_comparison.py` to profile just the policy inference without needing a robot:
|
||||
|
||||
```bash
|
||||
uv run examples/rtc/profile_rtc_comparison.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=50 \
|
||||
--execution_horizon=20
|
||||
```
|
||||
|
||||
**Output**: Side-by-side comparison of performance with and without RTC, including:
|
||||
- Mean/min/max inference times
|
||||
- Throughput (iterations per second)
|
||||
- Verdict on whether RTC is faster or slower
|
||||
|
||||
### 3. Enable Detailed Method-Level Profiling
|
||||
|
||||
For even more granular profiling, add the `--enable_detailed_profiling` flag:
|
||||
|
||||
```bash
|
||||
uv run examples/rtc/profile_rtc_comparison.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=50 \
|
||||
--execution_horizon=20 \
|
||||
--enable_detailed_profiling
|
||||
```
|
||||
|
||||
This will show timing for individual methods within the policy.
|
||||
|
||||
## Understanding the Output
|
||||
|
||||
### Key Metrics to Look At
|
||||
|
||||
1. **get_actions.policy_inference** - This should be the largest component
|
||||
- If RTC is enabled, this includes the RTC guidance overhead
|
||||
- Compare this with/without RTC to see the overhead
|
||||
|
||||
2. **get_actions.preprocessing** - Image preprocessing and normalization
|
||||
- Should be relatively fast
|
||||
- If slow, consider optimizing image processing
|
||||
|
||||
3. **get_actions.postprocessing** - Action denormalization
|
||||
- Should be minimal
|
||||
- If slow, check postprocessor implementation
|
||||
|
||||
4. **get_actions.action_queue_merge** - RTC-specific merging logic
|
||||
- Only present when RTC is enabled
|
||||
- If this is taking significant time, the RTC algorithm may need optimization
|
||||
|
||||
5. **robot.get_observation** - Robot communication overhead
|
||||
- If slow, check camera/sensor latency
|
||||
- Consider reducing image resolution
|
||||
|
||||
6. **robot.send_action** - Action execution overhead
|
||||
- Should be very fast
|
||||
- If slow, check robot communication
|
||||
|
||||
### Expected Performance
|
||||
|
||||
For a typical Pi0 policy on Apple Silicon (MPS):
|
||||
- **Without RTC**: ~100-200ms per inference
|
||||
- **With RTC**: Should be similar or slightly faster due to action reuse
|
||||
- **Preprocessing**: ~5-20ms depending on number of cameras
|
||||
- **Postprocessing**: ~1-5ms
|
||||
|
||||
If RTC is significantly slower, likely causes:
|
||||
1. **RTC overhead exceeds benefits** - The guidance computation is expensive
|
||||
2. **Execution horizon too small** - Not reusing enough actions to amortize overhead
|
||||
3. **No compilation** - Try with `--use_torch_compile`
|
||||
4. **Large prev_actions buffer** - Copying/processing previous actions is slow
|
||||
|
||||
## Profiling Your Own Code
|
||||
|
||||
### Using the Profiling Decorator
|
||||
|
||||
Add profiling to your own methods:
|
||||
|
||||
```python
|
||||
from lerobot.utils.profiling import profile_method, enable_profiling, print_profiling_summary
|
||||
|
||||
# Enable profiling
|
||||
enable_profiling()
|
||||
|
||||
# Decorate methods you want to profile
|
||||
@profile_method
|
||||
def my_slow_function(x):
|
||||
# ... your code ...
|
||||
return result
|
||||
|
||||
# At end of execution
|
||||
print_profiling_summary()
|
||||
```
|
||||
|
||||
### Using Profile Context Manager
|
||||
|
||||
For profiling specific code blocks:
|
||||
|
||||
```python
|
||||
from lerobot.utils.profiling import profile_section, enable_profiling
|
||||
|
||||
enable_profiling()
|
||||
|
||||
with profile_section("data_loading"):
|
||||
data = load_data()
|
||||
|
||||
with profile_section("model_inference"):
|
||||
output = model(data)
|
||||
```
|
||||
|
||||
### Adding Profiling to Policy Methods
|
||||
|
||||
To profile specific parts of the Pi0 policy, you can add decorators:
|
||||
|
||||
```python
|
||||
# In src/lerobot/policies/pi0/modeling_pi0.py
|
||||
from lerobot.utils.profiling import profile_method, profile_section
|
||||
|
||||
class Pi0Policy:
|
||||
@profile_method
|
||||
def predict_action_chunk(self, obs, inference_delay=0, prev_chunk_left_over=None):
|
||||
# ... existing code ...
|
||||
pass
|
||||
|
||||
def _generate_actions_with_rtc(self, ...):
|
||||
with profile_section("rtc.guidance_computation"):
|
||||
# ... guidance code ...
|
||||
pass
|
||||
|
||||
with profile_section("rtc.action_merging"):
|
||||
# ... merging code ...
|
||||
pass
|
||||
```
|
||||
|
||||
## Analyzing Results
|
||||
|
||||
### Comparison Checklist
|
||||
|
||||
When comparing RTC vs non-RTC performance, check:
|
||||
|
||||
- [ ] Is `policy_inference` time higher with RTC?
|
||||
- [ ] Is `action_queue_merge` taking significant time?
|
||||
- [ ] Are you running enough iterations to amortize warmup?
|
||||
- [ ] Is torch.compile enabled for fair comparison?
|
||||
- [ ] Is the execution horizon large enough? (should be >= 10-20)
|
||||
- [ ] Are you testing on the same hardware/device?
|
||||
|
||||
### Common Bottlenecks
|
||||
|
||||
1. **Image preprocessing dominates**
|
||||
- Solution: Reduce image resolution, use fewer cameras, or optimize preprocessing
|
||||
|
||||
2. **Action queue operations are slow**
|
||||
- Solution: Review queue implementation, consider using ring buffer
|
||||
|
||||
3. **RTC guidance is expensive**
|
||||
- Solution: Reduce guidance weight, simplify guidance computation, use torch.compile
|
||||
|
||||
4. **Robot communication is slow**
|
||||
- Solution: Increase baud rate, reduce action frequency, optimize protocol
|
||||
|
||||
5. **Memory allocation overhead**
|
||||
- Solution: Pre-allocate buffers, reuse tensors, avoid unnecessary copies
|
||||
|
||||
## Advanced: Adding Custom Metrics
|
||||
|
||||
You can add custom timing metrics to the profiled script:
|
||||
|
||||
```python
|
||||
from lerobot.utils.profiling import record_timing
|
||||
|
||||
start = time.perf_counter()
|
||||
# ... your code ...
|
||||
duration = time.perf_counter() - start
|
||||
record_timing("my_custom_metric", duration)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Profiling shows RTC is slower by >50%
|
||||
|
||||
1. Check if torch.compile is enabled: `--use_torch_compile`
|
||||
2. Increase execution horizon: `--rtc.execution_horizon=30`
|
||||
3. Verify inference_delay is calculated correctly
|
||||
4. Profile with `--enable_detailed_profiling` to find exact bottleneck
|
||||
|
||||
### Profiling output is empty
|
||||
|
||||
1. Make sure profiling is enabled with `enable_profiling()`
|
||||
2. Verify you're running enough iterations (at least 10)
|
||||
3. Check that code is actually executing (not short-circuited)
|
||||
|
||||
### Inconsistent results between runs
|
||||
|
||||
1. Run more iterations: `--num_iterations=100`
|
||||
2. Increase warmup iterations
|
||||
3. Check for thermal throttling on device
|
||||
4. Ensure no other processes competing for resources
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. Run both profiling scripts (with/without robot)
|
||||
2. Compare timing breakdowns
|
||||
3. Identify the largest bottleneck
|
||||
4. Focus optimization efforts on that component
|
||||
5. Re-run profiling to verify improvements
|
||||
|
||||
## Questions?
|
||||
|
||||
If profiling reveals unexpected bottlenecks or you need help interpreting results, please share:
|
||||
- The full profiling output
|
||||
- Your configuration (RTC enabled/disabled, execution horizon, etc.)
|
||||
- Hardware specs (device type, memory, etc.)
|
||||
- Policy type and size
|
||||
|
||||
@@ -1,208 +0,0 @@
|
||||
# RTC Profiling - Quick Start
|
||||
|
||||
Quick reference for profiling Pi0 with RTC to identify performance bottlenecks.
|
||||
|
||||
## 🚀 Quick Commands
|
||||
|
||||
### 1. Profile with Real Robot
|
||||
|
||||
```bash
|
||||
# With RTC enabled (profiled version)
|
||||
uv run examples/rtc/eval_with_real_robot_profiled.py \
|
||||
--policy.path=helper2424/pi05_check_rtc \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0}, front: {type: opencv, index_or_path: 1}}" \
|
||||
--task="Pick up object" \
|
||||
--duration=30
|
||||
```
|
||||
|
||||
### 2. Compare RTC vs No-RTC (No Robot Needed)
|
||||
|
||||
```bash
|
||||
uv run examples/rtc/profile_rtc_comparison.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=50 \
|
||||
--execution_horizon=20
|
||||
```
|
||||
|
||||
### 3. Detailed RTC Method Profiling
|
||||
|
||||
```bash
|
||||
uv run examples/rtc/profile_pi0_rtc_detailed.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=20 \
|
||||
--execution_horizon=20 \
|
||||
--enable_rtc_profiling
|
||||
```
|
||||
|
||||
## 📊 What Each Tool Does
|
||||
|
||||
| Tool | Purpose | Needs Robot? |
|
||||
|------|---------|--------------|
|
||||
| `eval_with_real_robot_profiled.py` | Profile actual robot execution with RTC | ✅ Yes |
|
||||
| `profile_rtc_comparison.py` | Compare RTC vs no-RTC side-by-side | ❌ No |
|
||||
| `profile_pi0_rtc_detailed.py` | Deep dive into RTC internals | ❌ No |
|
||||
|
||||
## 🔍 Key Metrics to Watch
|
||||
|
||||
### Overall Performance
|
||||
- **iteration.policy_inference** - Total policy inference time
|
||||
- **iteration.preprocessing** - Image preprocessing time
|
||||
- **iteration.postprocessing** - Action denormalization time
|
||||
|
||||
### RTC-Specific (with `--enable_rtc_profiling`)
|
||||
- **rtc.denoise_step.base_denoising** - Time without RTC overhead
|
||||
- **rtc.denoise_step.autograd_correction** - Gradient computation time
|
||||
- **rtc.denoise_step.guidance_computation** - Total RTC guidance overhead
|
||||
|
||||
### Robot Communication
|
||||
- **robot.get_observation** - Time to get robot state
|
||||
- **robot.send_action** - Time to send action command
|
||||
|
||||
## 🎯 Quick Diagnosis
|
||||
|
||||
### RTC is slower than expected?
|
||||
|
||||
1. **Check if torch.compile is enabled**
|
||||
```bash
|
||||
# Add this flag
|
||||
--use_torch_compile
|
||||
```
|
||||
|
||||
2. **Try larger execution horizon**
|
||||
```bash
|
||||
# Increase to amortize RTC overhead
|
||||
--rtc.execution_horizon=30
|
||||
```
|
||||
|
||||
3. **Profile to find bottleneck**
|
||||
```bash
|
||||
uv run examples/rtc/profile_pi0_rtc_detailed.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--enable_rtc_profiling
|
||||
```
|
||||
|
||||
### Preprocessing is slow?
|
||||
|
||||
- Reduce image resolution in robot config
|
||||
- Use fewer cameras
|
||||
- Check camera FPS settings
|
||||
|
||||
### Policy inference is slow?
|
||||
|
||||
- Enable torch.compile
|
||||
- Check device (MPS vs CUDA vs CPU)
|
||||
- Try smaller model if available
|
||||
|
||||
## 📈 Expected Performance
|
||||
|
||||
### Typical timings on Apple Silicon (MPS):
|
||||
|
||||
| Component | Time (ms) | Notes |
|
||||
|-----------|-----------|-------|
|
||||
| Policy inference | 100-200 | Depends on model size |
|
||||
| Preprocessing | 5-20 | Depends on #cameras |
|
||||
| Postprocessing | 1-5 | Usually fast |
|
||||
| RTC overhead | 10-50 | Should be < 50% of base |
|
||||
|
||||
### When RTC helps:
|
||||
- ✅ Execution horizon ≥ 10
|
||||
- ✅ Inference time > action execution rate
|
||||
- ✅ Using torch.compile
|
||||
- ✅ Proper inference_delay calculation
|
||||
|
||||
### When RTC might not help:
|
||||
- ❌ Very fast inference already
|
||||
- ❌ Small execution horizon (< 5)
|
||||
- ❌ No compilation (interpreted mode)
|
||||
- ❌ Inference delay not accounted for
|
||||
|
||||
## 🛠️ Adding Profiling to Your Code
|
||||
|
||||
### Quick snippet:
|
||||
|
||||
```python
|
||||
from lerobot.utils.profiling import enable_profiling, print_profiling_summary, profile_section
|
||||
|
||||
# Enable at start
|
||||
enable_profiling()
|
||||
|
||||
# Profile sections
|
||||
with profile_section("my_operation"):
|
||||
# ... your code ...
|
||||
pass
|
||||
|
||||
# Print at end
|
||||
print_profiling_summary()
|
||||
```
|
||||
|
||||
### Profile specific methods:
|
||||
|
||||
```python
|
||||
from lerobot.utils.profiling import profile_method
|
||||
|
||||
@profile_method
|
||||
def my_slow_function():
|
||||
# ... your code ...
|
||||
pass
|
||||
```
|
||||
|
||||
## 📝 Example Output
|
||||
|
||||
```
|
||||
PROFILING SUMMARY
|
||||
================================================================================
|
||||
Function Count Mean (ms)
|
||||
--------------------------------------------------------------------------------
|
||||
iteration.policy_inference 20 150.23
|
||||
iteration.preprocessing 20 12.45
|
||||
rtc.denoise_step.guidance_computation 200 15.67
|
||||
rtc.denoise_step.autograd_correction 200 8.23
|
||||
rtc.denoise_step.base_denoising 200 120.45
|
||||
================================================================================
|
||||
```
|
||||
|
||||
## 🚨 Common Issues
|
||||
|
||||
### "No profiling data available"
|
||||
- Did you call `enable_profiling()`?
|
||||
- Running enough iterations?
|
||||
|
||||
### Inconsistent results
|
||||
- Increase `--num_iterations`
|
||||
- Check for thermal throttling
|
||||
- Close other applications
|
||||
|
||||
### Can't find bottleneck
|
||||
- Enable `--enable_rtc_profiling` for detailed breakdown
|
||||
- Check both preprocessing and inference
|
||||
- Compare with and without RTC
|
||||
|
||||
## 📖 More Details
|
||||
|
||||
See `PROFILING_GUIDE.md` for comprehensive documentation.
|
||||
|
||||
## 🤔 Still Slow?
|
||||
|
||||
1. Run comparison: `profile_rtc_comparison.py`
|
||||
2. Run detailed profiling: `profile_pi0_rtc_detailed.py --enable_rtc_profiling`
|
||||
3. Share output for help (include device, model, settings)
|
||||
|
||||
## ✅ Quick Checklist
|
||||
|
||||
Before asking for help, verify:
|
||||
|
||||
- [ ] Ran comparison script (with/without RTC)
|
||||
- [ ] Tried torch.compile
|
||||
- [ ] Tested different execution horizons (10, 20, 30)
|
||||
- [ ] Profiled with detailed RTC profiling
|
||||
- [ ] Checked preprocessing vs inference split
|
||||
- [ ] Verified hardware (device type, thermal state)
|
||||
|
||||
@@ -1,352 +0,0 @@
|
||||
# RTC Profiling Toolkit
|
||||
|
||||
Complete toolkit for profiling Pi0 with RTC to identify performance bottlenecks.
|
||||
|
||||
## 📦 What's Included
|
||||
|
||||
### Scripts
|
||||
|
||||
1. **`eval_with_real_robot_profiled.py`**
|
||||
- Profiled version of the real robot eval script
|
||||
- Adds timing measurements throughout execution
|
||||
- Works with actual robot hardware
|
||||
- Same usage as original but with profiling output
|
||||
|
||||
2. **`profile_rtc_comparison.py`**
|
||||
- Side-by-side comparison of RTC vs no-RTC
|
||||
- No robot needed (uses mock observations)
|
||||
- Shows clear verdict on whether RTC is helping
|
||||
- Great for quick performance checks
|
||||
|
||||
3. **`profile_pi0_rtc_detailed.py`**
|
||||
- Most detailed profiling available
|
||||
- Can enable RTC method-level profiling
|
||||
- Provides insights and recommendations
|
||||
- Perfect for deep-dive investigations
|
||||
|
||||
4. **`add_rtc_profiling.py`**
|
||||
- Monkey-patching utility for RTC internals
|
||||
- Profiles individual RTC operations
|
||||
- Can be applied without modifying source
|
||||
- Shows exactly where RTC spends time
|
||||
|
||||
### Utilities
|
||||
|
||||
5. **`src/lerobot/utils/profiling.py`**
|
||||
- Core profiling utilities
|
||||
- Decorators for method profiling
|
||||
- Context managers for code blocks
|
||||
- Statistics collection and reporting
|
||||
|
||||
### Documentation
|
||||
|
||||
6. **`PROFILING_GUIDE.md`** - Comprehensive guide
|
||||
7. **`PROFILING_QUICK_START.md`** - Quick reference
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
### Step 1: Compare Performance
|
||||
|
||||
Run this first to see if RTC is actually slower:
|
||||
|
||||
```bash
|
||||
uv run examples/rtc/profile_rtc_comparison.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=50 \
|
||||
--execution_horizon=20
|
||||
```
|
||||
|
||||
**Expected output:**
|
||||
```
|
||||
COMPARISON SUMMARY
|
||||
================================================================================
|
||||
Metric Without RTC With RTC Difference
|
||||
--------------------------------------------------------------------------------
|
||||
Mean time (ms) 150.23 165.45 +15.22
|
||||
Throughput (iter/s) 6.66 6.05 -0.61
|
||||
================================================================================
|
||||
VERDICT
|
||||
✗ RTC is SLOWER by 10.1%
|
||||
Mean time increased by 15.22 ms
|
||||
|
||||
Possible reasons:
|
||||
- RTC overhead exceeds benefits at current execution horizon
|
||||
- No torch.compile enabled
|
||||
```
|
||||
|
||||
### Step 2: Identify Bottleneck
|
||||
|
||||
If RTC is slower, find out why:
|
||||
|
||||
```bash
|
||||
uv run examples/rtc/profile_pi0_rtc_detailed.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=20 \
|
||||
--execution_horizon=20 \
|
||||
--enable_rtc_profiling
|
||||
```
|
||||
|
||||
**Expected output:**
|
||||
```
|
||||
PROFILING SUMMARY
|
||||
================================================================================
|
||||
Function Count Mean (ms) Total (s)
|
||||
------------------------------------------------------------------------------------
|
||||
iteration.policy_inference 20 150.23 3.00
|
||||
rtc.denoise_step.guidance_computation 200 15.67 3.13
|
||||
rtc.denoise_step.autograd_correction 200 8.23 1.65
|
||||
iteration.preprocessing 20 12.45 0.25
|
||||
================================================================================
|
||||
|
||||
KEY INSIGHTS
|
||||
================================================================================
|
||||
Time breakdown:
|
||||
Policy inference: 150.23 ms (87.2%)
|
||||
Preprocessing: 12.45 ms (7.2%)
|
||||
Postprocessing: 2.10 ms (1.2%)
|
||||
|
||||
RTC breakdown:
|
||||
Base denoising: 120.45 ms
|
||||
Guidance compute: 15.67 ms
|
||||
Autograd correct: 8.23 ms
|
||||
RTC overhead: 23.90 ms (19.8% of base)
|
||||
|
||||
Recommendations:
|
||||
⚠ RTC autograd overhead is significant
|
||||
→ This is expected, but consider increasing execution_horizon
|
||||
→ Try torch.compile if not already enabled
|
||||
💡 torch.compile not enabled
|
||||
→ Try --use_torch_compile for potential speedup
|
||||
================================================================================
|
||||
```
|
||||
|
||||
### Step 3: Try Optimizations
|
||||
|
||||
Based on recommendations:
|
||||
|
||||
```bash
|
||||
# Try with torch.compile
|
||||
uv run examples/rtc/profile_rtc_comparison.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=50 \
|
||||
--execution_horizon=20 \
|
||||
--use_torch_compile
|
||||
|
||||
# Try larger execution horizon
|
||||
uv run examples/rtc/profile_rtc_comparison.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=50 \
|
||||
--execution_horizon=30
|
||||
```
|
||||
|
||||
### Step 4: Profile Real Robot (Optional)
|
||||
|
||||
Test with actual hardware:
|
||||
|
||||
```bash
|
||||
uv run examples/rtc/eval_with_real_robot_profiled.py \
|
||||
--policy.path=helper2424/pi05_check_rtc \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.cameras="{...}" \
|
||||
--task="Pick up object" \
|
||||
--duration=30
|
||||
```
|
||||
|
||||
## 🎯 Common Scenarios
|
||||
|
||||
### "RTC is 2x slower!"
|
||||
|
||||
This usually means:
|
||||
- RTC overhead is high but not getting benefits
|
||||
- Need to enable torch.compile
|
||||
- Execution horizon too small
|
||||
- Inference delay not calculated correctly
|
||||
|
||||
**Try:**
|
||||
1. `--use_torch_compile`
|
||||
2. Increase `--execution_horizon` to 30+
|
||||
3. Check inference_delay calculation
|
||||
|
||||
### "RTC is only slightly slower"
|
||||
|
||||
This is expected! RTC overhead is about 10-30% typically.
|
||||
The benefit comes during **execution**, not single inference:
|
||||
- Actions are reused across chunks
|
||||
- Overall system latency is reduced
|
||||
- Robot gets smoother actions
|
||||
|
||||
### "Want to optimize specific part"
|
||||
|
||||
Use the profiling utilities:
|
||||
|
||||
```python
|
||||
from lerobot.utils.profiling import enable_profiling, profile_section, print_profiling_summary
|
||||
|
||||
enable_profiling()
|
||||
|
||||
with profile_section("my_custom_operation"):
|
||||
# Your code here
|
||||
pass
|
||||
|
||||
print_profiling_summary()
|
||||
```
|
||||
|
||||
## 📊 Understanding Results
|
||||
|
||||
### Key Metrics
|
||||
|
||||
**Policy Inference Time**
|
||||
- Time for forward pass through model
|
||||
- Should be largest component (70-90%)
|
||||
- Includes RTC guidance if enabled
|
||||
|
||||
**Preprocessing Time**
|
||||
- Image normalization, resizing
|
||||
- Should be < 20% of total
|
||||
- If high: reduce image resolution
|
||||
|
||||
**RTC Guidance Overhead**
|
||||
- Extra time for RTC guidance computation
|
||||
- Typically 10-30% of base inference
|
||||
- If > 50%: RTC may not be beneficial at current settings
|
||||
|
||||
**Autograd Correction**
|
||||
- Time computing gradients for RTC
|
||||
- Usually 5-15% of base inference
|
||||
- Can be reduced with torch.compile
|
||||
|
||||
### Expected Ranges (Apple Silicon MPS)
|
||||
|
||||
| Metric | Good | Acceptable | Poor |
|
||||
|--------|------|------------|------|
|
||||
| Policy inference | 100-150ms | 150-250ms | >250ms |
|
||||
| Preprocessing | <20ms | 20-50ms | >50ms |
|
||||
| RTC overhead | 10-30% | 30-50% | >50% |
|
||||
|
||||
## 🔧 Optimization Guide
|
||||
|
||||
### If RTC overhead is too high:
|
||||
|
||||
1. **Enable compilation:**
|
||||
```bash
|
||||
--use_torch_compile
|
||||
```
|
||||
Expected improvement: 20-40% faster
|
||||
|
||||
2. **Increase execution horizon:**
|
||||
```bash
|
||||
--execution_horizon=30 # or higher
|
||||
```
|
||||
Amortizes RTC cost over more actions
|
||||
|
||||
3. **Check guidance weight:**
|
||||
```python
|
||||
# In config
|
||||
rtc.max_guidance_weight=1.0 # try 0.5 for less overhead
|
||||
```
|
||||
|
||||
### If preprocessing is slow:
|
||||
|
||||
1. **Reduce image resolution:**
|
||||
```python
|
||||
# In robot config
|
||||
cameras={
|
||||
"gripper": {"width": 320, "height": 240} # instead of 640x480
|
||||
}
|
||||
```
|
||||
|
||||
2. **Use fewer cameras:**
|
||||
- Profile which cameras are essential
|
||||
- Remove unnecessary views
|
||||
|
||||
### If inference is generally slow:
|
||||
|
||||
1. Use torch.compile (if not already)
|
||||
2. Check device is correct (MPS vs CUDA)
|
||||
3. Verify model is in eval mode
|
||||
4. Check for unnecessary gradient tracking
|
||||
|
||||
## 🐛 Troubleshooting
|
||||
|
||||
### Empty profiling output
|
||||
```python
|
||||
# Make sure to enable profiling!
|
||||
from lerobot.utils.profiling import enable_profiling
|
||||
enable_profiling()
|
||||
```
|
||||
|
||||
### Inconsistent timings
|
||||
- Run more iterations (50-100)
|
||||
- Check thermal throttling
|
||||
- Close background apps
|
||||
- Use `--warmup_iterations=10`
|
||||
|
||||
### Can't find bottleneck
|
||||
1. Start with `profile_rtc_comparison.py`
|
||||
2. Then run `profile_pi0_rtc_detailed.py --enable_rtc_profiling`
|
||||
3. Compare with/without RTC
|
||||
4. Check each component separately
|
||||
|
||||
## 📖 Full Documentation
|
||||
|
||||
- **`PROFILING_GUIDE.md`** - Complete reference with examples
|
||||
- **`PROFILING_QUICK_START.md`** - Quick commands and tips
|
||||
|
||||
## 🤝 Getting Help
|
||||
|
||||
If you're still experiencing issues:
|
||||
|
||||
1. Run comparison script and save output
|
||||
2. Run detailed profiling and save output
|
||||
3. Include:
|
||||
- Policy path
|
||||
- Device type
|
||||
- RTC settings (execution_horizon, etc.)
|
||||
- Hardware specs
|
||||
- Full profiling output
|
||||
|
||||
## 🎓 Learning More
|
||||
|
||||
### Profiling your own code:
|
||||
|
||||
```python
|
||||
from lerobot.utils.profiling import profile_method, enable_profiling
|
||||
|
||||
enable_profiling()
|
||||
|
||||
@profile_method
|
||||
def my_function():
|
||||
# Automatically profiled
|
||||
pass
|
||||
```
|
||||
|
||||
### RTC internals:
|
||||
|
||||
```python
|
||||
from examples.rtc.add_rtc_profiling import monkey_patch_rtc_profiling
|
||||
|
||||
enable_profiling()
|
||||
monkey_patch_rtc_profiling()
|
||||
|
||||
# Now RTC methods are profiled
|
||||
policy.predict_action_chunk(...)
|
||||
```
|
||||
|
||||
## ✨ Next Steps
|
||||
|
||||
1. Run `profile_rtc_comparison.py` to establish baseline
|
||||
2. Use `profile_pi0_rtc_detailed.py` to find bottlenecks
|
||||
3. Apply optimizations (torch.compile, larger horizon)
|
||||
4. Re-run comparison to verify improvements
|
||||
5. Test with real robot using profiled version
|
||||
|
||||
Happy profiling! 🚀
|
||||
|
||||
@@ -1,251 +0,0 @@
|
||||
# Real-Time Chunking (RTC) Examples
|
||||
|
||||
This directory contains examples and evaluation scripts for Real-Time Chunking (RTC), a technique for improving action chunking policies in real-time robot control.
|
||||
|
||||
## Overview
|
||||
|
||||
Real-Time Chunking addresses the challenge of maintaining consistency and reactivity when using action chunking policies with non-negligible inference latency. It uses a guidance technique during diffusion sampling to blend new action predictions with previously planned actions.
|
||||
|
||||
**Key Benefits:**
|
||||
|
||||
- Maintains consistency between consecutive action chunks
|
||||
- Reduces jitter and improves smoothness
|
||||
- Adapts to inference delays dynamically
|
||||
|
||||
**Reference:** [Physical Intelligence - Real-Time Chunking](https://www.physicalintelligence.company/download/real_time_chunking.pdf)
|
||||
|
||||
## Scripts
|
||||
|
||||
### 1. `eval_dataset.py`
|
||||
|
||||
Offline evaluation on dataset samples with detailed visualization and validation.
|
||||
|
||||
**Features:**
|
||||
|
||||
- Compare RTC vs non-RTC predictions on two random dataset samples
|
||||
- Validate RTC behavior (delay region, blend region, post-horizon region)
|
||||
- Generate debug visualizations:
|
||||
- Denoising step comparisons (x_t, v_t, x1_t, corrections)
|
||||
- Final action predictions comparison
|
||||
- Support for torch.compile() optimization
|
||||
- Memory-efficient sequential policy loading for large models
|
||||
|
||||
**Usage:**
|
||||
|
||||
```bash
|
||||
# Basic usage with SmolVLA policy
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps \
|
||||
--rtc.max_guidance_weight=10.0 \
|
||||
--seed=10
|
||||
|
||||
# With Pi0.5 policy on CUDA
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=lerobot/pi05_libero_finetuned \
|
||||
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda
|
||||
|
||||
# With Pi0 policy
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=lerobot/pi0_libero_finetuned \
|
||||
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda
|
||||
|
||||
# With torch.compile for faster inference
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda \
|
||||
--use_torch_compile=true \
|
||||
--torch_compile_mode=max-autotune
|
||||
|
||||
# Enable CUDA graphs (advanced - may cause tensor aliasing errors)
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--use_torch_compile=true \
|
||||
--torch_compile_backend=inductor \
|
||||
--torch_compile_mode=max-autotune \
|
||||
--torch_compile_disable_cudagraphs=false
|
||||
```
|
||||
|
||||
**Key Parameters:**
|
||||
|
||||
- `--policy.path`: Path to pretrained policy
|
||||
- `--dataset.repo_id`: Dataset to evaluate on
|
||||
- `--rtc.execution_horizon`: Number of steps to maintain consistency (default: 20)
|
||||
- `--rtc.max_guidance_weight`: Maximum guidance weight (default: 10.0)
|
||||
- `--rtc.prefix_attention_schedule`: Schedule type (ZEROS, ONES, LINEAR, EXP)
|
||||
- `--inference_delay`: Inference delay for RTC (default: 4)
|
||||
- `--seed`: Random seed for reproducibility (default: 42)
|
||||
- `--output_dir`: Directory to save visualizations (default: rtc_debug_output)
|
||||
- `--device`: Device to use (cuda, cpu, mps, auto)
|
||||
- `--use_torch_compile`: Enable torch.compile() for faster inference
|
||||
|
||||
**Output:**
|
||||
|
||||
The script generates several visualization files in `rtc_debug_output/`:
|
||||
|
||||
- `denoising_xt_comparison.png` - Noisy state evolution during denoising
|
||||
- `denoising_vt_comparison.png` - Velocity predictions during denoising
|
||||
- `denoising_x1t_comparison.png` - Predicted final states during denoising
|
||||
- `denoising_correction_comparison.png` - RTC guidance corrections applied
|
||||
- `final_actions_comparison.png` - Final action predictions (prev_chunk, no_rtc, rtc)
|
||||
|
||||
The script also validates RTC behavior and reports:
|
||||
|
||||
- ✅ Delay region [0:inference_delay]: RTC = prev_chunk
|
||||
- ✅ Blend region [inference_delay:execution_horizon]: prev_chunk ≤ RTC ≤ no_rtc
|
||||
- ✅ Post-horizon [execution_horizon:]: RTC = no_rtc
|
||||
|
||||
### 2. `eval_with_real_robot.py`
|
||||
|
||||
Real-time evaluation on physical robots or simulation environments.
|
||||
|
||||
**Features:**
|
||||
|
||||
- Run policy with RTC on real robot or simulation
|
||||
- Multi-threaded action execution and inference
|
||||
- Action queue management with proper timing
|
||||
- Latency tracking and adaptive inference delay
|
||||
- Support for both robots and gym environments
|
||||
- Support for torch.compile() optimization
|
||||
|
||||
**Usage:**
|
||||
|
||||
```bash
|
||||
# With real robot
|
||||
uv run python examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--robot.type=so100 \
|
||||
--task="pick up the cup" \
|
||||
--duration=30.0
|
||||
|
||||
# With simulation environment
|
||||
uv run python examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--env.type=pusht \
|
||||
--duration=60.0
|
||||
|
||||
# With policy compilation (CUDA only, not MPS)
|
||||
uv run python examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--robot.type=so100 \
|
||||
--use_torch_compile=true \
|
||||
--torch_compile_mode=max-autotune
|
||||
```
|
||||
|
||||
**Key Parameters:**
|
||||
|
||||
- `--policy.path`: Path to pretrained policy
|
||||
- `--robot.type` or `--env.type`: Robot or environment to use
|
||||
- `--task`: Task description (for VLA models)
|
||||
- `--rtc.execution_horizon`: Number of steps to maintain consistency (default: 10)
|
||||
- `--rtc.max_guidance_weight`: Maximum guidance weight (default: 1.0)
|
||||
- `--rtc.prefix_attention_schedule`: Schedule type (ZEROS, ONES, LINEAR, EXP)
|
||||
- `--duration`: How long to run (seconds, default: 30.0)
|
||||
- `--fps`: Action execution frequency (Hz, default: 10.0)
|
||||
- `--action_queue_size_to_get_new_actions`: Queue size threshold to request new actions (default: 30)
|
||||
- `--device`: Device to use (cuda, cpu, mps, auto)
|
||||
- `--use_torch_compile`: Enable torch.compile() for faster inference
|
||||
|
||||
## Understanding RTC Parameters
|
||||
|
||||
### `execution_horizon`
|
||||
|
||||
Number of timesteps from previous chunk to maintain consistency with. Higher values mean more consistency but potentially less reactivity.
|
||||
|
||||
**Typical values:** 8-12 steps for dataset evaluation, 10 steps for real-time execution
|
||||
|
||||
### `max_guidance_weight`
|
||||
|
||||
Upper bound on guidance strength. Higher values give stronger consistency but may over-constrain new predictions.
|
||||
|
||||
**Typical values:**
|
||||
|
||||
- Dataset evaluation: 10.0-100.0 (can be higher for analysis)
|
||||
- Real-time execution: 1.0-10.0 (more conservative)
|
||||
|
||||
### `prefix_attention_schedule`
|
||||
|
||||
How to weight consistency across the overlap region:
|
||||
|
||||
- `ZEROS`: Binary (full weight up to inference_delay, then zero)
|
||||
- `ONES`: Full weight across entire execution_horizon
|
||||
- `LINEAR`: Linear decay from inference_delay to execution_horizon
|
||||
- `EXP`: Exponential decay (recommended)
|
||||
|
||||
**Recommended:** `EXP`
|
||||
|
||||
### `inference_delay`
|
||||
|
||||
Number of timesteps from the prefix to use for guidance. Typically calculated dynamically based on inference latency in real-time execution, but fixed for dataset evaluation.
|
||||
|
||||
**Typical values:** 3-5 steps for dataset evaluation
|
||||
|
||||
### `action_queue_size_to_get_new_actions` (real-time only)
|
||||
|
||||
Threshold for requesting new action chunks. Should be higher than `inference_delay + execution_horizon` to ensure smooth operation.
|
||||
|
||||
**Typical values:** 20-30 steps
|
||||
|
||||
## Validation Rules (Dataset Evaluation)
|
||||
|
||||
The dataset evaluation script validates that RTC behavior matches expectations:
|
||||
|
||||
1. **Delay Region [0:inference_delay]**: RTC actions should equal previous chunk
|
||||
- Ensures consistency during the inference delay period
|
||||
|
||||
2. **Blend Region [inference_delay:execution_horizon]**: RTC should be between prev_chunk and no_rtc
|
||||
- Smooth transition from previous plan to new predictions
|
||||
|
||||
3. **Post-Horizon [execution_horizon:]**: RTC should equal no_rtc
|
||||
- Full adoption of new predictions after execution horizon
|
||||
|
||||
## Tips
|
||||
|
||||
1. **Start with dataset evaluation** (`eval_dataset.py`) to understand RTC behavior and tune parameters before running on robot
|
||||
2. **Use visualizations** to debug unexpected behavior - check denoising steps and final actions
|
||||
3. **Tune execution_horizon** based on your inference latency and action frequency
|
||||
4. **Monitor validation output** - failures indicate potential implementation issues or misconfigured parameters
|
||||
5. **Compare different schedules** - EXP usually works best but LINEAR can be more interpretable
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Validation fails in delay region
|
||||
|
||||
- Check that `prev_chunk_left_over` is properly passed to the policy
|
||||
- Verify RTC guidance is being applied during denoising
|
||||
- Look at denoising visualizations to see where guidance diverges
|
||||
|
||||
### Validation fails in post-horizon region
|
||||
|
||||
- RTC and no_rtc use different noise - verify same noise is being used for comparison
|
||||
- Check that weights are correctly zeroed out after execution horizon
|
||||
- Review prefix_attention_schedule visualization
|
||||
|
||||
### Poor performance on real robot
|
||||
|
||||
- Increase `action_queue_size_to_get_new_actions` if you see warnings
|
||||
- Reduce `max_guidance_weight` if robot is too conservative
|
||||
- Try different `prefix_attention_schedule` values
|
||||
- Enable torch.compile() for faster inference (CUDA only)
|
||||
|
||||
### Memory issues with large models
|
||||
|
||||
- The dataset evaluation script loads policies sequentially to minimize memory
|
||||
- For real-time execution, only one policy is loaded
|
||||
- Use smaller batch sizes if needed
|
||||
|
||||
## Related Documentation
|
||||
|
||||
- [RTC Implementation](../../src/lerobot/policies/rtc/modeling_rtc.py)
|
||||
- [RTC Configuration](../../src/lerobot/policies/rtc/configuration_rtc.py)
|
||||
- [Action Queue](../../src/lerobot/policies/rtc/action_queue.py)
|
||||
- [Physical Intelligence Paper](https://www.physicalintelligence.company/download/real_time_chunking.pdf)
|
||||
@@ -1,202 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Script to add profiling instrumentation to RTCProcessor.
|
||||
|
||||
This script shows which methods to profile in the RTC code to identify bottlenecks.
|
||||
You can either:
|
||||
1. Apply these changes directly to modeling_rtc.py
|
||||
2. Use monkey patching to add profiling without modifying source
|
||||
3. Use as reference for manual instrumentation
|
||||
|
||||
Usage:
|
||||
# Option 1: Monkey patch (no source changes)
|
||||
python examples/rtc/add_rtc_profiling.py
|
||||
|
||||
# Option 2: Apply changes to source
|
||||
# Copy the profiled methods below into src/lerobot/policies/rtc/modeling_rtc.py
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.utils.profiling import ProfileContext, enable_profiling, is_profiling_enabled
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def profile_denoise_step(self, x_t, prev_chunk_left_over, inference_delay, time, original_denoise_step_partial, execution_horizon=None) -> Tensor:
|
||||
"""Profiled version of denoise_step."""
|
||||
|
||||
if not is_profiling_enabled():
|
||||
# Call original implementation if profiling disabled
|
||||
return self._original_denoise_step(x_t, prev_chunk_left_over, inference_delay, time, original_denoise_step_partial, execution_horizon)
|
||||
|
||||
with ProfileContext("rtc.denoise_step.total"):
|
||||
# In the original implementation, the time goes from 0 to 1 and
|
||||
# In our implementation, the time goes from 1 to 0
|
||||
# So we need to invert the time
|
||||
tau = 1 - time
|
||||
|
||||
if prev_chunk_left_over is None:
|
||||
# First step, no guidance - return v_t
|
||||
with ProfileContext("rtc.denoise_step.base_denoising"):
|
||||
v_t = original_denoise_step_partial(x_t)
|
||||
return v_t
|
||||
|
||||
with ProfileContext("rtc.denoise_step.setup"):
|
||||
x_t = x_t.clone().detach()
|
||||
|
||||
squeezed = False
|
||||
if len(x_t.shape) < 3:
|
||||
x_t = x_t.unsqueeze(0)
|
||||
squeezed = True
|
||||
|
||||
if len(prev_chunk_left_over.shape) < 3:
|
||||
prev_chunk_left_over = prev_chunk_left_over.unsqueeze(0)
|
||||
|
||||
if execution_horizon is None:
|
||||
execution_horizon = self.rtc_config.execution_horizon
|
||||
|
||||
if execution_horizon > prev_chunk_left_over.shape[1]:
|
||||
execution_horizon = prev_chunk_left_over.shape[1]
|
||||
|
||||
batch_size = x_t.shape[0]
|
||||
action_chunk_size = x_t.shape[1]
|
||||
action_dim = x_t.shape[2]
|
||||
|
||||
# Padding
|
||||
with ProfileContext("rtc.denoise_step.padding"):
|
||||
if prev_chunk_left_over.shape[1] < action_chunk_size or prev_chunk_left_over.shape[2] < action_dim:
|
||||
padded = torch.zeros(batch_size, action_chunk_size, action_dim).to(x_t.device)
|
||||
padded[:, : prev_chunk_left_over.shape[1], : prev_chunk_left_over.shape[2]] = prev_chunk_left_over
|
||||
prev_chunk_left_over = padded
|
||||
|
||||
# Get prefix weights
|
||||
with ProfileContext("rtc.denoise_step.get_prefix_weights"):
|
||||
weights = (
|
||||
self.get_prefix_weights(inference_delay, execution_horizon, action_chunk_size)
|
||||
.to(x_t.device)
|
||||
.unsqueeze(0)
|
||||
.unsqueeze(-1)
|
||||
)
|
||||
|
||||
# Main RTC guidance computation
|
||||
with ProfileContext("rtc.denoise_step.guidance_computation"):
|
||||
with torch.enable_grad():
|
||||
# Base denoising
|
||||
with ProfileContext("rtc.denoise_step.base_denoising"):
|
||||
v_t = original_denoise_step_partial(x_t)
|
||||
|
||||
x_t.requires_grad_(True)
|
||||
|
||||
# Compute x1_t
|
||||
with ProfileContext("rtc.denoise_step.compute_x1_t"):
|
||||
x1_t = x_t - time * v_t
|
||||
|
||||
# Compute error
|
||||
with ProfileContext("rtc.denoise_step.compute_error"):
|
||||
err = (prev_chunk_left_over - x1_t) * weights
|
||||
grad_outputs = err.clone().detach()
|
||||
|
||||
# Compute correction via autograd
|
||||
with ProfileContext("rtc.denoise_step.autograd_correction"):
|
||||
correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]
|
||||
|
||||
# Compute guidance weight
|
||||
with ProfileContext("rtc.denoise_step.compute_guidance_weight"):
|
||||
max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
|
||||
tau_tensor = torch.as_tensor(tau)
|
||||
squared_one_minus_tau = (1 - tau_tensor) ** 2
|
||||
inv_r2 = (squared_one_minus_tau + tau_tensor**2) / (squared_one_minus_tau)
|
||||
c = torch.nan_to_num((1 - tau_tensor) / tau_tensor, posinf=max_guidance_weight)
|
||||
guidance_weight = torch.nan_to_num(c * inv_r2, posinf=max_guidance_weight)
|
||||
guidance_weight = torch.minimum(guidance_weight, max_guidance_weight)
|
||||
|
||||
# Apply guidance
|
||||
with ProfileContext("rtc.denoise_step.apply_guidance"):
|
||||
result = v_t - guidance_weight * correction
|
||||
|
||||
# Cleanup
|
||||
with ProfileContext("rtc.denoise_step.cleanup"):
|
||||
if squeezed:
|
||||
result = result.squeeze(0)
|
||||
correction = correction.squeeze(0)
|
||||
x1_t = x1_t.squeeze(0)
|
||||
err = err.squeeze(0)
|
||||
|
||||
self.track(
|
||||
time=time,
|
||||
x1_t=x1_t,
|
||||
correction=correction,
|
||||
err=err,
|
||||
weights=weights,
|
||||
guidance_weight=guidance_weight,
|
||||
inference_delay=inference_delay,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def monkey_patch_rtc_profiling():
|
||||
"""Apply profiling to RTCProcessor via monkey patching.
|
||||
|
||||
This modifies the RTCProcessor class at runtime to add profiling
|
||||
without changing source files.
|
||||
"""
|
||||
logger.info("Applying RTC profiling monkey patch...")
|
||||
|
||||
# Save original method
|
||||
RTCProcessor._original_denoise_step = RTCProcessor.denoise_step
|
||||
|
||||
# Replace with profiled version
|
||||
RTCProcessor.denoise_step = profile_denoise_step
|
||||
|
||||
logger.info("✓ RTC profiling enabled")
|
||||
|
||||
|
||||
def print_usage():
|
||||
"""Print usage instructions."""
|
||||
print("\n" + "="*80)
|
||||
print("RTC PROFILING INSTRUMENTATION")
|
||||
print("="*80)
|
||||
print("\nThis script provides profiling for RTCProcessor methods.")
|
||||
print("\nOption 1: Monkey Patch (Recommended)")
|
||||
print("-" * 40)
|
||||
print("Add to your script:")
|
||||
print("""
|
||||
from lerobot.utils.profiling import enable_profiling, print_profiling_summary
|
||||
from examples.rtc.add_rtc_profiling import monkey_patch_rtc_profiling
|
||||
|
||||
# Enable profiling
|
||||
enable_profiling()
|
||||
monkey_patch_rtc_profiling()
|
||||
|
||||
# ... run your code ...
|
||||
|
||||
# Print results
|
||||
print_profiling_summary()
|
||||
""")
|
||||
|
||||
print("\nOption 2: Manual Source Modification")
|
||||
print("-" * 40)
|
||||
print("1. Copy profile_denoise_step() from this file")
|
||||
print("2. Replace denoise_step() in src/lerobot/policies/rtc/modeling_rtc.py")
|
||||
print("3. Add profiling imports at top of file")
|
||||
|
||||
print("\nKey Metrics to Watch:")
|
||||
print("-" * 40)
|
||||
print("- rtc.denoise_step.base_denoising - Time for base policy inference")
|
||||
print("- rtc.denoise_step.autograd_correction - Time computing gradients")
|
||||
print("- rtc.denoise_step.guidance_computation - Total guidance overhead")
|
||||
print("- rtc.denoise_step.get_prefix_weights - Time computing weights")
|
||||
print("="*80 + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print_usage()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,549 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies on real robots.
|
||||
|
||||
This script demonstrates:
|
||||
1. Creating a robot and policy (SmolVLA, Pi0, etc.) with RTC
|
||||
2. Consuming actions from the policy while the robot executes
|
||||
3. Periodically requesting new action chunks in the background using threads
|
||||
4. Managing action buffers and timing for real-time operation
|
||||
|
||||
For simulation environments, see eval_with_simulation.py
|
||||
|
||||
Usage:
|
||||
# Run RTC with Real robot with RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=120
|
||||
|
||||
# Run RTC with Real robot without RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=false \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=120
|
||||
|
||||
# Run RTC with Real robot with pi0.5 policy
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=helper2424/pi05_check_rtc \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=120
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Event, Lock, Thread
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.processor.factory import (
|
||||
make_default_robot_action_processor,
|
||||
make_default_robot_observation_processor,
|
||||
)
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
koch_follower,
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.robots.utils import make_robot_from_config
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RobotWrapper:
|
||||
def __init__(self, robot: Robot):
|
||||
self.robot = robot
|
||||
self.lock = Lock()
|
||||
|
||||
def get_observation(self) -> dict[str, Tensor]:
|
||||
with self.lock:
|
||||
return self.robot.get_observation()
|
||||
|
||||
def send_action(self, action: Tensor):
|
||||
with self.lock:
|
||||
self.robot.send_action(action)
|
||||
|
||||
def observation_features(self) -> list[str]:
|
||||
with self.lock:
|
||||
return self.robot.observation_features
|
||||
|
||||
def action_features(self) -> list[str]:
|
||||
with self.lock:
|
||||
return self.robot.action_features
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTCDemoConfig(HubMixin):
|
||||
"""Configuration for RTC demo with action chunking policies and real robots."""
|
||||
|
||||
# Policy configuration
|
||||
policy: PreTrainedConfig | None = None
|
||||
|
||||
# Robot configuration
|
||||
robot: RobotConfig | None = None
|
||||
|
||||
# RTC configuration
|
||||
rtc: RTCConfig = field(
|
||||
default_factory=lambda: RTCConfig(
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=1.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
)
|
||||
)
|
||||
|
||||
# Demo parameters
|
||||
duration: float = 30.0 # Duration to run the demo (seconds)
|
||||
fps: float = 10.0 # Action execution frequency (Hz)
|
||||
|
||||
# Compute device
|
||||
device: str | None = None # Device to run on (cuda, cpu, auto)
|
||||
|
||||
# Get new actions horizon. The amount of executed steps after which will be requested new actions.
|
||||
# It should be higher than inference delay + execution horizon.
|
||||
action_queue_size_to_get_new_actions: int = 30
|
||||
|
||||
# Task to execute
|
||||
task: str = field(default="", metadata={"help": "Task to execute"})
|
||||
|
||||
# Torch compile configuration
|
||||
use_torch_compile: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
|
||||
)
|
||||
|
||||
torch_compile_backend: str = field(
|
||||
default="inductor",
|
||||
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
|
||||
)
|
||||
|
||||
torch_compile_mode: str = field(
|
||||
default="default",
|
||||
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
|
||||
)
|
||||
|
||||
torch_compile_disable_cudagraphs: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
|
||||
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
else:
|
||||
raise ValueError("Policy path is required")
|
||||
|
||||
# Validate that robot configuration is provided
|
||||
if self.robot is None:
|
||||
raise ValueError("Robot configuration must be provided")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
|
||||
|
||||
def is_image_key(k: str) -> bool:
|
||||
return k.startswith(OBS_IMAGES)
|
||||
|
||||
|
||||
def get_actions(
|
||||
policy,
|
||||
robot: RobotWrapper,
|
||||
robot_observation_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: RTCDemoConfig,
|
||||
):
|
||||
"""Thread function to request action chunks from the policy.
|
||||
|
||||
Args:
|
||||
policy: The policy instance (SmolVLA, Pi0, etc.)
|
||||
robot: The robot instance for getting observations
|
||||
robot_observation_processor: Processor for raw robot observations
|
||||
action_queue: Queue to put new action chunks
|
||||
shutdown_event: Event to signal shutdown
|
||||
cfg: Demo configuration
|
||||
"""
|
||||
try:
|
||||
logger.info("[GET_ACTIONS] Starting get actions thread")
|
||||
|
||||
latency_tracker = LatencyTracker() # Track latency of action chunks
|
||||
fps = cfg.fps
|
||||
time_per_chunk = 1.0 / fps
|
||||
|
||||
dataset_features = hw_to_dataset_features(robot.observation_features(), "observation")
|
||||
policy_device = policy.config.device
|
||||
|
||||
# Load preprocessor and postprocessor from pretrained files
|
||||
# The stats are embedded in the processor .safetensors files
|
||||
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=None, # Will load from pretrained processor files
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.policy.device},
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
|
||||
|
||||
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
|
||||
|
||||
if not cfg.rtc.enabled:
|
||||
get_actions_threshold = 0
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if action_queue.qsize() <= get_actions_threshold:
|
||||
current_time = time.perf_counter()
|
||||
action_index_before_inference = action_queue.get_action_index()
|
||||
prev_actions = action_queue.get_left_over()
|
||||
|
||||
inference_latency = latency_tracker.max()
|
||||
inference_delay = math.ceil(inference_latency / time_per_chunk)
|
||||
|
||||
obs = robot.get_observation()
|
||||
|
||||
# Apply robot observation processor
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
|
||||
obs_with_policy_features = build_dataset_frame(
|
||||
dataset_features, obs_processed, prefix="observation"
|
||||
)
|
||||
|
||||
for name in obs_with_policy_features:
|
||||
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
|
||||
if "image" in name:
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].type(torch.float32) / 255
|
||||
)
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
|
||||
)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
|
||||
|
||||
obs_with_policy_features["task"] = [cfg.task] # Task should be a list, not a string!
|
||||
obs_with_policy_features["robot_type"] = (
|
||||
robot.robot.name if hasattr(robot.robot, "name") else ""
|
||||
)
|
||||
|
||||
preproceseded_obs = preprocessor(obs_with_policy_features)
|
||||
|
||||
# Generate actions WITH RTC
|
||||
actions = policy.predict_action_chunk(
|
||||
preproceseded_obs,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_actions,
|
||||
)
|
||||
|
||||
# Store original actions (before postprocessing) for RTC
|
||||
original_actions = actions.squeeze(0).clone()
|
||||
|
||||
postprocessed_actions = postprocessor(actions)
|
||||
|
||||
postprocessed_actions = postprocessed_actions.squeeze(0)
|
||||
|
||||
new_latency = time.perf_counter() - current_time
|
||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||
latency_tracker.add(new_latency)
|
||||
|
||||
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
|
||||
logger.warning(
|
||||
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon."
|
||||
)
|
||||
|
||||
action_queue.merge(
|
||||
original_actions, postprocessed_actions, new_delay, action_index_before_inference
|
||||
)
|
||||
else:
|
||||
# Small sleep to prevent busy waiting
|
||||
time.sleep(0.1)
|
||||
|
||||
logger.info("[GET_ACTIONS] get actions thread shutting down")
|
||||
except Exception as e:
|
||||
logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def actor_control(
|
||||
robot: RobotWrapper,
|
||||
robot_action_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: RTCDemoConfig,
|
||||
):
|
||||
"""Thread function to execute actions on the robot.
|
||||
|
||||
Args:
|
||||
robot: The robot instance
|
||||
action_queue: Queue to get actions from
|
||||
shutdown_event: Event to signal shutdown
|
||||
cfg: Demo configuration
|
||||
"""
|
||||
try:
|
||||
logger.info("[ACTOR] Starting actor thread")
|
||||
|
||||
action_count = 0
|
||||
action_interval = 1.0 / cfg.fps
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Try to get an action from the queue with timeout
|
||||
action = action_queue.get()
|
||||
|
||||
if action is not None:
|
||||
action = action.cpu()
|
||||
action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
|
||||
action_processed = robot_action_processor((action_dict, None))
|
||||
robot.send_action(action_processed)
|
||||
|
||||
action_count += 1
|
||||
|
||||
dt_s = time.perf_counter() - start_time
|
||||
time.sleep(max(0, (action_interval - dt_s) - 0.001))
|
||||
|
||||
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
|
||||
except Exception as e:
|
||||
logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
|
||||
"""Apply torch.compile to the policy's predict_action_chunk method.
|
||||
|
||||
Args:
|
||||
policy: Policy instance to compile
|
||||
cfg: Configuration containing torch compile settings
|
||||
|
||||
Returns:
|
||||
Policy with compiled predict_action_chunk method
|
||||
"""
|
||||
|
||||
# PI models handle their own compilation
|
||||
if policy.type == "pi05" or policy.type == "pi0":
|
||||
return policy
|
||||
|
||||
try:
|
||||
# Check if torch.compile is available (PyTorch 2.0+)
|
||||
if not hasattr(torch, "compile"):
|
||||
logger.warning(
|
||||
f"torch.compile is not available. Requires PyTorch 2.0+. "
|
||||
f"Current version: {torch.__version__}. Skipping compilation."
|
||||
)
|
||||
return policy
|
||||
|
||||
logger.info("Applying torch.compile to predict_action_chunk...")
|
||||
logger.info(f" Backend: {cfg.torch_compile_backend}")
|
||||
logger.info(f" Mode: {cfg.torch_compile_mode}")
|
||||
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
|
||||
|
||||
# Compile the predict_action_chunk method
|
||||
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
|
||||
compile_kwargs = {
|
||||
"backend": cfg.torch_compile_backend,
|
||||
"mode": cfg.torch_compile_mode,
|
||||
}
|
||||
|
||||
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
|
||||
if cfg.torch_compile_disable_cudagraphs:
|
||||
compile_kwargs["options"] = {"triton.cudagraphs": False}
|
||||
|
||||
original_method = policy.predict_action_chunk
|
||||
compiled_method = torch.compile(original_method, **compile_kwargs)
|
||||
policy.predict_action_chunk = compiled_method
|
||||
logger.info("✓ Successfully compiled predict_action_chunk")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to apply torch.compile: {e}")
|
||||
logger.warning("Continuing without torch.compile")
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def demo_cli(cfg: RTCDemoConfig):
|
||||
"""Main entry point for RTC demo with draccus configuration."""
|
||||
|
||||
# Initialize logging
|
||||
init_logging()
|
||||
|
||||
logger.info(f"Using device: {cfg.device}")
|
||||
|
||||
# Setup signal handler for graceful shutdown
|
||||
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
|
||||
shutdown_event = signal_handler.shutdown_event
|
||||
|
||||
policy = None
|
||||
robot = None
|
||||
get_actions_thread = None
|
||||
actor_thread = None
|
||||
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
|
||||
# Load config and set compile_model for pi0/pi05 models
|
||||
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
|
||||
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
|
||||
config.compile_model = cfg.use_torch_compile
|
||||
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
||||
|
||||
# Turn on RTC
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
|
||||
# Init RTC processort, as by default if RTC disabled in the config
|
||||
# The processor won't be created
|
||||
policy.init_rtc_processor()
|
||||
|
||||
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
|
||||
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
|
||||
# Apply torch.compile to predict_action_chunk method if enabled
|
||||
if cfg.use_torch_compile:
|
||||
policy = _apply_torch_compile(policy, cfg)
|
||||
|
||||
# Create robot
|
||||
logger.info(f"Initializing robot: {cfg.robot.type}")
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
robot.connect()
|
||||
robot_wrapper = RobotWrapper(robot)
|
||||
|
||||
# Create robot observation processor
|
||||
robot_observation_processor = make_default_robot_observation_processor()
|
||||
robot_action_processor = make_default_robot_action_processor()
|
||||
|
||||
# Create action queue for communication between threads
|
||||
action_queue = ActionQueue(cfg.rtc)
|
||||
|
||||
# Start chunk requester thread
|
||||
get_actions_thread = Thread(
|
||||
target=get_actions,
|
||||
args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
|
||||
daemon=True,
|
||||
name="GetActions",
|
||||
)
|
||||
get_actions_thread.start()
|
||||
logger.info("Started get actions thread")
|
||||
|
||||
# Start action executor thread
|
||||
actor_thread = Thread(
|
||||
target=actor_control,
|
||||
args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
|
||||
daemon=True,
|
||||
name="Actor",
|
||||
)
|
||||
actor_thread.start()
|
||||
logger.info("Started actor thread")
|
||||
|
||||
logger.info("Started stop by duration thread")
|
||||
|
||||
# Main thread monitors for duration or shutdown
|
||||
logger.info(f"Running demo for {cfg.duration} seconds...")
|
||||
start_time = time.time()
|
||||
|
||||
while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
|
||||
time.sleep(10)
|
||||
|
||||
# Log queue status periodically
|
||||
if int(time.time() - start_time) % 5 == 0:
|
||||
logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
|
||||
|
||||
if time.time() - start_time > cfg.duration:
|
||||
break
|
||||
|
||||
logger.info("Demo duration reached or shutdown requested")
|
||||
|
||||
# Signal shutdown
|
||||
shutdown_event.set()
|
||||
|
||||
# Wait for threads to finish
|
||||
if get_actions_thread and get_actions_thread.is_alive():
|
||||
logger.info("Waiting for chunk requester thread to finish...")
|
||||
get_actions_thread.join()
|
||||
|
||||
if actor_thread and actor_thread.is_alive():
|
||||
logger.info("Waiting for action executor thread to finish...")
|
||||
actor_thread.join()
|
||||
|
||||
# Cleanup robot
|
||||
if robot:
|
||||
robot.disconnect()
|
||||
logger.info("Robot disconnected")
|
||||
|
||||
logger.info("Cleanup completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo_cli()
|
||||
logging.info("RTC demo finished")
|
||||
@@ -1,631 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Profiled version of eval_with_real_robot.py for performance analysis.
|
||||
|
||||
This version adds detailed timing measurements for:
|
||||
- Policy inference
|
||||
- Preprocessing
|
||||
- Postprocessing
|
||||
- Action queue operations
|
||||
- Robot communication
|
||||
- Thread execution times
|
||||
|
||||
Usage: Same as eval_with_real_robot.py but with profiling output.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Event, Lock, Thread
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.processor.factory import (
|
||||
make_default_robot_action_processor,
|
||||
make_default_robot_observation_processor,
|
||||
)
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
koch_follower,
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.robots.utils import make_robot_from_config
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProfileTimer:
|
||||
"""Context manager and utility class for timing code sections."""
|
||||
|
||||
def __init__(self, name: str, stats_dict: dict):
|
||||
self.name = name
|
||||
self.stats_dict = stats_dict
|
||||
self.start_time = None
|
||||
|
||||
def __enter__(self):
|
||||
self.start_time = time.perf_counter()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
elapsed = time.perf_counter() - self.start_time
|
||||
if self.name not in self.stats_dict:
|
||||
self.stats_dict[self.name] = []
|
||||
self.stats_dict[self.name].append(elapsed)
|
||||
|
||||
|
||||
class ProfilingStats:
|
||||
"""Global profiling statistics collector."""
|
||||
|
||||
def __init__(self):
|
||||
self.stats = defaultdict(list)
|
||||
self.lock = Lock()
|
||||
|
||||
def record(self, name: str, duration: float):
|
||||
with self.lock:
|
||||
self.stats[name].append(duration)
|
||||
|
||||
def timer(self, name: str):
|
||||
"""Return a context manager for timing."""
|
||||
return ProfileTimer(name, self.stats)
|
||||
|
||||
def get_summary(self) -> dict[str, dict[str, float]]:
|
||||
"""Get summary statistics for all timings."""
|
||||
with self.lock:
|
||||
summary = {}
|
||||
for name, times in self.stats.items():
|
||||
if times:
|
||||
summary[name] = {
|
||||
"count": len(times),
|
||||
"mean": sum(times) / len(times),
|
||||
"min": min(times),
|
||||
"max": max(times),
|
||||
"total": sum(times),
|
||||
}
|
||||
return summary
|
||||
|
||||
def print_summary(self):
|
||||
"""Print formatted summary of all timings."""
|
||||
summary = self.get_summary()
|
||||
|
||||
logger.info("\n" + "=" * 80)
|
||||
logger.info("PROFILING SUMMARY")
|
||||
logger.info("=" * 80)
|
||||
|
||||
# Sort by total time (descending)
|
||||
sorted_items = sorted(summary.items(), key=lambda x: x[1]["total"], reverse=True)
|
||||
|
||||
for name, stats in sorted_items:
|
||||
logger.info(f"\n{name}:")
|
||||
logger.info(f" Count: {stats['count']}")
|
||||
logger.info(f" Mean: {stats['mean']*1000:.2f} ms")
|
||||
logger.info(f" Min: {stats['min']*1000:.2f} ms")
|
||||
logger.info(f" Max: {stats['max']*1000:.2f} ms")
|
||||
logger.info(f" Total: {stats['total']:.2f} s")
|
||||
logger.info(f" Hz: {stats['count']/stats['total']:.2f}")
|
||||
|
||||
logger.info("\n" + "=" * 80)
|
||||
|
||||
|
||||
# Global profiling stats
|
||||
profiling_stats = ProfilingStats()
|
||||
|
||||
|
||||
class RobotWrapper:
|
||||
def __init__(self, robot: Robot):
|
||||
self.robot = robot
|
||||
self.lock = Lock()
|
||||
|
||||
def get_observation(self) -> dict[str, Tensor]:
|
||||
with profiling_stats.timer("robot.get_observation"):
|
||||
with self.lock:
|
||||
return self.robot.get_observation()
|
||||
|
||||
def send_action(self, action: Tensor):
|
||||
with profiling_stats.timer("robot.send_action"):
|
||||
with self.lock:
|
||||
self.robot.send_action(action)
|
||||
|
||||
def observation_features(self) -> list[str]:
|
||||
with self.lock:
|
||||
return self.robot.observation_features
|
||||
|
||||
def action_features(self) -> list[str]:
|
||||
with self.lock:
|
||||
return self.robot.action_features
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTCDemoConfig(HubMixin):
|
||||
"""Configuration for RTC demo with action chunking policies and real robots."""
|
||||
|
||||
# Policy configuration
|
||||
policy: PreTrainedConfig | None = None
|
||||
|
||||
# Robot configuration
|
||||
robot: RobotConfig | None = None
|
||||
|
||||
# RTC configuration
|
||||
rtc: RTCConfig = field(
|
||||
default_factory=lambda: RTCConfig(
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=1.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
)
|
||||
)
|
||||
|
||||
# Demo parameters
|
||||
duration: float = 30.0 # Duration to run the demo (seconds)
|
||||
fps: float = 10.0 # Action execution frequency (Hz)
|
||||
|
||||
# Compute device
|
||||
device: str | None = None # Device to run on (cuda, cpu, auto)
|
||||
|
||||
# Get new actions horizon. The amount of executed steps after which will be requested new actions.
|
||||
# It should be higher than inference delay + execution horizon.
|
||||
action_queue_size_to_get_new_actions: int = 30
|
||||
|
||||
# Task to execute
|
||||
task: str = field(default="", metadata={"help": "Task to execute"})
|
||||
|
||||
# Torch compile configuration
|
||||
use_torch_compile: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
|
||||
)
|
||||
|
||||
torch_compile_backend: str = field(
|
||||
default="inductor",
|
||||
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
|
||||
)
|
||||
|
||||
torch_compile_mode: str = field(
|
||||
default="default",
|
||||
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
|
||||
)
|
||||
|
||||
torch_compile_disable_cudagraphs: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
|
||||
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
else:
|
||||
raise ValueError("Policy path is required")
|
||||
|
||||
# Validate that robot configuration is provided
|
||||
if self.robot is None:
|
||||
raise ValueError("Robot configuration must be provided")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
|
||||
|
||||
def is_image_key(k: str) -> bool:
|
||||
return k.startswith(OBS_IMAGES)
|
||||
|
||||
|
||||
def get_actions(
|
||||
policy,
|
||||
robot: RobotWrapper,
|
||||
robot_observation_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: RTCDemoConfig,
|
||||
):
|
||||
"""Thread function to request action chunks from the policy with profiling.
|
||||
|
||||
Args:
|
||||
policy: The policy instance (SmolVLA, Pi0, etc.)
|
||||
robot: The robot instance for getting observations
|
||||
robot_observation_processor: Processor for raw robot observations
|
||||
action_queue: Queue to put new action chunks
|
||||
shutdown_event: Event to signal shutdown
|
||||
cfg: Demo configuration
|
||||
"""
|
||||
try:
|
||||
logger.info("[GET_ACTIONS] Starting get actions thread")
|
||||
|
||||
latency_tracker = LatencyTracker() # Track latency of action chunks
|
||||
fps = cfg.fps
|
||||
time_per_chunk = 1.0 / fps
|
||||
|
||||
dataset_features = hw_to_dataset_features(robot.observation_features(), "observation")
|
||||
policy_device = policy.config.device
|
||||
|
||||
# Load preprocessor and postprocessor from pretrained files
|
||||
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=None, # Will load from pretrained processor files
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.policy.device},
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
|
||||
|
||||
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
|
||||
|
||||
if not cfg.rtc.enabled:
|
||||
get_actions_threshold = 0
|
||||
|
||||
inference_count = 0
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if action_queue.qsize() <= get_actions_threshold:
|
||||
with profiling_stats.timer("get_actions.total_iteration"):
|
||||
inference_count += 1
|
||||
logger.info(f"[GET_ACTIONS] Starting inference #{inference_count}")
|
||||
|
||||
current_time = time.perf_counter()
|
||||
action_index_before_inference = action_queue.get_action_index()
|
||||
|
||||
with profiling_stats.timer("get_actions.get_prev_actions"):
|
||||
prev_actions = action_queue.get_left_over()
|
||||
|
||||
inference_latency = latency_tracker.max()
|
||||
inference_delay = math.ceil(inference_latency / time_per_chunk)
|
||||
|
||||
# Get observation
|
||||
obs = robot.get_observation()
|
||||
|
||||
# Apply robot observation processor
|
||||
with profiling_stats.timer("get_actions.robot_obs_processing"):
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
|
||||
# Build dataset frame
|
||||
with profiling_stats.timer("get_actions.build_dataset_frame"):
|
||||
obs_with_policy_features = build_dataset_frame(
|
||||
dataset_features, obs_processed, prefix="observation"
|
||||
)
|
||||
|
||||
# Convert to tensors and normalize
|
||||
with profiling_stats.timer("get_actions.tensor_conversion"):
|
||||
for name in obs_with_policy_features:
|
||||
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
|
||||
if "image" in name:
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].type(torch.float32) / 255
|
||||
)
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
|
||||
)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
|
||||
|
||||
obs_with_policy_features["task"] = [cfg.task]
|
||||
obs_with_policy_features["robot_type"] = (
|
||||
robot.robot.name if hasattr(robot.robot, "name") else ""
|
||||
)
|
||||
|
||||
# Preprocessing
|
||||
with profiling_stats.timer("get_actions.preprocessing"):
|
||||
preproceseded_obs = preprocessor(obs_with_policy_features)
|
||||
|
||||
# Policy inference
|
||||
with profiling_stats.timer("get_actions.policy_inference"):
|
||||
actions = policy.predict_action_chunk(
|
||||
preproceseded_obs,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_actions,
|
||||
)
|
||||
|
||||
# Clone for RTC
|
||||
with profiling_stats.timer("get_actions.clone_actions"):
|
||||
original_actions = actions.squeeze(0).clone()
|
||||
|
||||
# Postprocessing
|
||||
with profiling_stats.timer("get_actions.postprocessing"):
|
||||
postprocessed_actions = postprocessor(actions)
|
||||
postprocessed_actions = postprocessed_actions.squeeze(0)
|
||||
|
||||
# Update latency tracker
|
||||
new_latency = time.perf_counter() - current_time
|
||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||
latency_tracker.add(new_latency)
|
||||
|
||||
logger.info(
|
||||
f"[GET_ACTIONS] Inference #{inference_count} completed in {new_latency*1000:.2f}ms "
|
||||
f"(delay={new_delay} chunks)"
|
||||
)
|
||||
|
||||
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
|
||||
logger.warning(
|
||||
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, "
|
||||
"It should be higher than inference delay + execution horizon."
|
||||
)
|
||||
|
||||
# Merge into action queue
|
||||
with profiling_stats.timer("get_actions.action_queue_merge"):
|
||||
action_queue.merge(
|
||||
original_actions, postprocessed_actions, new_delay, action_index_before_inference
|
||||
)
|
||||
else:
|
||||
# Small sleep to prevent busy waiting
|
||||
time.sleep(0.1)
|
||||
|
||||
logger.info("[GET_ACTIONS] get actions thread shutting down")
|
||||
except Exception as e:
|
||||
logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def actor_control(
|
||||
robot: RobotWrapper,
|
||||
robot_action_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: RTCDemoConfig,
|
||||
):
|
||||
"""Thread function to execute actions on the robot with profiling.
|
||||
|
||||
Args:
|
||||
robot: The robot instance
|
||||
action_queue: Queue to get actions from
|
||||
shutdown_event: Event to signal shutdown
|
||||
cfg: Demo configuration
|
||||
"""
|
||||
try:
|
||||
logger.info("[ACTOR] Starting actor thread")
|
||||
|
||||
action_count = 0
|
||||
action_interval = 1.0 / cfg.fps
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
start_time = time.perf_counter()
|
||||
|
||||
with profiling_stats.timer("actor.total_iteration"):
|
||||
# Get action from queue
|
||||
with profiling_stats.timer("actor.queue_get"):
|
||||
action = action_queue.get()
|
||||
|
||||
if action is not None:
|
||||
# Process action
|
||||
with profiling_stats.timer("actor.action_processing"):
|
||||
action = action.cpu()
|
||||
action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
|
||||
action_processed = robot_action_processor((action_dict, None))
|
||||
|
||||
# Send to robot (includes robot.send_action timing)
|
||||
robot.send_action(action_processed)
|
||||
action_count += 1
|
||||
|
||||
# Sleep to maintain target FPS
|
||||
dt_s = time.perf_counter() - start_time
|
||||
sleep_time = max(0, (action_interval - dt_s) - 0.001)
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
|
||||
except Exception as e:
|
||||
logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
|
||||
"""Apply torch.compile to the policy's predict_action_chunk method.
|
||||
|
||||
Args:
|
||||
policy: Policy instance to compile
|
||||
cfg: Configuration containing torch compile settings
|
||||
|
||||
Returns:
|
||||
Policy with compiled predict_action_chunk method
|
||||
"""
|
||||
|
||||
# PI models handle their own compilation
|
||||
if policy.type == "pi05" or policy.type == "pi0":
|
||||
return policy
|
||||
|
||||
try:
|
||||
# Check if torch.compile is available (PyTorch 2.0+)
|
||||
if not hasattr(torch, "compile"):
|
||||
logger.warning(
|
||||
f"torch.compile is not available. Requires PyTorch 2.0+. "
|
||||
f"Current version: {torch.__version__}. Skipping compilation."
|
||||
)
|
||||
return policy
|
||||
|
||||
logger.info("Applying torch.compile to predict_action_chunk...")
|
||||
logger.info(f" Backend: {cfg.torch_compile_backend}")
|
||||
logger.info(f" Mode: {cfg.torch_compile_mode}")
|
||||
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
|
||||
|
||||
# Compile the predict_action_chunk method
|
||||
compile_kwargs = {
|
||||
"backend": cfg.torch_compile_backend,
|
||||
"mode": cfg.torch_compile_mode,
|
||||
}
|
||||
|
||||
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
|
||||
if cfg.torch_compile_disable_cudagraphs:
|
||||
compile_kwargs["options"] = {"triton.cudagraphs": False}
|
||||
|
||||
original_method = policy.predict_action_chunk
|
||||
compiled_method = torch.compile(original_method, **compile_kwargs)
|
||||
policy.predict_action_chunk = compiled_method
|
||||
logger.info("✓ Successfully compiled predict_action_chunk")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to apply torch.compile: {e}")
|
||||
logger.warning("Continuing without torch.compile")
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def demo_cli(cfg: RTCDemoConfig):
|
||||
"""Main entry point for RTC demo with profiling."""
|
||||
|
||||
# Initialize logging
|
||||
init_logging()
|
||||
|
||||
logger.info(f"Using device: {cfg.device}")
|
||||
logger.info("=" * 80)
|
||||
logger.info("PROFILING MODE ENABLED")
|
||||
logger.info("=" * 80)
|
||||
|
||||
# Setup signal handler for graceful shutdown
|
||||
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
|
||||
shutdown_event = signal_handler.shutdown_event
|
||||
|
||||
policy = None
|
||||
robot = None
|
||||
get_actions_thread = None
|
||||
actor_thread = None
|
||||
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
|
||||
# Load config and set compile_model for pi0/pi05 models
|
||||
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
|
||||
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
|
||||
config.compile_model = cfg.use_torch_compile
|
||||
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
||||
|
||||
# Turn on RTC
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
|
||||
# Init RTC processor
|
||||
policy.init_rtc_processor()
|
||||
|
||||
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
|
||||
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
|
||||
# Apply torch.compile to predict_action_chunk method if enabled
|
||||
if cfg.use_torch_compile:
|
||||
policy = _apply_torch_compile(policy, cfg)
|
||||
|
||||
# Create robot
|
||||
logger.info(f"Initializing robot: {cfg.robot.type}")
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
robot.connect()
|
||||
robot_wrapper = RobotWrapper(robot)
|
||||
|
||||
# Create robot observation processor
|
||||
robot_observation_processor = make_default_robot_observation_processor()
|
||||
robot_action_processor = make_default_robot_action_processor()
|
||||
|
||||
# Create action queue for communication between threads
|
||||
action_queue = ActionQueue(cfg.rtc)
|
||||
|
||||
# Start chunk requester thread
|
||||
get_actions_thread = Thread(
|
||||
target=get_actions,
|
||||
args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
|
||||
daemon=True,
|
||||
name="GetActions",
|
||||
)
|
||||
get_actions_thread.start()
|
||||
logger.info("Started get actions thread")
|
||||
|
||||
# Start action executor thread
|
||||
actor_thread = Thread(
|
||||
target=actor_control,
|
||||
args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
|
||||
daemon=True,
|
||||
name="Actor",
|
||||
)
|
||||
actor_thread.start()
|
||||
logger.info("Started actor thread")
|
||||
|
||||
logger.info("Started stop by duration thread")
|
||||
|
||||
# Main thread monitors for duration or shutdown
|
||||
logger.info(f"Running demo for {cfg.duration} seconds...")
|
||||
start_time = time.time()
|
||||
|
||||
while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
|
||||
time.sleep(10)
|
||||
|
||||
# Log queue status periodically
|
||||
if int(time.time() - start_time) % 5 == 0:
|
||||
logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
|
||||
|
||||
if time.time() - start_time > cfg.duration:
|
||||
break
|
||||
|
||||
logger.info("Demo duration reached or shutdown requested")
|
||||
|
||||
# Signal shutdown
|
||||
shutdown_event.set()
|
||||
|
||||
# Wait for threads to finish
|
||||
if get_actions_thread and get_actions_thread.is_alive():
|
||||
logger.info("Waiting for chunk requester thread to finish...")
|
||||
get_actions_thread.join()
|
||||
|
||||
if actor_thread and actor_thread.is_alive():
|
||||
logger.info("Waiting for action executor thread to finish...")
|
||||
actor_thread.join()
|
||||
|
||||
# Cleanup robot
|
||||
if robot:
|
||||
robot.disconnect()
|
||||
logger.info("Robot disconnected")
|
||||
|
||||
# Print profiling summary
|
||||
profiling_stats.print_summary()
|
||||
|
||||
logger.info("Cleanup completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo_cli()
|
||||
logging.info("RTC demo finished")
|
||||
|
||||
@@ -1,358 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Comprehensive profiling script for Pi0 with RTC.
|
||||
|
||||
This script demonstrates how to use all the profiling tools to identify
|
||||
bottlenecks in Pi0 policy inference with RTC enabled.
|
||||
|
||||
It profiles:
|
||||
1. Overall inference time
|
||||
2. RTC-specific operations (guidance, weights, etc.)
|
||||
3. Preprocessing/postprocessing
|
||||
4. Individual method timings
|
||||
|
||||
Usage:
|
||||
uv run examples/rtc/profile_pi0_rtc_detailed.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=20 \
|
||||
--execution_horizon=20 \
|
||||
--enable_rtc_profiling
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.profiling import (
|
||||
ProfileContext,
|
||||
clear_profiling_stats,
|
||||
enable_profiling,
|
||||
get_profiling_stats,
|
||||
print_profiling_summary,
|
||||
)
|
||||
|
||||
# Import monkey patching for RTC profiling
|
||||
try:
|
||||
from examples.rtc.add_rtc_profiling import monkey_patch_rtc_profiling
|
||||
except ImportError:
|
||||
logging.warning("Could not import add_rtc_profiling, detailed RTC profiling disabled")
|
||||
monkey_patch_rtc_profiling = None
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_mock_observation(policy_config, device: str) -> dict:
|
||||
"""Create a mock observation matching policy requirements.
|
||||
|
||||
Args:
|
||||
policy_config: Policy configuration
|
||||
device: Device to create tensors on
|
||||
|
||||
Returns:
|
||||
Mock observation dictionary
|
||||
"""
|
||||
obs = {}
|
||||
|
||||
# Create mock state observation
|
||||
state_dim = 10 # Typical robot state dimension
|
||||
obs["observation.state"] = torch.randn(1, state_dim, device=device)
|
||||
|
||||
# Create mock images if needed
|
||||
# For Pi0, we typically need at least one image
|
||||
image_height = 224
|
||||
image_width = 224
|
||||
|
||||
# Common image keys for Pi0
|
||||
image_keys = ["observation.images.gripper", "observation.images.front"]
|
||||
|
||||
for key in image_keys:
|
||||
# Images should be [B, C, H, W] and normalized to [0, 1]
|
||||
obs[key] = torch.rand(1, 3, image_height, image_width, device=device)
|
||||
|
||||
# Add task
|
||||
obs["task"] = ["Pick up the object"]
|
||||
|
||||
# Add language tokens and attention mask (required for Pi0)
|
||||
# These are mock values - in real usage they come from tokenizer
|
||||
max_seq_len = 32
|
||||
obs["observation.language_tokens"] = torch.randint(0, 1000, (1, max_seq_len), device=device)
|
||||
obs["observation.language_attention_mask"] = torch.ones(1, max_seq_len, device=device)
|
||||
|
||||
return obs
|
||||
|
||||
|
||||
def profile_single_iteration(
|
||||
policy,
|
||||
preprocessor,
|
||||
postprocessor,
|
||||
observation: dict,
|
||||
prev_actions: torch.Tensor | None,
|
||||
use_rtc: bool,
|
||||
inference_delay: int = 0,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None, dict]:
|
||||
"""Profile a single inference iteration.
|
||||
|
||||
Args:
|
||||
policy: Policy instance
|
||||
preprocessor: Observation preprocessor
|
||||
postprocessor: Action postprocessor
|
||||
observation: Input observation
|
||||
prev_actions: Previous action chunk (for RTC)
|
||||
use_rtc: Whether RTC is enabled
|
||||
inference_delay: Inference delay in timesteps
|
||||
|
||||
Returns:
|
||||
Tuple of (actions, new_prev_actions, timings)
|
||||
"""
|
||||
timings = {}
|
||||
|
||||
with ProfileContext("iteration.total"):
|
||||
# Preprocessing
|
||||
with ProfileContext("iteration.preprocessing"):
|
||||
preprocessed_obs = preprocessor(observation)
|
||||
|
||||
# Policy inference
|
||||
with ProfileContext("iteration.policy_inference"):
|
||||
if use_rtc:
|
||||
actions = policy.predict_action_chunk(
|
||||
preprocessed_obs,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_actions,
|
||||
)
|
||||
else:
|
||||
actions = policy.predict_action_chunk(preprocessed_obs)
|
||||
|
||||
# Clone for next iteration (if RTC)
|
||||
new_prev_actions = None
|
||||
if use_rtc:
|
||||
with ProfileContext("iteration.prepare_prev_actions"):
|
||||
execution_horizon = policy.config.rtc_config.execution_horizon
|
||||
if actions.shape[1] > execution_horizon:
|
||||
new_prev_actions = actions[:, execution_horizon:].clone()
|
||||
|
||||
# Postprocessing
|
||||
with ProfileContext("iteration.postprocessing"):
|
||||
processed_actions = postprocessor(actions)
|
||||
|
||||
return processed_actions, new_prev_actions, timings
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Detailed profiling for Pi0 with RTC")
|
||||
parser.add_argument("--policy_path", type=str, required=True, help="Path to pretrained policy")
|
||||
parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu/mps)")
|
||||
parser.add_argument("--num_iterations", type=int, default=20, help="Number of iterations")
|
||||
parser.add_argument("--execution_horizon", type=int, default=10, help="RTC execution horizon")
|
||||
parser.add_argument("--warmup_iterations", type=int, default=5, help="Warmup iterations")
|
||||
parser.add_argument("--enable_rtc_profiling", action="store_true", help="Enable detailed RTC profiling")
|
||||
parser.add_argument("--use_torch_compile", action="store_true", help="Use torch.compile")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.info("="*80)
|
||||
logger.info("DETAILED PI0 RTC PROFILING")
|
||||
logger.info("="*80)
|
||||
logger.info(f"Policy: {args.policy_path}")
|
||||
logger.info(f"Device: {args.device}")
|
||||
logger.info(f"Iterations: {args.num_iterations}")
|
||||
logger.info(f"Execution Horizon: {args.execution_horizon}")
|
||||
logger.info(f"RTC Profiling: {args.enable_rtc_profiling}")
|
||||
logger.info("="*80 + "\n")
|
||||
|
||||
# Enable profiling
|
||||
enable_profiling()
|
||||
|
||||
# Apply RTC profiling if requested
|
||||
if args.enable_rtc_profiling:
|
||||
if monkey_patch_rtc_profiling is not None:
|
||||
monkey_patch_rtc_profiling()
|
||||
logger.info("✓ Detailed RTC profiling enabled\n")
|
||||
else:
|
||||
logger.warning("⚠ Could not enable detailed RTC profiling\n")
|
||||
|
||||
# Load policy
|
||||
logger.info("Loading policy...")
|
||||
config = PreTrainedConfig.from_pretrained(args.policy_path)
|
||||
|
||||
if hasattr(config, "compile_model"):
|
||||
config.compile_model = args.use_torch_compile
|
||||
|
||||
policy_class = get_policy_class(config.type)
|
||||
policy = policy_class.from_pretrained(args.policy_path, config=config)
|
||||
|
||||
# Configure RTC
|
||||
policy.config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=args.execution_horizon,
|
||||
max_guidance_weight=1.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
)
|
||||
policy.init_rtc_processor()
|
||||
|
||||
policy = policy.to(args.device)
|
||||
policy.eval()
|
||||
|
||||
logger.info(f"✓ Policy loaded: {config.type}\n")
|
||||
|
||||
# Create preprocessor and postprocessor
|
||||
logger.info("Loading preprocessor/postprocessor...")
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=config,
|
||||
pretrained_path=args.policy_path,
|
||||
dataset_stats=None,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": args.device},
|
||||
},
|
||||
)
|
||||
logger.info("✓ Preprocessor/postprocessor loaded\n")
|
||||
|
||||
# Create mock observation
|
||||
logger.info("Creating mock observation...")
|
||||
observation = create_mock_observation(config, args.device)
|
||||
logger.info("✓ Mock observation created\n")
|
||||
|
||||
# Warmup
|
||||
logger.info(f"Warming up ({args.warmup_iterations} iterations)...")
|
||||
prev_actions = None
|
||||
for i in range(args.warmup_iterations):
|
||||
with torch.no_grad():
|
||||
_, prev_actions, _ = profile_single_iteration(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
observation=observation,
|
||||
prev_actions=prev_actions,
|
||||
use_rtc=True,
|
||||
inference_delay=0,
|
||||
)
|
||||
|
||||
# Clear warmup stats
|
||||
clear_profiling_stats()
|
||||
logger.info("✓ Warmup complete\n")
|
||||
|
||||
# Profiled run WITH RTC
|
||||
logger.info(f"Running profiled iterations WITH RTC ({args.num_iterations} iterations)...")
|
||||
prev_actions = None
|
||||
iteration_times = []
|
||||
|
||||
for i in range(args.num_iterations):
|
||||
start = time.perf_counter()
|
||||
|
||||
with torch.no_grad():
|
||||
_, prev_actions, _ = profile_single_iteration(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
observation=observation,
|
||||
prev_actions=prev_actions,
|
||||
use_rtc=True,
|
||||
inference_delay=0,
|
||||
)
|
||||
|
||||
# Sync CUDA if needed
|
||||
if args.device.startswith("cuda"):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
elapsed = time.perf_counter() - start
|
||||
iteration_times.append(elapsed)
|
||||
|
||||
if (i + 1) % 5 == 0:
|
||||
logger.info(f" Completed {i+1}/{args.num_iterations}")
|
||||
|
||||
logger.info("✓ Profiling complete\n")
|
||||
|
||||
# Print summary statistics
|
||||
logger.info("\n" + "="*80)
|
||||
logger.info("ITERATION TIMING SUMMARY")
|
||||
logger.info("="*80)
|
||||
|
||||
times_arr = np.array(iteration_times)
|
||||
logger.info(f"Mean time: {np.mean(times_arr)*1000:.2f} ms")
|
||||
logger.info(f"Median time: {np.median(times_arr)*1000:.2f} ms")
|
||||
logger.info(f"Std dev: {np.std(times_arr)*1000:.2f} ms")
|
||||
logger.info(f"Min time: {np.min(times_arr)*1000:.2f} ms")
|
||||
logger.info(f"Max time: {np.max(times_arr)*1000:.2f} ms")
|
||||
logger.info(f"Total time: {np.sum(times_arr):.2f} s")
|
||||
logger.info(f"Throughput: {len(times_arr)/np.sum(times_arr):.2f} iter/s")
|
||||
logger.info("="*80 + "\n")
|
||||
|
||||
# Print detailed profiling breakdown
|
||||
print_profiling_summary(sort_by="total")
|
||||
|
||||
# Print key insights
|
||||
stats = get_profiling_stats()
|
||||
|
||||
logger.info("\n" + "="*80)
|
||||
logger.info("KEY INSIGHTS")
|
||||
logger.info("="*80)
|
||||
|
||||
# Find bottlenecks
|
||||
if stats:
|
||||
policy_inference_time = stats.get("iteration.policy_inference", {}).get("mean", 0)
|
||||
preprocessing_time = stats.get("iteration.preprocessing", {}).get("mean", 0)
|
||||
postprocessing_time = stats.get("iteration.postprocessing", {}).get("mean", 0)
|
||||
|
||||
total_time = policy_inference_time + preprocessing_time + postprocessing_time
|
||||
|
||||
if total_time > 0:
|
||||
logger.info(f"\nTime breakdown:")
|
||||
logger.info(f" Policy inference: {policy_inference_time*1000:.2f} ms ({policy_inference_time/total_time*100:.1f}%)")
|
||||
logger.info(f" Preprocessing: {preprocessing_time*1000:.2f} ms ({preprocessing_time/total_time*100:.1f}%)")
|
||||
logger.info(f" Postprocessing: {postprocessing_time*1000:.2f} ms ({postprocessing_time/total_time*100:.1f}%)")
|
||||
|
||||
# RTC-specific insights
|
||||
if args.enable_rtc_profiling:
|
||||
rtc_guidance = stats.get("rtc.denoise_step.guidance_computation", {}).get("mean", 0)
|
||||
rtc_autograd = stats.get("rtc.denoise_step.autograd_correction", {}).get("mean", 0)
|
||||
rtc_base = stats.get("rtc.denoise_step.base_denoising", {}).get("mean", 0)
|
||||
|
||||
if rtc_guidance > 0:
|
||||
logger.info(f"\nRTC breakdown:")
|
||||
logger.info(f" Base denoising: {rtc_base*1000:.2f} ms")
|
||||
logger.info(f" Guidance compute: {rtc_guidance*1000:.2f} ms")
|
||||
logger.info(f" Autograd correct: {rtc_autograd*1000:.2f} ms")
|
||||
logger.info(f" RTC overhead: {(rtc_guidance - rtc_base)*1000:.2f} ms")
|
||||
|
||||
# Recommendations
|
||||
logger.info("\nRecommendations:")
|
||||
|
||||
if preprocessing_time > policy_inference_time * 0.3:
|
||||
logger.info(" ⚠ Preprocessing is taking >30% of time")
|
||||
logger.info(" → Consider reducing image resolution")
|
||||
logger.info(" → Consider using fewer cameras")
|
||||
|
||||
if args.enable_rtc_profiling and rtc_autograd > rtc_base * 0.5:
|
||||
logger.info(" ⚠ RTC autograd overhead is significant")
|
||||
logger.info(" → This is expected, but consider increasing execution_horizon")
|
||||
logger.info(" → Try torch.compile if not already enabled")
|
||||
|
||||
if not args.use_torch_compile:
|
||||
logger.info(" 💡 torch.compile not enabled")
|
||||
logger.info(" → Try --use_torch_compile for potential speedup")
|
||||
|
||||
logger.info("="*80 + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\n\nProfiling interrupted by user")
|
||||
sys.exit(0)
|
||||
except Exception as e:
|
||||
logger.error(f"\n\nError during profiling: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
@@ -1,347 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Script to compare performance with and without RTC enabled.
|
||||
|
||||
This script helps identify whether RTC is actually improving or degrading performance
|
||||
by running multiple inference passes and collecting detailed timing statistics.
|
||||
|
||||
Usage:
|
||||
# Profile with mock data (no robot needed)
|
||||
uv run examples/rtc/profile_rtc_comparison.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=50
|
||||
|
||||
# Profile with specific RTC config
|
||||
uv run examples/rtc/profile_rtc_comparison.py \
|
||||
--policy_path=helper2424/pi05_check_rtc \
|
||||
--device=mps \
|
||||
--num_iterations=50 \
|
||||
--execution_horizon=20
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.profiling import (
|
||||
clear_profiling_stats,
|
||||
enable_profiling,
|
||||
get_profiling_stats,
|
||||
print_profiling_summary,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfileResults:
|
||||
"""Results from profiling run."""
|
||||
|
||||
mode: str # "with_rtc" or "without_rtc"
|
||||
mean_time: float
|
||||
std_time: float
|
||||
min_time: float
|
||||
max_time: float
|
||||
times: list[float]
|
||||
throughput: float # iterations per second
|
||||
|
||||
|
||||
def create_mock_observation(policy, device: str) -> dict:
|
||||
"""Create a mock observation for testing.
|
||||
|
||||
Args:
|
||||
policy: Policy instance
|
||||
device: Device to create tensors on
|
||||
|
||||
Returns:
|
||||
Mock observation dictionary
|
||||
"""
|
||||
# Get expected input shapes from policy config
|
||||
# This is a simplified version - adjust based on actual policy requirements
|
||||
obs = {}
|
||||
|
||||
# Mock image observations (if needed)
|
||||
if hasattr(policy.config, "input_shapes"):
|
||||
for key, shape in policy.config.input_shapes.items():
|
||||
if "image" in key:
|
||||
# Typical image shape: (batch, channels, height, width)
|
||||
obs[key] = torch.randn(1, *shape, device=device)
|
||||
else:
|
||||
obs[key] = torch.randn(1, *shape, device=device)
|
||||
|
||||
# Add task if needed
|
||||
if "task" in policy.config.__dict__ or hasattr(policy, "accepts_task"):
|
||||
obs["task"] = ["Pick up the object"]
|
||||
|
||||
# Mock state observation
|
||||
obs["observation.state"] = torch.randn(1, 10, device=device) # Adjust size as needed
|
||||
|
||||
return obs
|
||||
|
||||
|
||||
def profile_inference(
|
||||
policy, observation: dict, num_iterations: int, use_rtc: bool, execution_horizon: int = 10
|
||||
) -> ProfileResults:
|
||||
"""Profile policy inference with or without RTC.
|
||||
|
||||
Args:
|
||||
policy: Policy instance
|
||||
observation: Observation dictionary
|
||||
num_iterations: Number of inference iterations to run
|
||||
use_rtc: Whether to enable RTC
|
||||
execution_horizon: Execution horizon for RTC
|
||||
|
||||
Returns:
|
||||
ProfileResults with timing statistics
|
||||
"""
|
||||
mode = "with_rtc" if use_rtc else "without_rtc"
|
||||
logger.info(f"\n{'='*80}")
|
||||
logger.info(f"Profiling: {mode.upper()}")
|
||||
logger.info(f"{'='*80}")
|
||||
|
||||
# Configure RTC
|
||||
if use_rtc:
|
||||
policy.config.rtc_config.enabled = True
|
||||
policy.config.rtc_config.execution_horizon = execution_horizon
|
||||
policy.init_rtc_processor()
|
||||
else:
|
||||
policy.config.rtc_config.enabled = False
|
||||
|
||||
times = []
|
||||
prev_actions = None
|
||||
|
||||
# Warmup
|
||||
logger.info("Warming up (5 iterations)...")
|
||||
for _ in range(5):
|
||||
with torch.no_grad():
|
||||
if use_rtc:
|
||||
_ = policy.predict_action_chunk(
|
||||
observation, inference_delay=0, prev_chunk_left_over=prev_actions
|
||||
)
|
||||
else:
|
||||
_ = policy.predict_action_chunk(observation)
|
||||
|
||||
# Actual profiling
|
||||
logger.info(f"Running {num_iterations} profiled iterations...")
|
||||
for i in range(num_iterations):
|
||||
start = time.perf_counter()
|
||||
|
||||
with torch.no_grad():
|
||||
if use_rtc:
|
||||
actions = policy.predict_action_chunk(
|
||||
observation, inference_delay=0, prev_chunk_left_over=prev_actions
|
||||
)
|
||||
# Simulate consuming some actions for next iteration
|
||||
if actions.shape[1] > execution_horizon:
|
||||
prev_actions = actions[:, execution_horizon:].clone()
|
||||
else:
|
||||
prev_actions = None
|
||||
else:
|
||||
actions = policy.predict_action_chunk(observation)
|
||||
|
||||
# Synchronize if using CUDA
|
||||
if observation["observation.state"].device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
|
||||
elapsed = time.perf_counter() - start
|
||||
times.append(elapsed)
|
||||
|
||||
if (i + 1) % 10 == 0:
|
||||
logger.info(f" Completed {i+1}/{num_iterations} iterations")
|
||||
|
||||
# Calculate statistics
|
||||
times_arr = np.array(times)
|
||||
results = ProfileResults(
|
||||
mode=mode,
|
||||
mean_time=float(np.mean(times_arr)),
|
||||
std_time=float(np.std(times_arr)),
|
||||
min_time=float(np.min(times_arr)),
|
||||
max_time=float(np.max(times_arr)),
|
||||
times=times,
|
||||
throughput=num_iterations / sum(times),
|
||||
)
|
||||
|
||||
logger.info(f"\nResults for {mode}:")
|
||||
logger.info(f" Mean time: {results.mean_time*1000:.2f} ms")
|
||||
logger.info(f" Std dev: {results.std_time*1000:.2f} ms")
|
||||
logger.info(f" Min time: {results.min_time*1000:.2f} ms")
|
||||
logger.info(f" Max time: {results.max_time*1000:.2f} ms")
|
||||
logger.info(f" Throughput: {results.throughput:.2f} iter/s")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def compare_results(results_without_rtc: ProfileResults, results_with_rtc: ProfileResults):
|
||||
"""Compare and print results from both runs.
|
||||
|
||||
Args:
|
||||
results_without_rtc: Results from run without RTC
|
||||
results_with_rtc: Results from run with RTC
|
||||
"""
|
||||
logger.info(f"\n{'='*80}")
|
||||
logger.info("COMPARISON SUMMARY")
|
||||
logger.info(f"{'='*80}")
|
||||
|
||||
mean_diff = results_with_rtc.mean_time - results_without_rtc.mean_time
|
||||
mean_diff_pct = (mean_diff / results_without_rtc.mean_time) * 100
|
||||
|
||||
throughput_diff = results_with_rtc.throughput - results_without_rtc.throughput
|
||||
throughput_diff_pct = (throughput_diff / results_without_rtc.throughput) * 100
|
||||
|
||||
logger.info(f"\n{'Metric':<30} {'Without RTC':>15} {'With RTC':>15} {'Difference':>15}")
|
||||
logger.info("-" * 80)
|
||||
logger.info(
|
||||
f"{'Mean time (ms)':<30} "
|
||||
f"{results_without_rtc.mean_time*1000:>15.2f} "
|
||||
f"{results_with_rtc.mean_time*1000:>15.2f} "
|
||||
f"{mean_diff*1000:>+15.2f}"
|
||||
)
|
||||
logger.info(
|
||||
f"{'Std dev (ms)':<30} "
|
||||
f"{results_without_rtc.std_time*1000:>15.2f} "
|
||||
f"{results_with_rtc.std_time*1000:>15.2f} "
|
||||
f"{(results_with_rtc.std_time - results_without_rtc.std_time)*1000:>+15.2f}"
|
||||
)
|
||||
logger.info(
|
||||
f"{'Min time (ms)':<30} "
|
||||
f"{results_without_rtc.min_time*1000:>15.2f} "
|
||||
f"{results_with_rtc.min_time*1000:>15.2f} "
|
||||
f"{(results_with_rtc.min_time - results_without_rtc.min_time)*1000:>+15.2f}"
|
||||
)
|
||||
logger.info(
|
||||
f"{'Max time (ms)':<30} "
|
||||
f"{results_without_rtc.max_time*1000:>15.2f} "
|
||||
f"{results_with_rtc.max_time*1000:>15.2f} "
|
||||
f"{(results_with_rtc.max_time - results_without_rtc.max_time)*1000:>+15.2f}"
|
||||
)
|
||||
logger.info(
|
||||
f"{'Throughput (iter/s)':<30} "
|
||||
f"{results_without_rtc.throughput:>15.2f} "
|
||||
f"{results_with_rtc.throughput:>15.2f} "
|
||||
f"{throughput_diff:>+15.2f}"
|
||||
)
|
||||
|
||||
logger.info(f"\n{'='*80}")
|
||||
logger.info("VERDICT")
|
||||
logger.info(f"{'='*80}")
|
||||
|
||||
if mean_diff_pct < -5:
|
||||
logger.info(f"✓ RTC is FASTER by {abs(mean_diff_pct):.1f}%")
|
||||
logger.info(f" Mean time reduced by {abs(mean_diff)*1000:.2f} ms")
|
||||
elif mean_diff_pct > 5:
|
||||
logger.info(f"✗ RTC is SLOWER by {mean_diff_pct:.1f}%")
|
||||
logger.info(f" Mean time increased by {mean_diff*1000:.2f} ms")
|
||||
logger.info("\n Possible reasons:")
|
||||
logger.info(" - RTC overhead exceeds benefits at current execution horizon")
|
||||
logger.info(" - Inference delay calculation not accounting for RTC processing")
|
||||
logger.info(" - Additional tensor operations in RTC guidance")
|
||||
else:
|
||||
logger.info(f"≈ Performance is SIMILAR (difference: {mean_diff_pct:+.1f}%)")
|
||||
|
||||
logger.info(f"{'='*80}\n")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Profile RTC performance")
|
||||
parser.add_argument(
|
||||
"--policy_path", type=str, required=True, help="Path to pretrained policy"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cuda", help="Device to run on (cuda/cpu/mps)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_iterations", type=int, default=50, help="Number of inference iterations"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--execution_horizon", type=int, default=10, help="RTC execution horizon"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_detailed_profiling",
|
||||
action="store_true",
|
||||
help="Enable detailed method-level profiling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_torch_compile", action="store_true", help="Use torch.compile for faster inference"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load policy
|
||||
logger.info(f"Loading policy from {args.policy_path}")
|
||||
config = PreTrainedConfig.from_pretrained(args.policy_path)
|
||||
policy_class = get_policy_class(config.type)
|
||||
|
||||
# Set compile flag if needed
|
||||
if hasattr(config, "compile_model"):
|
||||
config.compile_model = args.use_torch_compile
|
||||
|
||||
policy = policy_class.from_pretrained(args.policy_path, config=config)
|
||||
|
||||
# Initialize RTC config
|
||||
policy.config.rtc_config = RTCConfig(
|
||||
execution_horizon=args.execution_horizon,
|
||||
max_guidance_weight=1.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
)
|
||||
|
||||
policy = policy.to(args.device)
|
||||
policy.eval()
|
||||
|
||||
logger.info(f"Policy loaded: {config.type}")
|
||||
logger.info(f"Device: {args.device}")
|
||||
logger.info(f"Execution horizon: {args.execution_horizon}")
|
||||
|
||||
# Create mock observation
|
||||
logger.info("Creating mock observation...")
|
||||
observation = create_mock_observation(policy, args.device)
|
||||
|
||||
# Enable detailed profiling if requested
|
||||
if args.enable_detailed_profiling:
|
||||
enable_profiling()
|
||||
logger.info("Detailed profiling enabled")
|
||||
|
||||
# Profile without RTC
|
||||
results_without_rtc = profile_inference(
|
||||
policy=policy,
|
||||
observation=observation,
|
||||
num_iterations=args.num_iterations,
|
||||
use_rtc=False,
|
||||
execution_horizon=args.execution_horizon,
|
||||
)
|
||||
|
||||
if args.enable_detailed_profiling:
|
||||
logger.info("\nDetailed profiling stats (WITHOUT RTC):")
|
||||
print_profiling_summary()
|
||||
clear_profiling_stats()
|
||||
|
||||
# Profile with RTC
|
||||
results_with_rtc = profile_inference(
|
||||
policy=policy,
|
||||
observation=observation,
|
||||
num_iterations=args.num_iterations,
|
||||
use_rtc=True,
|
||||
execution_horizon=args.execution_horizon,
|
||||
)
|
||||
|
||||
if args.enable_detailed_profiling:
|
||||
logger.info("\nDetailed profiling stats (WITH RTC):")
|
||||
print_profiling_summary()
|
||||
|
||||
# Compare results
|
||||
compare_results(results_without_rtc, results_with_rtc)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
"""This script demonstrates how to train ACT Policy on a real-world dataset."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
|
||||
|
||||
def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
|
||||
if delta_indices is None:
|
||||
return [0]
|
||||
|
||||
return [i / fps for i in delta_indices]
|
||||
|
||||
|
||||
output_directory = Path("outputs/robot_learning_tutorial/act")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Select your device
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
|
||||
# This specifies the inputs the model will be expecting and the outputs it will produce
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
|
||||
cfg = ACTConfig(input_features=input_features, output_features=output_features)
|
||||
policy = ACTPolicy(cfg)
|
||||
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
# To perform action chunking, ACT expects a given number of actions as targets
|
||||
delta_timestamps = {
|
||||
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
|
||||
}
|
||||
|
||||
# add image features if they are present
|
||||
delta_timestamps |= {
|
||||
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
|
||||
}
|
||||
|
||||
# Instantiate the dataset
|
||||
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
|
||||
|
||||
# Create the optimizer and dataloader for offline training
|
||||
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
|
||||
batch_size = 32
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
# Number of training steps and logging frequency
|
||||
training_steps = 1
|
||||
log_freq = 1
|
||||
|
||||
# Run training loop
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = preprocessor(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save the policy checkpoint, alongside the pre/post processors
|
||||
policy.save_pretrained(output_directory)
|
||||
preprocessor.save_pretrained(output_directory)
|
||||
postprocessor.save_pretrained(output_directory)
|
||||
|
||||
# Save all assets to the Hub
|
||||
policy.push_to_hub("fracapuano/robot_learning_tutorial_act")
|
||||
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
|
||||
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
|
||||
@@ -1,57 +0,0 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "fracapuano/robot_learning_tutorial_act"
|
||||
model = ACTPolicy.from_pretrained(model_id)
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
preprocess, postprocess = make_pre_post_processors(model.config, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_metadata.features, device=device
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
|
||||
action = make_robot_action(action, dataset_metadata.features)
|
||||
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
@@ -1,11 +0,0 @@
|
||||
from lerobot.async_inference.configs import PolicyServerConfig
|
||||
from lerobot.async_inference.policy_server import serve
|
||||
|
||||
host = ... # something like "127.0.0.1" if you're exposing to localhost
|
||||
port = ... # something like 8080
|
||||
|
||||
config = PolicyServerConfig(
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
serve(config)
|
||||
@@ -1,55 +0,0 @@
|
||||
import threading
|
||||
|
||||
from lerobot.async_inference.configs import RobotClientConfig
|
||||
from lerobot.async_inference.helpers import visualize_action_queue_size
|
||||
from lerobot.async_inference.robot_client import RobotClient
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.robots.so100_follower import SO100FollowerConfig
|
||||
|
||||
# these cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
|
||||
# check the config.json on the Hub for the policy you are using to see the expected camera specs
|
||||
camera_cfg = {
|
||||
"up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)
|
||||
|
||||
server_address = ... # something like "127.0.0.1:8080" if using localhost
|
||||
|
||||
# 3. Create client configuration
|
||||
client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address=server_address,
|
||||
policy_device="mps",
|
||||
policy_type="act",
|
||||
pretrained_name_or_path="fracapuano/robot_learning_tutorial_act",
|
||||
chunk_size_threshold=0.5, # g
|
||||
actions_per_chunk=50, # make sure this is less than the max actions of the policy
|
||||
)
|
||||
|
||||
# 4. Create and start client
|
||||
client = RobotClient(client_cfg)
|
||||
|
||||
# 5. Provide a textual description of the task
|
||||
task = ...
|
||||
|
||||
if client.start():
|
||||
# Start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
action_receiver_thread.start()
|
||||
|
||||
try:
|
||||
# Run the control loop
|
||||
client.control_loop(task)
|
||||
except KeyboardInterrupt:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
# (Optionally) plot the action queue size
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
@@ -1,99 +0,0 @@
|
||||
"""This script demonstrates how to train Diffusion Policy on a real-world dataset."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
|
||||
|
||||
def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
|
||||
if delta_indices is None:
|
||||
return [0]
|
||||
|
||||
return [i / fps for i in delta_indices]
|
||||
|
||||
|
||||
output_directory = Path("outputs/robot_learning_tutorial/diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Select your device
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
|
||||
# This specifies the inputs the model will be expecting and the outputs it will produce
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
|
||||
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
|
||||
policy = DiffusionPolicy(cfg)
|
||||
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
# To perform action chunking, ACT expects a given number of actions as targets
|
||||
delta_timestamps = {
|
||||
"observation.state": make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps),
|
||||
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
|
||||
}
|
||||
|
||||
# add image features if they are present
|
||||
delta_timestamps |= {
|
||||
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
|
||||
}
|
||||
|
||||
# Instantiate the dataset
|
||||
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
|
||||
|
||||
# Create the optimizer and dataloader for offline training
|
||||
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
|
||||
batch_size = 32
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
# Number of training steps and logging frequency
|
||||
training_steps = 1
|
||||
log_freq = 1
|
||||
|
||||
# Run training loop
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = preprocessor(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save the policy checkpoint, alongside the pre/post processors
|
||||
policy.save_pretrained(output_directory)
|
||||
preprocessor.save_pretrained(output_directory)
|
||||
postprocessor.save_pretrained(output_directory)
|
||||
|
||||
# Save all assets to the Hub
|
||||
policy.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
|
||||
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
|
||||
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
|
||||
@@ -1,60 +0,0 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "fracapuano/robot_learning_tutorial_diffusion"
|
||||
|
||||
model = DiffusionPolicy.from_pretrained(model_id)
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config, model_id, dataset_stats=dataset_metadata.stats
|
||||
)
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_metadata.features, device=device
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_metadata.features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
@@ -1,67 +0,0 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "lerobot/pi0_base"
|
||||
|
||||
model = PI0Policy.from_pretrained(model_id)
|
||||
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config,
|
||||
model_id,
|
||||
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
|
||||
preprocessor_overrides={"device_processor": {"device": str(device)}},
|
||||
)
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"base_0_rgb": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"left_wrist_0_rgb": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
"right_wrist_0_rgb": OpenCVCameraConfig(index_or_path=2, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
task = "" # something like "pick the red block"
|
||||
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
|
||||
|
||||
# This is used to match the raw observation keys to the keys expected by the policy
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
@@ -1,345 +0,0 @@
|
||||
import multiprocessing as mp
|
||||
import signal
|
||||
from pathlib import Path
|
||||
from queue import Empty, Full
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.robots.so100_follower import SO100FollowerConfig
|
||||
from lerobot.teleoperators.so100_leader import SO100LeaderConfig
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
LOG_EVERY = 10
|
||||
SEND_EVERY = 10
|
||||
|
||||
|
||||
def run_learner(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_learner: SACPolicy,
|
||||
online_buffer: ReplayBuffer,
|
||||
offline_buffer: ReplayBuffer,
|
||||
lr: float = 3e-4,
|
||||
batch_size: int = 32,
|
||||
device: torch.device = "mps",
|
||||
):
|
||||
"""The learner process - trains SAC policy on transitions streamed from the actor, updating parameters
|
||||
for the actor to adopt."""
|
||||
policy_learner.train()
|
||||
policy_learner.to(device)
|
||||
|
||||
# Create Adam optimizer from scratch - simple and clean
|
||||
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
|
||||
|
||||
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
|
||||
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
|
||||
|
||||
training_step = 0
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
# retrieve incoming transitions from the actor process
|
||||
try:
|
||||
transitions = transitions_queue.get(timeout=0.1)
|
||||
for transition in transitions:
|
||||
# HIL-SERL: Add ALL transitions to online buffer
|
||||
online_buffer.add(**transition)
|
||||
|
||||
# HIL-SERL: Add ONLY human intervention transitions to offline buffer
|
||||
is_intervention = transition.get("complementary_info", {}).get("is_intervention", False)
|
||||
if is_intervention:
|
||||
offline_buffer.add(**transition)
|
||||
print(
|
||||
f"[LEARNER] Human intervention detected! Added to offline buffer (now {len(offline_buffer)} transitions)"
|
||||
)
|
||||
|
||||
except Empty:
|
||||
pass # No transitions available, continue
|
||||
|
||||
# Train if we have enough data
|
||||
if len(online_buffer) >= policy_learner.config.online_step_before_learning:
|
||||
# Sample from online buffer (autonomous + human data)
|
||||
online_batch = online_buffer.sample(batch_size // 2)
|
||||
|
||||
# Sample from offline buffer (human demonstrations only, either precollected or at runtime)
|
||||
offline_batch = offline_buffer.sample(batch_size // 2)
|
||||
|
||||
# Combine batches - this is the key HIL-SERL mechanism!
|
||||
batch = {}
|
||||
for key in online_batch:
|
||||
if key in offline_batch:
|
||||
batch[key] = torch.cat([online_batch[key], offline_batch[key]], dim=0)
|
||||
else:
|
||||
batch[key] = online_batch[key]
|
||||
|
||||
loss, _ = policy_learner.forward(batch)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
training_step += 1
|
||||
|
||||
if training_step % LOG_EVERY == 0:
|
||||
print(
|
||||
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
|
||||
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
|
||||
)
|
||||
|
||||
# Send updated parameters to actor every 10 training steps
|
||||
if training_step % SEND_EVERY == 0:
|
||||
try:
|
||||
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
|
||||
parameters_queue.put_nowait(state_dict)
|
||||
print("[LEARNER] Sent updated parameters to actor")
|
||||
except Full:
|
||||
# Missing write due to queue not being consumed (should happen rarely)
|
||||
pass
|
||||
|
||||
print("[LEARNER] Learner process finished")
|
||||
|
||||
|
||||
def run_actor(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_actor: SACPolicy,
|
||||
reward_classifier: Classifier,
|
||||
env_cfg: HILSerlRobotEnvConfig,
|
||||
device: torch.device = "mps",
|
||||
output_directory: Path | None = None,
|
||||
):
|
||||
"""The actor process - interacts with environment and collects data.
|
||||
The policy is frozen and only the parameters are updated, popping the most recent ones from a queue."""
|
||||
policy_actor.eval()
|
||||
policy_actor.to(device)
|
||||
|
||||
reward_classifier.eval()
|
||||
reward_classifier.to(device)
|
||||
|
||||
# Create robot environment inside the actor process
|
||||
env, teleop_device = make_robot_env(env_cfg)
|
||||
|
||||
try:
|
||||
for episode in range(MAX_EPISODES):
|
||||
if shutdown_event.is_set():
|
||||
break
|
||||
|
||||
obs, _info = env.reset()
|
||||
episode_reward = 0.0
|
||||
step = 0
|
||||
episode_transitions = []
|
||||
|
||||
print(f"[ACTOR] Starting episode {episode + 1}")
|
||||
|
||||
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
|
||||
try:
|
||||
new_params = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_params)
|
||||
print("[ACTOR] Updated policy parameters from learner")
|
||||
except Empty: # No new updated parameters available from learner, waiting
|
||||
pass
|
||||
|
||||
# Get action from policy
|
||||
policy_obs = make_policy_obs(obs, device=device)
|
||||
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
|
||||
action = action_tensor.squeeze(0).cpu().numpy()
|
||||
|
||||
# Step environment
|
||||
next_obs, _env_reward, terminated, truncated, _info = env.step(action)
|
||||
done = terminated or truncated
|
||||
|
||||
# Predict reward
|
||||
policy_next_obs = make_policy_obs(next_obs, device=device)
|
||||
reward = reward_classifier.predict_reward(policy_next_obs)
|
||||
|
||||
if reward >= 1.0 and not done: # success detected! halt episode
|
||||
terminated = True
|
||||
done = True
|
||||
|
||||
# In HIL-SERL, human interventions come from the teleop device
|
||||
is_intervention = False
|
||||
if hasattr(teleop_device, "get_teleop_events"):
|
||||
# Real intervention detection from teleop device
|
||||
teleop_events = teleop_device.get_teleop_events()
|
||||
is_intervention = teleop_events.get(TeleopEvents.IS_INTERVENTION, False)
|
||||
|
||||
# Store transition with intervention metadata
|
||||
transition = {
|
||||
"state": policy_obs,
|
||||
"action": action,
|
||||
"reward": float(reward) if hasattr(reward, "item") else reward,
|
||||
"next_state": policy_next_obs,
|
||||
"done": done,
|
||||
"truncated": truncated,
|
||||
"complementary_info": {
|
||||
"is_intervention": is_intervention,
|
||||
},
|
||||
}
|
||||
|
||||
episode_transitions.append(transition)
|
||||
|
||||
episode_reward += reward
|
||||
step += 1
|
||||
|
||||
obs = next_obs
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
# Send episode transitions to learner
|
||||
transitions_queue.put_nowait(episode_transitions)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("[ACTOR] Interrupted by user")
|
||||
finally:
|
||||
# Clean up
|
||||
if hasattr(env, "robot") and env.robot.is_connected:
|
||||
env.robot.disconnect()
|
||||
if teleop_device and hasattr(teleop_device, "disconnect"):
|
||||
teleop_device.disconnect()
|
||||
if output_directory is not None:
|
||||
policy_actor.save_pretrained(output_directory)
|
||||
print(f"[ACTOR] Latest actor policy saved at: {output_directory}")
|
||||
|
||||
print("[ACTOR] Actor process finished")
|
||||
|
||||
|
||||
def make_policy_obs(obs, device: torch.device = "cpu"):
|
||||
return {
|
||||
"observation.state": torch.from_numpy(obs["agent_pos"]).float().unsqueeze(0).to(device),
|
||||
**{
|
||||
f"observation.image.{k}": torch.from_numpy(obs["pixels"][k]).float().unsqueeze(0).to(device)
|
||||
for k in obs["pixels"]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
"""Main function - coordinates actor and learner processes."""
|
||||
|
||||
device = "mps" # or "cuda" or "cpu"
|
||||
output_directory = Path("outputs/robot_learning_tutorial/hil_serl")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ...
|
||||
leader_port = ...
|
||||
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ...
|
||||
leader_id = ...
|
||||
|
||||
# A pretrained model (to be used in-distribution!)
|
||||
reward_classifier_id = "fracapuano/reward_classifier_hil_serl_example"
|
||||
reward_classifier = Classifier.from_pretrained(reward_classifier_id)
|
||||
|
||||
reward_classifier.to(device)
|
||||
reward_classifier.eval()
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
# Robot and environment configuration
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id)
|
||||
teleop_cfg = SO100LeaderConfig(port=leader_port, id=leader_id)
|
||||
processor_cfg = HILSerlProcessorConfig(control_mode="leader")
|
||||
|
||||
env_cfg = HILSerlRobotEnvConfig(robot=robot_cfg, teleop=teleop_cfg, processor=processor_cfg)
|
||||
|
||||
# Create robot environment
|
||||
env, teleop_device = make_robot_env(env_cfg)
|
||||
|
||||
obs_features = hw_to_dataset_features(env.robot.observation_features, "observation")
|
||||
action_features = hw_to_dataset_features(env.robot.action_features, "action")
|
||||
|
||||
# Create SAC policy for action selection
|
||||
policy_cfg = SACConfig(
|
||||
device=device,
|
||||
input_features=obs_features,
|
||||
output_features=action_features,
|
||||
)
|
||||
|
||||
policy_actor = SACPolicy(policy_cfg)
|
||||
policy_learner = SACPolicy(policy_cfg)
|
||||
|
||||
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
|
||||
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
|
||||
|
||||
# Online buffer: initialized from scratch
|
||||
online_replay_buffer = ReplayBuffer(device=device, state_keys=list(obs_features.keys()))
|
||||
# Offline buffer: Created from dataset (pre-populated it with demonstrations)
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
lerobot_dataset=offline_dataset, device=device, state_keys=list(obs_features.keys())
|
||||
)
|
||||
|
||||
# Create communication channels between learner and actor processes
|
||||
transitions_queue = mp.Queue(maxsize=10)
|
||||
parameters_queue = mp.Queue(maxsize=2)
|
||||
shutdown_event = mp.Event()
|
||||
|
||||
|
||||
# Signal handler for graceful shutdown
|
||||
def signal_handler(sig):
|
||||
print(f"\nSignal {sig} received, shutting down...")
|
||||
shutdown_event.set()
|
||||
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Create processes
|
||||
learner_process = mp.Process(
|
||||
target=run_learner,
|
||||
args=(
|
||||
transitions_queue,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
policy_learner,
|
||||
online_replay_buffer,
|
||||
offline_replay_buffer,
|
||||
),
|
||||
kwargs={"device": device}, # can run on accelerated hardware for training
|
||||
)
|
||||
|
||||
actor_process = mp.Process(
|
||||
target=run_actor,
|
||||
args=(
|
||||
transitions_queue,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
policy_actor,
|
||||
reward_classifier,
|
||||
env_cfg,
|
||||
output_directory,
|
||||
),
|
||||
kwargs={"device": "cpu"}, # actor is frozen, can run on CPU or accelerate for inference
|
||||
)
|
||||
|
||||
learner_process.start()
|
||||
actor_process.start()
|
||||
|
||||
try:
|
||||
# Wait for actor to finish (it controls the episode loop)
|
||||
actor_process.join()
|
||||
shutdown_event.set()
|
||||
learner_process.join(timeout=10)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Main process interrupted")
|
||||
shutdown_event.set()
|
||||
actor_process.join(timeout=5)
|
||||
learner_process.join(timeout=10)
|
||||
|
||||
finally:
|
||||
if learner_process.is_alive():
|
||||
learner_process.terminate()
|
||||
if actor_process.is_alive():
|
||||
actor_process.terminate()
|
||||
@@ -1,62 +0,0 @@
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
# Device to use for training
|
||||
device = "mps" # or "cuda", or "cpu"
|
||||
|
||||
# Load the dataset used for training
|
||||
repo_id = "lerobot/example_hil_serl_dataset"
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
|
||||
# Configure the policy to extract features from the image frames
|
||||
camera_keys = dataset.meta.camera_keys
|
||||
|
||||
config = RewardClassifierConfig(
|
||||
num_cameras=len(camera_keys),
|
||||
device=device,
|
||||
# backbone model to extract features from the image frames
|
||||
model_name="microsoft/resnet-18",
|
||||
)
|
||||
|
||||
# Make policy, preprocessor, and optimizer
|
||||
policy = make_policy(config, ds_meta=dataset.meta)
|
||||
optimizer = config.get_optimizer_preset().build(policy.parameters())
|
||||
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
|
||||
|
||||
|
||||
classifier_id = "fracapuano/reward_classifier_hil_serl_example"
|
||||
|
||||
# Instantiate a dataloader
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
|
||||
|
||||
# Training loop
|
||||
num_epochs = 5
|
||||
for epoch in range(num_epochs):
|
||||
total_loss = 0
|
||||
total_accuracy = 0
|
||||
for batch in dataloader:
|
||||
# Preprocess the batch and move it to the correct device.
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Forward pass
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
total_accuracy += output_dict["accuracy"]
|
||||
|
||||
avg_loss = total_loss / len(dataloader)
|
||||
avg_accuracy = total_accuracy / len(dataloader)
|
||||
print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.2f}%")
|
||||
|
||||
print("Training finished!")
|
||||
|
||||
# You can now save the trained policy.
|
||||
policy.push_to_hub(classifier_id)
|
||||
@@ -1,66 +0,0 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "lerobot/smolvla_base"
|
||||
|
||||
model = SmolVLAPolicy.from_pretrained(model_id)
|
||||
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config,
|
||||
model_id,
|
||||
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
|
||||
preprocessor_overrides={"device_processor": {"device": str(device)}},
|
||||
)
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"camera1": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"camera2": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
task = "" # something like "pick the red block"
|
||||
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
|
||||
|
||||
# This is used to match the raw observation keys to the keys expected by the policy
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
||||
|
||||
[project]
|
||||
name = "lerobot"
|
||||
version = "0.4.2"
|
||||
version = "0.3.4"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
readme = "README.md"
|
||||
license = { text = "Apache-2.0" }
|
||||
@@ -62,7 +62,6 @@ dependencies = [
|
||||
"datasets>=4.0.0,<4.2.0",
|
||||
"diffusers>=0.27.2,<0.36.0",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
|
||||
"accelerate>=1.10.0,<2.0.0",
|
||||
|
||||
# Core dependencies
|
||||
"setuptools>=71.0.0,<81.0.0",
|
||||
@@ -74,15 +73,15 @@ dependencies = [
|
||||
"packaging>=24.2,<26.0",
|
||||
"pynput>=1.7.7,<1.9.0",
|
||||
"pyserial>=3.5,<4.0",
|
||||
"wandb>=0.20.0,<0.22.0", # TODO: Bumb dependency (compatible with protobuf)
|
||||
"wandb>=0.20.0,<0.23.0",
|
||||
|
||||
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
|
||||
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
|
||||
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
|
||||
|
||||
"draccus==0.10.0", # TODO: Remove ==
|
||||
"gymnasium>=1.1.1,<2.0.0",
|
||||
"rerun-sdk>=0.24.0,<0.27.0",
|
||||
"gymnasium>=1.0.0",
|
||||
"rerun-sdk>=0.21.0,<0.23.0", # TODO: Bumb dependency
|
||||
|
||||
# Support dependencies
|
||||
"deepdiff>=7.0.1,<9.0.0",
|
||||
@@ -97,8 +96,7 @@ dependencies = [
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.10.0"]
|
||||
transformers-dep = ["transformers>=4.53.0,<5.0.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb)
|
||||
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
|
||||
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
||||
@@ -114,26 +112,20 @@ intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
|
||||
"pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
|
||||
]
|
||||
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
|
||||
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0"]
|
||||
# stretch = [
|
||||
# "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'",
|
||||
# "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
|
||||
# "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"
|
||||
# ] # TODO: Currently not supported
|
||||
|
||||
# Policies
|
||||
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
|
||||
groot = [
|
||||
"lerobot[transformers-dep]",
|
||||
"peft>=0.13.0,<1.0.0",
|
||||
"dm-tree>=0.1.8,<1.0.0",
|
||||
"timm>=1.0.0,<1.1.0",
|
||||
"safetensors>=0.4.3,<1.0.0",
|
||||
"Pillow>=10.0.0,<13.0.0",
|
||||
"decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
|
||||
"ninja>=1.11.1,<2.0.0",
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
|
||||
@@ -143,8 +135,8 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
# Simulation
|
||||
aloha = ["gym-aloha>=0.1.2,<0.2.0"]
|
||||
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0"]
|
||||
metaworld = ["metaworld==3.0.0"]
|
||||
libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"]
|
||||
metaworld = ["metaworld>=3.0.0"]
|
||||
|
||||
# All
|
||||
all = [
|
||||
@@ -157,7 +149,6 @@ all = [
|
||||
"lerobot[intelrealsense]",
|
||||
"lerobot[pi]",
|
||||
"lerobot[smolvla]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[async]",
|
||||
"lerobot[dev]",
|
||||
@@ -242,6 +233,9 @@ exclude_dirs = [
|
||||
"tests",
|
||||
"benchmarks",
|
||||
"src/lerobot/datasets/push_dataset_to_hub",
|
||||
"src/lerobot/datasets/v2/convert_dataset_v1_to_v2",
|
||||
"src/lerobot/policies/pi0/conversion_scripts",
|
||||
"src/lerobot/scripts/push_dataset_to_hub.py",
|
||||
]
|
||||
skips = ["B101", "B311", "B404", "B603", "B615"]
|
||||
|
||||
@@ -256,8 +250,6 @@ default.extend-ignore-identifiers-re = [
|
||||
"pn",
|
||||
"ser",
|
||||
"ein",
|
||||
"thw",
|
||||
"inpt",
|
||||
]
|
||||
|
||||
# TODO: Uncomment when ready to use
|
||||
@@ -296,6 +288,7 @@ ignore_errors = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.envs.*"
|
||||
# Enable type checking only for the envs module
|
||||
ignore_errors = false
|
||||
|
||||
|
||||
@@ -303,22 +296,17 @@ ignore_errors = false
|
||||
# module = "lerobot.utils.*"
|
||||
# ignore_errors = false
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.configs.*"
|
||||
ignore_errors = false
|
||||
|
||||
# extra strictness for configs
|
||||
disallow_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
check_untyped_defs = true
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.configs.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.optim.*"
|
||||
# ignore_errors = false
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.model.*"
|
||||
ignore_errors = false
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.model.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.processor.*"
|
||||
@@ -328,9 +316,9 @@ ignore_errors = false
|
||||
# module = "lerobot.datasets.*"
|
||||
# ignore_errors = false
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.cameras.*"
|
||||
ignore_errors = false
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.cameras.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.motors.*"
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
#
|
||||
# This file is autogenerated by pip-compile with Python 3.10
|
||||
# by the following command:
|
||||
#
|
||||
@@ -13,62 +12,47 @@ absl-py==2.3.1
|
||||
# dm-tree
|
||||
# labmaze
|
||||
# mujoco
|
||||
# tensorboard
|
||||
accelerate==1.11.0
|
||||
# via
|
||||
# lerobot
|
||||
# peft
|
||||
accelerate==1.9.0
|
||||
# via lerobot
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.13.1
|
||||
aiohttp==3.12.15
|
||||
# via fsspec
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
antlr4-python3-runtime==4.9.3
|
||||
# via
|
||||
# hydra-core
|
||||
# omegaconf
|
||||
anyio==4.11.0
|
||||
# via
|
||||
# starlette
|
||||
# watchfiles
|
||||
asttokens==3.0.0
|
||||
# via stack-data
|
||||
async-timeout==5.0.1
|
||||
# via aiohttp
|
||||
attrs==25.4.0
|
||||
attrs==25.3.0
|
||||
# via
|
||||
# aiohttp
|
||||
# dm-tree
|
||||
# jsonlines
|
||||
# jsonschema
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
av==15.1.0
|
||||
av==15.0.0
|
||||
# via lerobot
|
||||
bddl==1.0.1
|
||||
# via libero
|
||||
certifi==2025.10.5
|
||||
blinker==1.9.0
|
||||
# via flask
|
||||
certifi==2025.7.14
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
cffi==2.0.0
|
||||
cffi==1.17.1
|
||||
# via pymunk
|
||||
cfgv==3.4.0
|
||||
# via pre-commit
|
||||
charset-normalizer==3.4.4
|
||||
charset-normalizer==3.4.2
|
||||
# via requests
|
||||
click==8.3.0
|
||||
click==8.2.1
|
||||
# via
|
||||
# uvicorn
|
||||
# flask
|
||||
# wandb
|
||||
cloudpickle==3.1.1
|
||||
# via
|
||||
# gymnasium
|
||||
# libero
|
||||
cmake==4.1.0
|
||||
# via gymnasium
|
||||
cmake==4.0.3
|
||||
# via lerobot
|
||||
cmeel==0.57.3
|
||||
# via
|
||||
@@ -110,27 +94,27 @@ coal-library==3.0.1
|
||||
# via pin
|
||||
contourpy==1.3.2
|
||||
# via matplotlib
|
||||
coverage[toml]==7.11.0
|
||||
coverage[toml]==7.10.1
|
||||
# via pytest-cov
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
datasets==4.1.1
|
||||
datasets==3.6.0
|
||||
# via lerobot
|
||||
debugpy==1.8.17
|
||||
debugpy==1.8.15
|
||||
# via lerobot
|
||||
decorator==5.2.1
|
||||
# via ipython
|
||||
deepdiff==8.6.1
|
||||
deepdiff==8.5.0
|
||||
# via lerobot
|
||||
diffusers==0.35.2
|
||||
diffusers==0.34.0
|
||||
# via lerobot
|
||||
dill==0.4.0
|
||||
dill==0.3.8
|
||||
# via
|
||||
# datasets
|
||||
# multiprocess
|
||||
distlib==0.4.0
|
||||
# via virtualenv
|
||||
dm-control==1.0.34
|
||||
dm-control==1.0.14
|
||||
# via gym-aloha
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
@@ -138,45 +122,29 @@ dm-tree==0.1.9
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# lerobot
|
||||
docopt==0.6.2
|
||||
# via num2words
|
||||
draccus==0.10.0
|
||||
# via lerobot
|
||||
dynamixel-sdk==3.8.4
|
||||
dynamixel-sdk==3.7.31
|
||||
# via lerobot
|
||||
easydict==1.13
|
||||
# via libero
|
||||
egl-probe @ git+https://github.com/huggingface/egl_probe.git
|
||||
# via
|
||||
# libero
|
||||
# robomimic
|
||||
eigenpy==3.10.3
|
||||
# via coal-library
|
||||
einops==0.8.1
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
# via lerobot
|
||||
eiquadprog==1.2.9
|
||||
# via placo
|
||||
etils[epath,epy]==1.13.0
|
||||
# via mujoco
|
||||
exceptiongroup==1.3.0
|
||||
# via
|
||||
# anyio
|
||||
# ipython
|
||||
# pytest
|
||||
executing==2.2.1
|
||||
executing==2.2.0
|
||||
# via stack-data
|
||||
farama-notifications==0.0.4
|
||||
# via gymnasium
|
||||
fastapi==0.119.1
|
||||
# via teleop
|
||||
fastjsonschema==2.21.2
|
||||
# via nbformat
|
||||
feetech-servo-sdk==1.0.0
|
||||
# via lerobot
|
||||
filelock==3.20.0
|
||||
filelock==3.18.0
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
@@ -184,25 +152,24 @@ filelock==3.20.0
|
||||
# torch
|
||||
# transformers
|
||||
# virtualenv
|
||||
fonttools==4.60.1
|
||||
flask==3.1.1
|
||||
# via lerobot
|
||||
fonttools==4.59.0
|
||||
# via matplotlib
|
||||
frozenlist==1.8.0
|
||||
frozenlist==1.7.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec[http]==2025.9.0
|
||||
fsspec[http]==2025.3.0
|
||||
# via
|
||||
# datasets
|
||||
# etils
|
||||
# huggingface-hub
|
||||
# torch
|
||||
future==1.0.0
|
||||
# via libero
|
||||
gitdb==4.0.12
|
||||
# via gitpython
|
||||
gitpython==3.1.45
|
||||
# via wandb
|
||||
glfw==2.10.0
|
||||
glfw==2.9.0
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
@@ -210,79 +177,61 @@ grpcio==1.73.1
|
||||
# via
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
grpcio-tools==1.73.1
|
||||
# via
|
||||
# lerobot
|
||||
# reachy2-sdk-api
|
||||
gym-aloha==0.1.3
|
||||
# via lerobot
|
||||
gym-hil==0.1.13
|
||||
gym-aloha==0.1.1
|
||||
# via lerobot
|
||||
gym-pusht==0.1.6
|
||||
gym-hil==0.1.10
|
||||
# via lerobot
|
||||
gymnasium==1.2.1
|
||||
gym-pusht==0.1.5
|
||||
# via lerobot
|
||||
gym-xarm==0.1.1
|
||||
# via lerobot
|
||||
gymnasium==0.29.1
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-pusht
|
||||
# gym-xarm
|
||||
# gymnasium-robotics
|
||||
# lerobot
|
||||
# libero
|
||||
# metaworld
|
||||
h11==0.16.0
|
||||
# via uvicorn
|
||||
h5py==3.15.1
|
||||
# via robomimic
|
||||
hebi-py==2.11.0
|
||||
# via lerobot
|
||||
# pettingzoo
|
||||
gymnasium-robotics==1.2.4
|
||||
# via gym-xarm
|
||||
hf-transfer==0.1.9
|
||||
# via huggingface-hub
|
||||
hf-xet==1.1.10
|
||||
hf-xet==1.1.5
|
||||
# via huggingface-hub
|
||||
hidapi==0.14.0.post4
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
httptools==0.7.1
|
||||
# via uvicorn
|
||||
huggingface-hub[cli,hf-transfer]==0.35.3
|
||||
huggingface-hub[cli,hf-transfer]==0.34.3
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# tokenizers
|
||||
# transformers
|
||||
hydra-core==1.3.2
|
||||
# via libero
|
||||
identify==2.6.15
|
||||
identify==2.6.12
|
||||
# via pre-commit
|
||||
idna==3.11
|
||||
idna==3.10
|
||||
# via
|
||||
# anyio
|
||||
# requests
|
||||
# yarl
|
||||
imageio[ffmpeg]==2.37.0
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gymnasium-robotics
|
||||
# lerobot
|
||||
# metaworld
|
||||
# robomimic
|
||||
# scikit-image
|
||||
imageio-ffmpeg==0.6.0
|
||||
# via
|
||||
# imageio
|
||||
# robomimic
|
||||
# via imageio
|
||||
importlib-metadata==8.7.0
|
||||
# via diffusers
|
||||
importlib-resources==6.5.2
|
||||
# via etils
|
||||
iniconfig==2.3.0
|
||||
iniconfig==2.1.0
|
||||
# via pytest
|
||||
inquirerpy==0.3.4
|
||||
# via huggingface-hub
|
||||
@@ -290,71 +239,50 @@ ipython==8.37.0
|
||||
# via meshcat
|
||||
ischedule==1.2.7
|
||||
# via placo
|
||||
itsdangerous==2.2.0
|
||||
# via flask
|
||||
jedi==0.19.2
|
||||
# via ipython
|
||||
jinja2==3.1.6
|
||||
# via torch
|
||||
# via
|
||||
# flask
|
||||
# gymnasium-robotics
|
||||
# torch
|
||||
jsonlines==4.0.0
|
||||
# via lerobot
|
||||
jsonschema==4.25.1
|
||||
# via nbformat
|
||||
jsonschema-specifications==2025.9.1
|
||||
# via jsonschema
|
||||
jupyter-core==5.9.1
|
||||
# via nbformat
|
||||
jupytext==1.18.1
|
||||
# via bddl
|
||||
kiwisolver==1.4.9
|
||||
kiwisolver==1.4.8
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lazy-loader==0.4
|
||||
# via scikit-image
|
||||
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
|
||||
# via lerobot
|
||||
llvmlite==0.45.1
|
||||
# via numba
|
||||
lxml==6.0.2
|
||||
lxml==6.0.0
|
||||
# via dm-control
|
||||
markdown==3.9
|
||||
# via tensorboard
|
||||
markdown-it-py==4.0.0
|
||||
# via
|
||||
# jupytext
|
||||
# mdit-py-plugins
|
||||
markupsafe==3.0.3
|
||||
markupsafe==3.0.2
|
||||
# via
|
||||
# flask
|
||||
# jinja2
|
||||
# werkzeug
|
||||
matplotlib==3.10.7
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
matplotlib-inline==0.2.1
|
||||
matplotlib==3.10.5
|
||||
# via lerobot
|
||||
matplotlib-inline==0.1.7
|
||||
# via ipython
|
||||
mdit-py-plugins==0.5.0
|
||||
# via jupytext
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mergedeep==1.3.4
|
||||
# via draccus
|
||||
meshcat==0.3.2
|
||||
# via placo
|
||||
metaworld==3.0.0
|
||||
# via lerobot
|
||||
mock-serial==0.0.1
|
||||
# via lerobot
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
mujoco==3.3.7
|
||||
mujoco==2.3.7
|
||||
# via
|
||||
# dm-control
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# libero
|
||||
# metaworld
|
||||
# robosuite
|
||||
multidict==6.7.0
|
||||
# gym-xarm
|
||||
# gymnasium-robotics
|
||||
multidict==6.6.3
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
@@ -362,25 +290,17 @@ multiprocess==0.70.16
|
||||
# via datasets
|
||||
mypy-extensions==1.1.0
|
||||
# via typing-inspect
|
||||
nbformat==5.10.4
|
||||
# via jupytext
|
||||
networkx==3.4.2
|
||||
# via
|
||||
# bddl
|
||||
# scikit-image
|
||||
# torch
|
||||
ninja==1.13.0
|
||||
# via lerobot
|
||||
nodeenv==1.9.1
|
||||
# via pre-commit
|
||||
num2words==0.5.14
|
||||
# via lerobot
|
||||
numba==0.62.1
|
||||
# via robosuite
|
||||
numpy==2.2.6
|
||||
# via
|
||||
# accelerate
|
||||
# bddl
|
||||
# cmeel-boost
|
||||
# contourpy
|
||||
# datasets
|
||||
@@ -389,43 +309,25 @@ numpy==2.2.6
|
||||
# dm-env
|
||||
# dm-tree
|
||||
# gymnasium
|
||||
# h5py
|
||||
# hebi-py
|
||||
# gymnasium-robotics
|
||||
# imageio
|
||||
# labmaze
|
||||
# libero
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# metaworld
|
||||
# mujoco
|
||||
# numba
|
||||
# opencv-python
|
||||
# opencv-python-headless
|
||||
# pandas
|
||||
# peft
|
||||
# pyquaternion
|
||||
# reachy2-sdk
|
||||
# pettingzoo
|
||||
# rerun-sdk
|
||||
# robomimic
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# scipy
|
||||
# shapely
|
||||
# teleop
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# tifffile
|
||||
# torchvision
|
||||
# transformers
|
||||
# transforms3d
|
||||
omegaconf==2.3.0
|
||||
# via hydra-core
|
||||
opencv-python==4.12.0.88
|
||||
# via
|
||||
# gym-pusht
|
||||
# libero
|
||||
# reachy2-sdk
|
||||
# robosuite
|
||||
# via gym-pusht
|
||||
opencv-python-headless==4.12.0.88
|
||||
# via lerobot
|
||||
orderly-set==5.5.0
|
||||
@@ -435,63 +337,53 @@ packaging==25.0
|
||||
# accelerate
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
# hydra-core
|
||||
# jupytext
|
||||
# lazy-loader
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# peft
|
||||
# pytest
|
||||
# reachy2-sdk
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# transformers
|
||||
# wandb
|
||||
pandas==2.3.3
|
||||
pandas==2.3.1
|
||||
# via
|
||||
# datasets
|
||||
# lerobot
|
||||
parso==0.8.5
|
||||
parso==0.8.4
|
||||
# via jedi
|
||||
peft==0.17.1
|
||||
# via lerobot
|
||||
pettingzoo==1.24.3
|
||||
# via gymnasium-robotics
|
||||
pexpect==4.9.0
|
||||
# via ipython
|
||||
pfzy==0.3.4
|
||||
# via inquirerpy
|
||||
pillow==12.0.0
|
||||
pillow==11.3.0
|
||||
# via
|
||||
# diffusers
|
||||
# imageio
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# rerun-sdk
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# torchvision
|
||||
pin==3.4.0
|
||||
# via placo
|
||||
placo==0.9.14
|
||||
# via lerobot
|
||||
platformdirs==4.5.0
|
||||
platformdirs==4.3.8
|
||||
# via
|
||||
# jupyter-core
|
||||
# virtualenv
|
||||
# wandb
|
||||
pluggy==1.6.0
|
||||
# via
|
||||
# pytest
|
||||
# pytest-cov
|
||||
pre-commit==4.3.0
|
||||
pre-commit==4.2.0
|
||||
# via lerobot
|
||||
prompt-toolkit==3.0.52
|
||||
prompt-toolkit==3.0.51
|
||||
# via
|
||||
# inquirerpy
|
||||
# ipython
|
||||
propcache==0.4.1
|
||||
propcache==0.3.2
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
@@ -500,17 +392,11 @@ protobuf==6.31.0
|
||||
# dm-control
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# wandb
|
||||
psutil==7.1.1
|
||||
psutil==7.0.0
|
||||
# via
|
||||
# accelerate
|
||||
# imageio
|
||||
# peft
|
||||
# robomimic
|
||||
ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pure-eval==0.2.3
|
||||
@@ -519,13 +405,11 @@ pyarrow==21.0.0
|
||||
# via
|
||||
# datasets
|
||||
# rerun-sdk
|
||||
pycparser==2.23
|
||||
pycparser==2.22
|
||||
# via cffi
|
||||
pydantic==2.12.3
|
||||
# via
|
||||
# fastapi
|
||||
# wandb
|
||||
pydantic-core==2.41.4
|
||||
pydantic==2.11.7
|
||||
# via wandb
|
||||
pydantic-core==2.33.2
|
||||
# via pydantic
|
||||
pygame==2.6.1
|
||||
# via
|
||||
@@ -540,42 +424,40 @@ pymunk==6.11.1
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
pyngrok==7.4.1
|
||||
pyngrok==7.2.12
|
||||
# via meshcat
|
||||
pynput==1.8.1
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
pyobjc-core==12.0
|
||||
pyobjc-core==11.1
|
||||
# via
|
||||
# pyobjc-framework-applicationservices
|
||||
# pyobjc-framework-cocoa
|
||||
# pyobjc-framework-coretext
|
||||
# pyobjc-framework-quartz
|
||||
pyobjc-framework-applicationservices==12.0
|
||||
pyobjc-framework-applicationservices==11.1
|
||||
# via pynput
|
||||
pyobjc-framework-cocoa==12.0
|
||||
pyobjc-framework-cocoa==11.1
|
||||
# via
|
||||
# pyobjc-framework-applicationservices
|
||||
# pyobjc-framework-coretext
|
||||
# pyobjc-framework-quartz
|
||||
pyobjc-framework-coretext==12.0
|
||||
pyobjc-framework-coretext==11.1
|
||||
# via pyobjc-framework-applicationservices
|
||||
pyobjc-framework-quartz==12.0
|
||||
pyobjc-framework-quartz==11.1
|
||||
# via
|
||||
# pynput
|
||||
# pyobjc-framework-applicationservices
|
||||
# pyobjc-framework-coretext
|
||||
pyopengl==3.1.10
|
||||
pyopengl==3.1.9
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.2.5
|
||||
pyparsing==3.2.3
|
||||
# via
|
||||
# dm-control
|
||||
# matplotlib
|
||||
pyquaternion==0.9.9
|
||||
# via reachy2-sdk
|
||||
pyrealsense2-macosx==2.54.2
|
||||
# via lerobot
|
||||
pyserial==3.5
|
||||
@@ -583,14 +465,12 @@ pyserial==3.5
|
||||
# dynamixel-sdk
|
||||
# feetech-servo-sdk
|
||||
# lerobot
|
||||
pytest==8.4.2
|
||||
pytest==8.4.1
|
||||
# via
|
||||
# bddl
|
||||
# lerobot
|
||||
# pytest-cov
|
||||
# pytest-timeout
|
||||
# teleop
|
||||
pytest-cov==7.0.0
|
||||
pytest-cov==6.2.1
|
||||
# via lerobot
|
||||
pytest-timeout==2.4.0
|
||||
# via lerobot
|
||||
@@ -598,73 +478,46 @@ python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# matplotlib
|
||||
# pandas
|
||||
python-dotenv==1.1.1
|
||||
# via uvicorn
|
||||
pytz==2025.2
|
||||
# via pandas
|
||||
pyyaml==6.0.3
|
||||
pyyaml==6.0.2
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# draccus
|
||||
# hebi-py
|
||||
# huggingface-hub
|
||||
# jupytext
|
||||
# omegaconf
|
||||
# peft
|
||||
# pre-commit
|
||||
# pyngrok
|
||||
# pyyaml-include
|
||||
# timm
|
||||
# transformers
|
||||
# uvicorn
|
||||
# wandb
|
||||
pyyaml-include==1.4.1
|
||||
# via draccus
|
||||
pyzmq==27.1.0
|
||||
pyzmq==27.0.0
|
||||
# via
|
||||
# lerobot
|
||||
# meshcat
|
||||
reachy2-sdk==1.0.14
|
||||
# via lerobot
|
||||
reachy2-sdk-api==1.0.21
|
||||
# via reachy2-sdk
|
||||
referencing==0.37.0
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2025.10.23
|
||||
regex==2025.7.34
|
||||
# via
|
||||
# diffusers
|
||||
# transformers
|
||||
requests==2.32.5
|
||||
requests==2.32.4
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# teleop
|
||||
# transformers
|
||||
# wandb
|
||||
rerun-sdk==0.26.1
|
||||
rerun-sdk==0.22.1
|
||||
# via lerobot
|
||||
rhoban-cmeel-jsoncpp==1.9.4.9
|
||||
# via placo
|
||||
robomimic==0.2.0
|
||||
# via libero
|
||||
robosuite==1.4.0
|
||||
# via libero
|
||||
rpds-py==0.28.0
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
safetensors==0.6.2
|
||||
safetensors==0.5.3
|
||||
# via
|
||||
# accelerate
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# transformers
|
||||
scikit-image==0.25.2
|
||||
# via
|
||||
@@ -673,12 +526,10 @@ scikit-image==0.25.2
|
||||
scipy==1.15.3
|
||||
# via
|
||||
# dm-control
|
||||
# metaworld
|
||||
# robosuite
|
||||
# scikit-image
|
||||
sentry-sdk==2.42.1
|
||||
sentry-sdk==2.34.1
|
||||
# via wandb
|
||||
shapely==2.1.2
|
||||
shapely==2.1.1
|
||||
# via gym-pusht
|
||||
six==1.17.0
|
||||
# via
|
||||
@@ -686,106 +537,64 @@ six==1.17.0
|
||||
# python-dateutil
|
||||
smmap==5.0.2
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
stack-data==0.6.3
|
||||
# via ipython
|
||||
starlette==0.48.0
|
||||
# via fastapi
|
||||
sympy==1.14.0
|
||||
# via torch
|
||||
teleop==0.1.2
|
||||
# via lerobot
|
||||
tensorboard==2.20.0
|
||||
# via robomimic
|
||||
tensorboard-data-server==0.7.2
|
||||
# via tensorboard
|
||||
tensorboardx==2.6.4
|
||||
# via robomimic
|
||||
termcolor==3.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
thop==0.1.1.post2209072238
|
||||
# via libero
|
||||
# via lerobot
|
||||
tifffile==2025.5.10
|
||||
# via scikit-image
|
||||
timm==1.0.20
|
||||
# via lerobot
|
||||
tokenizers==0.22.1
|
||||
tokenizers==0.21.4
|
||||
# via transformers
|
||||
toml==0.10.2
|
||||
# via draccus
|
||||
tomli==2.3.0
|
||||
tomli==2.2.1
|
||||
# via
|
||||
# cmeel
|
||||
# coverage
|
||||
# jupytext
|
||||
# pytest
|
||||
torch==2.7.1
|
||||
# via
|
||||
# accelerate
|
||||
# lerobot
|
||||
# peft
|
||||
# robomimic
|
||||
# thop
|
||||
# timm
|
||||
# torchvision
|
||||
torchcodec==0.5
|
||||
# via lerobot
|
||||
torchvision==0.22.1
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
# timm
|
||||
tornado==6.5.2
|
||||
# via lerobot
|
||||
tornado==6.5.1
|
||||
# via meshcat
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
# datasets
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# peft
|
||||
# robomimic
|
||||
# transformers
|
||||
traitlets==5.14.3
|
||||
# via
|
||||
# ipython
|
||||
# jupyter-core
|
||||
# matplotlib-inline
|
||||
# nbformat
|
||||
transformers==4.57.1
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
# peft
|
||||
transforms3d==0.4.2
|
||||
# via teleop
|
||||
typing-extensions==4.15.0
|
||||
transformers==4.51.3
|
||||
# via lerobot
|
||||
typing-extensions==4.14.1
|
||||
# via
|
||||
# aiosignal
|
||||
# anyio
|
||||
# etils
|
||||
# exceptiongroup
|
||||
# fastapi
|
||||
# gymnasium
|
||||
# huggingface-hub
|
||||
# ipython
|
||||
# multidict
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
# starlette
|
||||
# torch
|
||||
# typing-inspect
|
||||
# typing-inspection
|
||||
# uvicorn
|
||||
# virtualenv
|
||||
# wandb
|
||||
typing-inspect==0.9.0
|
||||
# via draccus
|
||||
typing-inspection==0.4.2
|
||||
typing-inspection==0.4.1
|
||||
# via pydantic
|
||||
tzdata==2025.2
|
||||
# via pandas
|
||||
@@ -795,36 +604,22 @@ urllib3==2.5.0
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
uvicorn[standard]==0.38.0
|
||||
# via teleop
|
||||
uvloop==0.22.1
|
||||
# via uvicorn
|
||||
virtualenv==20.35.3
|
||||
virtualenv==20.32.0
|
||||
# via pre-commit
|
||||
wandb==0.21.4
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
watchfiles==1.1.1
|
||||
# via uvicorn
|
||||
wcwidth==0.2.14
|
||||
wandb==0.21.0
|
||||
# via lerobot
|
||||
wcwidth==0.2.13
|
||||
# via prompt-toolkit
|
||||
websocket-client==1.9.0
|
||||
# via teleop
|
||||
websockets==15.0.1
|
||||
# via uvicorn
|
||||
werkzeug==3.1.3
|
||||
# via tensorboard
|
||||
wrapt==2.0.0
|
||||
# via flask
|
||||
wrapt==1.17.2
|
||||
# via dm-tree
|
||||
xxhash==3.6.0
|
||||
xxhash==3.5.0
|
||||
# via datasets
|
||||
yarl==1.22.0
|
||||
yarl==1.20.1
|
||||
# via aiohttp
|
||||
zipp==3.23.0
|
||||
# via
|
||||
# etils
|
||||
# importlib-metadata
|
||||
# via importlib-metadata
|
||||
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
# setuptools
|
||||
|
||||
@@ -13,62 +13,47 @@ absl-py==2.3.1
|
||||
# dm-tree
|
||||
# labmaze
|
||||
# mujoco
|
||||
# tensorboard
|
||||
accelerate==1.11.0
|
||||
# via
|
||||
# lerobot
|
||||
# peft
|
||||
accelerate==1.9.0
|
||||
# via lerobot
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.13.1
|
||||
aiohttp==3.12.15
|
||||
# via fsspec
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
antlr4-python3-runtime==4.9.3
|
||||
# via
|
||||
# hydra-core
|
||||
# omegaconf
|
||||
anyio==4.11.0
|
||||
# via
|
||||
# starlette
|
||||
# watchfiles
|
||||
asttokens==3.0.0
|
||||
# via stack-data
|
||||
async-timeout==5.0.1
|
||||
# via aiohttp
|
||||
attrs==25.4.0
|
||||
attrs==25.3.0
|
||||
# via
|
||||
# aiohttp
|
||||
# dm-tree
|
||||
# jsonlines
|
||||
# jsonschema
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
av==15.1.0
|
||||
av==15.0.0
|
||||
# via lerobot
|
||||
bddl==1.0.1
|
||||
# via libero
|
||||
certifi==2025.10.5
|
||||
blinker==1.9.0
|
||||
# via flask
|
||||
certifi==2025.7.14
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
cffi==2.0.0
|
||||
cffi==1.17.1
|
||||
# via pymunk
|
||||
cfgv==3.4.0
|
||||
# via pre-commit
|
||||
charset-normalizer==3.4.4
|
||||
charset-normalizer==3.4.2
|
||||
# via requests
|
||||
click==8.3.0
|
||||
click==8.2.1
|
||||
# via
|
||||
# uvicorn
|
||||
# flask
|
||||
# wandb
|
||||
cloudpickle==3.1.1
|
||||
# via
|
||||
# gymnasium
|
||||
# libero
|
||||
cmake==4.1.0
|
||||
# via gymnasium
|
||||
cmake==4.0.3
|
||||
# via lerobot
|
||||
cmeel==0.57.3
|
||||
# via
|
||||
@@ -110,29 +95,27 @@ coal-library==3.0.1
|
||||
# via pin
|
||||
contourpy==1.3.2
|
||||
# via matplotlib
|
||||
coverage[toml]==7.11.0
|
||||
coverage[toml]==7.10.1
|
||||
# via pytest-cov
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
datasets==4.1.1
|
||||
datasets==3.6.0
|
||||
# via lerobot
|
||||
debugpy==1.8.17
|
||||
debugpy==1.8.15
|
||||
# via lerobot
|
||||
decorator==5.2.1
|
||||
# via ipython
|
||||
decord==0.6.0
|
||||
deepdiff==8.5.0
|
||||
# via lerobot
|
||||
deepdiff==8.6.1
|
||||
diffusers==0.34.0
|
||||
# via lerobot
|
||||
diffusers==0.35.2
|
||||
# via lerobot
|
||||
dill==0.4.0
|
||||
dill==0.3.8
|
||||
# via
|
||||
# datasets
|
||||
# multiprocess
|
||||
distlib==0.4.0
|
||||
# via virtualenv
|
||||
dm-control==1.0.34
|
||||
dm-control==1.0.14
|
||||
# via gym-aloha
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
@@ -140,48 +123,31 @@ dm-tree==0.1.9
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# lerobot
|
||||
docopt==0.6.2
|
||||
# via num2words
|
||||
draccus==0.10.0
|
||||
# via lerobot
|
||||
dynamixel-sdk==3.8.4
|
||||
dynamixel-sdk==3.7.31
|
||||
# via lerobot
|
||||
easydict==1.13
|
||||
# via libero
|
||||
egl-probe @ git+https://github.com/huggingface/egl_probe.git
|
||||
# via
|
||||
# libero
|
||||
# robomimic
|
||||
eigenpy==3.10.3
|
||||
# via coal-library
|
||||
einops==0.8.1
|
||||
# via
|
||||
# flash-attn
|
||||
# lerobot
|
||||
# libero
|
||||
# via lerobot
|
||||
eiquadprog==1.2.9
|
||||
# via placo
|
||||
etils[epath,epy]==1.13.0
|
||||
# via mujoco
|
||||
evdev==1.9.2
|
||||
# via pynput
|
||||
exceptiongroup==1.3.0
|
||||
# via
|
||||
# anyio
|
||||
# ipython
|
||||
# pytest
|
||||
executing==2.2.1
|
||||
executing==2.2.0
|
||||
# via stack-data
|
||||
farama-notifications==0.0.4
|
||||
# via gymnasium
|
||||
fastapi==0.119.1
|
||||
# via teleop
|
||||
fastjsonschema==2.21.2
|
||||
# via nbformat
|
||||
feetech-servo-sdk==1.0.0
|
||||
# via lerobot
|
||||
filelock==3.20.0
|
||||
filelock==3.18.0
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
@@ -189,27 +155,24 @@ filelock==3.20.0
|
||||
# torch
|
||||
# transformers
|
||||
# virtualenv
|
||||
flash-attn==2.8.3
|
||||
flask==3.1.1
|
||||
# via lerobot
|
||||
fonttools==4.60.1
|
||||
fonttools==4.59.0
|
||||
# via matplotlib
|
||||
frozenlist==1.8.0
|
||||
frozenlist==1.7.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec[http]==2025.9.0
|
||||
fsspec[http]==2025.3.0
|
||||
# via
|
||||
# datasets
|
||||
# etils
|
||||
# huggingface-hub
|
||||
# torch
|
||||
future==1.0.0
|
||||
# via libero
|
||||
gitdb==4.0.12
|
||||
# via gitpython
|
||||
gitpython==3.1.45
|
||||
# via wandb
|
||||
glfw==2.10.0
|
||||
glfw==2.9.0
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
@@ -217,79 +180,61 @@ grpcio==1.73.1
|
||||
# via
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
grpcio-tools==1.73.1
|
||||
# via
|
||||
# lerobot
|
||||
# reachy2-sdk-api
|
||||
gym-aloha==0.1.3
|
||||
# via lerobot
|
||||
gym-hil==0.1.13
|
||||
gym-aloha==0.1.1
|
||||
# via lerobot
|
||||
gym-pusht==0.1.6
|
||||
gym-hil==0.1.10
|
||||
# via lerobot
|
||||
gymnasium==1.2.1
|
||||
gym-pusht==0.1.5
|
||||
# via lerobot
|
||||
gym-xarm==0.1.1
|
||||
# via lerobot
|
||||
gymnasium==0.29.1
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-pusht
|
||||
# gym-xarm
|
||||
# gymnasium-robotics
|
||||
# lerobot
|
||||
# libero
|
||||
# metaworld
|
||||
h11==0.16.0
|
||||
# via uvicorn
|
||||
h5py==3.15.1
|
||||
# via robomimic
|
||||
hebi-py==2.11.0
|
||||
# via lerobot
|
||||
# pettingzoo
|
||||
gymnasium-robotics==1.2.4
|
||||
# via gym-xarm
|
||||
hf-transfer==0.1.9
|
||||
# via huggingface-hub
|
||||
hf-xet==1.1.10
|
||||
hf-xet==1.1.5
|
||||
# via huggingface-hub
|
||||
hidapi==0.14.0.post4
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
httptools==0.7.1
|
||||
# via uvicorn
|
||||
huggingface-hub[cli,hf-transfer]==0.35.3
|
||||
huggingface-hub[cli,hf-transfer]==0.34.3
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# tokenizers
|
||||
# transformers
|
||||
hydra-core==1.3.2
|
||||
# via libero
|
||||
identify==2.6.15
|
||||
identify==2.6.12
|
||||
# via pre-commit
|
||||
idna==3.11
|
||||
idna==3.10
|
||||
# via
|
||||
# anyio
|
||||
# requests
|
||||
# yarl
|
||||
imageio[ffmpeg]==2.37.0
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gymnasium-robotics
|
||||
# lerobot
|
||||
# metaworld
|
||||
# robomimic
|
||||
# scikit-image
|
||||
imageio-ffmpeg==0.6.0
|
||||
# via
|
||||
# imageio
|
||||
# robomimic
|
||||
# via imageio
|
||||
importlib-metadata==8.7.0
|
||||
# via diffusers
|
||||
importlib-resources==6.5.2
|
||||
# via etils
|
||||
iniconfig==2.3.0
|
||||
iniconfig==2.1.0
|
||||
# via pytest
|
||||
inquirerpy==0.3.4
|
||||
# via huggingface-hub
|
||||
@@ -297,71 +242,50 @@ ipython==8.37.0
|
||||
# via meshcat
|
||||
ischedule==1.2.7
|
||||
# via placo
|
||||
itsdangerous==2.2.0
|
||||
# via flask
|
||||
jedi==0.19.2
|
||||
# via ipython
|
||||
jinja2==3.1.6
|
||||
# via torch
|
||||
# via
|
||||
# flask
|
||||
# gymnasium-robotics
|
||||
# torch
|
||||
jsonlines==4.0.0
|
||||
# via lerobot
|
||||
jsonschema==4.25.1
|
||||
# via nbformat
|
||||
jsonschema-specifications==2025.9.1
|
||||
# via jsonschema
|
||||
jupyter-core==5.9.1
|
||||
# via nbformat
|
||||
jupytext==1.18.1
|
||||
# via bddl
|
||||
kiwisolver==1.4.9
|
||||
kiwisolver==1.4.8
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lazy-loader==0.4
|
||||
# via scikit-image
|
||||
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
|
||||
# via lerobot
|
||||
llvmlite==0.45.1
|
||||
# via numba
|
||||
lxml==6.0.2
|
||||
lxml==6.0.0
|
||||
# via dm-control
|
||||
markdown==3.9
|
||||
# via tensorboard
|
||||
markdown-it-py==4.0.0
|
||||
# via
|
||||
# jupytext
|
||||
# mdit-py-plugins
|
||||
markupsafe==3.0.3
|
||||
markupsafe==3.0.2
|
||||
# via
|
||||
# flask
|
||||
# jinja2
|
||||
# werkzeug
|
||||
matplotlib==3.10.7
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
matplotlib-inline==0.2.1
|
||||
matplotlib==3.10.5
|
||||
# via lerobot
|
||||
matplotlib-inline==0.1.7
|
||||
# via ipython
|
||||
mdit-py-plugins==0.5.0
|
||||
# via jupytext
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mergedeep==1.3.4
|
||||
# via draccus
|
||||
meshcat==0.3.2
|
||||
# via placo
|
||||
metaworld==3.0.0
|
||||
# via lerobot
|
||||
mock-serial==0.0.1
|
||||
# via lerobot
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
mujoco==3.3.7
|
||||
mujoco==2.3.7
|
||||
# via
|
||||
# dm-control
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# libero
|
||||
# metaworld
|
||||
# robosuite
|
||||
multidict==6.7.0
|
||||
# gym-xarm
|
||||
# gymnasium-robotics
|
||||
multidict==6.6.3
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
@@ -369,63 +293,42 @@ multiprocess==0.70.16
|
||||
# via datasets
|
||||
mypy-extensions==1.1.0
|
||||
# via typing-inspect
|
||||
nbformat==5.10.4
|
||||
# via jupytext
|
||||
networkx==3.4.2
|
||||
# via
|
||||
# bddl
|
||||
# scikit-image
|
||||
# torch
|
||||
ninja==1.13.0
|
||||
# via lerobot
|
||||
nodeenv==1.9.1
|
||||
# via pre-commit
|
||||
num2words==0.5.14
|
||||
# via lerobot
|
||||
numba==0.62.1
|
||||
# via robosuite
|
||||
numpy==2.2.6
|
||||
# via
|
||||
# accelerate
|
||||
# bddl
|
||||
# cmeel-boost
|
||||
# contourpy
|
||||
# datasets
|
||||
# decord
|
||||
# diffusers
|
||||
# dm-control
|
||||
# dm-env
|
||||
# dm-tree
|
||||
# gymnasium
|
||||
# h5py
|
||||
# hebi-py
|
||||
# gymnasium-robotics
|
||||
# imageio
|
||||
# labmaze
|
||||
# libero
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# metaworld
|
||||
# mujoco
|
||||
# numba
|
||||
# opencv-python
|
||||
# opencv-python-headless
|
||||
# pandas
|
||||
# peft
|
||||
# pyquaternion
|
||||
# reachy2-sdk
|
||||
# pettingzoo
|
||||
# rerun-sdk
|
||||
# robomimic
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# scipy
|
||||
# shapely
|
||||
# teleop
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# tifffile
|
||||
# torchvision
|
||||
# transformers
|
||||
# transforms3d
|
||||
nvidia-cublas-cu12==12.6.4.1
|
||||
# via
|
||||
# nvidia-cudnn-cu12
|
||||
@@ -463,14 +366,8 @@ nvidia-nvjitlink-cu12==12.6.85
|
||||
# torch
|
||||
nvidia-nvtx-cu12==12.6.77
|
||||
# via torch
|
||||
omegaconf==2.3.0
|
||||
# via hydra-core
|
||||
opencv-python==4.12.0.88
|
||||
# via
|
||||
# gym-pusht
|
||||
# libero
|
||||
# reachy2-sdk
|
||||
# robosuite
|
||||
# via gym-pusht
|
||||
opencv-python-headless==4.12.0.88
|
||||
# via lerobot
|
||||
orderly-set==5.5.0
|
||||
@@ -480,63 +377,53 @@ packaging==25.0
|
||||
# accelerate
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
# hydra-core
|
||||
# jupytext
|
||||
# lazy-loader
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# peft
|
||||
# pytest
|
||||
# reachy2-sdk
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# transformers
|
||||
# wandb
|
||||
pandas==2.3.3
|
||||
pandas==2.3.1
|
||||
# via
|
||||
# datasets
|
||||
# lerobot
|
||||
parso==0.8.5
|
||||
parso==0.8.4
|
||||
# via jedi
|
||||
peft==0.17.1
|
||||
# via lerobot
|
||||
pettingzoo==1.24.3
|
||||
# via gymnasium-robotics
|
||||
pexpect==4.9.0
|
||||
# via ipython
|
||||
pfzy==0.3.4
|
||||
# via inquirerpy
|
||||
pillow==12.0.0
|
||||
pillow==11.3.0
|
||||
# via
|
||||
# diffusers
|
||||
# imageio
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# rerun-sdk
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# torchvision
|
||||
pin==3.4.0
|
||||
# via placo
|
||||
placo==0.9.14
|
||||
# via lerobot
|
||||
platformdirs==4.5.0
|
||||
platformdirs==4.3.8
|
||||
# via
|
||||
# jupyter-core
|
||||
# virtualenv
|
||||
# wandb
|
||||
pluggy==1.6.0
|
||||
# via
|
||||
# pytest
|
||||
# pytest-cov
|
||||
pre-commit==4.3.0
|
||||
pre-commit==4.2.0
|
||||
# via lerobot
|
||||
prompt-toolkit==3.0.52
|
||||
prompt-toolkit==3.0.51
|
||||
# via
|
||||
# inquirerpy
|
||||
# ipython
|
||||
propcache==0.4.1
|
||||
propcache==0.3.2
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
@@ -545,17 +432,11 @@ protobuf==6.31.0
|
||||
# dm-control
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# wandb
|
||||
psutil==7.1.1
|
||||
psutil==7.0.0
|
||||
# via
|
||||
# accelerate
|
||||
# imageio
|
||||
# peft
|
||||
# robomimic
|
||||
ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pure-eval==0.2.3
|
||||
@@ -564,13 +445,11 @@ pyarrow==21.0.0
|
||||
# via
|
||||
# datasets
|
||||
# rerun-sdk
|
||||
pycparser==2.23
|
||||
pycparser==2.22
|
||||
# via cffi
|
||||
pydantic==2.12.3
|
||||
# via
|
||||
# fastapi
|
||||
# wandb
|
||||
pydantic-core==2.41.4
|
||||
pydantic==2.11.7
|
||||
# via wandb
|
||||
pydantic-core==2.33.2
|
||||
# via pydantic
|
||||
pygame==2.6.1
|
||||
# via
|
||||
@@ -585,22 +464,20 @@ pymunk==6.11.1
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
pyngrok==7.4.1
|
||||
pyngrok==7.2.12
|
||||
# via meshcat
|
||||
pynput==1.8.1
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
pyopengl==3.1.10
|
||||
pyopengl==3.1.9
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.2.5
|
||||
pyparsing==3.2.3
|
||||
# via
|
||||
# dm-control
|
||||
# matplotlib
|
||||
pyquaternion==0.9.9
|
||||
# via reachy2-sdk
|
||||
pyrealsense2==2.56.5.9235
|
||||
# via lerobot
|
||||
pyserial==3.5
|
||||
@@ -608,14 +485,12 @@ pyserial==3.5
|
||||
# dynamixel-sdk
|
||||
# feetech-servo-sdk
|
||||
# lerobot
|
||||
pytest==8.4.2
|
||||
pytest==8.4.1
|
||||
# via
|
||||
# bddl
|
||||
# lerobot
|
||||
# pytest-cov
|
||||
# pytest-timeout
|
||||
# teleop
|
||||
pytest-cov==7.0.0
|
||||
pytest-cov==6.2.1
|
||||
# via lerobot
|
||||
pytest-timeout==2.4.0
|
||||
# via lerobot
|
||||
@@ -623,75 +498,48 @@ python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# matplotlib
|
||||
# pandas
|
||||
python-dotenv==1.1.1
|
||||
# via uvicorn
|
||||
python-xlib==0.33
|
||||
# via pynput
|
||||
pytz==2025.2
|
||||
# via pandas
|
||||
pyyaml==6.0.3
|
||||
pyyaml==6.0.2
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# draccus
|
||||
# hebi-py
|
||||
# huggingface-hub
|
||||
# jupytext
|
||||
# omegaconf
|
||||
# peft
|
||||
# pre-commit
|
||||
# pyngrok
|
||||
# pyyaml-include
|
||||
# timm
|
||||
# transformers
|
||||
# uvicorn
|
||||
# wandb
|
||||
pyyaml-include==1.4.1
|
||||
# via draccus
|
||||
pyzmq==27.1.0
|
||||
pyzmq==27.0.0
|
||||
# via
|
||||
# lerobot
|
||||
# meshcat
|
||||
reachy2-sdk==1.0.14
|
||||
# via lerobot
|
||||
reachy2-sdk-api==1.0.21
|
||||
# via reachy2-sdk
|
||||
referencing==0.37.0
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2025.10.23
|
||||
regex==2025.7.34
|
||||
# via
|
||||
# diffusers
|
||||
# transformers
|
||||
requests==2.32.5
|
||||
requests==2.32.4
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# teleop
|
||||
# transformers
|
||||
# wandb
|
||||
rerun-sdk==0.26.1
|
||||
rerun-sdk==0.22.1
|
||||
# via lerobot
|
||||
rhoban-cmeel-jsoncpp==1.9.4.9
|
||||
# via placo
|
||||
robomimic==0.2.0
|
||||
# via libero
|
||||
robosuite==1.4.0
|
||||
# via libero
|
||||
rpds-py==0.28.0
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
safetensors==0.6.2
|
||||
safetensors==0.5.3
|
||||
# via
|
||||
# accelerate
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# transformers
|
||||
scikit-image==0.25.2
|
||||
# via
|
||||
@@ -700,12 +548,10 @@ scikit-image==0.25.2
|
||||
scipy==1.15.3
|
||||
# via
|
||||
# dm-control
|
||||
# metaworld
|
||||
# robosuite
|
||||
# scikit-image
|
||||
sentry-sdk==2.42.1
|
||||
sentry-sdk==2.34.1
|
||||
# via wandb
|
||||
shapely==2.1.2
|
||||
shapely==2.1.1
|
||||
# via gym-pusht
|
||||
six==1.17.0
|
||||
# via
|
||||
@@ -714,109 +560,66 @@ six==1.17.0
|
||||
# python-xlib
|
||||
smmap==5.0.2
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
stack-data==0.6.3
|
||||
# via ipython
|
||||
starlette==0.48.0
|
||||
# via fastapi
|
||||
sympy==1.14.0
|
||||
# via torch
|
||||
teleop==0.1.2
|
||||
# via lerobot
|
||||
tensorboard==2.20.0
|
||||
# via robomimic
|
||||
tensorboard-data-server==0.7.2
|
||||
# via tensorboard
|
||||
tensorboardx==2.6.4
|
||||
# via robomimic
|
||||
termcolor==3.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
thop==0.1.1.post2209072238
|
||||
# via libero
|
||||
# via lerobot
|
||||
tifffile==2025.5.10
|
||||
# via scikit-image
|
||||
timm==1.0.20
|
||||
# via lerobot
|
||||
tokenizers==0.22.1
|
||||
tokenizers==0.21.4
|
||||
# via transformers
|
||||
toml==0.10.2
|
||||
# via draccus
|
||||
tomli==2.3.0
|
||||
tomli==2.2.1
|
||||
# via
|
||||
# cmeel
|
||||
# coverage
|
||||
# jupytext
|
||||
# pytest
|
||||
torch==2.7.1
|
||||
# via
|
||||
# accelerate
|
||||
# flash-attn
|
||||
# lerobot
|
||||
# peft
|
||||
# robomimic
|
||||
# thop
|
||||
# timm
|
||||
# torchvision
|
||||
torchcodec==0.5
|
||||
# via lerobot
|
||||
torchvision==0.22.1
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
# timm
|
||||
tornado==6.5.2
|
||||
# via lerobot
|
||||
tornado==6.5.1
|
||||
# via meshcat
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
# datasets
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# peft
|
||||
# robomimic
|
||||
# transformers
|
||||
traitlets==5.14.3
|
||||
# via
|
||||
# ipython
|
||||
# jupyter-core
|
||||
# matplotlib-inline
|
||||
# nbformat
|
||||
transformers==4.57.1
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
# peft
|
||||
transforms3d==0.4.2
|
||||
# via teleop
|
||||
transformers==4.51.3
|
||||
# via lerobot
|
||||
triton==3.3.1
|
||||
# via torch
|
||||
typing-extensions==4.15.0
|
||||
typing-extensions==4.14.1
|
||||
# via
|
||||
# aiosignal
|
||||
# anyio
|
||||
# etils
|
||||
# exceptiongroup
|
||||
# fastapi
|
||||
# gymnasium
|
||||
# huggingface-hub
|
||||
# ipython
|
||||
# multidict
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
# starlette
|
||||
# torch
|
||||
# typing-inspect
|
||||
# typing-inspection
|
||||
# uvicorn
|
||||
# virtualenv
|
||||
# wandb
|
||||
typing-inspect==0.9.0
|
||||
# via draccus
|
||||
typing-inspection==0.4.2
|
||||
typing-inspection==0.4.1
|
||||
# via pydantic
|
||||
tzdata==2025.2
|
||||
# via pandas
|
||||
@@ -826,36 +629,22 @@ urllib3==2.5.0
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
uvicorn[standard]==0.38.0
|
||||
# via teleop
|
||||
uvloop==0.22.1
|
||||
# via uvicorn
|
||||
virtualenv==20.35.3
|
||||
virtualenv==20.32.0
|
||||
# via pre-commit
|
||||
wandb==0.21.4
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
watchfiles==1.1.1
|
||||
# via uvicorn
|
||||
wcwidth==0.2.14
|
||||
wandb==0.21.0
|
||||
# via lerobot
|
||||
wcwidth==0.2.13
|
||||
# via prompt-toolkit
|
||||
websocket-client==1.9.0
|
||||
# via teleop
|
||||
websockets==15.0.1
|
||||
# via uvicorn
|
||||
werkzeug==3.1.3
|
||||
# via tensorboard
|
||||
wrapt==2.0.0
|
||||
# via flask
|
||||
wrapt==1.17.2
|
||||
# via dm-tree
|
||||
xxhash==3.6.0
|
||||
xxhash==3.5.0
|
||||
# via datasets
|
||||
yarl==1.22.0
|
||||
yarl==1.20.1
|
||||
# via aiohttp
|
||||
zipp==3.23.0
|
||||
# via
|
||||
# etils
|
||||
# importlib-metadata
|
||||
# via importlib-metadata
|
||||
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
# setuptools
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# requirements.in
|
||||
|
||||
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.0.1 25A362 arm64).
|
||||
# Darwin MacBook-Pro.local 25.0.0 Darwin Kernel Version 25.0.0: Wed Sep 17 21:42:08 PDT 2025; root:xnu-12377.1.9~141/RELEASE_ARM64_T8132 arm64
|
||||
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 15.5 24F74 arm64).
|
||||
# Darwin MacBook-Pro.local 24.5.0 Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:43 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T8132 arm64
|
||||
|
||||
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.3 LTS x86_64).
|
||||
# Linux mlerobot-linux 6.14.0-33-generic #33~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 17:02:30 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
|
||||
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.2 LTS x86_64).
|
||||
# Linux mlerobot-linux 6.14.0-27-generic #27~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Jul 22 17:38:49 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
|
||||
|
||||
-e .[all]
|
||||
|
||||
@@ -16,7 +16,7 @@ import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
@@ -268,7 +268,6 @@ class RemotePolicyConfig:
|
||||
lerobot_features: dict[str, PolicyFeature]
|
||||
actions_per_chunk: int
|
||||
device: str = "cpu"
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
|
||||
|
||||
@@ -159,10 +159,7 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
self.preprocessor, self.postprocessor = make_pre_post_processors(
|
||||
self.policy.config,
|
||||
pretrained_path=policy_specs.pretrained_name_or_path,
|
||||
preprocessor_overrides={
|
||||
"device_processor": device_override,
|
||||
"rename_observations_processor": {"rename_map": policy_specs.rename_map},
|
||||
},
|
||||
preprocessor_overrides={"device_processor": device_override},
|
||||
postprocessor_overrides={"device_processor": device_override},
|
||||
)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import abc
|
||||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
import numpy as np
|
||||
|
||||
from .configs import CameraConfig, ColorMode
|
||||
|
||||
@@ -89,7 +89,7 @@ class Camera(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
|
||||
"""Capture and return a single frame from the camera.
|
||||
|
||||
Args:
|
||||
@@ -102,7 +102,7 @@ class Camera(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def async_read(self, timeout_ms: float = ...) -> NDArray[Any]:
|
||||
def async_read(self, timeout_ms: float = ...) -> np.ndarray:
|
||||
"""Asynchronously capture and return a single frame from the camera.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -18,7 +18,7 @@ import abc
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
import draccus # type: ignore # TODO: add type stubs for draccus
|
||||
import draccus
|
||||
|
||||
|
||||
class ColorMode(str, Enum):
|
||||
@@ -34,11 +34,11 @@ class Cv2Rotation(int, Enum):
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus
|
||||
class CameraConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return str(self.get_choice_name(self.__class__))
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@@ -14,5 +14,3 @@
|
||||
|
||||
from .camera_opencv import OpenCVCamera
|
||||
from .configuration_opencv import OpenCVCameraConfig
|
||||
|
||||
__all__ = ["OpenCVCamera", "OpenCVCameraConfig"]
|
||||
|
||||
@@ -25,12 +25,11 @@ from pathlib import Path
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
|
||||
# Fix MSMF hardware transform compatibility for Windows before importing cv2
|
||||
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
|
||||
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
@@ -122,7 +121,7 @@ class OpenCVCamera(Camera):
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.latest_frame: np.ndarray | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
self.rotation: int | None = get_cv2_rotation(config.rotation)
|
||||
@@ -141,7 +140,7 @@ class OpenCVCamera(Camera):
|
||||
"""Checks if the camera is currently connected and opened."""
|
||||
return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened()
|
||||
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
def connect(self, warmup: bool = True):
|
||||
"""
|
||||
Connects to the OpenCV camera specified in the configuration.
|
||||
|
||||
@@ -181,14 +180,12 @@ class OpenCVCamera(Camera):
|
||||
|
||||
def _configure_capture_settings(self) -> None:
|
||||
"""
|
||||
Applies the specified FOURCC, FPS, width, and height settings to the connected camera.
|
||||
Applies the specified FPS, width, and height settings to the connected camera.
|
||||
|
||||
This method attempts to set the camera properties via OpenCV. It checks if
|
||||
the camera successfully applied the settings and raises an error if not.
|
||||
FOURCC is set first (if specified) as it can affect the available FPS and resolution options.
|
||||
|
||||
Args:
|
||||
fourcc: The desired FOURCC code (e.g., "MJPG", "YUYV"). If None, auto-detect.
|
||||
fps: The desired frames per second. If None, the setting is skipped.
|
||||
width: The desired capture width. If None, the setting is skipped.
|
||||
height: The desired capture height. If None, the setting is skipped.
|
||||
@@ -202,11 +199,10 @@ class OpenCVCamera(Camera):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
|
||||
|
||||
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
|
||||
if self.config.fourcc is not None:
|
||||
self._validate_fourcc()
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
if self.fps is None:
|
||||
self.fps = self.videocapture.get(cv2.CAP_PROP_FPS)
|
||||
else:
|
||||
self._validate_fps()
|
||||
|
||||
default_width = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_WIDTH)))
|
||||
default_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
|
||||
@@ -220,56 +216,18 @@ class OpenCVCamera(Camera):
|
||||
else:
|
||||
self._validate_width_and_height()
|
||||
|
||||
if self.fps is None:
|
||||
self.fps = self.videocapture.get(cv2.CAP_PROP_FPS)
|
||||
else:
|
||||
self._validate_fps()
|
||||
|
||||
def _validate_fps(self) -> None:
|
||||
"""Validates and sets the camera's frames per second (FPS)."""
|
||||
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
if self.fps is None:
|
||||
raise ValueError(f"{self} FPS is not set")
|
||||
|
||||
success = self.videocapture.set(cv2.CAP_PROP_FPS, float(self.fps))
|
||||
actual_fps = self.videocapture.get(cv2.CAP_PROP_FPS)
|
||||
# Use math.isclose for robust float comparison
|
||||
if not success or not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
|
||||
raise RuntimeError(f"{self} failed to set fps={self.fps} ({actual_fps=}).")
|
||||
|
||||
def _validate_fourcc(self) -> None:
|
||||
"""Validates and sets the camera's FOURCC code."""
|
||||
|
||||
fourcc_code = cv2.VideoWriter_fourcc(*self.config.fourcc)
|
||||
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
success = self.videocapture.set(cv2.CAP_PROP_FOURCC, fourcc_code)
|
||||
actual_fourcc_code = self.videocapture.get(cv2.CAP_PROP_FOURCC)
|
||||
|
||||
# Convert actual FOURCC code back to string for comparison
|
||||
actual_fourcc_code_int = int(actual_fourcc_code)
|
||||
actual_fourcc = "".join([chr((actual_fourcc_code_int >> 8 * i) & 0xFF) for i in range(4)])
|
||||
|
||||
if not success or actual_fourcc != self.config.fourcc:
|
||||
logger.warning(
|
||||
f"{self} failed to set fourcc={self.config.fourcc} (actual={actual_fourcc}, success={success}). "
|
||||
f"Continuing with default format."
|
||||
)
|
||||
|
||||
def _validate_width_and_height(self) -> None:
|
||||
"""Validates and sets the camera's frame capture width and height."""
|
||||
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
if self.capture_width is None or self.capture_height is None:
|
||||
raise ValueError(f"{self} capture_width or capture_height is not set")
|
||||
|
||||
width_success = self.videocapture.set(cv2.CAP_PROP_FRAME_WIDTH, float(self.capture_width))
|
||||
height_success = self.videocapture.set(cv2.CAP_PROP_FRAME_HEIGHT, float(self.capture_height))
|
||||
|
||||
@@ -300,12 +258,11 @@ class OpenCVCamera(Camera):
|
||||
"""
|
||||
found_cameras_info = []
|
||||
|
||||
targets_to_scan: list[str | int]
|
||||
if platform.system() == "Linux":
|
||||
possible_paths = sorted(Path("/dev").glob("video*"), key=lambda p: p.name)
|
||||
targets_to_scan = [str(p) for p in possible_paths]
|
||||
else:
|
||||
targets_to_scan = [int(i) for i in range(MAX_OPENCV_INDEX)]
|
||||
targets_to_scan = list(range(MAX_OPENCV_INDEX))
|
||||
|
||||
for target in targets_to_scan:
|
||||
camera = cv2.VideoCapture(target)
|
||||
@@ -314,12 +271,6 @@ class OpenCVCamera(Camera):
|
||||
default_height = int(camera.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
default_fps = camera.get(cv2.CAP_PROP_FPS)
|
||||
default_format = camera.get(cv2.CAP_PROP_FORMAT)
|
||||
|
||||
# Get FOURCC code and convert to string
|
||||
default_fourcc_code = camera.get(cv2.CAP_PROP_FOURCC)
|
||||
default_fourcc_code_int = int(default_fourcc_code)
|
||||
default_fourcc = "".join([chr((default_fourcc_code_int >> 8 * i) & 0xFF) for i in range(4)])
|
||||
|
||||
camera_info = {
|
||||
"name": f"OpenCV Camera @ {target}",
|
||||
"type": "OpenCV",
|
||||
@@ -327,7 +278,6 @@ class OpenCVCamera(Camera):
|
||||
"backend_api": camera.getBackendName(),
|
||||
"default_stream_profile": {
|
||||
"format": default_format,
|
||||
"fourcc": default_fourcc,
|
||||
"width": default_width,
|
||||
"height": default_height,
|
||||
"fps": default_fps,
|
||||
@@ -339,7 +289,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return found_cameras_info
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
|
||||
@@ -367,9 +317,6 @@ class OpenCVCamera(Camera):
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
ret, frame = self.videocapture.read()
|
||||
|
||||
if not ret or frame is None:
|
||||
@@ -382,7 +329,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return processed_frame
|
||||
|
||||
def _postprocess_image(self, image: NDArray[Any], color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
def _postprocess_image(self, image: np.ndarray, color_mode: ColorMode | None = None) -> np.ndarray:
|
||||
"""
|
||||
Applies color conversion, dimension validation, and rotation to a raw frame.
|
||||
|
||||
@@ -425,7 +372,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return processed_image
|
||||
|
||||
def _read_loop(self) -> None:
|
||||
def _read_loop(self):
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
@@ -436,9 +383,6 @@ class OpenCVCamera(Camera):
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read()
|
||||
@@ -475,7 +419,7 @@ class OpenCVCamera(Camera):
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
|
||||
@@ -518,7 +462,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
def disconnect(self):
|
||||
"""
|
||||
Disconnects from the camera and cleans up resources.
|
||||
|
||||
|
||||
@@ -17,8 +17,6 @@ from pathlib import Path
|
||||
|
||||
from ..configs import CameraConfig, ColorMode, Cv2Rotation
|
||||
|
||||
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"]
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("opencv")
|
||||
@dataclass
|
||||
@@ -35,9 +33,8 @@ class OpenCVCameraConfig(CameraConfig):
|
||||
OpenCVCameraConfig(0, 30, 1280, 720) # 1280x720 @ 30FPS
|
||||
OpenCVCameraConfig(/dev/video4, 60, 640, 480) # 640x480 @ 60FPS
|
||||
|
||||
# Advanced configurations with FOURCC format
|
||||
OpenCVCameraConfig(128422271347, 30, 640, 480, rotation=Cv2Rotation.ROTATE_90, fourcc="MJPG") # With 90° rotation and MJPG format
|
||||
OpenCVCameraConfig(0, 30, 1280, 720, fourcc="YUYV") # With YUYV format
|
||||
# Advanced configurations
|
||||
OpenCVCameraConfig(128422271347, 30, 640, 480, rotation=Cv2Rotation.ROTATE_90) # With 90° rotation
|
||||
```
|
||||
|
||||
Attributes:
|
||||
@@ -49,21 +46,17 @@ class OpenCVCameraConfig(CameraConfig):
|
||||
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
|
||||
rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation.
|
||||
warmup_s: Time reading frames before returning from connect (in seconds)
|
||||
fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect).
|
||||
|
||||
Note:
|
||||
- Only 3-channel color output (RGB/BGR) is currently supported.
|
||||
- FOURCC codes must be 4-character strings (e.g., "MJPG", "YUYV"). Some common FOUCC codes: https://learn.microsoft.com/en-us/windows/win32/medfound/video-fourccs#fourcc-constants
|
||||
- Setting FOURCC can help achieve higher frame rates on some cameras.
|
||||
"""
|
||||
|
||||
index_or_path: int | Path
|
||||
color_mode: ColorMode = ColorMode.RGB
|
||||
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
|
||||
warmup_s: int = 1
|
||||
fourcc: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
def __post_init__(self):
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
|
||||
@@ -78,8 +71,3 @@ class OpenCVCameraConfig(CameraConfig):
|
||||
raise ValueError(
|
||||
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
|
||||
)
|
||||
|
||||
if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4):
|
||||
raise ValueError(
|
||||
f"`fourcc` must be a 4-character string (e.g., 'MJPG', 'YUYV'), but '{self.fourcc}' is provided."
|
||||
)
|
||||
|
||||
@@ -16,8 +16,6 @@ from dataclasses import dataclass
|
||||
|
||||
from ..configs import CameraConfig, ColorMode
|
||||
|
||||
__all__ = ["CameraConfig", "ColorMode", "Reachy2CameraConfig"]
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("reachy2_camera")
|
||||
@dataclass
|
||||
@@ -64,7 +62,7 @@ class Reachy2CameraConfig(CameraConfig):
|
||||
port: int = 50065
|
||||
# use_depth: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
def __post_init__(self):
|
||||
if self.name not in ["teleop", "depth"]:
|
||||
raise ValueError(f"`name` is expected to be 'teleop' or 'depth', but {self.name} is provided.")
|
||||
if (self.name == "teleop" and self.image_type not in ["left", "right"]) or (
|
||||
|
||||
@@ -23,17 +23,13 @@ import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
|
||||
# Fix MSMF hardware transform compatibility for Windows before importing cv2
|
||||
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
|
||||
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
import numpy as np # type: ignore # TODO: add type stubs for numpy
|
||||
from reachy2_sdk.media.camera import CameraView # type: ignore # TODO: add type stubs for reachy2_sdk
|
||||
from reachy2_sdk.media.camera_manager import ( # type: ignore # TODO: add type stubs for reachy2_sdk
|
||||
CameraManager,
|
||||
)
|
||||
import cv2
|
||||
import numpy as np
|
||||
from reachy2_sdk.media.camera import CameraView
|
||||
from reachy2_sdk.media.camera_manager import CameraManager
|
||||
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
@@ -77,7 +73,7 @@ class Reachy2Camera(Camera):
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.latest_frame: np.ndarray | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -87,17 +83,13 @@ class Reachy2Camera(Camera):
|
||||
def is_connected(self) -> bool:
|
||||
"""Checks if the camera is currently connected and opened."""
|
||||
if self.config.name == "teleop":
|
||||
return bool(
|
||||
self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
|
||||
)
|
||||
return self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
|
||||
elif self.config.name == "depth":
|
||||
return bool(
|
||||
self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
|
||||
)
|
||||
return self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
|
||||
else:
|
||||
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
|
||||
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
def connect(self, warmup: bool = True):
|
||||
"""
|
||||
Connects to the Reachy2 CameraManager as specified in the configuration.
|
||||
"""
|
||||
@@ -139,7 +131,7 @@ class Reachy2Camera(Camera):
|
||||
camera_manager.disconnect()
|
||||
return initialized_cameras
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
|
||||
@@ -160,7 +152,7 @@ class Reachy2Camera(Camera):
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8)
|
||||
frame = None
|
||||
|
||||
if self.cam_manager is None:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
@@ -187,7 +179,7 @@ class Reachy2Camera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def _read_loop(self) -> None:
|
||||
def _read_loop(self):
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
@@ -198,9 +190,6 @@ class Reachy2Camera(Camera):
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read()
|
||||
@@ -237,7 +226,7 @@ class Reachy2Camera(Camera):
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
|
||||
@@ -280,7 +269,7 @@ class Reachy2Camera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
def disconnect(self):
|
||||
"""
|
||||
Stops the background read thread (if running).
|
||||
|
||||
|
||||
@@ -21,12 +21,11 @@ import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
import numpy as np # type: ignore # TODO: add type stubs for numpy
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import pyrealsense2 as rs # type: ignore # TODO: add type stubs for pyrealsense2
|
||||
import pyrealsense2 as rs
|
||||
except Exception as e:
|
||||
logging.info(f"Could not import realsense: {e}")
|
||||
|
||||
@@ -133,7 +132,7 @@ class RealSenseCamera(Camera):
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.latest_frame: np.ndarray | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
self.rotation: int | None = get_cv2_rotation(config.rotation)
|
||||
@@ -151,7 +150,7 @@ class RealSenseCamera(Camera):
|
||||
"""Checks if the camera pipeline is started and streams are active."""
|
||||
return self.rs_pipeline is not None and self.rs_profile is not None
|
||||
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
def connect(self, warmup: bool = True):
|
||||
"""
|
||||
Connects to the RealSense camera specified in the configuration.
|
||||
|
||||
@@ -265,7 +264,7 @@ class RealSenseCamera(Camera):
|
||||
serial_number = str(found_devices[0]["serial_number"])
|
||||
return serial_number
|
||||
|
||||
def _configure_rs_pipeline_config(self, rs_config: Any) -> None:
|
||||
def _configure_rs_pipeline_config(self, rs_config):
|
||||
"""Creates and configures the RealSense pipeline configuration object."""
|
||||
rs.config.enable_device(rs_config, self.serial_number)
|
||||
|
||||
@@ -294,9 +293,6 @@ class RealSenseCamera(Camera):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.")
|
||||
|
||||
if self.rs_profile is None:
|
||||
raise RuntimeError(f"{self}: rs_profile must be initialized before use.")
|
||||
|
||||
stream = self.rs_profile.get_stream(rs.stream.color).as_video_stream_profile()
|
||||
|
||||
if self.fps is None:
|
||||
@@ -312,7 +308,7 @@ class RealSenseCamera(Camera):
|
||||
self.width, self.height = actual_width, actual_height
|
||||
self.capture_width, self.capture_height = actual_width, actual_height
|
||||
|
||||
def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]:
|
||||
def read_depth(self, timeout_ms: int = 200) -> np.ndarray:
|
||||
"""
|
||||
Reads a single frame (depth) synchronously from the camera.
|
||||
|
||||
@@ -340,9 +336,6 @@ class RealSenseCamera(Camera):
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.rs_pipeline is None:
|
||||
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
|
||||
|
||||
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
|
||||
|
||||
if not ret or frame is None:
|
||||
@@ -358,7 +351,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return depth_map_processed
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> NDArray[Any]:
|
||||
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> np.ndarray:
|
||||
"""
|
||||
Reads a single frame (color) synchronously from the camera.
|
||||
|
||||
@@ -383,9 +376,6 @@ class RealSenseCamera(Camera):
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.rs_pipeline is None:
|
||||
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
|
||||
|
||||
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
|
||||
|
||||
if not ret or frame is None:
|
||||
@@ -402,8 +392,8 @@ class RealSenseCamera(Camera):
|
||||
return color_image_processed
|
||||
|
||||
def _postprocess_image(
|
||||
self, image: NDArray[Any], color_mode: ColorMode | None = None, depth_frame: bool = False
|
||||
) -> NDArray[Any]:
|
||||
self, image: np.ndarray, color_mode: ColorMode | None = None, depth_frame: bool = False
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Applies color conversion, dimension validation, and rotation to a raw color frame.
|
||||
|
||||
@@ -448,7 +438,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return processed_image
|
||||
|
||||
def _read_loop(self) -> None:
|
||||
def _read_loop(self):
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
@@ -459,9 +449,6 @@ class RealSenseCamera(Camera):
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read(timeout_ms=500)
|
||||
@@ -487,7 +474,7 @@ class RealSenseCamera(Camera):
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
def _stop_read_thread(self) -> None:
|
||||
def _stop_read_thread(self):
|
||||
"""Signals the background read thread to stop and waits for it to join."""
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
@@ -499,7 +486,7 @@ class RealSenseCamera(Camera):
|
||||
self.stop_event = None
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
|
||||
"""
|
||||
Reads the latest available frame data (color) asynchronously.
|
||||
|
||||
@@ -542,7 +529,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
def disconnect(self):
|
||||
"""
|
||||
Disconnects from the camera, stops the pipeline, and cleans up resources.
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ class RealSenseCameraConfig(CameraConfig):
|
||||
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
|
||||
warmup_s: int = 1
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
def __post_init__(self):
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
|
||||
|
||||
@@ -53,14 +53,14 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
|
||||
|
||||
|
||||
def get_cv2_rotation(rotation: Cv2Rotation) -> int | None:
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
import cv2
|
||||
|
||||
if rotation == Cv2Rotation.ROTATE_90:
|
||||
return int(cv2.ROTATE_90_CLOCKWISE)
|
||||
return cv2.ROTATE_90_CLOCKWISE
|
||||
elif rotation == Cv2Rotation.ROTATE_180:
|
||||
return int(cv2.ROTATE_180)
|
||||
return cv2.ROTATE_180
|
||||
elif rotation == Cv2Rotation.ROTATE_270:
|
||||
return int(cv2.ROTATE_90_COUNTERCLOCKWISE)
|
||||
return cv2.ROTATE_90_COUNTERCLOCKWISE
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -69,8 +69,8 @@ def get_cv2_backend() -> int:
|
||||
import cv2
|
||||
|
||||
if platform.system() == "Windows":
|
||||
return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION
|
||||
return cv2.CAP_MSMF # Use MSMF for Windows instead of AVFOUNDATION
|
||||
# elif platform.system() == "Darwin": # macOS
|
||||
# return cv2.CAP_AVFOUNDATION
|
||||
else: # Linux and others
|
||||
return int(cv2.CAP_ANY)
|
||||
return cv2.CAP_ANY
|
||||
|
||||
@@ -57,7 +57,7 @@ class EvalConfig:
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
use_async_envs: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
def __post_init__(self):
|
||||
if self.batch_size > self.n_episodes:
|
||||
raise ValueError(
|
||||
"The eval batch size is greater than the number of eval episodes "
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import datetime as dt
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot import envs, policies # noqa: F401
|
||||
@@ -22,8 +22,6 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.default import EvalConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalPipelineConfig:
|
||||
@@ -36,31 +34,25 @@ class EvalPipelineConfig:
|
||||
output_dir: Path | None = None
|
||||
job_name: str | None = None
|
||||
seed: int | None = 1000
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
self.policy.pretrained_path = policy_path
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
logging.warning(
|
||||
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
|
||||
)
|
||||
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
self.job_name = f"{self.policy.type if self.policy is not None else 'scratch'}"
|
||||
self.job_name = f"{self.policy.type}"
|
||||
else:
|
||||
self.job_name = (
|
||||
f"{self.env.type}_{self.policy.type if self.policy is not None else 'scratch'}"
|
||||
)
|
||||
|
||||
logger.warning(f"No job name provided, using '{self.job_name}' as job name.")
|
||||
self.job_name = f"{self.env.type}_{self.policy.type}"
|
||||
|
||||
if not self.output_dir:
|
||||
now = dt.datetime.now()
|
||||
|
||||
@@ -16,19 +16,14 @@ import inspect
|
||||
import pkgutil
|
||||
import sys
|
||||
from argparse import ArgumentError
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from collections.abc import Sequence
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from pkgutil import ModuleInfo
|
||||
from types import ModuleType
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.utils.utils import has_method
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., object])
|
||||
|
||||
PATH_KEY = "path"
|
||||
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
|
||||
|
||||
@@ -65,7 +60,7 @@ def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict[str, str]:
|
||||
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
|
||||
"""Parse plugin-related arguments from command-line arguments.
|
||||
|
||||
This function extracts arguments from command-line arguments that match a specified suffix pattern.
|
||||
@@ -132,7 +127,7 @@ def load_plugin(plugin_path: str) -> None:
|
||||
f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
|
||||
) from e
|
||||
|
||||
def iter_namespace(ns_pkg: ModuleType) -> Iterable[ModuleInfo]:
|
||||
def iter_namespace(ns_pkg):
|
||||
return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".")
|
||||
|
||||
try:
|
||||
@@ -153,8 +148,6 @@ def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | No
|
||||
|
||||
|
||||
def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
|
||||
if args is None:
|
||||
return []
|
||||
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
|
||||
|
||||
|
||||
@@ -178,8 +171,7 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
||||
if isinstance(fields_to_filter, str):
|
||||
fields_to_filter = [fields_to_filter]
|
||||
|
||||
filtered_args = [] if args is None else list(args)
|
||||
|
||||
filtered_args = args
|
||||
for field in fields_to_filter:
|
||||
if get_path_arg(field, args):
|
||||
if get_type_arg(field, args):
|
||||
@@ -192,7 +184,7 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
||||
return filtered_args
|
||||
|
||||
|
||||
def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||
def wrap(config_path: Path | None = None):
|
||||
"""
|
||||
HACK: Similar to draccus.wrap but does three additional things:
|
||||
- Will remove '.path' arguments from CLI in order to process them later on.
|
||||
@@ -203,9 +195,9 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||
from the CLI '.type' arguments
|
||||
"""
|
||||
|
||||
def wrapper_outer(fn: F) -> F:
|
||||
def wrapper_outer(fn):
|
||||
@wraps(fn)
|
||||
def wrapper_inner(*args: Any, **kwargs: Any) -> Any:
|
||||
def wrapper_inner(*args, **kwargs):
|
||||
argspec = inspect.getfullargspec(fn)
|
||||
argtype = argspec.annotations[argspec.args[0]]
|
||||
if len(args) > 0 and type(args[0]) is argtype:
|
||||
@@ -233,6 +225,6 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||
response = fn(cfg, *args, **kwargs)
|
||||
return response
|
||||
|
||||
return cast(F, wrapper_inner)
|
||||
return wrapper_inner
|
||||
|
||||
return cast(Callable[[F], F], wrapper_outer)
|
||||
return wrapper_outer
|
||||
|
||||
@@ -14,12 +14,12 @@
|
||||
import abc
|
||||
import builtins
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar
|
||||
from typing import TypeVar
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -34,11 +34,10 @@ from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
||||
|
||||
T = TypeVar("T", bound="PreTrainedConfig")
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: ignore[misc,name-defined] #TODO: draccus issue
|
||||
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
"""
|
||||
Base configuration class for policy models.
|
||||
|
||||
@@ -58,12 +57,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
|
||||
device: str | None = None # e.g. "cuda", "cuda:0", "cpu", or "mps"
|
||||
device: str | None = None # cuda | cpu | mp
|
||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||
# automatic gradient scaling is used.
|
||||
use_amp: bool = False
|
||||
|
||||
push_to_hub: bool = True # type: ignore[assignment] # TODO: use a different name to avoid override
|
||||
push_to_hub: bool = True
|
||||
repo_id: str | None = None
|
||||
|
||||
# Upload on private repository on the Hugging Face hub.
|
||||
@@ -74,41 +73,38 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
license: str | None = None
|
||||
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
|
||||
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
|
||||
pretrained_path: Path | None = None
|
||||
pretrained_path: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
def __post_init__(self):
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
auto_device = auto_select_torch_device()
|
||||
logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
self.device = auto_device.type
|
||||
|
||||
# Automatically deactivate AMP if necessary
|
||||
if self.use_amp and not is_amp_available(self.device):
|
||||
logger.warning(
|
||||
logging.warning(
|
||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
||||
)
|
||||
self.use_amp = False
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
choice_name = self.get_choice_name(self.__class__)
|
||||
if not isinstance(choice_name, str):
|
||||
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
|
||||
return choice_name
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
def observation_delta_indices(self) -> list | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def action_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
def action_delta_indices(self) -> list | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
def reward_delta_indices(self) -> list | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -158,13 +154,13 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict[Any, Any] | None = None,
|
||||
resume_download: bool = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
**policy_kwargs: Any,
|
||||
**policy_kwargs,
|
||||
) -> T:
|
||||
model_id = str(pretrained_name_or_path)
|
||||
config_file: str | None = None
|
||||
@@ -172,7 +168,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
if CONFIG_NAME in os.listdir(model_id):
|
||||
config_file = os.path.join(model_id, CONFIG_NAME)
|
||||
else:
|
||||
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
||||
print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
||||
else:
|
||||
try:
|
||||
config_file = hf_hub_download(
|
||||
@@ -198,9 +194,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
with draccus.config_type("json"):
|
||||
orig_config = draccus.parse(cls, config_file, args=[])
|
||||
|
||||
if config_file is None:
|
||||
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
|
||||
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ import datetime as dt
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -64,18 +63,18 @@ class TrainPipelineConfig(HubMixin):
|
||||
scheduler: LRSchedulerConfig | None = None
|
||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
checkpoint_path: Path | None = field(init=False, default=None)
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def validate(self) -> None:
|
||||
def __post_init__(self):
|
||||
self.checkpoint_path = None
|
||||
|
||||
def validate(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
# Only load the policy config
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
self.policy.pretrained_path = policy_path
|
||||
elif self.resume:
|
||||
# The entire train config is already loaded, we just need to get the checkpoint dir
|
||||
config_path = parser.parse_arg("config_path")
|
||||
@@ -83,22 +82,14 @@ class TrainPipelineConfig(HubMixin):
|
||||
raise ValueError(
|
||||
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
|
||||
)
|
||||
|
||||
if not Path(config_path).resolve().exists():
|
||||
raise NotADirectoryError(
|
||||
f"{config_path=} is expected to be a local path. "
|
||||
"Resuming from the hub is not supported for now."
|
||||
)
|
||||
|
||||
policy_dir = Path(config_path).parent
|
||||
if self.policy is not None:
|
||||
self.policy.pretrained_path = policy_dir
|
||||
self.checkpoint_path = policy_dir.parent
|
||||
|
||||
if self.policy is None:
|
||||
raise ValueError(
|
||||
"Policy is not configured. Please specify a pretrained policy with `--policy.path`."
|
||||
)
|
||||
policy_path = Path(config_path).parent
|
||||
self.policy.pretrained_path = policy_path
|
||||
self.checkpoint_path = policy_path.parent
|
||||
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
@@ -135,8 +126,8 @@ class TrainPipelineConfig(HubMixin):
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
|
||||
def to_dict(self) -> dict:
|
||||
return draccus.encode(self)
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||
@@ -148,13 +139,13 @@ class TrainPipelineConfig(HubMixin):
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict[Any, Any] | None = None,
|
||||
resume_download: bool = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
**kwargs: Any,
|
||||
**kwargs,
|
||||
) -> "TrainPipelineConfig":
|
||||
model_id = str(pretrained_name_or_path)
|
||||
config_file: str | None = None
|
||||
@@ -190,6 +181,4 @@ class TrainPipelineConfig(HubMixin):
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||
# NOTE: In RL, we don't need an offline dataset
|
||||
# TODO: Make `TrainPipelineConfig.dataset` optional
|
||||
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
||||
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset
|
||||
|
||||
@@ -42,11 +42,4 @@ class NormalizationMode(str, Enum):
|
||||
@dataclass
|
||||
class PolicyFeature:
|
||||
type: FeatureType
|
||||
shape: tuple[int, ...]
|
||||
|
||||
|
||||
class RTCAttentionSchedule(str, Enum):
|
||||
ZEROS = "ZEROS"
|
||||
ONES = "ONES"
|
||||
LINEAR = "LINEAR"
|
||||
EXP = "EXP"
|
||||
shape: tuple
|
||||
|
||||
@@ -39,7 +39,6 @@ from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DATA_DIR,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
@@ -963,23 +962,28 @@ def _copy_data_with_feature_changes(
|
||||
remove_features: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Copy data while adding or removing features."""
|
||||
data_dir = dataset.root / DATA_DIR
|
||||
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||
if dataset.meta.episodes is None:
|
||||
dataset.meta.episodes = load_episodes(dataset.meta.root)
|
||||
|
||||
if not parquet_files:
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
# Map file paths to episode indices to extract chunk/file indices
|
||||
file_to_episodes: dict[Path, set[int]] = {}
|
||||
for ep_idx in range(dataset.meta.total_episodes):
|
||||
file_path = dataset.meta.get_data_file_path(ep_idx)
|
||||
if file_path not in file_to_episodes:
|
||||
file_to_episodes[file_path] = set()
|
||||
file_to_episodes[file_path].add(ep_idx)
|
||||
|
||||
frame_idx = 0
|
||||
|
||||
for src_path in tqdm(parquet_files, desc="Processing data files"):
|
||||
df = pd.read_parquet(src_path).reset_index(drop=True)
|
||||
for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"):
|
||||
df = pd.read_parquet(dataset.root / src_path).reset_index(drop=True)
|
||||
|
||||
relative_path = src_path.relative_to(dataset.root)
|
||||
chunk_dir = relative_path.parts[1]
|
||||
file_name = relative_path.parts[2]
|
||||
|
||||
chunk_idx = int(chunk_dir.split("-")[1])
|
||||
file_idx = int(file_name.split("-")[1].split(".")[0])
|
||||
# Get chunk_idx and file_idx from the source file's first episode
|
||||
episodes_in_file = file_to_episodes[src_path]
|
||||
first_ep_idx = min(episodes_in_file)
|
||||
src_ep = dataset.meta.episodes[first_ep_idx]
|
||||
chunk_idx = src_ep["data/chunk_index"]
|
||||
file_idx = src_ep["data/file_index"]
|
||||
|
||||
if remove_features:
|
||||
df = df.drop(columns=remove_features, errors="ignore")
|
||||
@@ -1005,7 +1009,7 @@ def _copy_data_with_feature_changes(
|
||||
df[feature_name] = feature_slice
|
||||
frame_idx = end_idx
|
||||
|
||||
# Write using the same chunk/file structure as source
|
||||
# Write using the preserved chunk_idx and file_idx from source
|
||||
dst_path = new_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@@ -430,7 +430,9 @@ class LeRobotDatasetMetadata:
|
||||
video_keys = [video_key] if video_key is not None else self.video_keys
|
||||
for key in video_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
video_path = self.root / self.video_path.format(
|
||||
video_key=video_key, chunk_index=0, file_index=0
|
||||
)
|
||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||
|
||||
def update_chunk_settings(
|
||||
@@ -684,7 +686,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.episode_buffer = None
|
||||
self.writer = None
|
||||
self.latest_episode = None
|
||||
self._current_file_start_frame = None # Track the starting frame index of the current parquet file
|
||||
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
@@ -707,8 +708,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if not self._check_cached_episodes_sufficient():
|
||||
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
|
||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||
if is_valid_version(self.revision):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
self.download(download_videos)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
|
||||
@@ -835,14 +835,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return hf_dataset
|
||||
|
||||
def _check_cached_episodes_sufficient(self) -> bool:
|
||||
"""Check if the cached dataset contains all requested episodes and their video files."""
|
||||
"""Check if the cached dataset contains all requested episodes."""
|
||||
if self.hf_dataset is None or len(self.hf_dataset) == 0:
|
||||
return False
|
||||
|
||||
# Get available episode indices from cached dataset
|
||||
available_episodes = {
|
||||
ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx
|
||||
for ep_idx in self.hf_dataset.unique("episode_index")
|
||||
for ep_idx in self.hf_dataset["episode_index"]
|
||||
}
|
||||
|
||||
# Determine requested episodes
|
||||
@@ -854,18 +854,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
requested_episodes = set(self.episodes)
|
||||
|
||||
# Check if all requested episodes are available in cached data
|
||||
if not requested_episodes.issubset(available_episodes):
|
||||
return False
|
||||
|
||||
# Check if all required video files exist
|
||||
if len(self.meta.video_keys) > 0:
|
||||
for ep_idx in requested_episodes:
|
||||
for vid_key in self.meta.video_keys:
|
||||
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
||||
if not video_path.exists():
|
||||
return False
|
||||
|
||||
return True
|
||||
return requested_episodes.issubset(available_episodes)
|
||||
|
||||
def create_hf_dataset(self) -> datasets.Dataset:
|
||||
features = get_hf_features_from_features(self.features)
|
||||
@@ -940,26 +929,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return query_timestamps
|
||||
|
||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||
"""
|
||||
Query dataset for indices across keys, skipping video keys.
|
||||
|
||||
Tries column-first [key][indices] for speed, falls back to row-first.
|
||||
|
||||
Args:
|
||||
query_indices: Dict mapping keys to index lists to retrieve
|
||||
|
||||
Returns:
|
||||
Dict with stacked tensors of queried data (video keys excluded)
|
||||
"""
|
||||
result: dict = {}
|
||||
for key, q_idx in query_indices.items():
|
||||
if key in self.meta.video_keys:
|
||||
continue
|
||||
try:
|
||||
result[key] = torch.stack(self.hf_dataset[key][q_idx])
|
||||
except (KeyError, TypeError, IndexError):
|
||||
result[key] = torch.stack(self.hf_dataset[q_idx][key])
|
||||
return result
|
||||
return {
|
||||
key: torch.stack(self.hf_dataset[q_idx][key])
|
||||
for key, q_idx in query_indices.items()
|
||||
if key not in self.meta.video_keys
|
||||
}
|
||||
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
@@ -1257,7 +1231,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Initialize indices and frame count for a new dataset made of the first episode data
|
||||
chunk_idx, file_idx = 0, 0
|
||||
global_frame_index = 0
|
||||
self._current_file_start_frame = 0
|
||||
# However, if the episodes already exists
|
||||
# It means we are resuming recording, so we need to load the latest episode
|
||||
# Update the indices to avoid overwriting the latest episode
|
||||
@@ -1269,7 +1242,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
# When resuming, move to the next file
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
|
||||
self._current_file_start_frame = global_frame_index
|
||||
else:
|
||||
# Retrieve information from the latest parquet file
|
||||
latest_ep = self.latest_episode
|
||||
@@ -1280,7 +1252,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
latest_size_in_mb = get_file_size_in_mb(latest_path)
|
||||
|
||||
frames_in_current_file = global_frame_index - self._current_file_start_frame
|
||||
frames_in_current_file = global_frame_index - latest_ep["dataset_from_index"]
|
||||
av_size_per_frame = (
|
||||
latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0
|
||||
)
|
||||
@@ -1294,7 +1266,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
|
||||
self._close_writer()
|
||||
self._writer_closed_for_reading = False
|
||||
self._current_file_start_frame = global_frame_index
|
||||
|
||||
ep_dict["data/chunk_index"] = chunk_idx
|
||||
ep_dict["data/file_index"] = file_idx
|
||||
@@ -1501,7 +1472,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
obj.writer = None
|
||||
obj.latest_episode = None
|
||||
obj._current_file_start_frame = None
|
||||
# Initialize tracking for incremental recording
|
||||
obj._lazy_loading = False
|
||||
obj._recorded_frames = 0
|
||||
|
||||
@@ -206,11 +206,6 @@ class ImageTransformsConfig:
|
||||
type="SharpnessJitter",
|
||||
kwargs={"sharpness": (0.5, 1.5)},
|
||||
),
|
||||
"affine": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="RandomAffine",
|
||||
kwargs={"degrees": (-5.0, 5.0), "translate": (0.05, 0.05)},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -222,8 +217,6 @@ def make_transform_from_config(cfg: ImageTransformConfig):
|
||||
return v2.ColorJitter(**cfg.kwargs)
|
||||
elif cfg.type == "SharpnessJitter":
|
||||
return SharpnessJitter(**cfg.kwargs)
|
||||
elif cfg.type == "RandomAffine":
|
||||
return v2.RandomAffine(**cfg.kwargs)
|
||||
else:
|
||||
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ OLD
|
||||
videos/chunk-000/CAMERA/episode_000000.mp4
|
||||
|
||||
NEW
|
||||
videos/CAMERA/chunk-000/file_000.mp4
|
||||
videos/chunk-000/file_000.mp4
|
||||
-------------------------
|
||||
OLD
|
||||
episodes.jsonl
|
||||
|
||||
@@ -342,8 +342,8 @@ def encode_video_frames(
|
||||
# Define video output frame size (assuming all input frames are the same size)
|
||||
if len(input_list) == 0:
|
||||
raise FileNotFoundError(f"No images found in {imgs_dir}.")
|
||||
with Image.open(input_list[0]) as dummy_image:
|
||||
width, height = dummy_image.size
|
||||
dummy_image = Image.open(input_list[0])
|
||||
width, height = dummy_image.size
|
||||
|
||||
# Define video codec options
|
||||
video_options = {}
|
||||
@@ -373,12 +373,11 @@ def encode_video_frames(
|
||||
|
||||
# Loop through input frames and encode them
|
||||
for input_data in input_list:
|
||||
with Image.open(input_data) as input_image:
|
||||
input_image = input_image.convert("RGB")
|
||||
input_frame = av.VideoFrame.from_image(input_image)
|
||||
packet = output_stream.encode(input_frame)
|
||||
if packet:
|
||||
output.mux(packet)
|
||||
input_image = Image.open(input_data).convert("RGB")
|
||||
input_frame = av.VideoFrame.from_image(input_image)
|
||||
packet = output_stream.encode(input_frame)
|
||||
if packet:
|
||||
output.mux(packet)
|
||||
|
||||
# Flush the encoder
|
||||
packet = output_stream.encode()
|
||||
|
||||
@@ -37,16 +37,6 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@property
|
||||
def package_name(self) -> str:
|
||||
"""Package name to import if environment not found in gym registry"""
|
||||
return f"gym_{self.type}"
|
||||
|
||||
@property
|
||||
def gym_id(self) -> str:
|
||||
"""ID string used in gym.make() to instantiate the environment"""
|
||||
return f"{self.package_name}/{self.task}"
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def gym_kwargs(self) -> dict:
|
||||
|
||||
@@ -16,10 +16,8 @@
|
||||
import importlib
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.registration import registry as gym_registry
|
||||
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
|
||||
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
||||
|
||||
|
||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
@@ -34,24 +32,15 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
|
||||
|
||||
def make_env(
|
||||
cfg: EnvConfig | str,
|
||||
n_envs: int = 1,
|
||||
use_async_envs: bool = False,
|
||||
hub_cache_dir: str | None = None,
|
||||
trust_remote_code: bool = False,
|
||||
cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False
|
||||
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
|
||||
"""Makes a gym vector environment according to the config or Hub reference.
|
||||
"""Makes a gym vector environment according to the config.
|
||||
|
||||
Args:
|
||||
cfg (EnvConfig | str): Either an `EnvConfig` object describing the environment to build locally,
|
||||
or a Hugging Face Hub repository identifier (e.g. `"username/repo"`). In the latter case,
|
||||
the repo must include a Python file (usually `env.py`).
|
||||
cfg (EnvConfig): the config of the environment to instantiate.
|
||||
n_envs (int, optional): The number of parallelized env to return. Defaults to 1.
|
||||
use_async_envs (bool, optional): Whether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
|
||||
False.
|
||||
hub_cache_dir (str | None): Optional cache path for downloaded hub files.
|
||||
trust_remote_code (bool): **Explicit consent** to execute remote code from the Hub.
|
||||
Default False — must be set to True to import/exec hub `env.py`.
|
||||
|
||||
Raises:
|
||||
ValueError: if n_envs < 1
|
||||
@@ -64,21 +53,6 @@ def make_env(
|
||||
- For single-task environments: a single suite entry (cfg.type) with task_id=0.
|
||||
|
||||
"""
|
||||
# if user passed a hub id string (e.g., "username/repo", "username/repo@main:env.py")
|
||||
# simplified: only support hub-provided `make_env`
|
||||
if isinstance(cfg, str):
|
||||
# _download_hub_file will raise the same RuntimeError if trust_remote_code is False
|
||||
repo_id, file_path, local_file, revision = _download_hub_file(cfg, trust_remote_code, hub_cache_dir)
|
||||
|
||||
# import and surface clear import errors
|
||||
module = _import_hub_module(local_file, repo_id)
|
||||
|
||||
# call the hub-provided make_env
|
||||
raw_result = _call_make_env(module, n_envs=n_envs, use_async_envs=use_async_envs)
|
||||
|
||||
# normalize the return into {suite: {task_id: vec_env}}
|
||||
return _normalize_hub_result(raw_result)
|
||||
|
||||
if n_envs < 1:
|
||||
raise ValueError("`n_envs` must be at least 1")
|
||||
|
||||
@@ -110,24 +84,17 @@ def make_env(
|
||||
gym_kwargs=cfg.gym_kwargs,
|
||||
env_cls=env_cls,
|
||||
)
|
||||
package_name = f"gym_{cfg.type}"
|
||||
try:
|
||||
importlib.import_module(package_name)
|
||||
except ModuleNotFoundError as e:
|
||||
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
|
||||
raise e
|
||||
|
||||
if cfg.gym_id not in gym_registry:
|
||||
print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...")
|
||||
try:
|
||||
importlib.import_module(cfg.package_name)
|
||||
except ModuleNotFoundError as e:
|
||||
raise ModuleNotFoundError(
|
||||
f"Package '{cfg.package_name}' required for env '{cfg.type}' not found. "
|
||||
f"Please install it or check PYTHONPATH."
|
||||
) from e
|
||||
|
||||
if cfg.gym_id not in gym_registry:
|
||||
raise gym.error.NameNotFound(
|
||||
f"Environment '{cfg.gym_id}' not registered even after importing '{cfg.package_name}'."
|
||||
)
|
||||
gym_handle = f"{package_name}/{cfg.task}"
|
||||
|
||||
def _make_one():
|
||||
return gym.make(cfg.gym_id, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
|
||||
return gym.make(gym_handle, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
|
||||
|
||||
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
|
||||
|
||||
|
||||
@@ -13,8 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib.util
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Mapping, Sequence
|
||||
from functools import singledispatch
|
||||
@@ -24,7 +22,6 @@ import einops
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
@@ -198,132 +195,3 @@ def _(envs: Sequence) -> None:
|
||||
@close_envs.register
|
||||
def _(env: gym.Env) -> None:
|
||||
_close_single_env(env)
|
||||
|
||||
|
||||
# helper to safely load a python file as a module
|
||||
def _load_module_from_path(path: str, module_name: str | None = None):
|
||||
module_name = module_name or f"hub_env_{os.path.basename(path).replace('.', '_')}"
|
||||
spec = importlib.util.spec_from_file_location(module_name, path)
|
||||
if spec is None:
|
||||
raise ImportError(f"Could not load module spec for {module_name} from {path}")
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module) # type: ignore
|
||||
return module
|
||||
|
||||
|
||||
# helper to parse hub string (supports "user/repo", "user/repo@rev", optional path)
|
||||
# examples:
|
||||
# "user/repo" -> will look for env.py at repo root
|
||||
# "user/repo@main:envs/my_env.py" -> explicit revision and path
|
||||
def _parse_hub_url(hub_uri: str):
|
||||
# very small parser: [repo_id][@revision][:path]
|
||||
# repo_id is required (user/repo or org/repo)
|
||||
revision = None
|
||||
file_path = "env.py"
|
||||
if "@" in hub_uri:
|
||||
repo_and_rev, *rest = hub_uri.split(":", 1)
|
||||
repo_id, rev = repo_and_rev.split("@", 1)
|
||||
revision = rev
|
||||
if rest:
|
||||
file_path = rest[0]
|
||||
else:
|
||||
repo_id, *rest = hub_uri.split(":", 1)
|
||||
if rest:
|
||||
file_path = rest[0]
|
||||
return repo_id, revision, file_path
|
||||
|
||||
|
||||
def _download_hub_file(
|
||||
cfg_str: str,
|
||||
trust_remote_code: bool,
|
||||
hub_cache_dir: str | None,
|
||||
) -> tuple[str, str, str, str]:
|
||||
"""
|
||||
Parse `cfg_str` (hub URL), enforce `trust_remote_code`, and return
|
||||
(repo_id, file_path, local_file, revision).
|
||||
"""
|
||||
if not trust_remote_code:
|
||||
raise RuntimeError(
|
||||
f"Refusing to execute remote code from the Hub for '{cfg_str}'. "
|
||||
"Executing hub env modules runs arbitrary Python code from third-party repositories. "
|
||||
"If you trust this repo and understand the risks, call `make_env(..., trust_remote_code=True)` "
|
||||
"and prefer pinning to a specific revision: 'user/repo@<commit-hash>:env.py'."
|
||||
)
|
||||
|
||||
repo_id, revision, file_path = _parse_hub_url(cfg_str)
|
||||
|
||||
try:
|
||||
local_file = hf_hub_download(
|
||||
repo_id=repo_id, filename=file_path, revision=revision, cache_dir=hub_cache_dir
|
||||
)
|
||||
except Exception as e:
|
||||
# fallback to snapshot download
|
||||
snapshot_dir = snapshot_download(repo_id=repo_id, revision=revision, cache_dir=hub_cache_dir)
|
||||
local_file = os.path.join(snapshot_dir, file_path)
|
||||
if not os.path.exists(local_file):
|
||||
raise FileNotFoundError(
|
||||
f"Could not find {file_path} in repository {repo_id}@{revision or 'main'}"
|
||||
) from e
|
||||
|
||||
return repo_id, file_path, local_file, revision
|
||||
|
||||
|
||||
def _import_hub_module(local_file: str, repo_id: str) -> Any:
|
||||
"""
|
||||
Import the downloaded file as a module and surface helpful import error messages.
|
||||
"""
|
||||
module_name = f"hub_env_{repo_id.replace('/', '_')}"
|
||||
try:
|
||||
module = _load_module_from_path(local_file, module_name=module_name)
|
||||
except ModuleNotFoundError as e:
|
||||
missing = getattr(e, "name", None) or str(e)
|
||||
raise ModuleNotFoundError(
|
||||
f"Hub env '{repo_id}:{os.path.basename(local_file)}' failed to import because the dependency "
|
||||
f"'{missing}' is not installed locally.\n\n"
|
||||
) from e
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
f"Failed to load hub env module '{repo_id}:{os.path.basename(local_file)}'. Import error: {e}\n\n"
|
||||
) from e
|
||||
return module
|
||||
|
||||
|
||||
def _call_make_env(module: Any, n_envs: int, use_async_envs: bool) -> Any:
|
||||
"""
|
||||
Ensure module exposes make_env and call it.
|
||||
"""
|
||||
if not hasattr(module, "make_env"):
|
||||
raise AttributeError(
|
||||
f"The hub module {getattr(module, '__name__', 'hub_module')} must expose `make_env(n_envs=int, use_async_envs=bool)`."
|
||||
)
|
||||
entry_fn = module.make_env
|
||||
return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs)
|
||||
|
||||
|
||||
def _normalize_hub_result(result: Any) -> dict[str, dict[int, gym.vector.VectorEnv]]:
|
||||
"""
|
||||
Normalize possible return types from hub `make_env` into the mapping:
|
||||
{ suite_name: { task_id: vector_env } }
|
||||
Accepts:
|
||||
- dict (assumed already correct)
|
||||
- gym.vector.VectorEnv
|
||||
- gym.Env (will be wrapped into SyncVectorEnv)
|
||||
"""
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
|
||||
# VectorEnv: use its spec.id if available
|
||||
if isinstance(result, gym.vector.VectorEnv):
|
||||
suite_name = getattr(result, "spec", None) and getattr(result.spec, "id", None) or "hub_env"
|
||||
return {suite_name: {0: result}}
|
||||
|
||||
# Single Env: wrap into SyncVectorEnv
|
||||
if isinstance(result, gym.Env):
|
||||
vec = gym.vector.SyncVectorEnv([lambda: result])
|
||||
suite_name = getattr(result, "spec", None) and getattr(result.spec, "id", None) or "hub_env"
|
||||
return {suite_name: {0: vec}}
|
||||
|
||||
raise ValueError(
|
||||
"Hub `make_env` must return either a mapping {suite: {task_id: vec_env}}, "
|
||||
"a gym.vector.VectorEnv, or a single gym.Env."
|
||||
)
|
||||
|
||||
@@ -22,18 +22,18 @@ class RobotKinematics:
|
||||
self,
|
||||
urdf_path: str,
|
||||
target_frame_name: str = "gripper_frame_link",
|
||||
joint_names: list[str] | None = None,
|
||||
joint_names: list[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize placo-based kinematics solver.
|
||||
|
||||
Args:
|
||||
urdf_path (str): Path to the robot URDF file
|
||||
target_frame_name (str): Name of the end-effector frame in the URDF
|
||||
joint_names (list[str] | None): List of joint names to use for the kinematics solver
|
||||
urdf_path: Path to the robot URDF file
|
||||
target_frame_name: Name of the end-effector frame in the URDF
|
||||
joint_names: List of joint names to use for the kinematics solver
|
||||
"""
|
||||
try:
|
||||
import placo # type: ignore[import-not-found] # C++ library with Python bindings, no type stubs available. TODO: Create stub file or request upstream typing support.
|
||||
import placo
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"placo is required for RobotKinematics. "
|
||||
@@ -52,7 +52,7 @@ class RobotKinematics:
|
||||
# Initialize frame task for IK
|
||||
self.tip_frame = self.solver.add_frame_task(self.target_frame_name, np.eye(4))
|
||||
|
||||
def forward_kinematics(self, joint_pos_deg: np.ndarray) -> np.ndarray:
|
||||
def forward_kinematics(self, joint_pos_deg):
|
||||
"""
|
||||
Compute forward kinematics for given joint configuration given the target frame name in the constructor.
|
||||
|
||||
@@ -77,12 +77,8 @@ class RobotKinematics:
|
||||
return self.robot.get_T_world_frame(self.target_frame_name)
|
||||
|
||||
def inverse_kinematics(
|
||||
self,
|
||||
current_joint_pos: np.ndarray,
|
||||
desired_ee_pose: np.ndarray,
|
||||
position_weight: float = 1.0,
|
||||
orientation_weight: float = 0.01,
|
||||
) -> np.ndarray:
|
||||
self, current_joint_pos, desired_ee_pose, position_weight=1.0, orientation_weight=0.01
|
||||
):
|
||||
"""
|
||||
Compute inverse kinematics using placo solver.
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ class OperatingMode(Enum):
|
||||
|
||||
# This mode controls position. This mode is identical to the Multi-turn Position Control from existing
|
||||
# DYNAMIXEL. 512 turns are supported(-256[rev] ~ 256[rev]). This mode is ideal for multi-turn wrists or
|
||||
# conveyor systems or a system that requires an additional reduction gear. Note that Max Position
|
||||
# conveyer systems or a system that requires an additional reduction gear. Note that Max Position
|
||||
# Limit(48), Min Position Limit(52) are not used on Extended Position Control Mode.
|
||||
EXTENDED_POSITION = 4
|
||||
|
||||
|
||||
@@ -206,12 +206,8 @@ MODEL_BAUDRATE_TABLE = {
|
||||
# Sign-Magnitude encoding bits
|
||||
STS_SMS_SERIES_ENCODINGS_TABLE = {
|
||||
"Homing_Offset": 11,
|
||||
"Goal_Position": 15,
|
||||
"Goal_Velocity": 15,
|
||||
"Goal_Speed": 15,
|
||||
"Present_Position": 15,
|
||||
"Present_Velocity": 15,
|
||||
"Present_Speed": 15,
|
||||
}
|
||||
|
||||
MODEL_ENCODING_TABLE = {
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import logging
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
@@ -80,11 +79,7 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
|
||||
@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
|
||||
@dataclass
|
||||
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
"""Used by Physical Intelligence to train Pi0.
|
||||
|
||||
Automatically scales warmup and decay steps if num_training_steps < num_decay_steps.
|
||||
This ensures the learning rate schedule completes properly even with shorter training runs.
|
||||
"""
|
||||
"""Used by Physical Intelligence to train Pi0"""
|
||||
|
||||
num_warmup_steps: int
|
||||
num_decay_steps: int
|
||||
@@ -92,39 +87,23 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
decay_lr: float
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
# Auto-scale scheduler parameters if training steps are shorter than configured decay steps
|
||||
actual_warmup_steps = self.num_warmup_steps
|
||||
actual_decay_steps = self.num_decay_steps
|
||||
|
||||
if num_training_steps < self.num_decay_steps:
|
||||
# Calculate scaling factor to fit the schedule into the available training steps
|
||||
scale_factor = num_training_steps / self.num_decay_steps
|
||||
actual_warmup_steps = int(self.num_warmup_steps * scale_factor)
|
||||
actual_decay_steps = num_training_steps
|
||||
|
||||
logging.info(
|
||||
f"Auto-scaling LR scheduler: "
|
||||
f"num_training_steps ({num_training_steps}) < num_decay_steps ({self.num_decay_steps}). "
|
||||
f"Scaling warmup: {self.num_warmup_steps} → {actual_warmup_steps}, "
|
||||
f"decay: {self.num_decay_steps} → {actual_decay_steps} "
|
||||
f"(scale factor: {scale_factor:.3f})"
|
||||
)
|
||||
del num_training_steps
|
||||
|
||||
def lr_lambda(current_step):
|
||||
def linear_warmup_schedule(current_step):
|
||||
if current_step <= 0:
|
||||
return 1 / (actual_warmup_steps + 1)
|
||||
frac = 1 - current_step / actual_warmup_steps
|
||||
return (1 / (actual_warmup_steps + 1) - 1) * frac + 1
|
||||
return 1 / (self.num_warmup_steps + 1)
|
||||
frac = 1 - current_step / self.num_warmup_steps
|
||||
return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
|
||||
|
||||
def cosine_decay_schedule(current_step):
|
||||
step = min(current_step, actual_decay_steps)
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / actual_decay_steps))
|
||||
step = min(current_step, self.num_decay_steps)
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
|
||||
alpha = self.decay_lr / self.peak_lr
|
||||
decayed = (1 - alpha) * cosine_decay + alpha
|
||||
return decayed
|
||||
|
||||
if current_step < actual_warmup_steps:
|
||||
if current_step < self.num_warmup_steps:
|
||||
return linear_warmup_schedule(current_step)
|
||||
|
||||
return cosine_decay_schedule(current_step)
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
@@ -30,5 +29,4 @@ __all__ = [
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
"GrootConfig",
|
||||
]
|
||||
|
||||
@@ -626,8 +626,8 @@ class ACTDecoderLayer(nn.Module):
|
||||
x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
|
||||
encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
|
||||
cross-attending with.
|
||||
encoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
|
||||
decoder_pos_embed: (DS, 1, C) positional embedding for the queries (from the decoder).
|
||||
decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
|
||||
encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder).
|
||||
Returns:
|
||||
(DS, B, C) tensor of decoder output features.
|
||||
"""
|
||||
|
||||
@@ -30,7 +30,6 @@ from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.envs.utils import env_to_policy_features
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
@@ -38,7 +37,6 @@ from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.utils import validate_visual_features_consistency
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
@@ -103,10 +101,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "groot":
|
||||
from lerobot.policies.groot.modeling_groot import GrootPolicy
|
||||
|
||||
return GrootPolicy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
@@ -148,8 +142,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif policy_type == "groot":
|
||||
return GrootConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
@@ -207,27 +199,6 @@ def make_pre_post_processors(
|
||||
policy configuration type.
|
||||
"""
|
||||
if pretrained_path:
|
||||
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
|
||||
if isinstance(policy_cfg, GrootConfig):
|
||||
# GROOT handles normalization in groot_pack_inputs_v3 step
|
||||
# Need to override both stats AND normalize_min_max since saved config might be empty
|
||||
preprocessor_overrides = {}
|
||||
postprocessor_overrides = {}
|
||||
preprocessor_overrides["groot_pack_inputs_v3"] = {
|
||||
"stats": kwargs.get("dataset_stats"),
|
||||
"normalize_min_max": True,
|
||||
}
|
||||
|
||||
# Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats
|
||||
env_action_dim = policy_cfg.output_features["action"].shape[0]
|
||||
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = {
|
||||
"stats": kwargs.get("dataset_stats"),
|
||||
"normalize_min_max": True,
|
||||
"env_action_dim": env_action_dim,
|
||||
}
|
||||
kwargs["preprocessor_overrides"] = preprocessor_overrides
|
||||
kwargs["postprocessor_overrides"] = postprocessor_overrides
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
@@ -322,14 +293,6 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, GrootConfig):
|
||||
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
|
||||
|
||||
processors = make_groot_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
|
||||
|
||||
@@ -340,7 +303,6 @@ def make_policy(
|
||||
cfg: PreTrainedConfig,
|
||||
ds_meta: LeRobotDatasetMetadata | None = None,
|
||||
env_cfg: EnvConfig | None = None,
|
||||
rename_map: dict[str, str] | None = None,
|
||||
) -> PreTrainedPolicy:
|
||||
"""
|
||||
Instantiate a policy model.
|
||||
@@ -357,8 +319,6 @@ def make_policy(
|
||||
statistics for normalization layers.
|
||||
env_cfg: Environment configuration used to infer feature shapes and types.
|
||||
One of `ds_meta` or `env_cfg` must be provided.
|
||||
rename_map: Optional mapping of dataset or environment feature keys to match
|
||||
expected policy feature names (e.g., `"left"` → `"camera1"`).
|
||||
|
||||
Returns:
|
||||
An instantiated and device-placed policy model.
|
||||
@@ -400,10 +360,8 @@ def make_policy(
|
||||
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
|
||||
features = env_to_policy_features(env_cfg)
|
||||
|
||||
if not cfg.output_features:
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
if not cfg.input_features:
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
kwargs["config"] = cfg
|
||||
|
||||
if cfg.pretrained_path:
|
||||
@@ -420,8 +378,4 @@ def make_policy(
|
||||
|
||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||
|
||||
if not rename_map:
|
||||
validate_visual_features_consistency(cfg, features)
|
||||
# TODO: (jadechoghari) - add a check_state(cfg, features) and check_action(cfg, features)
|
||||
|
||||
return policy
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
../../../../docs/source/policy_groot_README.md
|
||||
@@ -1,54 +0,0 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def swish(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
"""
|
||||
Produces a sinusoidal encoding of shape (B, T, w)
|
||||
given timesteps of shape (B, T).
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, timesteps):
|
||||
# timesteps: shape (B, T)
|
||||
# We'll compute sin/cos frequencies across dim T
|
||||
timesteps = timesteps.float() # ensure float
|
||||
|
||||
b, t = timesteps.shape
|
||||
device = timesteps.device
|
||||
|
||||
half_dim = self.embedding_dim // 2
|
||||
# typical log space frequencies for sinusoidal encoding
|
||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
|
||||
torch.log(torch.tensor(10000.0)) / half_dim
|
||||
)
|
||||
# Expand timesteps to (B, T, 1) then multiply
|
||||
freqs = timesteps.unsqueeze(-1) * exponent.exp() # (B, T, half_dim)
|
||||
|
||||
sin = torch.sin(freqs)
|
||||
cos = torch.cos(freqs)
|
||||
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
|
||||
|
||||
return enc
|
||||
@@ -1,370 +0,0 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from diffusers.models.attention import Attention, FeedForward
|
||||
from diffusers.models.embeddings import (
|
||||
SinusoidalPositionalEmbedding,
|
||||
TimestepEmbedding,
|
||||
Timesteps,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
|
||||
class TimestepEncoder(nn.Module):
|
||||
def __init__(self, embedding_dim, compute_dtype=torch.float32):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timesteps):
|
||||
dtype = next(self.parameters()).dtype
|
||||
timesteps_proj = self.time_proj(timesteps).to(dtype)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-5,
|
||||
chunk_dim: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
self.chunk_dim = chunk_dim
|
||||
output_dim = embedding_dim * 2
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, output_dim)
|
||||
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
temb: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
temb = self.linear(self.silu(temb))
|
||||
scale, shift = temb.chunk(2, dim=1)
|
||||
x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
return x
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout=0.0,
|
||||
cross_attention_dim: int | None = None,
|
||||
activation_fn: str = "geglu",
|
||||
attention_bias: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
|
||||
norm_eps: float = 1e-5,
|
||||
final_dropout: bool = False,
|
||||
attention_type: str = "default",
|
||||
positional_embeddings: str | None = None,
|
||||
num_positional_embeddings: int | None = None,
|
||||
ff_inner_dim: int | None = None,
|
||||
ff_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.dropout = dropout
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.activation_fn = activation_fn
|
||||
self.attention_bias = attention_bias
|
||||
self.norm_elementwise_affine = norm_elementwise_affine
|
||||
self.positional_embeddings = positional_embeddings
|
||||
self.num_positional_embeddings = num_positional_embeddings
|
||||
self.norm_type = norm_type
|
||||
|
||||
if positional_embeddings and (num_positional_embeddings is None):
|
||||
raise ValueError(
|
||||
"If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
|
||||
)
|
||||
|
||||
if positional_embeddings == "sinusoidal":
|
||||
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
# Define 3 blocks. Each block has its own normalization layer.
|
||||
# 1. Self-Attn
|
||||
if norm_type == "ada_norm":
|
||||
self.norm1 = AdaLayerNorm(dim)
|
||||
else:
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
if final_dropout:
|
||||
self.final_dropout = nn.Dropout(dropout)
|
||||
else:
|
||||
self.final_dropout = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
encoder_hidden_states: torch.Tensor | None = None,
|
||||
encoder_attention_mask: torch.Tensor | None = None,
|
||||
temb: torch.LongTensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# 0. Self-Attention
|
||||
if self.norm_type == "ada_norm":
|
||||
norm_hidden_states = self.norm1(hidden_states, temb)
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
if self.pos_embed is not None:
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
# encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
if self.final_dropout:
|
||||
attn_output = self.final_dropout(attn_output)
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 4. Feed-forward
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DiT(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 8,
|
||||
attention_head_dim: int = 64,
|
||||
output_dim: int = 26,
|
||||
num_layers: int = 12,
|
||||
dropout: float = 0.1,
|
||||
attention_bias: bool = True,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
num_embeds_ada_norm: int | None = 1000,
|
||||
upcast_attention: bool = False,
|
||||
norm_type: str = "ada_norm",
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-5,
|
||||
max_num_positional_embeddings: int = 512,
|
||||
compute_dtype=torch.float32,
|
||||
final_dropout: bool = True,
|
||||
positional_embeddings: str | None = "sinusoidal",
|
||||
interleave_self_attention=False,
|
||||
cross_attention_dim: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Timestep encoder
|
||||
self.timestep_encoder = TimestepEncoder(
|
||||
embedding_dim=self.inner_dim, compute_dtype=self.config.compute_dtype
|
||||
)
|
||||
|
||||
all_blocks = []
|
||||
for idx in range(self.config.num_layers):
|
||||
use_self_attn = idx % 2 == 1 and interleave_self_attention
|
||||
curr_cross_attention_dim = cross_attention_dim if not use_self_attn else None
|
||||
|
||||
all_blocks += [
|
||||
BasicTransformerBlock(
|
||||
self.inner_dim,
|
||||
self.config.num_attention_heads,
|
||||
self.config.attention_head_dim,
|
||||
dropout=self.config.dropout,
|
||||
activation_fn=self.config.activation_fn,
|
||||
attention_bias=self.config.attention_bias,
|
||||
upcast_attention=self.config.upcast_attention,
|
||||
norm_type=norm_type,
|
||||
norm_elementwise_affine=self.config.norm_elementwise_affine,
|
||||
norm_eps=self.config.norm_eps,
|
||||
positional_embeddings=positional_embeddings,
|
||||
num_positional_embeddings=self.config.max_num_positional_embeddings,
|
||||
final_dropout=final_dropout,
|
||||
cross_attention_dim=curr_cross_attention_dim,
|
||||
)
|
||||
]
|
||||
self.transformer_blocks = nn.ModuleList(all_blocks)
|
||||
|
||||
# Output blocks
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
|
||||
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
|
||||
print(
|
||||
"Total number of DiT parameters: ",
|
||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor, # Shape: (B, T, D)
|
||||
encoder_hidden_states: torch.Tensor, # Shape: (B, S, D)
|
||||
timestep: torch.LongTensor | None = None,
|
||||
encoder_attention_mask: torch.Tensor | None = None,
|
||||
return_all_hidden_states: bool = False,
|
||||
):
|
||||
# Encode timesteps
|
||||
temb = self.timestep_encoder(timestep)
|
||||
|
||||
# Process through transformer blocks - single pass through the blocks
|
||||
hidden_states = hidden_states.contiguous()
|
||||
encoder_hidden_states = encoder_hidden_states.contiguous()
|
||||
|
||||
all_hidden_states = [hidden_states]
|
||||
|
||||
# Process through transformer blocks
|
||||
for idx, block in enumerate(self.transformer_blocks):
|
||||
if idx % 2 == 1 and self.config.interleave_self_attention:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
temb=temb,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=None,
|
||||
temb=temb,
|
||||
)
|
||||
all_hidden_states.append(hidden_states)
|
||||
|
||||
# Output processing
|
||||
conditioning = temb
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
if return_all_hidden_states:
|
||||
return self.proj_out_2(hidden_states), all_hidden_states
|
||||
else:
|
||||
return self.proj_out_2(hidden_states)
|
||||
|
||||
|
||||
class SelfAttentionTransformer(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 8,
|
||||
attention_head_dim: int = 64,
|
||||
output_dim: int = 26,
|
||||
num_layers: int = 12,
|
||||
dropout: float = 0.1,
|
||||
attention_bias: bool = True,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
num_embeds_ada_norm: int | None = 1000,
|
||||
upcast_attention: bool = False,
|
||||
max_num_positional_embeddings: int = 512,
|
||||
compute_dtype=torch.float32,
|
||||
final_dropout: bool = True,
|
||||
positional_embeddings: str | None = "sinusoidal",
|
||||
interleave_self_attention=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
self.inner_dim,
|
||||
self.config.num_attention_heads,
|
||||
self.config.attention_head_dim,
|
||||
dropout=self.config.dropout,
|
||||
activation_fn=self.config.activation_fn,
|
||||
attention_bias=self.config.attention_bias,
|
||||
upcast_attention=self.config.upcast_attention,
|
||||
positional_embeddings=positional_embeddings,
|
||||
num_positional_embeddings=self.config.max_num_positional_embeddings,
|
||||
final_dropout=final_dropout,
|
||||
)
|
||||
for _ in range(self.config.num_layers)
|
||||
]
|
||||
)
|
||||
print(
|
||||
"Total number of SelfAttentionTransformer parameters: ",
|
||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor, # Shape: (B, T, D)
|
||||
return_all_hidden_states: bool = False,
|
||||
):
|
||||
# Process through transformer blocks - single pass through the blocks
|
||||
hidden_states = hidden_states.contiguous()
|
||||
all_hidden_states = [hidden_states]
|
||||
|
||||
# Process through transformer blocks
|
||||
for _idx, block in enumerate(self.transformer_blocks):
|
||||
hidden_states = block(hidden_states)
|
||||
all_hidden_states.append(hidden_states)
|
||||
|
||||
if return_all_hidden_states:
|
||||
return hidden_states, all_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
@@ -1,406 +0,0 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
from torch.distributions import Beta
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
PretrainedConfig = object
|
||||
BatchFeature = None
|
||||
|
||||
from lerobot.policies.groot.action_head.action_encoder import (
|
||||
SinusoidalPositionalEncoding,
|
||||
swish,
|
||||
)
|
||||
|
||||
from .cross_attention_dit import DiT, SelfAttentionTransformer
|
||||
|
||||
|
||||
class CategorySpecificLinear(nn.Module):
|
||||
def __init__(self, num_categories, input_dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
# For each category, we have separate weights and biases.
|
||||
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
|
||||
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
|
||||
|
||||
def forward(self, x, cat_ids):
|
||||
selected_w = self.W[cat_ids]
|
||||
selected_b = self.b[cat_ids]
|
||||
return torch.bmm(x, selected_w) + selected_b.unsqueeze(1)
|
||||
|
||||
|
||||
class CategorySpecificMLP(nn.Module):
|
||||
def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
|
||||
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
|
||||
|
||||
def forward(self, x, cat_ids):
|
||||
hidden = F.relu(self.layer1(x, cat_ids))
|
||||
return self.layer2(hidden, cat_ids)
|
||||
|
||||
|
||||
class MultiEmbodimentActionEncoder(nn.Module):
|
||||
def __init__(self, action_dim, hidden_size, num_embodiments):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_embodiments = num_embodiments
|
||||
|
||||
# W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
|
||||
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size) # (d -> w)
|
||||
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size) # (2w -> w)
|
||||
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size) # (w -> w)
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||
|
||||
def forward(self, actions, timesteps, cat_ids):
|
||||
"""
|
||||
actions: shape (B, T, action_dim)
|
||||
timesteps: shape (B,) -- a single scalar per batch item
|
||||
cat_ids: shape (B,)
|
||||
returns: shape (B, T, hidden_size)
|
||||
"""
|
||||
b, t, _ = actions.shape
|
||||
|
||||
# 1) Expand each batch's single scalar time 'tau' across all T steps
|
||||
# so that shape => (B, T)
|
||||
# e.g. if timesteps is (B,), replicate across T
|
||||
if timesteps.dim() == 1 and timesteps.shape[0] == b:
|
||||
# shape (B,) => (B,T)
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, t)
|
||||
else:
|
||||
raise ValueError("Expected `timesteps` to have shape (B,) so we can replicate across T.")
|
||||
|
||||
# 2) Standard action MLP step for shape => (B, T, w)
|
||||
a_emb = self.W1(actions, cat_ids)
|
||||
|
||||
# 3) Get the sinusoidal encoding (B, T, w)
|
||||
tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)
|
||||
|
||||
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
|
||||
x = torch.cat([a_emb, tau_emb], dim=-1)
|
||||
x = swish(self.W2(x, cat_ids))
|
||||
|
||||
# 5) Finally W3 => (B, T, w)
|
||||
x = self.W3(x, cat_ids)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlowmatchingActionHeadConfig(PretrainedConfig):
|
||||
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
|
||||
|
||||
add_pos_embed: bool = field(default=True, metadata={"help": "Whether to add positional embedding"})
|
||||
model_dtype: str = field(default="float32", metadata={"help": "Model data type."})
|
||||
diffusion_model_cfg: dict = field(default=None, metadata={"help": "Diffusion model configuration."})
|
||||
input_embedding_dim: int = field(default=1536, metadata={"help": "Input embedding channel dimension."})
|
||||
backbone_embedding_dim: int = field(
|
||||
default=1536, metadata={"help": "Backbone embedding channel dimension."}
|
||||
)
|
||||
|
||||
hidden_size: int = field(default=1024, metadata={"help": "Input embedding dimension."})
|
||||
max_seq_len: int = field(default=1024, metadata={"help": "Maximum Sequence Length"})
|
||||
action_dim: int = field(default=None, metadata={"help": "Action dimension."})
|
||||
action_horizon: int = field(default=None, metadata={"help": "Action horizon."})
|
||||
noise_beta_alpha: float = field(default=1.5, metadata={"help": ""})
|
||||
noise_beta_beta: float = field(default=1.0, metadata={"help": ""})
|
||||
noise_s: float = field(default=0.999, metadata={"help": "Flow matching noise Beta distribution s."})
|
||||
num_timestep_buckets: int = field(
|
||||
default=1000, metadata={"help": "Number of timestep discretization buckets."}
|
||||
)
|
||||
num_inference_timesteps: int = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of inference steps for noise diffusion."},
|
||||
)
|
||||
max_num_embodiments: int = field(default=32, metadata={"help": "Number of embodiments."})
|
||||
tune_projector: bool = field(default=True, metadata={"help": "Whether to tune the projector."})
|
||||
tune_diffusion_model: bool = field(
|
||||
default=True, metadata={"help": "Whether to tune the diffusion model."}
|
||||
)
|
||||
load_pretrained_det_decode_layer_path: str = field(
|
||||
default=None, metadata={"help": "Path to pretrained detection model."}
|
||||
)
|
||||
detection_coeff: float = field(default=1.0, metadata={"help": "Detection coefficient."})
|
||||
|
||||
freeze_decode_layer: bool = field(default=False)
|
||||
expand_batch: int = field(default=None)
|
||||
use_vlln: bool = field(default=True)
|
||||
|
||||
vl_self_attention_cfg: dict = field(default=None)
|
||||
num_target_vision_tokens: int = field(default=32, metadata={"help": "Number of target vision tokens."})
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
class FlowmatchingActionHead(nn.Module):
|
||||
config_class = FlowmatchingActionHeadConfig
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FlowmatchingActionHeadConfig,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.input_embedding_dim = config.input_embedding_dim
|
||||
|
||||
self.model = DiT(**config.diffusion_model_cfg)
|
||||
self.action_dim = config.action_dim
|
||||
self.action_horizon = config.action_horizon
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
|
||||
self.state_encoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=config.max_state_dim,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.input_embedding_dim,
|
||||
)
|
||||
self.action_encoder = MultiEmbodimentActionEncoder(
|
||||
action_dim=config.action_dim,
|
||||
hidden_size=self.input_embedding_dim,
|
||||
num_embodiments=config.max_num_embodiments,
|
||||
)
|
||||
self.action_decoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=self.hidden_size,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.action_dim,
|
||||
)
|
||||
self.future_tokens = nn.Embedding(config.num_target_vision_tokens, self.input_embedding_dim)
|
||||
nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)
|
||||
|
||||
self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
|
||||
self.vl_self_attention = (
|
||||
SelfAttentionTransformer(**config.vl_self_attention_cfg) if config.use_vlln else nn.Identity()
|
||||
)
|
||||
|
||||
if config.add_pos_embed:
|
||||
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
|
||||
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
||||
|
||||
self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
|
||||
self.num_timestep_buckets = config.num_timestep_buckets
|
||||
self.config = config
|
||||
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
|
||||
|
||||
def set_trainable_parameters(self, tune_projector: bool, tune_diffusion_model: bool):
|
||||
self.tune_projector = tune_projector
|
||||
self.tune_diffusion_model = tune_diffusion_model
|
||||
for p in self.parameters():
|
||||
p.requires_grad = True
|
||||
if not tune_projector:
|
||||
self.state_encoder.requires_grad_(False)
|
||||
self.action_encoder.requires_grad_(False)
|
||||
self.action_decoder.requires_grad_(False)
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.requires_grad_(False)
|
||||
if not tune_diffusion_model:
|
||||
self.model.requires_grad_(False)
|
||||
print(f"Tune action head projector: {self.tune_projector}")
|
||||
print(f"Tune action head diffusion model: {self.tune_diffusion_model}")
|
||||
# Check if any parameters are still trainable. If not, print a warning.
|
||||
if not tune_projector and not tune_diffusion_model:
|
||||
for name, p in self.named_parameters():
|
||||
if p.requires_grad:
|
||||
print(f"Action head trainable parameter: {name}")
|
||||
if not any(p.requires_grad for p in self.parameters()):
|
||||
print("Warning: No action head trainable parameters found.")
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self):
|
||||
"""
|
||||
Huggingface will call model.train() at each training_step. To ensure
|
||||
the expected behaviors for modules like dropout, batchnorm, etc., we
|
||||
need to call model.eval() for the frozen modules.
|
||||
"""
|
||||
if self.training:
|
||||
if not self.tune_projector:
|
||||
self.state_encoder.eval()
|
||||
self.action_encoder.eval()
|
||||
self.action_decoder.eval()
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.eval()
|
||||
if not self.tune_diffusion_model:
|
||||
self.model.eval()
|
||||
|
||||
def sample_time(self, batch_size, device, dtype):
|
||||
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
||||
return (self.config.noise_s - sample) / self.config.noise_s
|
||||
|
||||
def prepare_input(self, batch: dict) -> BatchFeature:
|
||||
return BatchFeature(data=batch)
|
||||
|
||||
def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
|
||||
backbone_features = backbone_output["backbone_features"]
|
||||
backbone_features = self.vlln(backbone_features)
|
||||
backbone_features = self.vl_self_attention(backbone_features)
|
||||
backbone_output["backbone_features"] = backbone_features
|
||||
return backbone_output
|
||||
|
||||
def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
# Set frozen modules to eval
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
|
||||
if self.config.expand_batch is not None:
|
||||
for k, v in backbone_output.items():
|
||||
ndim = len(v.shape)
|
||||
factors = [self.config.expand_batch]
|
||||
while len(factors) < ndim:
|
||||
factors.append(1)
|
||||
factors = tuple(factors)
|
||||
expanded = v.repeat(*factors)
|
||||
backbone_output[k] = expanded
|
||||
|
||||
for k, v in action_input.items():
|
||||
ndim = len(v.shape)
|
||||
factors = [self.config.expand_batch]
|
||||
while len(factors) < ndim:
|
||||
factors.append(1)
|
||||
factors = tuple(factors)
|
||||
expanded = v.repeat(*factors)
|
||||
action_input[k] = expanded
|
||||
|
||||
# Get vision and language embeddings.
|
||||
vl_embs = backbone_output.backbone_features
|
||||
device = vl_embs.device
|
||||
|
||||
# Get embodiment ID.
|
||||
embodiment_id = action_input.embodiment_id
|
||||
|
||||
# Embed state.
|
||||
state_features = self.state_encoder(action_input.state, embodiment_id)
|
||||
|
||||
# Embed noised action trajectory.
|
||||
actions = action_input.action
|
||||
noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
|
||||
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
|
||||
t = t[:, None, None] # shape (B,1,1) for broadcast
|
||||
|
||||
noisy_trajectory = (1 - t) * noise + t * actions
|
||||
velocity = actions - noise
|
||||
|
||||
# Convert (continuous) t -> discrete if needed
|
||||
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
|
||||
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
|
||||
|
||||
# Maybe add position embedding.
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
|
||||
action_features = action_features + pos_embs
|
||||
|
||||
# Join vision, language, state and action embedding along sequence dimension.
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
|
||||
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
|
||||
|
||||
vl_attn_mask = backbone_output.backbone_attention_mask
|
||||
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embs,
|
||||
encoder_attention_mask=vl_attn_mask,
|
||||
timestep=t_discretized,
|
||||
return_all_hidden_states=False, # NOTE (YL): not using flare now
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
pred_actions = pred[:, -actions.shape[1] :]
|
||||
|
||||
# Slice out only the action portion of pred and target.
|
||||
action_mask = action_input.action_mask
|
||||
loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
|
||||
loss = loss.sum() / action_mask.sum()
|
||||
output_dict = {
|
||||
"loss": loss,
|
||||
}
|
||||
return BatchFeature(data=output_dict)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_action(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
|
||||
# Get vision and language embeddings.
|
||||
vl_embs = backbone_output.backbone_features
|
||||
embodiment_id = action_input.embodiment_id
|
||||
|
||||
# Embed state.
|
||||
state_features = self.state_encoder(action_input.state, embodiment_id)
|
||||
|
||||
# Set initial actions as the sampled noise.
|
||||
batch_size = vl_embs.shape[0]
|
||||
device = vl_embs.device
|
||||
actions = torch.randn(
|
||||
size=(batch_size, self.config.action_horizon, self.config.action_dim),
|
||||
dtype=vl_embs.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
num_steps = self.num_inference_timesteps
|
||||
dt = 1.0 / num_steps
|
||||
|
||||
# Run denoising steps.
|
||||
for t in range(num_steps):
|
||||
t_cont = t / float(num_steps) # e.g. goes 0, 1/N, 2/N, ...
|
||||
t_discretized = int(t_cont * self.num_timestep_buckets)
|
||||
|
||||
# Embed noised action trajectory.
|
||||
timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
|
||||
action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
|
||||
# Maybe add position embedding.
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
|
||||
action_features = action_features + pos_embs
|
||||
|
||||
# Join vision, language, state and action embedding along sequence dimension.
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
|
||||
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
|
||||
|
||||
# Run model forward.
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embs,
|
||||
timestep=timesteps_tensor,
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
|
||||
pred_velocity = pred[:, -self.action_horizon :]
|
||||
|
||||
# Update actions using euler integration.
|
||||
actions = actions + dt * pred_velocity
|
||||
return BatchFeature(data={"action_pred": actions})
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(iter(self.parameters())).dtype
|
||||
@@ -1,201 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("groot")
|
||||
@dataclass
|
||||
class GrootConfig(PreTrainedConfig):
|
||||
"""Configuration for Groot policy wrapper."""
|
||||
|
||||
# Basic policy settings
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
|
||||
# Dimension settings (must match pretrained GR00T model expectations)
|
||||
# Maximum state dimension. Shorter states will be zero-padded.
|
||||
max_state_dim: int = 64
|
||||
|
||||
# Maximum action dimension. Shorter actions will be zero-padded.
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Normalization (start with identity, adjust as needed)
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Image preprocessing (adjust to match Groot's expected input)
|
||||
image_size: tuple[int, int] = (224, 224)
|
||||
|
||||
# Groot-specific model parameters (from groot_finetune_script.py)
|
||||
|
||||
# Path or HuggingFace model ID for the base Groot model
|
||||
base_model_path: str = "nvidia/GR00T-N1.5-3B"
|
||||
|
||||
# HF repo ID (or local path) that hosts vocab.json and merges.txt for Eagle tokenizer.
|
||||
tokenizer_assets_repo: str = "lerobot/eagle2hg-processor-groot-n1p5"
|
||||
|
||||
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
|
||||
embodiment_tag: str = "new_embodiment"
|
||||
|
||||
# Fine-tuning control arguments
|
||||
|
||||
# Whether to fine-tune the llm backbone
|
||||
tune_llm: bool = False
|
||||
|
||||
# Whether to fine-tune the vision tower
|
||||
tune_visual: bool = False
|
||||
|
||||
# Whether to fine-tune the projector
|
||||
tune_projector: bool = True
|
||||
|
||||
# Whether to fine-tune the diffusion model
|
||||
tune_diffusion_model: bool = True
|
||||
|
||||
# LoRA parameters (from groot_finetune_script.py)
|
||||
# Rank for the LORA model. If 0, no LORA will be used.
|
||||
lora_rank: int = 0
|
||||
|
||||
# Alpha value for the LORA model
|
||||
lora_alpha: int = 16
|
||||
|
||||
# Dropout rate for the LORA model
|
||||
lora_dropout: float = 0.1
|
||||
|
||||
# Whether to use the full model for LORA
|
||||
lora_full_model: bool = False
|
||||
|
||||
# Training parameters (matching groot_finetune_script.py)
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.95, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-5
|
||||
warmup_ratio: float = 0.05
|
||||
use_bf16: bool = True
|
||||
|
||||
# Dataset parameters
|
||||
# Video backend to use for training ('decord' or 'torchvision_av')
|
||||
video_backend: str = "decord"
|
||||
|
||||
# Whether to balance dataset weights in mixture datasets
|
||||
balance_dataset_weights: bool = True
|
||||
|
||||
# Whether to sample trajectories weighted by their length
|
||||
balance_trajectory_weights: bool = True
|
||||
|
||||
# Optional dataset paths for delegating training to Isaac-GR00T runner
|
||||
dataset_paths: list[str] | None = None
|
||||
output_dir: str = "./tmp/gr00t"
|
||||
save_steps: int = 1000
|
||||
max_steps: int = 10000
|
||||
batch_size: int = 32
|
||||
dataloader_num_workers: int = 8
|
||||
report_to: str = "wandb"
|
||||
resume: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
# groot_repo_path is now optional since we ported the components
|
||||
# No validation needed
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features for Groot."""
|
||||
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||
if not image_features:
|
||||
raise ValueError(
|
||||
"Groot policy requires at least one visual input feature. "
|
||||
"No features of type FeatureType.VISUAL found in input_features."
|
||||
)
|
||||
|
||||
if "observation.state" not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,),
|
||||
)
|
||||
self.input_features["observation.state"] = state_feature
|
||||
else:
|
||||
state_shape = self.input_features["observation.state"].shape
|
||||
state_dim = state_shape[0] if state_shape else 0
|
||||
if state_dim > self.max_state_dim:
|
||||
raise ValueError(
|
||||
f"State dimension {state_dim} exceeds max_state_dim {self.max_state_dim}. "
|
||||
f"Either reduce state dimension or increase max_state_dim in config."
|
||||
)
|
||||
|
||||
if "action" not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,),
|
||||
)
|
||||
self.output_features["action"] = action_feature
|
||||
else:
|
||||
action_shape = self.output_features["action"].shape
|
||||
action_dim = action_shape[0] if action_shape else 0
|
||||
if action_dim > self.max_action_dim:
|
||||
raise ValueError(
|
||||
f"Action dimension {action_dim} exceeds max_action_dim {self.max_action_dim}. "
|
||||
f"Either reduce action dimension or increase max_action_dim in config."
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
"""Return optimizer configuration."""
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
"""Return scheduler configuration."""
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
num_warmup_steps=int(10000 * self.warmup_ratio), # 5% warmup by default
|
||||
num_decay_steps=10000, # Adjust based on training steps
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.optimizer_lr * 0.1,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
"""Return indices for delta observations (None for Groot)."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
"""Return indices for delta actions."""
|
||||
return list(range(min(self.chunk_size, 16)))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
"""Return indices for delta rewards (None for Groot)."""
|
||||
return None
|
||||
@@ -1,135 +0,0 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
|
||||
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
|
||||
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Eagle25VLConfig(PretrainedConfig):
|
||||
model_type = "eagle_2_5_vl"
|
||||
is_composition = True
|
||||
sub_configs = {"vision_config": SiglipVisionConfig, "text_config": Qwen2Config}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
use_backbone_lora=0,
|
||||
use_llm_lora=0,
|
||||
pad2square=False,
|
||||
select_layer=-4,
|
||||
force_image_size=None,
|
||||
downsample_ratio=0.5,
|
||||
template=None,
|
||||
dynamic_image_size=False,
|
||||
use_thumbnail=False,
|
||||
loss_version="v1",
|
||||
min_dynamic_tiles=1,
|
||||
max_dynamic_tiles=6,
|
||||
mlp_checkpoint=False,
|
||||
initializer_range=0.02,
|
||||
_attn_implementation="flash_attention_2",
|
||||
_attn_implementation_autoset=False,
|
||||
llm_config=None,
|
||||
image_token_index=None,
|
||||
use_pixel_shuffle=True,
|
||||
mlp_connector_layers=2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if vision_config is None:
|
||||
vision_config = {"model_type": "siglip_vision_model"}
|
||||
logger.info("vision_config is None. Initializing the InternVisionConfig with default values.")
|
||||
|
||||
if text_config is None:
|
||||
text_config = {"architectures": ["Qwen2ForCausalLM"]}
|
||||
logger.info(
|
||||
"text_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`)."
|
||||
)
|
||||
|
||||
if vision_config["model_type"] == "siglip_vision_model":
|
||||
self.vision_config = SiglipVisionConfig(**vision_config)
|
||||
else:
|
||||
raise ValueError("Unsupported model_type: {}".format(vision_config["model_type"]))
|
||||
|
||||
if text_config["architectures"][0] == "LlamaForCausalLM":
|
||||
self.text_config = LlamaConfig(**text_config)
|
||||
elif text_config["architectures"][0] == "Qwen2ForCausalLM":
|
||||
self.text_config = Qwen2Config(**text_config)
|
||||
elif text_config["architectures"][0] == "Qwen3ForCausalLM":
|
||||
self.text_config = Qwen3Config(**text_config)
|
||||
else:
|
||||
raise ValueError("Unsupported architecture: {}".format(text_config["architectures"][0]))
|
||||
self.use_backbone_lora = use_backbone_lora
|
||||
self.use_llm_lora = use_llm_lora
|
||||
self.mlp_checkpoint = mlp_checkpoint
|
||||
self.pad2square = pad2square
|
||||
self.select_layer = select_layer
|
||||
self.force_image_size = force_image_size
|
||||
self.downsample_ratio = downsample_ratio
|
||||
self.template = template
|
||||
self.dynamic_image_size = dynamic_image_size
|
||||
self.use_thumbnail = use_thumbnail
|
||||
self.loss_version = loss_version
|
||||
self.initializer_range = initializer_range
|
||||
self.min_dynamic_tiles = min_dynamic_tiles
|
||||
self.max_dynamic_tiles = max_dynamic_tiles
|
||||
self.tie_word_embeddings = self.text_config.tie_word_embeddings
|
||||
self._attn_implementation = _attn_implementation
|
||||
self._attn_implementation_autoset = _attn_implementation_autoset
|
||||
self.image_token_index = image_token_index
|
||||
self.use_pixel_shuffle = use_pixel_shuffle
|
||||
self.mlp_connector_layers = mlp_connector_layers
|
||||
logger.info(f"min_dynamic_tiles: {self.min_dynamic_tiles}")
|
||||
logger.info(f"max_dynamic_tiles: {self.max_dynamic_tiles}")
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
||||
|
||||
Returns:
|
||||
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
output["vision_config"] = self.vision_config.to_dict()
|
||||
output["text_config"] = self.text_config.to_dict()
|
||||
output["model_type"] = self.__class__.model_type
|
||||
output["use_backbone_lora"] = self.use_backbone_lora
|
||||
output["use_llm_lora"] = self.use_llm_lora
|
||||
output["pad2square"] = self.pad2square
|
||||
output["select_layer"] = self.select_layer
|
||||
output["force_image_size"] = self.force_image_size
|
||||
output["downsample_ratio"] = self.downsample_ratio
|
||||
output["template"] = self.template
|
||||
output["dynamic_image_size"] = self.dynamic_image_size
|
||||
output["use_thumbnail"] = self.use_thumbnail
|
||||
output["min_dynamic_tiles"] = self.min_dynamic_tiles
|
||||
output["max_dynamic_tiles"] = self.max_dynamic_tiles
|
||||
output["tie_word_embeddings"] = self.tie_word_embeddings
|
||||
output["_attn_implementation"] = self._attn_implementation
|
||||
output["_attn_implementation_autoset"] = self._attn_implementation_autoset
|
||||
output["use_pixel_shuffle"] = self.use_pixel_shuffle
|
||||
output["mlp_connector_layers"] = self.mlp_connector_layers
|
||||
return output
|
||||
@@ -1,504 +0,0 @@
|
||||
# --------------------------------------------------------
|
||||
# NVIDIA
|
||||
# Copyright (c) 2025 NVIDIA
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
|
||||
|
||||
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
|
||||
from typing import Optional
|
||||
|
||||
from transformers.image_processing_utils import (
|
||||
BatchFeature,
|
||||
get_patch_output_size,
|
||||
)
|
||||
from transformers.image_processing_utils_fast import (
|
||||
BaseImageProcessorFast,
|
||||
DefaultFastImageProcessorKwargs,
|
||||
group_images_by_shape,
|
||||
reorder_images,
|
||||
)
|
||||
from transformers.image_utils import (
|
||||
IMAGENET_STANDARD_MEAN, # 0.5, 0.5, 0.5
|
||||
IMAGENET_STANDARD_STD, # 0.5, 0.5, 0.5
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
get_image_size,
|
||||
make_flat_list_of_images,
|
||||
validate_kwargs,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
TensorType,
|
||||
add_start_docstrings,
|
||||
is_torch_available,
|
||||
is_torchvision_v2_available,
|
||||
)
|
||||
from transformers.video_utils import VideoInput
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
from transformers.image_utils import pil_torch_interpolation_mapping
|
||||
else:
|
||||
from torchvision.transforms import functional as F # noqa: N812
|
||||
|
||||
|
||||
def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> torch.Tensor:
|
||||
"""Crop the given numpy array.
|
||||
|
||||
Args:
|
||||
img (torch.Tensor): Image to be cropped. Format should be (C, H, W).
|
||||
left (int): The left coordinate of the crop box.
|
||||
top (int): The top coordinate of the crop box.
|
||||
right (int): The right coordinate of the crop box.
|
||||
bottom (int): The bottom coordinate of the crop box.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Cropped image.
|
||||
"""
|
||||
if not isinstance(img, torch.Tensor):
|
||||
raise TypeError(f"img should be torch.Tensor. Got {type(img)}")
|
||||
|
||||
if img.ndim not in [2, 3]:
|
||||
raise ValueError(f"Image should have 2 or 3 dimensions. Got {img.ndim}")
|
||||
|
||||
img_height = img.shape[1]
|
||||
img_width = img.shape[2]
|
||||
if top < 0 or left < 0 or bottom > img_height or right > img_width:
|
||||
raise ValueError("Crop coordinates out of bounds")
|
||||
|
||||
if top >= bottom or left >= right:
|
||||
raise ValueError("Invalid crop coordinates")
|
||||
|
||||
return img[:, top:bottom, left:right]
|
||||
|
||||
|
||||
class Eagle25VLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
||||
max_dynamic_tiles: int | None
|
||||
min_dynamic_tiles: int | None
|
||||
use_thumbnail: bool | None
|
||||
pad_during_tiling: bool | None
|
||||
do_pad: bool | None
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.",
|
||||
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, TODO: this was depreciated from transformers remove!
|
||||
"""
|
||||
image_grid_pinpoints (`List[List[int]]`, *optional*):
|
||||
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
|
||||
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
|
||||
method. Not used for processing videos.
|
||||
do_pad (`bool`, *optional*):
|
||||
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
||||
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
||||
""",
|
||||
)
|
||||
class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = IMAGENET_STANDARD_MEAN
|
||||
image_std = IMAGENET_STANDARD_STD
|
||||
size = {"height": 448, "width": 448}
|
||||
default_to_square = False
|
||||
crop_size = None
|
||||
do_resize = True
|
||||
do_center_crop = None
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
do_pad = True
|
||||
max_dynamic_tiles = 12
|
||||
min_dynamic_tiles = 1
|
||||
use_thumbnail = True
|
||||
pad_during_tiling = False
|
||||
valid_kwargs = Eagle25VLFastImageProcessorKwargs
|
||||
model_input_names = ["pixel_values_videos"]
|
||||
|
||||
def __init__(self, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@add_start_docstrings(
|
||||
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, TODO: this was depreciated from transformers remove!
|
||||
"""
|
||||
max_dynamic_tiles (`int`, *optional*):
|
||||
The maximum number of dynamic tiles to use for processing high resolution images.
|
||||
min_dynamic_tiles (`int`, *optional*):
|
||||
The minimum number of dynamic tiles to use for processing high resolution images.
|
||||
use_thumbnail (`bool`, *optional*):
|
||||
Whether to use a thumbnail for processing high resolution images.
|
||||
pad_during_tiling (`bool`, *optional*):
|
||||
Whether to pad the image during tiling.
|
||||
do_pad (`bool`, *optional*):
|
||||
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
||||
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
||||
""",
|
||||
)
|
||||
|
||||
# NOTE(YL): we will overload the preprocess method to add the image_flags
|
||||
# def preprocess(
|
||||
# self, images: ImageInput, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]
|
||||
# ) -> BatchFeature:
|
||||
# return super().preprocess(images, **kwargs)
|
||||
|
||||
def _prepare_images_structure(
|
||||
self,
|
||||
images: ImageInput,
|
||||
expected_ndims: int = 3,
|
||||
) -> ImageInput:
|
||||
"""
|
||||
Prepare the images structure for processing.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
The input images to process.
|
||||
expected_ndims (`int`, *optional*, defaults to 3):
|
||||
Expected number of dimensions for the images (added for transformers >=4.53.0 compatibility).
|
||||
|
||||
Returns:
|
||||
`ImageInput`: The images with a valid nesting.
|
||||
"""
|
||||
return make_flat_list_of_images(images)
|
||||
|
||||
def _resize_for_patching(
|
||||
self,
|
||||
image: "torch.Tensor",
|
||||
target_resolution: tuple,
|
||||
interpolation: "F.InterpolationMode",
|
||||
input_data_format: ChannelDimension,
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Resizes an image to a target resolution while maintaining aspect ratio.
|
||||
|
||||
Args:
|
||||
image ("torch.Tensor"):
|
||||
The input image.
|
||||
target_resolution (tuple):
|
||||
The target resolution (height, width) of the image.
|
||||
interpolation (`InterpolationMode`):
|
||||
Resampling filter to use if resizing the image.
|
||||
input_data_format (`ChannelDimension` or `str`):
|
||||
The channel dimension format of the input image.
|
||||
|
||||
Returns:
|
||||
"torch.Tensor": The resized and padded image.
|
||||
"""
|
||||
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
|
||||
# Resize the image
|
||||
resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
|
||||
|
||||
return resized_image
|
||||
|
||||
def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
|
||||
"""
|
||||
previous version mainly focus on ratio.
|
||||
We also consider area ratio here.
|
||||
"""
|
||||
best_factor = float("-inf")
|
||||
best_ratio = (1, 1)
|
||||
area = width * height
|
||||
for ratio in target_ratios:
|
||||
target_aspect_ratio = ratio[0] / ratio[1]
|
||||
# ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
||||
# area_ratio = (ratio[0] * ratio[1] * image_size * image_size) / area
|
||||
"""
|
||||
new area > 60% of original image area is enough.
|
||||
"""
|
||||
factor_based_on_area_n_ratio = min(
|
||||
(ratio[0] * ratio[1] * image_size * image_size) / area, 0.6
|
||||
) * min(target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio)
|
||||
|
||||
if factor_based_on_area_n_ratio > best_factor:
|
||||
best_factor = factor_based_on_area_n_ratio
|
||||
best_ratio = ratio
|
||||
|
||||
return best_ratio
|
||||
|
||||
def _pad_for_patching(
|
||||
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Pad an image to a target resolution while maintaining aspect ratio.
|
||||
"""
|
||||
target_height, target_width = target_resolution
|
||||
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
|
||||
paste_x = (target_width - new_width) // 2
|
||||
paste_y = (target_height - new_height) // 2
|
||||
|
||||
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
|
||||
|
||||
return padded_image
|
||||
|
||||
def _get_image_patches(
|
||||
self,
|
||||
image: "torch.Tensor",
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
size: tuple,
|
||||
tile_size: int,
|
||||
use_thumbnail: bool,
|
||||
interpolation: "F.InterpolationMode",
|
||||
pad_during_tiling: bool,
|
||||
) -> list["torch.Tensor"]:
|
||||
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
|
||||
orig_height, orig_width = image_size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = {
|
||||
(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num
|
||||
}
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = self.find_closest_aspect_ratio(
|
||||
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
|
||||
)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = tile_size * target_aspect_ratio[0]
|
||||
target_height = tile_size * target_aspect_ratio[1]
|
||||
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
if pad_during_tiling:
|
||||
resized_image = self._resize_for_patching(
|
||||
image,
|
||||
(target_height, target_width),
|
||||
interpolation=interpolation,
|
||||
input_data_format=ChannelDimension.FIRST,
|
||||
)
|
||||
padded_image = self._pad_for_patching(
|
||||
resized_image,
|
||||
(target_height, target_width),
|
||||
input_data_format=ChannelDimension.FIRST,
|
||||
)
|
||||
image_used_to_split = padded_image
|
||||
else:
|
||||
image_used_to_split = F.resize(image, (target_height, target_width), interpolation=interpolation)
|
||||
|
||||
processed_tiles = []
|
||||
for i in range(blocks):
|
||||
box = (
|
||||
(i % (target_width // tile_size)) * tile_size,
|
||||
(i // (target_width // tile_size)) * tile_size,
|
||||
((i % (target_width // tile_size)) + 1) * tile_size,
|
||||
((i // (target_width // tile_size)) + 1) * tile_size,
|
||||
)
|
||||
# split the image
|
||||
split_img = crop(image_used_to_split, box[0], box[1], box[2], box[3])
|
||||
processed_tiles.append(split_img)
|
||||
assert len(processed_tiles) == blocks
|
||||
|
||||
if use_thumbnail and len(processed_tiles) != 1:
|
||||
thumbnail_img = F.resize(image, (tile_size, tile_size), interpolation=interpolation)
|
||||
processed_tiles.append(thumbnail_img)
|
||||
|
||||
return processed_tiles
|
||||
|
||||
def _pad_for_batching(
|
||||
self,
|
||||
pixel_values: list["torch.Tensor"],
|
||||
) -> list["torch.Tensor"]:
|
||||
"""
|
||||
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
|
||||
|
||||
Args:
|
||||
pixel_values (`List[torch.Tensor]`):
|
||||
An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`)
|
||||
|
||||
Returns:
|
||||
List[`torch.Tensor`]: The padded images.
|
||||
"""
|
||||
max_patch = max(len(x) for x in pixel_values)
|
||||
pixel_values = [
|
||||
torch.nn.functional.pad(image, pad=[0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[0]])
|
||||
for image in pixel_values
|
||||
]
|
||||
|
||||
return pixel_values
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: list["torch.Tensor"],
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
max_dynamic_tiles: int,
|
||||
min_dynamic_tiles: int,
|
||||
use_thumbnail: bool,
|
||||
pad_during_tiling: bool,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: float | list[float] | None,
|
||||
image_std: float | list[float] | None,
|
||||
do_pad: bool,
|
||||
return_tensors: str | TensorType | None,
|
||||
pad_size: SizeDict | None = None, # Added for transformers >=4.53.0 compatibility
|
||||
disable_grouping: bool | None = None, # Added for transformers >=4.53.0 compatibility
|
||||
) -> BatchFeature:
|
||||
processed_images = []
|
||||
image_sizes = []
|
||||
# Determine the size tuple
|
||||
if size and size.height and size.width:
|
||||
size_tuple = (size.height, size.width)
|
||||
else:
|
||||
size_tuple = (size.shortest_edge, size.shortest_edge)
|
||||
|
||||
# Determine the patch size
|
||||
if crop_size and crop_size.height:
|
||||
tile_size = crop_size.height
|
||||
elif size and size.height:
|
||||
tile_size = size.height
|
||||
else:
|
||||
tile_size = size.shortest_edge
|
||||
|
||||
for image in images:
|
||||
image_patches = self._get_image_patches(
|
||||
image,
|
||||
min_num=min_dynamic_tiles,
|
||||
max_num=max_dynamic_tiles,
|
||||
size=size_tuple,
|
||||
tile_size=tile_size,
|
||||
use_thumbnail=use_thumbnail,
|
||||
interpolation=interpolation,
|
||||
pad_during_tiling=pad_during_tiling,
|
||||
)
|
||||
|
||||
# Group images by size for batched processing
|
||||
processed_image_patches_grouped = {}
|
||||
# Added for transformers >=4.53.0 compatibility
|
||||
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
|
||||
image_patches,
|
||||
disable_grouping=disable_grouping,
|
||||
)
|
||||
|
||||
for shape, stacked_image_patches in grouped_image_patches.items():
|
||||
if do_resize:
|
||||
stacked_image_patches = self.resize(
|
||||
image=stacked_image_patches,
|
||||
size=size,
|
||||
interpolation=interpolation,
|
||||
)
|
||||
if do_center_crop:
|
||||
stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
|
||||
# Fused rescale and normalize
|
||||
stacked_image_patches = self.rescale_and_normalize(
|
||||
stacked_image_patches,
|
||||
do_rescale,
|
||||
rescale_factor,
|
||||
do_normalize,
|
||||
image_mean,
|
||||
image_std,
|
||||
)
|
||||
processed_image_patches_grouped[shape] = stacked_image_patches
|
||||
processed_image_patches = reorder_images(
|
||||
processed_image_patches_grouped, grouped_image_patches_index
|
||||
)
|
||||
processed_image_patches = (
|
||||
torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
|
||||
)
|
||||
processed_images.append(processed_image_patches)
|
||||
image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
|
||||
|
||||
if do_pad:
|
||||
processed_images = self._pad_for_batching(processed_images)
|
||||
|
||||
# processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||
processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images
|
||||
return BatchFeature(
|
||||
data={"pixel_values": processed_images, "image_sizes": image_sizes},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
videos: VideoInput = None,
|
||||
**kwargs: Unpack[Eagle25VLFastImageProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
validate_kwargs(
|
||||
captured_kwargs=kwargs.keys(),
|
||||
valid_processor_keys=self.valid_kwargs.__annotations__.keys(),
|
||||
)
|
||||
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
||||
# by the user, it gets its default value from the instance, or is set to None.
|
||||
for kwarg_name in self.valid_kwargs.__annotations__:
|
||||
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
||||
|
||||
# Extract parameters that are only used for preparing the input images
|
||||
do_convert_rgb = kwargs.pop("do_convert_rgb")
|
||||
input_data_format = kwargs.pop("input_data_format")
|
||||
device = kwargs.pop("device")
|
||||
# Prepare input images
|
||||
# transformers >= 4.53.0: uses _prepare_image_like_inputs instead of _prepare_input_images
|
||||
if images is not None:
|
||||
images = self._prepare_image_like_inputs(
|
||||
images=images,
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
input_data_format=input_data_format,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if videos is not None:
|
||||
videos = self._prepare_image_like_inputs(
|
||||
images=videos,
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
input_data_format=input_data_format,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Update kwargs that need further processing before being validated
|
||||
kwargs = self._further_process_kwargs(**kwargs)
|
||||
|
||||
# Validate kwargs
|
||||
self._validate_preprocess_kwargs(**kwargs)
|
||||
|
||||
# torch resize uses interpolation instead of resample
|
||||
# Added for transformers >=4.53.0 compatibility
|
||||
resample = kwargs.pop("resample", self.resample)
|
||||
kwargs["interpolation"] = (
|
||||
pil_torch_interpolation_mapping[resample]
|
||||
if isinstance(resample, PILImageResampling | int)
|
||||
else resample
|
||||
)
|
||||
|
||||
# Filter kwargs to only include those accepted by _preprocess
|
||||
valid_preprocess_kwargs = {
|
||||
"do_resize",
|
||||
"size",
|
||||
"max_dynamic_tiles",
|
||||
"min_dynamic_tiles",
|
||||
"use_thumbnail",
|
||||
"pad_during_tiling",
|
||||
"interpolation",
|
||||
"do_center_crop",
|
||||
"crop_size",
|
||||
"do_rescale",
|
||||
"rescale_factor",
|
||||
"do_normalize",
|
||||
"image_mean",
|
||||
"image_std",
|
||||
"do_pad",
|
||||
"return_tensors",
|
||||
"pad_size",
|
||||
"disable_grouping",
|
||||
}
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_preprocess_kwargs}
|
||||
if images is not None:
|
||||
return self._preprocess(images, **filtered_kwargs)
|
||||
elif videos is not None:
|
||||
return self._preprocess(videos, **filtered_kwargs)
|
||||
|
||||
|
||||
__all__ = ["Eagle25VLImageProcessorFast"]
|
||||
@@ -1,395 +0,0 @@
|
||||
# --------------------------------------------------------
|
||||
# NVIDIA
|
||||
# Copyright (c) 2025 NVIDIA
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
|
||||
import inspect
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint as cp
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import GenerationConfig
|
||||
from transformers.generation import GenerationMixin
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
|
||||
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
|
||||
from transformers.models.siglip.modeling_siglip import SiglipVisionModel
|
||||
from transformers.utils import add_start_docstrings, logging
|
||||
|
||||
from .configuration_eagle2_5_vl import Eagle25VLConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/modeling_llava_onevision.py#L241C1-L280C1
|
||||
EAGLE2_5_VL_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`Eagle25VLConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Eagle2_5_VL Model outputting raw hidden-states without any specific head on top.",
|
||||
EAGLE2_5_VL_START_DOCSTRING,
|
||||
)
|
||||
class Eagle25VLPreTrainedModel(PreTrainedModel):
|
||||
config_class = Eagle25VLConfig
|
||||
base_model_prefix = "model"
|
||||
main_input_name = "input_ids"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = [
|
||||
"Qwen2DecoderLayer",
|
||||
"LlamaDecoderLayer",
|
||||
"Siglip2EncoderLayer",
|
||||
"SiglipEncoderLayer",
|
||||
]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_sdpa = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear | nn.Conv2d):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
class Eagle25VLForConditionalGeneration(Eagle25VLPreTrainedModel, GenerationMixin):
|
||||
config_class = Eagle25VLConfig
|
||||
|
||||
def __init__(self, config: Eagle25VLConfig, vision_model=None, language_model=None):
|
||||
super().__init__(config)
|
||||
|
||||
image_size = config.force_image_size or config.vision_config.image_size
|
||||
patch_size = config.vision_config.patch_size
|
||||
self.patch_size = patch_size
|
||||
if config.use_pixel_shuffle:
|
||||
self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio**2))
|
||||
else:
|
||||
self.num_image_token = int((image_size // patch_size) ** 2)
|
||||
|
||||
self.select_layer = config.select_layer
|
||||
self.downsample_ratio = config.downsample_ratio
|
||||
self.loss_version = config.loss_version
|
||||
self.mlp_checkpoint = config.mlp_checkpoint
|
||||
self.use_pixel_shuffle = config.use_pixel_shuffle
|
||||
self.mlp_connector_layers = config.mlp_connector_layers
|
||||
logger.info(f"num_image_token: {self.num_image_token}")
|
||||
logger.info(f"mlp_checkpoint: {self.mlp_checkpoint}")
|
||||
if vision_model is not None:
|
||||
self.vision_model = vision_model
|
||||
else:
|
||||
if config.vision_config.model_type == "siglip_vision_model":
|
||||
config.vision_config._attn_implementation = "flash_attention_2"
|
||||
self.vision_model = SiglipVisionModel(config.vision_config)
|
||||
else:
|
||||
raise NotImplementedError(f"{config.vision_config.model_type} is not implemented.")
|
||||
|
||||
if language_model is not None:
|
||||
self.language_model = language_model
|
||||
else:
|
||||
if config.text_config.architectures[0] == "LlamaForCausalLM":
|
||||
self.language_model = LlamaForCausalLM(config.text_config)
|
||||
elif config.text_config.architectures[0] == "Phi3ForCausalLM":
|
||||
raise NotImplementedError("Phi3 is not implemented.")
|
||||
# self.language_model = Phi3ForCausalLM(config.text_config)
|
||||
elif config.text_config.architectures[0] == "Qwen2ForCausalLM":
|
||||
assert config.text_config._attn_implementation == "flash_attention_2", (
|
||||
f"Qwen2 must use flash_attention_2 but got {config.text_config._attn_implementation}"
|
||||
)
|
||||
self.language_model = Qwen2ForCausalLM(config.text_config)
|
||||
elif config.text_config.architectures[0] == "Qwen3ForCausalLM":
|
||||
self.language_model = Qwen3ForCausalLM(config.text_config)
|
||||
else:
|
||||
raise NotImplementedError(f"{config.text_config.architectures[0]} is not implemented.")
|
||||
|
||||
vit_hidden_size = config.vision_config.hidden_size
|
||||
llm_hidden_size = config.text_config.hidden_size
|
||||
|
||||
if config.mlp_connector_layers == 2:
|
||||
self.mlp1 = nn.Sequential(
|
||||
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
|
||||
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
|
||||
nn.GELU(),
|
||||
nn.Linear(llm_hidden_size, llm_hidden_size),
|
||||
)
|
||||
elif config.mlp_connector_layers == 1 and config.use_pixel_shuffle:
|
||||
self.mlp1 = nn.Sequential(
|
||||
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
|
||||
)
|
||||
elif config.mlp_connector_layers == 1 and not config.use_pixel_shuffle:
|
||||
self.mlp1 = nn.Sequential(
|
||||
nn.Linear(vit_hidden_size, llm_hidden_size),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"{config.mlp_connector_layers} is not implemented.")
|
||||
|
||||
self.image_token_index = config.image_token_index
|
||||
self.neftune_alpha = None
|
||||
|
||||
if config.use_backbone_lora:
|
||||
self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
|
||||
|
||||
self.use_llm_lora = config.use_llm_lora
|
||||
if config.use_llm_lora:
|
||||
self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
|
||||
|
||||
self.check_forward_kwargs()
|
||||
|
||||
def check_forward_kwargs(self):
|
||||
# We intentionally avoid using **kwargs in forward because Hugging Face Transformers
|
||||
# has special handling for functions with **kwargs parameters that would affect
|
||||
# how our model is processed during training and inference.
|
||||
forward_params = inspect.signature(self.forward).parameters
|
||||
assert not any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())
|
||||
|
||||
def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
|
||||
lora_config = LoraConfig(
|
||||
r=r,
|
||||
target_modules=[
|
||||
"self_attn.q_proj",
|
||||
"self_attn.k_proj",
|
||||
"self_attn.v_proj",
|
||||
"self_attn.out_proj",
|
||||
"mlp.fc1",
|
||||
"mlp.fc2",
|
||||
],
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
)
|
||||
self.vision_model = get_peft_model(self.vision_model, lora_config)
|
||||
self.vision_model.print_trainable_parameters()
|
||||
|
||||
def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
|
||||
lora_config = LoraConfig(
|
||||
r=r,
|
||||
target_modules=[
|
||||
"self_attn.q_proj",
|
||||
"self_attn.k_proj",
|
||||
"self_attn.v_proj",
|
||||
"self_attn.o_proj",
|
||||
"mlp.gate_proj",
|
||||
"mlp.down_proj",
|
||||
"mlp.up_proj",
|
||||
],
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
self.language_model = get_peft_model(self.language_model, lora_config)
|
||||
self.language_model.enable_input_require_grads()
|
||||
self.language_model.print_trainable_parameters()
|
||||
self.use_llm_lora = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
image_flags: torch.LongTensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
return_dict: bool | None = None,
|
||||
num_tiles_list: list[torch.Tensor] | None = None,
|
||||
) -> tuple | CausalLMOutputWithPast:
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
|
||||
vit_embeds = self.extract_feature(pixel_values)
|
||||
|
||||
if image_flags is not None:
|
||||
image_flags = image_flags.view(-1)
|
||||
vit_embeds = vit_embeds[image_flags == 1]
|
||||
|
||||
b, n, c = input_embeds.shape
|
||||
input_embeds = input_embeds.reshape(b * n, c)
|
||||
|
||||
input_ids = input_ids.reshape(b * n)
|
||||
selected = input_ids == self.image_token_index
|
||||
try:
|
||||
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, c)
|
||||
except Exception as e:
|
||||
vit_embeds = vit_embeds.reshape(-1, c)
|
||||
print(
|
||||
f"warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, "
|
||||
f"vit_embeds.shape={vit_embeds.shape}"
|
||||
)
|
||||
n_token = selected.sum()
|
||||
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
|
||||
|
||||
input_embeds = input_embeds.reshape(b, n, c)
|
||||
|
||||
outputs = self.language_model(
|
||||
inputs_embeds=input_embeds,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
logits = outputs.logits
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def pixel_shuffle(self, x, scale_factor=0.5):
|
||||
n, w, h, c = x.size()
|
||||
# N, W, H, C --> N, W, H * scale, C // scale
|
||||
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
|
||||
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
|
||||
x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor)))
|
||||
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
return x
|
||||
|
||||
def extract_feature(self, pixel_values):
|
||||
if self.select_layer == -1:
|
||||
vit_embeds = self.vision_model(
|
||||
pixel_values=pixel_values, output_hidden_states=False, return_dict=True
|
||||
)
|
||||
if hasattr(vit_embeds, "last_hidden_state"):
|
||||
vit_embeds = vit_embeds.last_hidden_state
|
||||
|
||||
else:
|
||||
vit_embeds = self.vision_model(
|
||||
pixel_values=pixel_values, output_hidden_states=True, return_dict=True
|
||||
).hidden_states[self.select_layer]
|
||||
|
||||
if self.use_pixel_shuffle:
|
||||
h = w = int(vit_embeds.shape[1] ** 0.5)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
||||
vit_embeds = self.pixel_shuffle(
|
||||
vit_embeds, scale_factor=self.downsample_ratio
|
||||
) # torch.Size([B, 1024, 1024]) -> torch.Size([B, 16, 16, 4096])
|
||||
vit_embeds = vit_embeds.reshape(
|
||||
vit_embeds.shape[0], -1, vit_embeds.shape[-1]
|
||||
) # torch.Size([B, 16, 16, 4096]) -> torch.Size([B, 256, 4096])
|
||||
|
||||
if self.mlp_checkpoint and vit_embeds.requires_grad:
|
||||
vit_embeds = cp.checkpoint(self.mlp1, vit_embeds)
|
||||
else:
|
||||
vit_embeds = self.mlp1(vit_embeds)
|
||||
|
||||
return vit_embeds
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
input_ids: torch.FloatTensor | None = None,
|
||||
attention_mask: torch.LongTensor | None = None,
|
||||
visual_features: torch.FloatTensor | None = None,
|
||||
generation_config: GenerationConfig | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
image_sizes: list[tuple[int, int]] | None = None,
|
||||
**generate_kwargs,
|
||||
) -> torch.LongTensor:
|
||||
if pixel_values is not None:
|
||||
if visual_features is not None:
|
||||
vit_embeds = visual_features
|
||||
else:
|
||||
vit_embeds = self.extract_feature(pixel_values)
|
||||
|
||||
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
b, n, c = input_embeds.shape
|
||||
input_embeds = input_embeds.reshape(b * n, c)
|
||||
|
||||
input_ids = input_ids.reshape(b * n)
|
||||
selected = input_ids == self.config.image_token_index
|
||||
assert selected.sum() != 0
|
||||
input_embeds[selected] = vit_embeds.reshape(-1, c).to(input_embeds.device)
|
||||
|
||||
input_embeds = input_embeds.reshape(b, n, c)
|
||||
else:
|
||||
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
|
||||
if "use_cache" not in generate_kwargs:
|
||||
generate_kwargs["use_cache"] = True
|
||||
|
||||
outputs = self.language_model.generate(
|
||||
inputs_embeds=input_embeds,
|
||||
attention_mask=attention_mask,
|
||||
generation_config=generation_config,
|
||||
output_hidden_states=output_hidden_states,
|
||||
**generate_kwargs,
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
|
||||
def get_input_embeddings(self):
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_input_embeddings
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_output_embeddings
|
||||
def get_output_embeddings(self):
|
||||
return self.language_model.get_output_embeddings()
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_output_embeddings
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.language_model.set_output_embeddings(new_embeddings)
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_decoder
|
||||
def set_decoder(self, decoder):
|
||||
self.language_model.set_decoder(decoder)
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_decoder
|
||||
def get_decoder(self):
|
||||
return self.language_model.get_decoder()
|
||||
@@ -1,518 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Processor class for Eagle25VL.
|
||||
copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/processing_llava_onevision.py
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import re
|
||||
from io import BytesIO
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from transformers.utils import logging
|
||||
from transformers.video_utils import VideoInput
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
FRAME_FACTOR = 2
|
||||
FPS = 2.0
|
||||
FPS_MIN_FRAMES = 4
|
||||
FPS_MAX_FRAMES = 256
|
||||
|
||||
|
||||
def to_rgb(pil_image: Image.Image) -> Image.Image:
|
||||
if pil_image.mode == "RGBA":
|
||||
white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
|
||||
white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
|
||||
return white_background
|
||||
else:
|
||||
return pil_image.convert("RGB")
|
||||
|
||||
|
||||
def fetch_image(ele: dict[str, str | Image.Image]) -> Image.Image:
|
||||
image = ele["image"] if "image" in ele else ele["image_url"]
|
||||
image_obj = None
|
||||
if isinstance(image, Image.Image):
|
||||
image_obj = image
|
||||
elif image.startswith("http://") or image.startswith("https://"):
|
||||
response = requests.get(image, stream=True, timeout=10)
|
||||
image_obj = Image.open(BytesIO(response.content))
|
||||
elif image.startswith("file://"):
|
||||
image_obj = Image.open(image[7:])
|
||||
elif image.startswith("data:image"):
|
||||
if "base64," in image:
|
||||
_, base64_data = image.split("base64,", 1)
|
||||
data = base64.b64decode(base64_data)
|
||||
image_obj = Image.open(BytesIO(data))
|
||||
else:
|
||||
image_obj = Image.open(image)
|
||||
if image_obj is None:
|
||||
raise ValueError(
|
||||
f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
|
||||
)
|
||||
image = to_rgb(image_obj)
|
||||
if "scale_factor" in ele:
|
||||
scale_factor = ele["scale_factor"]
|
||||
image = image.resize((image.width * scale_factor, image.height * scale_factor), Image.BILINEAR)
|
||||
return image
|
||||
|
||||
|
||||
class Eagle25VLProcessorKwargs(ProcessingKwargs, total=False):
|
||||
# see processing_utils.ProcessingKwargs documentation for usage.
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
},
|
||||
"images_kwargs": {},
|
||||
"videos_kwargs": {"max_dynamic_tiles": 1},
|
||||
}
|
||||
|
||||
|
||||
class Eagle25VLProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a Eagle25VL processor which wraps a Eagle25VL video processor, Eagle25VL image processor and a Eagle25VL tokenizer into a single processor.
|
||||
|
||||
[`Eagle25VLProcessor`] offers all the functionalities of [`Eagle25VLVideoProcessor`], [`Eagle25VLImageProcessor`] and [`Eagle25VLTokenizer`]. See the
|
||||
[`~Eagle25VLVideoProcessor.__call__`], [`~Eagle25VLProcessor.__call__`] and [`~Eagle25VLProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor ([`LlavaOnevisionImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`LlamaTokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
num_image_tokens (`int`, *optional*):
|
||||
Number of image tokens for one imagethat will be returned by vision tower.
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Should be same as in model's config
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
image_token (`str`, *optional*, defaults to `"<image>"`):
|
||||
Special token used to denote image location.
|
||||
video_token (`str`, *optional*, defaults to `"<video>"`):
|
||||
Special token used to denote video location.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = [
|
||||
"chat_template",
|
||||
"num_image_tokens",
|
||||
"vision_feature_select_strategy",
|
||||
"image_token",
|
||||
"video_token",
|
||||
"images_kwargs",
|
||||
"videos_kwargs",
|
||||
"text_kwargs",
|
||||
]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
vision_feature_select_strategy=None,
|
||||
chat_template=None,
|
||||
image_token="<IMG_CONTEXT>", # nosec: B107
|
||||
video_token="<IMG_CONTEXT>", # nosec: B107
|
||||
tokens_per_tile=256,
|
||||
image_placeholder="image",
|
||||
video_placeholder="video",
|
||||
image_start_token="<img>",
|
||||
image_end_token="</img>",
|
||||
**kwargs,
|
||||
):
|
||||
self.vision_feature_select_strategy = vision_feature_select_strategy
|
||||
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
|
||||
self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
|
||||
self.image_token_id = (
|
||||
tokenizer.image_token_id
|
||||
if getattr(tokenizer, "image_token_id", None)
|
||||
else tokenizer.convert_tokens_to_ids(self.image_token)
|
||||
)
|
||||
self.video_token_id = (
|
||||
tokenizer.video_token_id
|
||||
if getattr(tokenizer, "video_token_id", None)
|
||||
else tokenizer.convert_tokens_to_ids(self.video_token)
|
||||
)
|
||||
self.image_placeholder = image_placeholder
|
||||
self.video_placeholder = video_placeholder
|
||||
self.tokens_per_tile = tokens_per_tile
|
||||
self.image_start_token = image_start_token
|
||||
self.image_end_token = image_end_token
|
||||
if "auto_map" in kwargs:
|
||||
self.auto_map = kwargs["auto_map"]
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
|
||||
def replace_media_placeholder(
|
||||
self, text, image_list, video_list, timestamps_list, fps_list, **output_kwargs
|
||||
):
|
||||
num_of_images_in_this_sample = 0
|
||||
num_of_videos_in_this_sample = 0
|
||||
# Regular expression pattern to match formats like <image-1> or <video-2>
|
||||
pattern = re.compile(rf"<({self.image_placeholder}|{self.video_placeholder})-(\d+)>")
|
||||
unified_frame_list = []
|
||||
|
||||
# image_min_dynamic_tiles = output_kwargs["images_kwargs"].get(
|
||||
# "min_dynamic_tiles", self.image_processor.min_dynamic_tiles
|
||||
# )
|
||||
# image_max_dynamic_tiles = output_kwargs["images_kwargs"].get(
|
||||
# "max_dynamic_tiles", self.image_processor.max_dynamic_tiles
|
||||
# )
|
||||
# image_use_thumbnail = output_kwargs["images_kwargs"].get(
|
||||
# "use_thumbnail", self.image_processor.use_thumbnail
|
||||
# )
|
||||
video_min_dynamic_tiles = output_kwargs["videos_kwargs"].get(
|
||||
"min_dynamic_tiles", self.image_processor.min_dynamic_tiles
|
||||
)
|
||||
video_max_dynamic_tiles = output_kwargs["videos_kwargs"].get(
|
||||
"max_dynamic_tiles", self.image_processor.max_dynamic_tiles
|
||||
)
|
||||
video_use_thumbnail = output_kwargs["videos_kwargs"].get(
|
||||
"use_thumbnail", self.image_processor.use_thumbnail
|
||||
)
|
||||
|
||||
tile_size = self.image_processor.size.get("height", 448)
|
||||
|
||||
# Function to replace tags in a single text
|
||||
def replace_in_text(text):
|
||||
# repl callback function for each match replacement operation
|
||||
def repl(match):
|
||||
nonlocal unified_frame_list
|
||||
nonlocal num_of_images_in_this_sample
|
||||
nonlocal num_of_videos_in_this_sample
|
||||
media_type = match.group(1) # 'image' or 'video'
|
||||
idx_in_list = int(match.group(2)) - 1 # Convert to list index (0-based)
|
||||
# Select the corresponding path based on media type
|
||||
idx_mapper = {
|
||||
0: "first",
|
||||
1: "second",
|
||||
2: "third",
|
||||
3: "fourth",
|
||||
4: "fifth",
|
||||
5: "sixth",
|
||||
6: "seventh",
|
||||
7: "eighth",
|
||||
8: "ninth",
|
||||
9: "tenth",
|
||||
}
|
||||
if media_type == "image":
|
||||
image_inputs = self.image_processor(
|
||||
images=[image_list[idx_in_list]],
|
||||
videos=None,
|
||||
**output_kwargs["images_kwargs"],
|
||||
)
|
||||
num_all_tiles = image_inputs["pixel_values"].shape[0]
|
||||
special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
|
||||
unified_frame_list.append(image_inputs)
|
||||
num_of_images_in_this_sample += 1
|
||||
|
||||
elif media_type == "video":
|
||||
video_inputs = self.image_processor(
|
||||
images=None,
|
||||
videos=[video_list[idx_in_list]],
|
||||
**output_kwargs["videos_kwargs"],
|
||||
)
|
||||
num_all_tiles = video_inputs["pixel_values"].shape[0]
|
||||
image_sizes = video_inputs["image_sizes"]
|
||||
if timestamps_list is not None and -1 not in timestamps_list:
|
||||
frame_timestamps = timestamps_list[idx_in_list]
|
||||
else:
|
||||
frame_timestamps = None
|
||||
sampled_fps = fps_list[idx_in_list] if fps_list is not None else None
|
||||
|
||||
num_of_tiles_each_frame = [
|
||||
self.get_number_tiles_based_on_image_size(
|
||||
image_size,
|
||||
video_min_dynamic_tiles,
|
||||
video_max_dynamic_tiles,
|
||||
video_use_thumbnail,
|
||||
tile_size,
|
||||
)
|
||||
for image_size in image_sizes
|
||||
]
|
||||
assert sum(num_of_tiles_each_frame) == num_all_tiles, (
|
||||
f"The number of tiles in each frame is not equal to the total number of tiles: {sum(num_of_tiles_each_frame)} != {num_all_tiles}"
|
||||
)
|
||||
|
||||
if frame_timestamps is not None:
|
||||
assert len(frame_timestamps) == len(num_of_tiles_each_frame), (
|
||||
f"The number of timestamps is not equal to the number of frames: {len(frame_timestamps)} != {len(num_of_tiles_each_frame)}"
|
||||
)
|
||||
special_placeholder = [
|
||||
f"Frame {i + 1} sample at {frame_timestamps[i]:.2f}s: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
|
||||
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
|
||||
]
|
||||
else:
|
||||
special_placeholder = [
|
||||
f"Frame {i + 1}: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
|
||||
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
|
||||
]
|
||||
|
||||
if sampled_fps is not None:
|
||||
special_placeholder = (
|
||||
f"The {idx_mapper[idx_in_list]} video sampled with {sampled_fps:.2f} fps: "
|
||||
+ "".join(special_placeholder)
|
||||
)
|
||||
else:
|
||||
special_placeholder = f"The {idx_mapper[idx_in_list]} video: " + "".join(
|
||||
special_placeholder
|
||||
)
|
||||
unified_frame_list.append(video_inputs)
|
||||
num_of_videos_in_this_sample += 1
|
||||
else:
|
||||
raise ValueError(f"Unknown media type: {media_type}")
|
||||
return special_placeholder
|
||||
|
||||
return pattern.sub(repl, text)
|
||||
|
||||
text = replace_in_text(text)
|
||||
if len(unified_frame_list) > 0:
|
||||
pixel_values = torch.cat([frame["pixel_values"] for frame in unified_frame_list])
|
||||
image_sizes = torch.cat([frame["image_sizes"] for frame in unified_frame_list])
|
||||
else:
|
||||
pixel_values = None
|
||||
image_sizes = None
|
||||
return (
|
||||
text,
|
||||
pixel_values,
|
||||
image_sizes,
|
||||
num_of_images_in_this_sample,
|
||||
num_of_videos_in_this_sample,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput = None,
|
||||
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
|
||||
audio=None,
|
||||
videos: VideoInput = None,
|
||||
**kwargs: Unpack[Eagle25VLProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
||||
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
|
||||
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
|
||||
LlavaNextImageProcessor's [`~LlavaNextImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
|
||||
of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
- **pixel_values_videos** -- Pixel values of a video input to be fed to a model. Returned when `videos` is not `None`.
|
||||
- **image_sizes** -- Size of each image that will be used to unpad an image. Returned when `images` is not `None`.
|
||||
"""
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
Eagle25VLProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if isinstance(text, str):
|
||||
text_list = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
elif isinstance(text, list) and isinstance(text[0], str):
|
||||
text_list = text
|
||||
|
||||
if images is None:
|
||||
images = []
|
||||
if videos is None:
|
||||
videos = []
|
||||
|
||||
pixel_values_list = []
|
||||
image_sizes_list = []
|
||||
new_sample_list = []
|
||||
image_start_idx = 0
|
||||
video_start_idx = 0
|
||||
timestamps_batch = output_kwargs["videos_kwargs"].pop("timestamps", None)
|
||||
fps_batch = output_kwargs["videos_kwargs"].pop("fps", None)
|
||||
for sample in text_list:
|
||||
timestamps_list = timestamps_batch[video_start_idx:] if timestamps_batch is not None else None
|
||||
fps_list = fps_batch[video_start_idx:] if fps_batch is not None else None
|
||||
(
|
||||
sample,
|
||||
pixel_values,
|
||||
image_sizes,
|
||||
num_of_images_in_this_sample,
|
||||
num_of_videos_in_this_sample,
|
||||
) = self.replace_media_placeholder(
|
||||
sample,
|
||||
images[image_start_idx:],
|
||||
videos[video_start_idx:],
|
||||
timestamps_list,
|
||||
fps_list,
|
||||
**output_kwargs,
|
||||
)
|
||||
new_sample_list.append(sample)
|
||||
if pixel_values is not None:
|
||||
pixel_values_list.append(pixel_values)
|
||||
image_sizes_list.append(image_sizes)
|
||||
image_start_idx += num_of_images_in_this_sample
|
||||
video_start_idx += num_of_videos_in_this_sample
|
||||
|
||||
if len(pixel_values_list) > 0:
|
||||
image_inputs = {
|
||||
"pixel_values": torch.cat(pixel_values_list),
|
||||
"image_sizes": torch.cat(image_sizes_list),
|
||||
}
|
||||
else:
|
||||
image_inputs = {}
|
||||
video_inputs = {}
|
||||
text_inputs = self.tokenizer(new_sample_list, **output_kwargs["text_kwargs"])
|
||||
return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs})
|
||||
|
||||
def get_number_tiles_based_on_image_size(
|
||||
self, image_size: tuple, min_num: int, max_num: int, use_thumbnail: bool, tile_size: int
|
||||
) -> int:
|
||||
"""
|
||||
Get the number of tiles based on the image size.
|
||||
"""
|
||||
orig_height, orig_width = image_size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = {
|
||||
(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num
|
||||
}
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = self.image_processor.find_closest_aspect_ratio(
|
||||
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
|
||||
)
|
||||
tiles_num = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
if use_thumbnail and tiles_num > 1:
|
||||
tiles_num += 1
|
||||
return tiles_num
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
@property
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||
|
||||
# override to save video-config in a separate config file
|
||||
def save_pretrained(self, save_directory, **kwargs):
|
||||
if os.path.isfile(save_directory):
|
||||
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
outputs = super().save_pretrained(save_directory, **kwargs)
|
||||
return outputs
|
||||
|
||||
# override to load video-config from a separate config file
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# if return_unused_kwargs a tuple is returned where the second element is 'unused_kwargs'
|
||||
if isinstance(processor, tuple):
|
||||
processor = processor[0]
|
||||
return processor
|
||||
|
||||
# Copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
|
||||
def process_vision_info(
|
||||
self,
|
||||
conversations: list[dict] | list[list[dict]],
|
||||
return_video_kwargs: bool = False,
|
||||
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, dict | None]:
|
||||
vision_infos = self.extract_vision_info(conversations)
|
||||
## Read images or videos
|
||||
image_inputs = []
|
||||
video_inputs = []
|
||||
video_sample_fps_list = []
|
||||
video_timestamps_list = []
|
||||
for vision_info in vision_infos:
|
||||
if "image" in vision_info or "image_url" in vision_info:
|
||||
image_inputs.append(fetch_image(vision_info))
|
||||
else:
|
||||
raise ValueError("image, image_url or video should in content.")
|
||||
if len(image_inputs) == 0:
|
||||
image_inputs = None
|
||||
if len(video_inputs) == 0:
|
||||
video_inputs = None
|
||||
if return_video_kwargs:
|
||||
return (
|
||||
image_inputs,
|
||||
video_inputs,
|
||||
{"fps": video_sample_fps_list, "timestamps": video_timestamps_list},
|
||||
)
|
||||
return image_inputs, video_inputs
|
||||
|
||||
def extract_vision_info(self, conversations: list[dict] | list[list[dict]]) -> list[dict]:
|
||||
vision_infos = []
|
||||
if isinstance(conversations[0], dict):
|
||||
conversations = [conversations]
|
||||
for conversation in conversations:
|
||||
for message in conversation:
|
||||
if isinstance(message["content"], list):
|
||||
for ele in message["content"]:
|
||||
if (
|
||||
"image" in ele
|
||||
or "image_url" in ele
|
||||
or "video" in ele
|
||||
or ele["type"] in ("image", "image_url", "video")
|
||||
):
|
||||
vision_infos.append(ele)
|
||||
return vision_infos
|
||||
|
||||
|
||||
__all__ = ["Eagle25VLProcessor"]
|
||||
@@ -1,376 +0,0 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
AutoConfig = None
|
||||
AutoModel = None
|
||||
PretrainedConfig = object
|
||||
PreTrainedModel = object
|
||||
BatchFeature = None
|
||||
|
||||
try:
|
||||
import tree
|
||||
except ImportError:
|
||||
tree = None
|
||||
|
||||
from lerobot.policies.groot.action_head.flow_matching_action_head import (
|
||||
FlowmatchingActionHead,
|
||||
FlowmatchingActionHeadConfig,
|
||||
)
|
||||
from lerobot.policies.groot.utils import ensure_eagle_cache_ready
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
DEFAULT_VENDOR_EAGLE_PATH = str((Path(__file__).resolve().parent / "eagle2_hg_model").resolve())
|
||||
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
|
||||
|
||||
|
||||
class EagleBackbone(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
tune_llm: bool = False,
|
||||
tune_visual: bool = False,
|
||||
select_layer: int = -1,
|
||||
reproject_vision: bool = False,
|
||||
use_flash_attention: bool = False,
|
||||
load_bf16: bool = False,
|
||||
eagle_path: str = DEFAULT_VENDOR_EAGLE_PATH,
|
||||
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO,
|
||||
project_to_dim: int = 1536,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
tune_llm: whether to tune the LLM model (default: True)
|
||||
tune_visual: whether to tune the visual model (default: False)
|
||||
"""
|
||||
super().__init__()
|
||||
assert not reproject_vision, "Reproject vision is not implemented here, set to False"
|
||||
|
||||
# Prefer loading Eagle model config from the cache directory where vendor files were copied.
|
||||
vendor_dir = DEFAULT_VENDOR_EAGLE_PATH
|
||||
cache_dir = HF_LEROBOT_HOME / tokenizer_assets_repo
|
||||
try:
|
||||
ensure_eagle_cache_ready(vendor_dir, cache_dir, tokenizer_assets_repo)
|
||||
except Exception as exc: # nosec: B110
|
||||
print(f"[GROOT] Warning: failed to prepare Eagle cache for backbone: {exc}")
|
||||
|
||||
config = AutoConfig.from_pretrained(str(cache_dir), trust_remote_code=True)
|
||||
self.eagle_model = AutoModel.from_config(config, trust_remote_code=True)
|
||||
|
||||
if project_to_dim is not None:
|
||||
self.eagle_linear = torch.nn.Linear(2048, project_to_dim)
|
||||
else:
|
||||
self.eagle_linear = torch.nn.Identity()
|
||||
|
||||
# needed since we don't use these layers. Also saves compute
|
||||
while len(self.eagle_model.language_model.model.layers) > select_layer:
|
||||
self.eagle_model.language_model.model.layers.pop(-1)
|
||||
|
||||
self.select_layer = select_layer
|
||||
self.set_trainable_parameters(tune_llm, tune_visual)
|
||||
|
||||
def set_trainable_parameters(self, tune_llm: bool, tune_visual: bool):
|
||||
self.tune_llm = tune_llm
|
||||
self.tune_visual = tune_visual
|
||||
for p in self.parameters():
|
||||
p.requires_grad = True
|
||||
if not tune_llm:
|
||||
self.eagle_model.language_model.requires_grad_(False)
|
||||
if not tune_visual:
|
||||
self.eagle_model.vision_model.requires_grad_(False)
|
||||
self.eagle_model.mlp1.requires_grad_(False)
|
||||
print(f"Tune backbone llm: {self.tune_llm}")
|
||||
print(f"Tune backbone visual: {self.tune_visual}")
|
||||
# Check if any parameters are still trainable. If not, print a warning.
|
||||
if not tune_llm and not tune_visual:
|
||||
for name, p in self.named_parameters():
|
||||
if p.requires_grad:
|
||||
print(f"Backbone trainable parameter: {name}")
|
||||
if not any(p.requires_grad for p in self.parameters()):
|
||||
print("Warning: No backbone trainable parameters found.")
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self):
|
||||
"""
|
||||
Huggingface will call model.train() at each training_step. To ensure
|
||||
the expected behaviors for modules like dropout, batchnorm, etc., we
|
||||
need to call model.eval() for the frozen modules.
|
||||
"""
|
||||
if self.training:
|
||||
if self.eagle_model.language_model and not self.tune_llm:
|
||||
self.eagle_model.language_model.eval()
|
||||
if self.eagle_model.vision_model and not self.tune_visual:
|
||||
self.eagle_model.vision_model.eval()
|
||||
|
||||
def prepare_input(self, batch: dict) -> BatchFeature:
|
||||
return BatchFeature(data=batch)
|
||||
|
||||
def forward_eagle(self, vl_input: BatchFeature) -> BatchFeature:
|
||||
eagle_prefix = "eagle_"
|
||||
eagle_input = {
|
||||
k.removeprefix(eagle_prefix): v for k, v in vl_input.items() if k.startswith(eagle_prefix)
|
||||
}
|
||||
del eagle_input["image_sizes"]
|
||||
|
||||
eagle_output = self.eagle_model(**eagle_input, output_hidden_states=True, return_dict=True)
|
||||
eagle_features = eagle_output.hidden_states[self.select_layer]
|
||||
|
||||
eagle_features = self.eagle_linear(eagle_features)
|
||||
return eagle_features, eagle_input["attention_mask"]
|
||||
|
||||
def forward(self, vl_input: BatchFeature) -> BatchFeature:
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
|
||||
eagle_embeds, eagle_mask = self.forward_eagle(vl_input)
|
||||
|
||||
# YL (TODO HACK): to resolve DDP issue when tune_visual=True
|
||||
# Ensure all trainable parameters in vision_model are used in the forward pass for DDP compatibility
|
||||
if self.training and self.tune_visual:
|
||||
dummy_term = torch.tensor(
|
||||
0.0, device=eagle_embeds.device, dtype=eagle_embeds.dtype, requires_grad=True
|
||||
)
|
||||
for param in self.eagle_model.vision_model.parameters():
|
||||
if param.requires_grad:
|
||||
dummy_term = dummy_term + 0.0 * param.sum()
|
||||
eagle_embeds = eagle_embeds + dummy_term
|
||||
|
||||
return BatchFeature(
|
||||
data={"backbone_features": eagle_embeds, "backbone_attention_mask": eagle_mask}
|
||||
) # [B, T2, hidden_size]
|
||||
|
||||
|
||||
BACKBONE_FEATURE_KEY = "backbone_features"
|
||||
ACTION_KEY = "action_pred"
|
||||
LOSS_KEY = "loss"
|
||||
ERROR_MSG = "Error: unexpected input/output"
|
||||
N_COLOR_CHANNELS = 3
|
||||
|
||||
|
||||
# config
|
||||
@dataclass
|
||||
class GR00TN15Config(PretrainedConfig):
|
||||
model_type = "gr00t_n1_5"
|
||||
backbone_cfg: dict = field(init=False, metadata={"help": "Backbone configuration."})
|
||||
|
||||
action_head_cfg: dict = field(init=False, metadata={"help": "Action head configuration."})
|
||||
|
||||
action_horizon: int = field(init=False, metadata={"help": "Action horizon."})
|
||||
|
||||
action_dim: int = field(init=False, metadata={"help": "Action dimension."})
|
||||
compute_dtype: str = field(default="float32", metadata={"help": "Compute dtype."})
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
# real model
|
||||
class GR00TN15(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
config_class = GR00TN15Config
|
||||
"""
|
||||
we expect the backbone output to have a key 'backbone_features' with shape (batch_size, n, hidden_size)
|
||||
here n is variable and can be e.g. time, 1 or user specified
|
||||
we expect the action head output to have a key 'action_pred' with shape (batch_size, time, action_dim) during inference time
|
||||
we expect these to have type BatchFeature, and they can of course have many other user specified keys too
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GR00TN15Config,
|
||||
local_model_path: str,
|
||||
):
|
||||
assert isinstance(config.backbone_cfg, dict)
|
||||
assert isinstance(config.action_head_cfg, dict)
|
||||
|
||||
super().__init__(config)
|
||||
self.local_model_path = local_model_path
|
||||
|
||||
self.backbone = EagleBackbone(**config.backbone_cfg)
|
||||
action_head_cfg = FlowmatchingActionHeadConfig(**config.action_head_cfg)
|
||||
self.action_head = FlowmatchingActionHead(action_head_cfg)
|
||||
|
||||
self.action_horizon = config.action_horizon
|
||||
self.action_dim = config.action_dim
|
||||
self.compute_dtype = config.compute_dtype
|
||||
|
||||
def validate_inputs(self, inputs):
|
||||
# NOTE -- this should be handled internally by the model
|
||||
# however, doing that will likely be breaking changes -- so we'll need to do it after the deadline
|
||||
|
||||
detected_error = False
|
||||
error_msg = ERROR_MSG
|
||||
if "action" in inputs:
|
||||
action = inputs["action"]
|
||||
# In inference, action may be omitted or None; validate only when it's a tensor.
|
||||
if action is None:
|
||||
pass # allow None during inference
|
||||
elif isinstance(action, torch.Tensor):
|
||||
shape_ok = (
|
||||
len(action.shape) == 3
|
||||
and action.shape[1] == self.action_horizon
|
||||
and action.shape[2] == self.action_dim
|
||||
)
|
||||
if not shape_ok:
|
||||
error_msg += f"\n{action.shape=}"
|
||||
detected_error = True
|
||||
else:
|
||||
# Unexpected non-tensor type provided for action
|
||||
error_msg += f"\nInvalid type for action: {type(action)}"
|
||||
detected_error = True
|
||||
|
||||
if "video" in inputs:
|
||||
video = inputs["video"]
|
||||
type_ok = isinstance(video, np.ndarray)
|
||||
dtype_ok = video.dtype == np.uint8
|
||||
shape_ok = len(video.shape) == 6 and video.shape[3] == N_COLOR_CHANNELS
|
||||
if not type_ok:
|
||||
error_msg += f"\n{type(video)=}"
|
||||
detected_error = True
|
||||
if not dtype_ok:
|
||||
error_msg += f"\n{video.dtype=}"
|
||||
detected_error = True
|
||||
if not shape_ok:
|
||||
error_msg += f"\n{video.shape=}"
|
||||
detected_error = True
|
||||
|
||||
if detected_error:
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def validate_data(self, action_head_outputs, backbone_outputs, is_training):
|
||||
fail_backbone = (
|
||||
not isinstance(backbone_outputs, BatchFeature) or BACKBONE_FEATURE_KEY not in backbone_outputs
|
||||
)
|
||||
|
||||
if fail_backbone:
|
||||
error_msg = ERROR_MSG
|
||||
error_msg += f"\n{isinstance(backbone_outputs, BatchFeature)=}"
|
||||
error_msg += f"\n{BACKBONE_FEATURE_KEY in backbone_outputs=}"
|
||||
error_msg += f"\n{backbone_outputs[BACKBONE_FEATURE_KEY].shape=}"
|
||||
raise ValueError(error_msg)
|
||||
|
||||
fail_action_head = (not isinstance(action_head_outputs, BatchFeature)) or not (
|
||||
(
|
||||
LOSS_KEY in action_head_outputs and is_training
|
||||
) # there might not be an action prediction during training
|
||||
or (
|
||||
ACTION_KEY in action_head_outputs
|
||||
and action_head_outputs[ACTION_KEY].shape[1] == self.action_horizon
|
||||
and action_head_outputs[ACTION_KEY].shape[2] == self.action_dim
|
||||
)
|
||||
)
|
||||
|
||||
if fail_action_head:
|
||||
error_msg = ERROR_MSG
|
||||
error_msg += f"\n{isinstance(action_head_outputs, BatchFeature)=}"
|
||||
error_msg += f"\n{LOSS_KEY in action_head_outputs=}"
|
||||
error_msg += f"\n{action_head_outputs[ACTION_KEY].shape=}"
|
||||
error_msg += f"\n{self.action_horizon=}"
|
||||
error_msg += f"\n{self.action_dim=}"
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: dict,
|
||||
) -> BatchFeature:
|
||||
backbone_inputs, action_inputs = self.prepare_input(inputs)
|
||||
backbone_outputs = self.backbone(backbone_inputs)
|
||||
action_head_outputs = self.action_head(backbone_outputs, action_inputs)
|
||||
self.validate_data(action_head_outputs, backbone_outputs, is_training=True)
|
||||
return action_head_outputs
|
||||
|
||||
def get_action(
|
||||
self,
|
||||
inputs: dict,
|
||||
) -> BatchFeature:
|
||||
backbone_inputs, action_inputs = self.prepare_input(inputs)
|
||||
# Because the behavior of backbones remains the same for training and inference, we can use `forward` for backbones.
|
||||
backbone_outputs = self.backbone(backbone_inputs)
|
||||
action_head_outputs = self.action_head.get_action(backbone_outputs, action_inputs)
|
||||
self.validate_data(action_head_outputs, backbone_outputs, is_training=False)
|
||||
return action_head_outputs
|
||||
|
||||
def prepare_input(self, inputs) -> tuple[BatchFeature, BatchFeature]:
|
||||
self.validate_inputs(inputs)
|
||||
backbone_inputs = self.backbone.prepare_input(inputs)
|
||||
action_inputs = self.action_head.prepare_input(inputs)
|
||||
|
||||
def to_device_with_maybe_dtype(x):
|
||||
# Cast floating tensors to a memory-efficient compute dtype when requested.
|
||||
# Rationale: Upcasting backbone activations to fp32 significantly increases VRAM.
|
||||
# When compute_dtype is bfloat16, prefer bf16 for activations to match AMP behavior.
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return x
|
||||
if torch.is_floating_point(x):
|
||||
if getattr(self, "compute_dtype", None) == "bfloat16":
|
||||
return x.to(self.device, dtype=torch.bfloat16)
|
||||
# Fallback: preserve previous behavior if not using bf16 compute
|
||||
return x.to(self.device, dtype=self.action_head.dtype)
|
||||
# Non-floating tensors: move device only
|
||||
return x.to(self.device)
|
||||
|
||||
backbone_inputs = tree.map_structure(to_device_with_maybe_dtype, backbone_inputs)
|
||||
action_inputs = tree.map_structure(to_device_with_maybe_dtype, action_inputs)
|
||||
return backbone_inputs, action_inputs
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
|
||||
tune_visual = kwargs.pop("tune_visual", True)
|
||||
tune_llm = kwargs.pop("tune_llm", False)
|
||||
tune_projector = kwargs.pop("tune_projector", True)
|
||||
tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)
|
||||
|
||||
print(f"Loading pretrained dual brain from {pretrained_model_name_or_path}")
|
||||
print(f"Tune backbone vision tower: {tune_visual}")
|
||||
print(f"Tune backbone LLM: {tune_llm}")
|
||||
print(f"Tune action head projector: {tune_projector}")
|
||||
print(f"Tune action head DiT: {tune_diffusion_model}")
|
||||
|
||||
# get the current model path being downloaded
|
||||
try:
|
||||
# NOTE(YL) This downloads the model to the local cache and returns the local path to the model
|
||||
# saved in ~/.cache/huggingface/hub/
|
||||
local_model_path = snapshot_download(pretrained_model_name_or_path, repo_type="model")
|
||||
# HFValidationError, RepositoryNotFoundError
|
||||
except (HFValidationError, RepositoryNotFoundError):
|
||||
print(
|
||||
f"Model not found or avail in the huggingface hub. Loading from local path: {pretrained_model_name_or_path}"
|
||||
)
|
||||
local_model_path = pretrained_model_name_or_path
|
||||
|
||||
pretrained_model = super().from_pretrained(
|
||||
local_model_path, local_model_path=local_model_path, **kwargs
|
||||
)
|
||||
|
||||
pretrained_model.backbone.set_trainable_parameters(tune_visual=tune_visual, tune_llm=tune_llm)
|
||||
pretrained_model.action_head.set_trainable_parameters(
|
||||
tune_projector=tune_projector, tune_diffusion_model=tune_diffusion_model
|
||||
)
|
||||
return pretrained_model
|
||||
@@ -1,198 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Groot Policy Wrapper for LeRobot Integration
|
||||
|
||||
Minimal integration that delegates to Isaac-GR00T components where possible
|
||||
without porting their code. The intent is to:
|
||||
|
||||
- Download and load the pretrained GR00T model via GR00TN15.from_pretrained
|
||||
- Optionally align action horizon similar to gr00t_finetune.py
|
||||
- Expose predict_action via GR00T model.get_action
|
||||
- Provide a training forward that can call the GR00T model forward if batch
|
||||
structure matches.
|
||||
|
||||
Notes:
|
||||
- Dataset loading and full training orchestration is handled by Isaac-GR00T
|
||||
TrainRunner in their codebase. If you want to invoke that flow end-to-end
|
||||
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
|
||||
"""
|
||||
|
||||
import os
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.groot.groot_n1 import GR00TN15
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
class GrootPolicy(PreTrainedPolicy):
|
||||
"""Wrapper around external Groot model for LeRobot integration."""
|
||||
|
||||
name = "groot"
|
||||
config_class = GrootConfig
|
||||
|
||||
def __init__(self, config: GrootConfig):
|
||||
"""Initialize Groot policy wrapper."""
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
# Initialize GR00T model using ported components
|
||||
self._groot_model = self._create_groot_model()
|
||||
|
||||
self.reset()
|
||||
|
||||
def _create_groot_model(self):
|
||||
"""Create and initialize the GR00T model using Isaac-GR00T API.
|
||||
|
||||
This is only called when creating a NEW policy (not when loading from checkpoint).
|
||||
|
||||
Steps (delegating to Isaac-GR00T):
|
||||
1) Download and load pretrained model via GR00TN15.from_pretrained
|
||||
2) Align action horizon with data_config if provided
|
||||
"""
|
||||
# Handle Flash Attention compatibility issues
|
||||
self._handle_flash_attention_compatibility()
|
||||
|
||||
model = GR00TN15.from_pretrained(
|
||||
pretrained_model_name_or_path=self.config.base_model_path,
|
||||
tune_llm=self.config.tune_llm,
|
||||
tune_visual=self.config.tune_visual,
|
||||
tune_projector=self.config.tune_projector,
|
||||
tune_diffusion_model=self.config.tune_diffusion_model,
|
||||
)
|
||||
|
||||
model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
|
||||
model.config.compute_dtype = model.compute_dtype
|
||||
|
||||
return model
|
||||
|
||||
def reset(self):
|
||||
"""Reset policy state when environment resets."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Training forward pass.
|
||||
|
||||
Delegates to Isaac-GR00T model.forward when inputs are compatible.
|
||||
"""
|
||||
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
|
||||
allowed_base = {"state", "state_mask", "action", "action_mask", "embodiment_id"}
|
||||
groot_inputs = {
|
||||
k: v
|
||||
for k, v in batch.items()
|
||||
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
|
||||
}
|
||||
|
||||
# Get device from model parameters
|
||||
device = next(self.parameters()).device
|
||||
|
||||
# Run GR00T forward under bf16 autocast when enabled to reduce activation memory
|
||||
# Rationale: Matches original GR00T finetuning (bf16 compute, fp32 params) and avoids fp32 upcasts.
|
||||
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
|
||||
outputs = self._groot_model.forward(groot_inputs)
|
||||
|
||||
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
|
||||
loss = outputs.get("loss")
|
||||
|
||||
loss_dict = {"loss": loss.item()}
|
||||
|
||||
return loss, loss_dict
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions for inference by delegating to Isaac-GR00T.
|
||||
|
||||
Returns a tensor of shape (B, n_action_steps, action_dim).
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
|
||||
# Preprocessing is handled by the processor pipeline, so we just filter the batch
|
||||
# NOTE: During inference, we should NOT pass action/action_mask (that's what we're predicting)
|
||||
allowed_base = {"state", "state_mask", "embodiment_id"}
|
||||
groot_inputs = {
|
||||
k: v
|
||||
for k, v in batch.items()
|
||||
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
|
||||
}
|
||||
|
||||
# Get device from model parameters
|
||||
device = next(self.parameters()).device
|
||||
|
||||
# Use bf16 autocast for inference to keep memory low and match backbone dtype
|
||||
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
|
||||
outputs = self._groot_model.get_action(groot_inputs)
|
||||
|
||||
actions = outputs.get("action_pred")
|
||||
|
||||
original_action_dim = self.config.output_features["action"].shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select single action from action queue."""
|
||||
self.eval()
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.predict_action_chunk(batch)
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
# -------------------------
|
||||
# Internal helpers
|
||||
# -------------------------
|
||||
def _handle_flash_attention_compatibility(self) -> None:
|
||||
"""Handle Flash Attention compatibility issues by setting environment variables.
|
||||
|
||||
This addresses the common 'undefined symbol' error that occurs when Flash Attention
|
||||
is compiled against a different PyTorch version than what's currently installed.
|
||||
"""
|
||||
|
||||
# Set environment variables to handle Flash Attention compatibility
|
||||
# These help with symbol resolution issues
|
||||
os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
|
||||
os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")
|
||||
|
||||
# Try to import flash_attn and handle failures gracefully
|
||||
try:
|
||||
import flash_attn
|
||||
|
||||
print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
|
||||
except ImportError as e:
|
||||
print(f"[GROOT] Flash Attention not available: {e}")
|
||||
print("[GROOT] Will use fallback attention mechanism")
|
||||
except Exception as e:
|
||||
if "undefined symbol" in str(e):
|
||||
print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
|
||||
print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
|
||||
print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
|
||||
print(" pip uninstall flash-attn")
|
||||
print(" pip install --no-build-isolation flash-attn==2.6.3")
|
||||
print("[GROOT] Continuing with fallback attention mechanism")
|
||||
else:
|
||||
print(f"[GROOT] Flash Attention error: {e}")
|
||||
print("[GROOT] Continuing with fallback attention mechanism")
|
||||
@@ -1,664 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoProcessor, ProcessorMixin
|
||||
else:
|
||||
AutoProcessor = None
|
||||
ProcessorMixin = object
|
||||
|
||||
from lerobot.configs.types import (
|
||||
FeatureType,
|
||||
NormalizationMode,
|
||||
PolicyFeature,
|
||||
)
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
HF_LEROBOT_HOME,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
# Defaults for Eagle processor locations
|
||||
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
|
||||
|
||||
|
||||
def make_groot_pre_post_processors(
|
||||
config: GrootConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Create preprocessor and postprocessor for Groot policy.
|
||||
|
||||
This creates a processing pipeline that transforms LeRobot data format into
|
||||
the format expected by Isaac-GR00T models:
|
||||
|
||||
Preprocessing steps:
|
||||
1. Optional key renaming (dataset-specific key mapping)
|
||||
2. Add batch dimension to unbatched data
|
||||
3. Pack video/state/action/language/embodiment and apply optional min-max normalization before padding
|
||||
4. Encode video+language with Eagle VLM into intermediate eagle_content
|
||||
5. Collate eagle_content into batched eagle_* tensors
|
||||
6. Move tensors to device (GPU)
|
||||
|
||||
NOTE: We optionally apply min-max normalization to STATE and ACTION using
|
||||
dataset-provided statistics prior to padding, mapping values to [-1, 1].
|
||||
This mirrors SO100-style preprocessing and keeps scales consistent with GR00T.
|
||||
|
||||
Args:
|
||||
config: Groot configuration containing data_config, embodiment_tag, etc.
|
||||
dataset_stats: Optional per-key min/max statistics for normalization before padding.
|
||||
|
||||
Returns:
|
||||
Tuple of (preprocessor, postprocessor) pipelines
|
||||
"""
|
||||
|
||||
# Get horizon/dimension parameters from config
|
||||
# These should match the config used for the pretrained model
|
||||
# Default values match most GR00T configs (state_horizon=1, action_horizon=16)
|
||||
state_horizon = 1
|
||||
# CRITICAL: Pretrained GR00T models use action_horizon=16 max!
|
||||
# The model architecture hardcodes this limit
|
||||
action_horizon = min(config.chunk_size, 16)
|
||||
max_state_dim = config.max_state_dim
|
||||
max_action_dim = config.max_action_dim
|
||||
|
||||
# Pass raw dataset_stats; normalization will occur inside pack step before padding
|
||||
padded_stats = dataset_stats or {}
|
||||
|
||||
# Define feature specs for optional normalization steps
|
||||
_features: dict[str, PolicyFeature] = {
|
||||
# Observation features (only add those we may normalize)
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_horizon, max_state_dim)),
|
||||
# Action feature
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(action_horizon, max_action_dim)),
|
||||
}
|
||||
|
||||
# Normalize STATE and ACTION with min_max (SO100-like default)
|
||||
_norm_map = {
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
||||
FeatureType.STATE: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
# Determine env action dimension from config (simple, object-like PolicyFeature)
|
||||
try:
|
||||
env_action_dim = int(config.output_features["action"].shape[0])
|
||||
except Exception:
|
||||
env_action_dim = 0
|
||||
|
||||
input_steps: list[ProcessorStep] = [
|
||||
# 1. Rename keys if needed (e.g., dataset-specific camera names)
|
||||
# Leave empty for now - add mappings if your dataset uses different key names
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
# 2. Add batch dimension for single samples
|
||||
AddBatchDimensionProcessorStep(),
|
||||
# 3. Pack video/state/action/language/embodiment; apply optional min-max normalization before padding
|
||||
GrootPackInputsStep(
|
||||
state_horizon=state_horizon,
|
||||
action_horizon=action_horizon,
|
||||
max_state_dim=max_state_dim,
|
||||
max_action_dim=max_action_dim,
|
||||
language_key="task",
|
||||
formalize_language=False,
|
||||
embodiment_tag=config.embodiment_tag,
|
||||
normalize_min_max=True,
|
||||
stats=padded_stats,
|
||||
),
|
||||
# 4. Eagle encode (creates eagle_content)
|
||||
GrootEagleEncodeStep(
|
||||
tokenizer_assets_repo=config.tokenizer_assets_repo,
|
||||
),
|
||||
# 5. Collate eagle_content -> eagle_* tensors
|
||||
GrootEagleCollateStep(
|
||||
tokenizer_assets_repo=config.tokenizer_assets_repo,
|
||||
),
|
||||
# 6. Move to device
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
# Postprocessing: slice to env action dim and unnormalize to env scale, then move to CPU
|
||||
output_steps: list[ProcessorStep] = [
|
||||
GrootActionUnpackUnnormalizeStep(
|
||||
env_action_dim=env_action_dim,
|
||||
stats=padded_stats,
|
||||
normalize_min_max=True,
|
||||
),
|
||||
# Finally, move to CPU for env interaction
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# GR00T specific processor steps
|
||||
|
||||
|
||||
def _to_uint8_np_bhwc(img_t: torch.Tensor) -> np.ndarray:
|
||||
# img_t: (B, C, H, W) float in [0,1] or uint8
|
||||
if img_t.dtype.is_floating_point:
|
||||
img_t = (img_t.clamp(0, 1) * 255.0).to(torch.uint8)
|
||||
return rearrange(img_t.cpu().numpy(), "b c h w -> b h w c")
|
||||
|
||||
|
||||
def _build_eagle_processor(tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO) -> ProcessorMixin:
|
||||
# Validate that the cache directory is ready. If not, instruct the user.
|
||||
cache_dir = HF_LEROBOT_HOME / tokenizer_assets_repo
|
||||
required = [
|
||||
cache_dir / "processor_config.json",
|
||||
cache_dir / "preprocessor_config.json",
|
||||
cache_dir / "image_processing_eagle2_5_vl_fast.py",
|
||||
]
|
||||
if not all(p.exists() for p in required):
|
||||
raise FileNotFoundError(
|
||||
f"[GROOT] Eagle processor cache at '{cache_dir}' is not populated. "
|
||||
"Vendor files are copied during model creation. Create the policy/model first, "
|
||||
"or call ensure_eagle_cache_ready() before building processors."
|
||||
)
|
||||
proc = AutoProcessor.from_pretrained(str(cache_dir), trust_remote_code=True, use_fast=True)
|
||||
proc.tokenizer.padding_side = "left"
|
||||
return proc
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="groot_pack_inputs_v3")
|
||||
class GrootPackInputsStep(ProcessorStep):
|
||||
state_horizon: int = 1
|
||||
action_horizon: int = 16
|
||||
max_state_dim: int = 64
|
||||
max_action_dim: int = 32
|
||||
language_key: str = "task"
|
||||
formalize_language: bool = False
|
||||
embodiment_tag: str = "new_embodiment"
|
||||
embodiment_mapping: dict[str, int] = field(
|
||||
default_factory=lambda: {
|
||||
"new_embodiment": 31, # Match original GR00T EMBODIMENT_TAG_MAPPING
|
||||
"oxe_droid": 17,
|
||||
"agibot_genie1": 26,
|
||||
"gr1": 24,
|
||||
"so100": 2,
|
||||
"unitree_g1": 3,
|
||||
}
|
||||
)
|
||||
# Min-max normalization (SO100-like) applied BEFORE padding
|
||||
normalize_min_max: bool = True
|
||||
stats: dict[str, dict[str, Any]] | None = None
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
|
||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
||||
|
||||
def _align_vec(vec: Any, target_dim: int, *, default: float) -> torch.Tensor:
|
||||
t = torch.as_tensor(vec)
|
||||
t = t.flatten().to(
|
||||
dtype=torch.float32,
|
||||
device=next(
|
||||
(v.device for v in obs.values() if isinstance(v, torch.Tensor)), torch.device("cpu")
|
||||
),
|
||||
)
|
||||
d = int(t.shape[-1]) if t.numel() > 0 else 0
|
||||
if d == target_dim:
|
||||
return t
|
||||
if d < target_dim:
|
||||
pad = torch.full((target_dim - d,), default, dtype=t.dtype, device=t.device)
|
||||
return torch.cat([t, pad], dim=0)
|
||||
return t[:target_dim]
|
||||
|
||||
def _min_max_norm(x: torch.Tensor, key: str) -> torch.Tensor:
|
||||
if not self.normalize_min_max:
|
||||
return x
|
||||
if self.stats is None or key not in self.stats:
|
||||
return x
|
||||
stats_k = self.stats[key]
|
||||
last_dim = x.shape[-1]
|
||||
min_v = _align_vec(stats_k.get("min", torch.zeros(last_dim)), last_dim, default=0.0)
|
||||
max_v = _align_vec(stats_k.get("max", torch.ones(last_dim)), last_dim, default=1.0)
|
||||
denom = max_v - min_v
|
||||
mask = denom != 0
|
||||
safe_denom = torch.where(mask, denom, torch.ones_like(denom))
|
||||
mapped = 2 * (x - min_v) / safe_denom - 1
|
||||
return torch.where(mask, mapped, torch.zeros_like(mapped))
|
||||
|
||||
# 1) Video (B, T=1, V, H, W, C) uint8
|
||||
img_keys = sorted([k for k in obs if k.startswith("observation.images.")])
|
||||
if not img_keys and "observation.image" in obs:
|
||||
img_keys = ["observation.image"]
|
||||
if img_keys:
|
||||
cams = [_to_uint8_np_bhwc(obs[k]) for k in img_keys]
|
||||
video = np.stack(cams, axis=1) # (B, V, H, W, C)
|
||||
video = np.expand_dims(video, axis=1) # (B, 1, V, H, W, C)
|
||||
# GR00T validates that video.shape[3] == 3 (channels), so reorder to (B, T, V, C, H, W)
|
||||
video = np.transpose(video, (0, 1, 2, 5, 3, 4)) # (B, 1, V, C, H, W)
|
||||
obs["video"] = video
|
||||
# Drop raw images to avoid confusion downstream
|
||||
for k in img_keys:
|
||||
obs.pop(k, None)
|
||||
|
||||
# 2) Language (string)
|
||||
lang = comp.get(self.language_key)
|
||||
if isinstance(lang, list):
|
||||
lang = lang[0] if len(lang) > 0 else None
|
||||
if not lang:
|
||||
lang = "Perform the task."
|
||||
if self.formalize_language:
|
||||
lang = (lang or "").lower()
|
||||
lang = "".join(ch for ch in lang if ch.isalnum() or ch.isspace())
|
||||
comp["language"] = lang
|
||||
|
||||
# 3) State/state_mask -> (B, 1, max_state_dim)
|
||||
if "observation.state" in obs:
|
||||
state = obs["observation.state"] # (B, D)
|
||||
if state.dim() != 2:
|
||||
raise ValueError(f"state must be (B, D), got {tuple(state.shape)}")
|
||||
bsz, d = state.shape
|
||||
# Normalize BEFORE padding
|
||||
if self.normalize_min_max:
|
||||
state = _min_max_norm(state, "observation.state")
|
||||
state = state.unsqueeze(1) # (B, 1, D)
|
||||
if d > self.max_state_dim:
|
||||
state = state[:, :, : self.max_state_dim]
|
||||
d = self.max_state_dim
|
||||
elif d < self.max_state_dim:
|
||||
pad = torch.zeros(bsz, 1, self.max_state_dim - d, dtype=state.dtype, device=state.device)
|
||||
state = torch.cat([state, pad], dim=2)
|
||||
state_mask = torch.zeros(bsz, 1, self.max_state_dim, dtype=torch.bool, device=state.device)
|
||||
state_mask[:, :, :d] = True
|
||||
obs["state"] = state
|
||||
obs["state_mask"] = state_mask
|
||||
|
||||
# 4) Action/action_mask -> (B, action_horizon, max_action_dim)
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if isinstance(action, torch.Tensor):
|
||||
# Normalize BEFORE temporal expansion/padding
|
||||
if self.normalize_min_max:
|
||||
if action.dim() == 2:
|
||||
action = _min_max_norm(action, "action")
|
||||
elif action.dim() == 3:
|
||||
b, t, d = action.shape
|
||||
flat = action.reshape(b * t, d)
|
||||
flat = _min_max_norm(flat, "action")
|
||||
action = flat.view(b, t, d)
|
||||
if action.dim() == 2:
|
||||
action = action.unsqueeze(1).repeat(1, self.action_horizon, 1)
|
||||
elif action.dim() == 3:
|
||||
b, t, d = action.shape
|
||||
if t < self.action_horizon:
|
||||
last = action[:, -1:, :]
|
||||
pad = last.repeat(1, self.action_horizon - t, 1)
|
||||
action = torch.cat([action, pad], dim=1)
|
||||
elif t > self.action_horizon:
|
||||
action = action[:, : self.action_horizon, :]
|
||||
else:
|
||||
raise ValueError(f"action must be (B, D) or (B, T, D), got {tuple(action.shape)}")
|
||||
|
||||
b, t, d = action.shape
|
||||
if d > self.max_action_dim:
|
||||
action = action[:, :, : self.max_action_dim]
|
||||
d = self.max_action_dim
|
||||
elif d < self.max_action_dim:
|
||||
pad = torch.zeros(b, t, self.max_action_dim - d, dtype=action.dtype, device=action.device)
|
||||
action = torch.cat([action, pad], dim=2)
|
||||
action_mask = torch.zeros(b, t, self.max_action_dim, dtype=torch.bool, device=action.device)
|
||||
action_mask[:, :, :d] = True
|
||||
transition[TransitionKey.ACTION] = action
|
||||
comp["action_mask"] = action_mask
|
||||
|
||||
# 5) Embodiment id as LongTensor (B,)
|
||||
emb_id = self.embodiment_mapping.get(self.embodiment_tag, 0)
|
||||
# Infer batch size/device from any tensor in obs or action
|
||||
bsz = None
|
||||
device = torch.device("cpu")
|
||||
for v in list(obs.values()) + [transition.get(TransitionKey.ACTION)]:
|
||||
if isinstance(v, torch.Tensor):
|
||||
bsz = v.shape[0]
|
||||
device = v.device
|
||||
break
|
||||
if bsz is None and "video" in obs and isinstance(obs["video"], np.ndarray):
|
||||
bsz = obs["video"].shape[0]
|
||||
if bsz is None:
|
||||
bsz = 1
|
||||
comp["embodiment_id"] = torch.full((bsz,), emb_id, dtype=torch.long, device=device)
|
||||
|
||||
transition[TransitionKey.OBSERVATION] = obs
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = comp
|
||||
return transition
|
||||
|
||||
# Pipeline API requirement: declare how features change (we keep it simple)
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns a serializable dictionary of the processor's configuration.
|
||||
|
||||
Excludes 'stats' since they are saved separately via state_dict().
|
||||
"""
|
||||
return {
|
||||
"state_horizon": self.state_horizon,
|
||||
"action_horizon": self.action_horizon,
|
||||
"max_state_dim": self.max_state_dim,
|
||||
"max_action_dim": self.max_action_dim,
|
||||
"language_key": self.language_key,
|
||||
"formalize_language": self.formalize_language,
|
||||
"embodiment_tag": self.embodiment_tag,
|
||||
"embodiment_mapping": self.embodiment_mapping,
|
||||
"normalize_min_max": self.normalize_min_max,
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Returns normalization statistics as a flat state dictionary.
|
||||
|
||||
This enables saving stats to safetensors files, similar to normalizer_processor.
|
||||
"""
|
||||
if not self.stats:
|
||||
return {}
|
||||
|
||||
flat: dict[str, torch.Tensor] = {}
|
||||
for key, sub in self.stats.items():
|
||||
for stat_name, value in sub.items():
|
||||
tensor = torch.as_tensor(value).cpu()
|
||||
flat[f"{key}.{stat_name}"] = tensor
|
||||
return flat
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""
|
||||
Loads normalization statistics from a flat state dictionary.
|
||||
|
||||
This enables loading stats from safetensors files during from_pretrained.
|
||||
"""
|
||||
if not state:
|
||||
return
|
||||
|
||||
reconstructed: dict[str, dict[str, Any]] = {}
|
||||
for flat_key, tensor in state.items():
|
||||
if "." in flat_key:
|
||||
key, stat_name = flat_key.rsplit(".", 1)
|
||||
if key not in reconstructed:
|
||||
reconstructed[key] = {}
|
||||
reconstructed[key][stat_name] = tensor
|
||||
|
||||
if reconstructed:
|
||||
self.stats = reconstructed
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="groot_eagle_encode_v3")
|
||||
class GrootEagleEncodeStep(ProcessorStep):
|
||||
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO
|
||||
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
|
||||
|
||||
@property
|
||||
def proc(self) -> ProcessorMixin:
|
||||
if self._proc is None:
|
||||
self._proc = _build_eagle_processor(self.tokenizer_assets_repo)
|
||||
return self._proc
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
|
||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
||||
|
||||
if "video" not in obs:
|
||||
return transition
|
||||
|
||||
video = obs["video"] # (B, T, V, H, W, C) uint8
|
||||
lang = comp.get("language", "Perform the task.")
|
||||
if isinstance(lang, list):
|
||||
lang = lang[0] if len(lang) > 0 else "Perform the task."
|
||||
|
||||
bsz = video.shape[0]
|
||||
eagle_contents: list[dict[str, Any]] = []
|
||||
for b in range(bsz):
|
||||
vt = video[b] # (T, V, C, H, W) after reorder
|
||||
if vt.ndim != 5:
|
||||
# Fallback: assume (T, V, H, W, C)
|
||||
t, v, h, w, c = vt.shape
|
||||
flat = rearrange(vt, "t v h w c -> (t v) h w c")
|
||||
else:
|
||||
t, v, c, h, w = vt.shape
|
||||
flat = rearrange(vt, "t v c h w -> (t v) h w c")
|
||||
images = [Image.fromarray(flat[i]) for i in range(t * v)]
|
||||
# Format language as string list representation to match Original GROOT
|
||||
lang_formatted = str([lang])
|
||||
text_content = [{"type": "text", "text": lang_formatted}]
|
||||
image_content = [{"type": "image", "image": img} for img in images]
|
||||
conv = [{"role": "user", "content": image_content + text_content}]
|
||||
text_list = [self.proc.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)]
|
||||
img_inputs, vid_inputs = self.proc.process_vision_info(conv)
|
||||
eagle_contents.append(
|
||||
{
|
||||
"text_list": text_list,
|
||||
"image_inputs": img_inputs,
|
||||
"video_inputs": vid_inputs,
|
||||
}
|
||||
)
|
||||
|
||||
comp["eagle_content"] = eagle_contents
|
||||
transition[TransitionKey.OBSERVATION] = obs
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = comp
|
||||
return transition
|
||||
|
||||
# Pipeline API requirement: declare how features change (no schema change here)
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
# Original GR00T-style collate: converts eagle_content -> eagle_* tensors
|
||||
def collate(features: list[dict[str, Any]], eagle_processor: ProcessorMixin) -> dict[str, Any]:
|
||||
batch: dict[str, Any] = {}
|
||||
keys = features[0].keys()
|
||||
|
||||
for key in keys:
|
||||
values = [elem[key] for elem in features]
|
||||
|
||||
if key == "eagle_content":
|
||||
text_list: list[str] = []
|
||||
image_inputs: list[Any] = []
|
||||
for v in values:
|
||||
curr_text_list = v["text_list"]
|
||||
curr_image_inputs = v["image_inputs"]
|
||||
text_list += curr_text_list
|
||||
image_inputs += curr_image_inputs
|
||||
eagle_inputs = eagle_processor(
|
||||
text=text_list,
|
||||
images=image_inputs,
|
||||
images_kwargs={"min_dynamic_tiles": 1, "max_dynamic_tiles": 1, "use_thumbnail": False},
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
for k, v in eagle_inputs.items():
|
||||
k = "eagle_" + k
|
||||
batch[k] = v
|
||||
elif key in ("pixel_values", "image_grid_thw", "attention_mask", "input_ids"):
|
||||
# Concat in existing batch dimension.
|
||||
batch[key] = torch.cat(values)
|
||||
else:
|
||||
# state, state_mask, action and action_mask.
|
||||
# Stack to form the batch dimension.
|
||||
batch[key] = torch.from_numpy(np.stack(values))
|
||||
return batch
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="groot_eagle_collate_v3")
|
||||
class GrootEagleCollateStep(ProcessorStep):
|
||||
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO
|
||||
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
|
||||
|
||||
@property
|
||||
def proc(self) -> ProcessorMixin:
|
||||
if self._proc is None:
|
||||
self._proc = _build_eagle_processor(self.tokenizer_assets_repo)
|
||||
return self._proc
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
|
||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
||||
contents = comp.get("eagle_content")
|
||||
if not contents:
|
||||
return transition
|
||||
|
||||
# Build features list as original API expects: one dict per batch item
|
||||
features = [{"eagle_content": content} for content in contents]
|
||||
batched = collate(features, self.proc)
|
||||
|
||||
# Inject eagle_* tensors and remove the temporary content and raw video to free memory
|
||||
for k, v in batched.items():
|
||||
comp[k] = v
|
||||
comp.pop("eagle_content", None)
|
||||
obs.pop(
|
||||
"video", None
|
||||
) # The video has been fully encoded into eagle_* tensors, so we don't need the raw video anymore
|
||||
transition[TransitionKey.OBSERVATION] = obs
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = comp
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v1")
|
||||
class GrootActionUnpackUnnormalizeStep(ProcessorStep):
|
||||
env_action_dim: int = 0
|
||||
# Apply inverse of min-max normalization if it was used in preprocessor
|
||||
normalize_min_max: bool = True
|
||||
stats: dict[str, dict[str, Any]] | None = None
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# Expect model outputs to be in TransitionKey.ACTION as (B, T, D_model)
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if not isinstance(action, torch.Tensor):
|
||||
return transition
|
||||
|
||||
# Select last timestep and slice to env dimension
|
||||
if action.dim() == 3:
|
||||
action = action[:, -1, :]
|
||||
# Now action is (B, D_model)
|
||||
if self.env_action_dim and action.shape[-1] >= self.env_action_dim:
|
||||
action = action[..., : self.env_action_dim]
|
||||
|
||||
# Inverse min-max normalization mirroring _min_max_norm:
|
||||
# forward: y = 2 * (x - min) / denom - 1, with y=0 when denom==0
|
||||
# inverse: x = (y+1)/2 * denom + min, and when denom==0 -> x = min
|
||||
if self.normalize_min_max and self.stats is not None:
|
||||
stats_k = self.stats.get("action", {})
|
||||
d = action.shape[-1]
|
||||
min_v = torch.as_tensor(
|
||||
stats_k.get("min", torch.zeros(d)), dtype=action.dtype, device=action.device
|
||||
)
|
||||
max_v = torch.as_tensor(
|
||||
stats_k.get("max", torch.ones(d)), dtype=action.dtype, device=action.device
|
||||
)
|
||||
if min_v.numel() != d:
|
||||
min_v = torch.nn.functional.pad(min_v.flatten()[:d], (0, max(0, d - min_v.numel())))
|
||||
min_v = min_v.to(action.device, dtype=action.dtype)
|
||||
if max_v.numel() != d:
|
||||
max_v = torch.nn.functional.pad(max_v.flatten()[:d], (0, max(0, d - max_v.numel())))
|
||||
max_v = max_v.to(action.device, dtype=action.dtype)
|
||||
denom = max_v - min_v
|
||||
mask = denom != 0
|
||||
safe_denom = torch.where(mask, denom, torch.ones_like(denom))
|
||||
inv = (action + 1.0) * 0.5 * safe_denom + min_v
|
||||
action = torch.where(mask, inv, min_v)
|
||||
|
||||
transition[TransitionKey.ACTION] = action
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns a serializable dictionary of the processor's configuration.
|
||||
|
||||
Excludes 'stats' since they are saved separately via state_dict().
|
||||
"""
|
||||
return {
|
||||
"env_action_dim": self.env_action_dim,
|
||||
"normalize_min_max": self.normalize_min_max,
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Returns normalization statistics as a flat state dictionary.
|
||||
|
||||
This enables saving stats to safetensors files, similar to normalizer_processor.
|
||||
"""
|
||||
if not self.stats:
|
||||
return {}
|
||||
|
||||
flat: dict[str, torch.Tensor] = {}
|
||||
for key, sub in self.stats.items():
|
||||
for stat_name, value in sub.items():
|
||||
tensor = torch.as_tensor(value).cpu()
|
||||
flat[f"{key}.{stat_name}"] = tensor
|
||||
return flat
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""
|
||||
Loads normalization statistics from a flat state dictionary.
|
||||
|
||||
This enables loading stats from safetensors files during from_pretrained.
|
||||
"""
|
||||
if not state:
|
||||
return
|
||||
|
||||
reconstructed: dict[str, dict[str, Any]] = {}
|
||||
for flat_key, tensor in state.items():
|
||||
if "." in flat_key:
|
||||
key, stat_name = flat_key.rsplit(".", 1)
|
||||
if key not in reconstructed:
|
||||
reconstructed[key] = {}
|
||||
reconstructed[key][stat_name] = tensor
|
||||
|
||||
if reconstructed:
|
||||
self.stats = reconstructed
|
||||
@@ -1,47 +0,0 @@
|
||||
from pathlib import Path
|
||||
from shutil import copytree
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
|
||||
def ensure_eagle_cache_ready(vendor_dir: Path, cache_dir: Path, assets_repo: str) -> None:
|
||||
"""Populate the Eagle processor directory in cache and ensure tokenizer assets exist.
|
||||
|
||||
- Copies the vendored Eagle files into cache_dir (overwriting when needed).
|
||||
- Downloads vocab.json and merges.txt into the same cache_dir if missing.
|
||||
"""
|
||||
cache_dir = Path(cache_dir)
|
||||
vendor_dir = Path(vendor_dir)
|
||||
|
||||
try:
|
||||
# Populate/refresh cache with vendor files to ensure a complete processor directory
|
||||
print(f"[GROOT] Copying vendor Eagle files to cache: {vendor_dir} -> {cache_dir}")
|
||||
copytree(vendor_dir, cache_dir, dirs_exist_ok=True)
|
||||
except Exception as exc: # nosec: B110
|
||||
print(f"[GROOT] Warning: Failed to copy vendor Eagle files to cache: {exc}")
|
||||
|
||||
required_assets = [
|
||||
"vocab.json",
|
||||
"merges.txt",
|
||||
"added_tokens.json",
|
||||
"chat_template.json",
|
||||
"special_tokens_map.json",
|
||||
"config.json",
|
||||
"generation_config.json",
|
||||
"preprocessor_config.json",
|
||||
"processor_config.json",
|
||||
"tokenizer_config.json",
|
||||
]
|
||||
|
||||
print(f"[GROOT] Assets repo: {assets_repo} \n Cache dir: {cache_dir}")
|
||||
|
||||
for fname in required_assets:
|
||||
dst = cache_dir / fname
|
||||
if not dst.exists():
|
||||
print(f"[GROOT] Fetching {fname}")
|
||||
hf_hub_download(
|
||||
repo_id=assets_repo,
|
||||
filename=fname,
|
||||
repo_type="model",
|
||||
local_dir=str(cache_dir),
|
||||
)
|
||||
@@ -20,7 +20,6 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
@@ -48,9 +47,6 @@ class PI0Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
|
||||
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
@@ -79,8 +75,6 @@ class PI0Config(PreTrainedConfig):
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
# Scheduler settings: see openpi `CosineDecaySchedule`
|
||||
# Note: These will auto-scale if --steps < scheduler_decay_steps
|
||||
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
@@ -19,12 +19,11 @@ import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
@@ -43,7 +42,6 @@ else:
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -53,12 +51,6 @@ from lerobot.utils.constants import (
|
||||
)
|
||||
|
||||
|
||||
class ActionSelectKwargs(TypedDict, total=False):
|
||||
inference_delay: int | None
|
||||
prev_chunk_left_over: Tensor | None
|
||||
execution_horizon: int | None
|
||||
|
||||
|
||||
def get_safe_dtype(target_dtype, device_type):
|
||||
"""Get a safe dtype for the given device type."""
|
||||
if device_type == "mps" and target_dtype == torch.float64:
|
||||
@@ -511,10 +503,9 @@ class PaliGemmaWithExpertModel(
|
||||
class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
"""Core PI0 PyTorch model."""
|
||||
|
||||
def __init__(self, config: PI0Config, rtc_processor: RTCProcessor | None = None):
|
||||
def __init__(self, config: PI0Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.rtc_processor = rtc_processor
|
||||
|
||||
paligemma_config = get_gemma_config(config.paligemma_variant)
|
||||
action_expert_config = get_gemma_config(config.action_expert_variant)
|
||||
@@ -569,9 +560,6 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
|
||||
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
@@ -768,15 +756,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
|
||||
def sample_actions(
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
lang_tokens,
|
||||
lang_masks,
|
||||
state,
|
||||
noise=None,
|
||||
num_steps=None,
|
||||
**kwargs: Unpack[ActionSelectKwargs],
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, noise=None, num_steps=None
|
||||
) -> Tensor:
|
||||
"""Do a full inference forward and compute the action."""
|
||||
if num_steps is None:
|
||||
@@ -818,41 +798,14 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
|
||||
# Define a closure function to properly capture expanded_time
|
||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
||||
return self.denoise_step(
|
||||
state=state,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
inference_delay=inference_delay,
|
||||
time=time,
|
||||
original_denoise_step_partial=denoise_step_partial_call,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
|
||||
# Record x_t and v_t after Euler step
|
||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||
|
||||
v_t = self.denoise_step(
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
x_t = x_t + dt * v_t
|
||||
time += dt
|
||||
|
||||
return x_t
|
||||
@@ -916,8 +869,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
self.config = config
|
||||
|
||||
# Initialize the core PI0 model
|
||||
self.init_rtc_processor()
|
||||
self.model = PI0Pytorch(config, rtc_processor=self.rtc_processor)
|
||||
self.model = PI0Pytorch(config)
|
||||
|
||||
# Enable gradient checkpointing if requested
|
||||
if config.gradient_checkpointing:
|
||||
@@ -1107,22 +1059,6 @@ class PI0Policy(PreTrainedPolicy):
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
def init_rtc_processor(self):
|
||||
"""Initialize RTC processor if RTC is enabled in config."""
|
||||
self.rtc_processor = None
|
||||
|
||||
# Create processor if config provided
|
||||
# If RTC is not enabled - we can still track the denoising data
|
||||
if self.config.rtc_config is not None:
|
||||
self.rtc_processor = RTCProcessor(self.config.rtc_config)
|
||||
|
||||
model_value = getattr(self, "model", None)
|
||||
if model_value is not None:
|
||||
model_value.rtc_processor = self.rtc_processor
|
||||
|
||||
def _rtc_enabled(self) -> bool:
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
|
||||
"""Preprocess images for the model.
|
||||
|
||||
@@ -1201,10 +1137,6 @@ class PI0Policy(PreTrainedPolicy):
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations."""
|
||||
assert not self._rtc_enabled(), (
|
||||
"RTC is not supported for select_action, use it with predict_action_chunk"
|
||||
)
|
||||
|
||||
self.eval()
|
||||
|
||||
# Action queue logic for n_action_steps > 1
|
||||
@@ -1216,7 +1148,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
return self._action_queue.popleft()
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
self.eval()
|
||||
|
||||
@@ -1225,8 +1157,8 @@ class PI0Policy(PreTrainedPolicy):
|
||||
lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
state = self.prepare_state(batch)
|
||||
|
||||
# Sample actions using the model (pass through RTC kwargs)
|
||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, **kwargs)
|
||||
# Sample actions using the model
|
||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state)
|
||||
|
||||
# Unpad actions to actual action dimension
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
@@ -20,7 +20,6 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi05")
|
||||
@@ -47,9 +46,6 @@ class PI05Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
|
||||
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
@@ -79,8 +75,6 @@ class PI05Config(PreTrainedConfig):
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
# Scheduler settings: see openpi `CosineDecaySchedule`
|
||||
# Note: These will auto-scale if --steps < scheduler_decay_steps
|
||||
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
@@ -19,12 +19,11 @@ import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
@@ -43,7 +42,6 @@ else:
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -52,12 +50,6 @@ from lerobot.utils.constants import (
|
||||
)
|
||||
|
||||
|
||||
class ActionSelectKwargs(TypedDict, total=False):
|
||||
inference_delay: int | None
|
||||
prev_chunk_left_over: Tensor | None
|
||||
execution_horizon: int | None
|
||||
|
||||
|
||||
def get_safe_dtype(target_dtype, device_type):
|
||||
"""Get a safe dtype for the given device type."""
|
||||
if device_type == "mps" and target_dtype == torch.float64:
|
||||
@@ -510,10 +502,9 @@ class PaliGemmaWithExpertModel(
|
||||
class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
"""Core PI05 PyTorch model."""
|
||||
|
||||
def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None):
|
||||
def __init__(self, config: PI05Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.rtc_processor = rtc_processor
|
||||
|
||||
paligemma_config = get_gemma_config(config.paligemma_variant)
|
||||
action_expert_config = get_gemma_config(config.action_expert_variant)
|
||||
@@ -565,9 +556,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
|
||||
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
@@ -743,16 +731,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
return F.mse_loss(u_t, v_t, reduction="none")
|
||||
|
||||
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
|
||||
def sample_actions(
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
tokens,
|
||||
masks,
|
||||
noise=None,
|
||||
num_steps=None,
|
||||
**kwargs: Unpack[ActionSelectKwargs],
|
||||
) -> Tensor:
|
||||
def sample_actions(self, images, img_masks, tokens, masks, noise=None, num_steps=None) -> Tensor:
|
||||
"""Do a full inference forward and compute the action."""
|
||||
if num_steps is None:
|
||||
num_steps = self.config.num_inference_steps
|
||||
@@ -791,40 +770,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
|
||||
# Define a closure function to properly capture expanded_time
|
||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
||||
return self.denoise_step(
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
inference_delay=inference_delay,
|
||||
time=time,
|
||||
original_denoise_step_partial=denoise_step_partial_call,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
|
||||
# Record x_t and v_t after Euler step
|
||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||
|
||||
v_t = self.denoise_step(
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
x_t = x_t + dt * v_t
|
||||
time += dt
|
||||
|
||||
return x_t
|
||||
@@ -887,8 +839,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
self.config = config
|
||||
|
||||
# Initialize the core PI05 model
|
||||
self.init_rtc_processor()
|
||||
self.model = PI05Pytorch(config, rtc_processor=self.rtc_processor)
|
||||
self.model = PI05Pytorch(config)
|
||||
|
||||
# Enable gradient checkpointing if requested
|
||||
if config.gradient_checkpointing:
|
||||
@@ -1084,22 +1035,6 @@ class PI05Policy(PreTrainedPolicy):
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
def init_rtc_processor(self):
|
||||
"""Initialize RTC processor if RTC is enabled in config."""
|
||||
self.rtc_processor = None
|
||||
|
||||
# Create processor if config provided
|
||||
# If RTC is not enabled - we can still track the denoising data
|
||||
if self.config.rtc_config is not None:
|
||||
self.rtc_processor = RTCProcessor(self.config.rtc_config)
|
||||
|
||||
model_value = getattr(self, "model", None)
|
||||
if model_value is not None:
|
||||
model_value.rtc_processor = self.rtc_processor
|
||||
|
||||
def _rtc_enabled(self) -> bool:
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
|
||||
"""Preprocess images for the model.
|
||||
|
||||
@@ -1174,10 +1109,6 @@ class PI05Policy(PreTrainedPolicy):
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations."""
|
||||
assert not self._rtc_enabled(), (
|
||||
"RTC is not supported for select_action, use it with predict_action_chunk"
|
||||
)
|
||||
|
||||
self.eval()
|
||||
|
||||
# Action queue logic for n_action_steps > 1
|
||||
@@ -1189,7 +1120,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
return self._action_queue.popleft()
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
self.eval()
|
||||
|
||||
@@ -1197,8 +1128,8 @@ class PI05Policy(PreTrainedPolicy):
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
|
||||
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
|
||||
actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
|
||||
# Sample actions using the model (no separate state needed for PI05)
|
||||
actions = self.model.sample_actions(images, img_masks, tokens, masks)
|
||||
|
||||
# Unpad actions to actual action dimension
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
# Real-Time Chunking (RTC) Module
|
||||
|
||||
This module implements Real-Time Chunking and related adaptive inference techniques for robotics policies in LeRobot.
|
||||
|
||||
## Overview
|
||||
|
||||
Real-Time Chunking (RTC) addresses the challenge of real-time inference in action chunking policies by treating chunk generation as an inpainting problem. It strategically handles overlapping timesteps between action chunks using prefix attention mechanisms.
|
||||
|
||||
It is particularly effective for handling long-horizon inference in robotics policies.
|
||||
|
||||
## Integration with Policies
|
||||
|
||||
RTC can be integrated with any policy that supports flow mathicng for chunking:
|
||||
|
||||
- **SmolVLA**: Vision-language-action model with RTC support
|
||||
- **Pi0**: Action prediction model with adaptive chunking
|
||||
- **Pi05**: Action prediction model with adaptive chunking
|
||||
|
||||
## Original Implementation
|
||||
|
||||
This implementation is based on Physical Intelligence's Kinetix RTC:
|
||||
|
||||
- [Original RTC implementation](https://github.com/Physical-Intelligence/real-time-chunking-kinetix/blob/main/src/model.py#L214)
|
||||
- [Kinetix GitHub Repository](https://github.com/Physical-Intelligence/real-time-chunking-kinetix)
|
||||
|
||||
## References
|
||||
|
||||
- [Real Time Chunking Paper](https://www.physicalintelligence.company/research/real_time_chunking)
|
||||
- [Physical Intelligence Kinetix](https://github.com/Physical-Intelligence/real-time-chunking-kinetix)
|
||||
|
||||
## How to run
|
||||
|
||||
### Check with data from the dataset
|
||||
|
||||
```bash
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps \
|
||||
--seed=42
|
||||
```
|
||||
|
||||
This script will evaluate RTC on a data from a dataset and save the results to a file, u can check the results in the `rtc_debug_output` directory.
|
||||
|
||||
The example output should look like this:
|
||||

|
||||
|
||||
It shows how flow matching works with RTC and without it. The chart shows values of action predictions for each timestep. The colour shows the the generation progress. The blue ones - earlier timesteps, the yellow ones - later timesteps. The red line is the ground truth (previous action chunk).
|
||||
@@ -1,219 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Action queue management for Real-Time Chunking (RTC).
|
||||
|
||||
This module provides ActionQueue, a thread-safe queue for managing action chunks
|
||||
in real-time control scenarios. It supports both RTC-enabled and non-RTC modes,
|
||||
handling action merging and leftover tracking.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from threading import Lock
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ActionQueue:
|
||||
"""Thread-safe queue for managing action chunks in real-time control.
|
||||
|
||||
This queue handles two types of action sequences:
|
||||
- Original actions: Used for RTC to compute leftovers from previous chunks
|
||||
- Processed actions: Post-processed actions ready for robot execution
|
||||
|
||||
The queue operates in two modes:
|
||||
1. RTC-enabled: Replaces the entire queue with new actions, accounting for inference delay
|
||||
2. RTC-disabled: Appends new actions to the queue, maintaining continuity
|
||||
|
||||
Args:
|
||||
cfg (RTCConfig): Configuration for Real-Time Chunking behavior.
|
||||
|
||||
Attributes:
|
||||
queue (Tensor | None): Processed actions for robot rollout (time_steps, action_dim).
|
||||
original_queue (Tensor | None): Original actions for RTC computation (time_steps, action_dim).
|
||||
last_index (int): Current consumption index in the queue.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg: RTCConfig):
|
||||
"""Initialize the action queue.
|
||||
|
||||
Args:
|
||||
cfg: RTC configuration controlling queue behavior.
|
||||
"""
|
||||
self.queue = None # Processed actions for robot rollout
|
||||
self.original_queue = None # Original actions for RTC
|
||||
self.lock = Lock()
|
||||
self.last_index = 0
|
||||
self.cfg = cfg
|
||||
|
||||
def get(self) -> Tensor | None:
|
||||
"""Get the next action from the queue.
|
||||
|
||||
Returns:
|
||||
Tensor | None: The next action (action_dim,) or None if queue is empty.
|
||||
Returns a clone to prevent external modifications.
|
||||
"""
|
||||
with self.lock:
|
||||
if self.queue is None or self.last_index >= len(self.queue):
|
||||
return None
|
||||
|
||||
action = self.queue[self.last_index]
|
||||
self.last_index += 1
|
||||
return action.clone()
|
||||
|
||||
def qsize(self) -> int:
|
||||
"""Get the number of remaining actions in the queue.
|
||||
|
||||
Returns:
|
||||
int: Number of unconsumed actions.
|
||||
"""
|
||||
if self.queue is None:
|
||||
return 0
|
||||
length = len(self.queue)
|
||||
return length - self.last_index
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""Check if the queue is empty.
|
||||
|
||||
Returns:
|
||||
bool: True if no actions remain, False otherwise.
|
||||
"""
|
||||
if self.queue is None:
|
||||
return True
|
||||
|
||||
length = len(self.queue)
|
||||
return length - self.last_index <= 0
|
||||
|
||||
def get_action_index(self) -> int:
|
||||
"""Get the current action consumption index.
|
||||
|
||||
Returns:
|
||||
int: Index of the next action to be consumed.
|
||||
"""
|
||||
return self.last_index
|
||||
|
||||
def get_left_over(self) -> Tensor | None:
|
||||
"""Get leftover original actions for RTC prev_chunk_left_over.
|
||||
|
||||
These are the unconsumed actions from the current chunk, which will be
|
||||
used by RTC to compute corrections for the next chunk.
|
||||
|
||||
Returns:
|
||||
Tensor | None: Remaining original actions (remaining_steps, action_dim),
|
||||
or None if no original queue exists.
|
||||
"""
|
||||
with self.lock:
|
||||
if self.original_queue is None:
|
||||
return None
|
||||
return self.original_queue[self.last_index :]
|
||||
|
||||
def merge(
|
||||
self,
|
||||
original_actions: Tensor,
|
||||
processed_actions: Tensor,
|
||||
real_delay: int,
|
||||
action_index_before_inference: int | None = 0,
|
||||
):
|
||||
"""Merge new actions into the queue.
|
||||
|
||||
This method operates differently based on RTC mode:
|
||||
- RTC enabled: Replaces the queue, accounting for inference delay
|
||||
- RTC disabled: Appends to the queue, maintaining continuity
|
||||
|
||||
Args:
|
||||
original_actions: Unprocessed actions from policy (time_steps, action_dim).
|
||||
processed_actions: Post-processed actions for robot (time_steps, action_dim).
|
||||
real_delay: Number of time steps of inference delay.
|
||||
action_index_before_inference: Index before inference started, for validation.
|
||||
"""
|
||||
with self.lock:
|
||||
self._check_delays(real_delay, action_index_before_inference)
|
||||
|
||||
if self.cfg.enabled:
|
||||
self._replace_actions_queue(original_actions, processed_actions, real_delay)
|
||||
return
|
||||
|
||||
self._append_actions_queue(original_actions, processed_actions)
|
||||
|
||||
def _replace_actions_queue(self, original_actions: Tensor, processed_actions: Tensor, real_delay: int):
|
||||
"""Replace the queue with new actions (RTC mode).
|
||||
|
||||
Discards the first `real_delay` actions since they correspond to the time
|
||||
spent during inference, when the robot was executing previous actions.
|
||||
|
||||
Args:
|
||||
original_actions: Unprocessed actions from policy.
|
||||
processed_actions: Post-processed actions for robot.
|
||||
real_delay: Number of time steps to skip due to inference delay.
|
||||
"""
|
||||
self.original_queue = original_actions[real_delay:].clone()
|
||||
self.queue = processed_actions[real_delay:].clone()
|
||||
|
||||
logger.debug(f"original_actions shape: {self.original_queue.shape}")
|
||||
logger.debug(f"processed_actions shape: {self.queue.shape}")
|
||||
logger.debug(f"real_delay: {real_delay}")
|
||||
|
||||
self.last_index = 0
|
||||
|
||||
def _append_actions_queue(self, original_actions: Tensor, processed_actions: Tensor):
|
||||
"""Append new actions to the queue (non-RTC mode).
|
||||
|
||||
Removes already-consumed actions and appends new ones, maintaining
|
||||
queue continuity without replacement.
|
||||
|
||||
Args:
|
||||
original_actions: Unprocessed actions from policy.
|
||||
processed_actions: Post-processed actions for robot.
|
||||
"""
|
||||
if self.queue is None:
|
||||
self.original_queue = original_actions.clone()
|
||||
self.queue = processed_actions.clone()
|
||||
return
|
||||
|
||||
self.original_queue = torch.cat([self.original_queue, original_actions.clone()])
|
||||
self.original_queue = self.original_queue[self.last_index :]
|
||||
|
||||
self.queue = torch.cat([self.queue, processed_actions.clone()])
|
||||
self.queue = self.queue[self.last_index :]
|
||||
|
||||
self.last_index = 0
|
||||
|
||||
def _check_delays(self, real_delay: int, action_index_before_inference: int | None = None):
|
||||
"""Validate that computed delays match expectations.
|
||||
|
||||
Compares the delay computed from inference latency with the actual
|
||||
number of actions consumed during inference.
|
||||
|
||||
Args:
|
||||
real_delay: Delay computed from inference latency.
|
||||
action_index_before_inference: Action index when inference started.
|
||||
"""
|
||||
if action_index_before_inference is None:
|
||||
return
|
||||
|
||||
indexes_diff = self.last_index - action_index_before_inference
|
||||
if indexes_diff != real_delay:
|
||||
# Let's check that action index difference (real delay calculated based on action queue)
|
||||
# is the same as delay calculated based on inference latency
|
||||
logger.warning(
|
||||
f"[ACTION_QUEUE] Indexes diff is not equal to real delay. "
|
||||
f"Indexes diff: {indexes_diff}, real delay: {real_delay}"
|
||||
)
|
||||
@@ -1,55 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Real Time Chunking (RTC) and Bidirectional Decoding (BID) configuration classes.
|
||||
|
||||
Based on:
|
||||
- Real Time Chunking: https://www.physicalintelligence.company/research/real_time_chunking
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTCConfig:
|
||||
"""Configuration for Real Time Chunking (RTC) inference.
|
||||
|
||||
RTC improves real-time inference by treating chunk generation as an inpainting problem,
|
||||
strategically handling overlapping timesteps between action chunks using prefix attention.
|
||||
"""
|
||||
|
||||
# Infrastructure
|
||||
enabled: bool = False
|
||||
|
||||
# Core RTC settings
|
||||
# Todo change to exp
|
||||
prefix_attention_schedule: RTCAttentionSchedule = RTCAttentionSchedule.LINEAR
|
||||
max_guidance_weight: float = 10.0
|
||||
execution_horizon: int = 10
|
||||
|
||||
# Debug settings
|
||||
debug: bool = False
|
||||
debug_maxlen: int = 100
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate RTC configuration parameters."""
|
||||
if self.max_guidance_weight <= 0:
|
||||
raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}")
|
||||
if self.debug_maxlen <= 0:
|
||||
raise ValueError(f"debug_maxlen must be positive, got {self.debug_maxlen}")
|
||||
@@ -1,233 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Debug information handler for Real-Time Chunking (RTC)."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class DebugStep:
|
||||
"""Container for debug information from a single denoising step.
|
||||
|
||||
Attributes:
|
||||
step_idx (int): Step index/counter.
|
||||
x_t (Tensor | None): Current latent/state tensor.
|
||||
v_t (Tensor | None): Velocity from denoiser.
|
||||
x1_t (Tensor | None): Denoised prediction (x_t - time * v_t).
|
||||
correction (Tensor | None): Correction gradient tensor.
|
||||
err (Tensor | None): Weighted error term.
|
||||
weights (Tensor | None): Prefix attention weights.
|
||||
guidance_weight (float | Tensor | None): Applied guidance weight.
|
||||
time (float | Tensor | None): Time parameter.
|
||||
inference_delay (int | None): Inference delay parameter.
|
||||
execution_horizon (int | None): Execution horizon parameter.
|
||||
metadata (dict[str, Any]): Additional metadata.
|
||||
"""
|
||||
|
||||
step_idx: int = 0
|
||||
x_t: Tensor | None = None
|
||||
v_t: Tensor | None = None
|
||||
x1_t: Tensor | None = None
|
||||
correction: Tensor | None = None
|
||||
err: Tensor | None = None
|
||||
weights: Tensor | None = None
|
||||
guidance_weight: float | Tensor | None = None
|
||||
time: float | Tensor | None = None
|
||||
inference_delay: int | None = None
|
||||
execution_horizon: int | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self, include_tensors: bool = False) -> dict[str, Any]:
|
||||
"""Convert debug step to dictionary.
|
||||
|
||||
Args:
|
||||
include_tensors (bool): If True, include tensor values. If False, only include
|
||||
tensor statistics (shape, mean, std, min, max).
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the debug step.
|
||||
"""
|
||||
result = {
|
||||
"step_idx": self.step_idx,
|
||||
"guidance_weight": (
|
||||
self.guidance_weight.item()
|
||||
if isinstance(self.guidance_weight, Tensor)
|
||||
else self.guidance_weight
|
||||
),
|
||||
"time": self.time.item() if isinstance(self.time, Tensor) else self.time,
|
||||
"inference_delay": self.inference_delay,
|
||||
"execution_horizon": self.execution_horizon,
|
||||
"metadata": self.metadata.copy(),
|
||||
}
|
||||
|
||||
# Add tensor information
|
||||
tensor_fields = ["x_t", "v_t", "x1_t", "correction", "err", "weights"]
|
||||
for field_name in tensor_fields:
|
||||
tensor = getattr(self, field_name)
|
||||
if tensor is not None:
|
||||
if include_tensors:
|
||||
result[field_name] = tensor.detach().cpu()
|
||||
else:
|
||||
result[f"{field_name}_stats"] = {
|
||||
"shape": tuple(tensor.shape),
|
||||
"mean": tensor.mean().item(),
|
||||
"std": tensor.std().item(),
|
||||
"min": tensor.min().item(),
|
||||
"max": tensor.max().item(),
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Tracker:
|
||||
"""Collects and manages debug information for RTC processing.
|
||||
|
||||
This tracker stores debug information from recent denoising steps in a dictionary,
|
||||
using time as the key for efficient lookups and updates.
|
||||
|
||||
Args:
|
||||
enabled (bool): Whether debug collection is enabled.
|
||||
maxlen (int | None): Optional sliding window size. If provided, only the
|
||||
most recent ``maxlen`` debug steps are kept. If ``None``, keeps all.
|
||||
"""
|
||||
|
||||
def __init__(self, enabled: bool = False, maxlen: int = 100):
|
||||
self.enabled = enabled
|
||||
self._steps = {} if enabled else None # Dictionary with time as key
|
||||
self._maxlen = maxlen
|
||||
self._step_counter = 0
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Clear all recorded debug information."""
|
||||
if self.enabled and self._steps is not None:
|
||||
self._steps.clear()
|
||||
self._step_counter = 0
|
||||
|
||||
@torch._dynamo.disable
|
||||
def track(
|
||||
self,
|
||||
time: float | Tensor,
|
||||
x_t: Tensor | None = None,
|
||||
v_t: Tensor | None = None,
|
||||
x1_t: Tensor | None = None,
|
||||
correction: Tensor | None = None,
|
||||
err: Tensor | None = None,
|
||||
weights: Tensor | None = None,
|
||||
guidance_weight: float | Tensor | None = None,
|
||||
inference_delay: int | None = None,
|
||||
execution_horizon: int | None = None,
|
||||
**metadata,
|
||||
) -> None:
|
||||
"""Track debug information for a denoising step at a given time.
|
||||
|
||||
If a step with the given time already exists, it will be updated with the new data.
|
||||
Otherwise, a new step will be created. Only non-None fields are updated/set.
|
||||
|
||||
Note: This method is excluded from torch.compile to avoid graph breaks from
|
||||
operations like .item() which are incompatible with compiled graphs.
|
||||
|
||||
Args:
|
||||
time (float | Tensor): Time parameter - used as the key to identify the step.
|
||||
x_t (Tensor | None): Current latent/state tensor.
|
||||
v_t (Tensor | None): Velocity from denoiser.
|
||||
x1_t (Tensor | None): Denoised prediction.
|
||||
correction (Tensor | None): Correction gradient tensor.
|
||||
err (Tensor | None): Weighted error term.
|
||||
weights (Tensor | None): Prefix attention weights.
|
||||
guidance_weight (float | Tensor | None): Applied guidance weight.
|
||||
inference_delay (int | None): Inference delay parameter.
|
||||
execution_horizon (int | None): Execution horizon parameter.
|
||||
**metadata: Additional metadata to store.
|
||||
"""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
# Convert time to float and round to avoid float precision issues
|
||||
time_value = time.item() if isinstance(time, Tensor) else time
|
||||
time_key = round(time_value, 6) # Use rounded time as dictionary key
|
||||
|
||||
# Check if step with this time already exists
|
||||
if time_key in self._steps:
|
||||
# Update existing step with non-None fields
|
||||
existing_step = self._steps[time_key]
|
||||
if x_t is not None:
|
||||
existing_step.x_t = x_t.detach().clone()
|
||||
if v_t is not None:
|
||||
existing_step.v_t = v_t.detach().clone()
|
||||
if x1_t is not None:
|
||||
existing_step.x1_t = x1_t.detach().clone()
|
||||
if correction is not None:
|
||||
existing_step.correction = correction.detach().clone()
|
||||
if err is not None:
|
||||
existing_step.err = err.detach().clone()
|
||||
if weights is not None:
|
||||
existing_step.weights = weights.detach().clone()
|
||||
if guidance_weight is not None:
|
||||
existing_step.guidance_weight = guidance_weight
|
||||
if inference_delay is not None:
|
||||
existing_step.inference_delay = inference_delay
|
||||
if execution_horizon is not None:
|
||||
existing_step.execution_horizon = execution_horizon
|
||||
if metadata:
|
||||
existing_step.metadata.update(metadata)
|
||||
else:
|
||||
# Create new step
|
||||
step = DebugStep(
|
||||
step_idx=self._step_counter,
|
||||
x_t=x_t.detach().clone() if x_t is not None else None,
|
||||
v_t=v_t.detach().clone() if v_t is not None else None,
|
||||
x1_t=x1_t.detach().clone() if x1_t is not None else None,
|
||||
correction=correction.detach().clone() if correction is not None else None,
|
||||
err=err.detach().clone() if err is not None else None,
|
||||
weights=weights.detach().clone() if weights is not None else None,
|
||||
guidance_weight=guidance_weight,
|
||||
time=time_value,
|
||||
inference_delay=inference_delay,
|
||||
execution_horizon=execution_horizon,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
# Add to dictionary
|
||||
self._steps[time_key] = step
|
||||
self._step_counter += 1
|
||||
|
||||
# Enforce maxlen if set
|
||||
if self._maxlen is not None and len(self._steps) > self._maxlen:
|
||||
# Remove oldest entry (first key in dict - Python 3.7+ preserves insertion order)
|
||||
oldest_key = next(iter(self._steps))
|
||||
del self._steps[oldest_key]
|
||||
|
||||
def get_all_steps(self) -> list[DebugStep]:
|
||||
"""Get all recorded debug steps.
|
||||
|
||||
Returns:
|
||||
List of all DebugStep objects (may be empty if disabled).
|
||||
"""
|
||||
if not self.enabled or self._steps is None:
|
||||
return []
|
||||
|
||||
return list(self._steps.values())
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of recorded debug steps."""
|
||||
if not self.enabled or self._steps is None:
|
||||
return 0
|
||||
return len(self._steps)
|
||||
@@ -1,117 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Visualization utilities for RTC debug information."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class RTCDebugVisualizer:
|
||||
"""Visualizer for RTC debug information.
|
||||
|
||||
This class provides methods to visualize debug information collected by the Tracker,
|
||||
including corrections, errors, weights, and guidance weights over denoising steps.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def plot_waypoints(
|
||||
axes,
|
||||
tensor,
|
||||
start_from: int = 0,
|
||||
color: str = "blue",
|
||||
label: str = "",
|
||||
alpha: float = 0.7,
|
||||
linewidth: float = 2,
|
||||
marker: str | None = None,
|
||||
markersize: int = 4,
|
||||
):
|
||||
"""Plot trajectories across multiple dimensions.
|
||||
|
||||
This function plots a tensor's values across time for multiple dimensions,
|
||||
with each dimension plotted on a separate axis.
|
||||
|
||||
Args:
|
||||
axes: Array of matplotlib axes (one for each dimension).
|
||||
tensor: The tensor to plot (can be torch.Tensor or numpy array).
|
||||
Shape should be (time_steps, num_dims) or (batch, time_steps, num_dims).
|
||||
start_from: Starting index for the x-axis.
|
||||
color: Color for the plot lines.
|
||||
label: Label for the plot legend.
|
||||
alpha: Transparency level for the plot.
|
||||
linewidth: Width of the plot lines.
|
||||
marker: Marker style for data points (e.g., 'o', 's', '^').
|
||||
markersize: Size of the markers.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# Handle None tensor
|
||||
if tensor is None:
|
||||
return
|
||||
|
||||
# Convert tensor to numpy if needed
|
||||
tensor_np = tensor.detach().cpu().numpy() if isinstance(tensor, torch.Tensor) else tensor
|
||||
|
||||
# Handle different tensor shapes
|
||||
if tensor_np.ndim == 3:
|
||||
# If batch dimension present, take first batch
|
||||
tensor_np = tensor_np[0]
|
||||
elif tensor_np.ndim == 1:
|
||||
# If 1D, reshape to (time_steps, 1)
|
||||
tensor_np = tensor_np.reshape(-1, 1)
|
||||
|
||||
# Get dimensions
|
||||
time_steps, num_dims = tensor_np.shape
|
||||
|
||||
# Create x-axis indices
|
||||
x_indices = np.arange(start_from, start_from + time_steps)
|
||||
|
||||
# Plot each dimension on its corresponding axis
|
||||
num_axes = len(axes) if hasattr(axes, "__len__") else 1
|
||||
for dim_idx in range(min(num_dims, num_axes)):
|
||||
ax = axes[dim_idx] if hasattr(axes, "__len__") else axes
|
||||
|
||||
# Plot the trajectory
|
||||
if marker:
|
||||
ax.plot(
|
||||
x_indices,
|
||||
tensor_np[:, dim_idx],
|
||||
color=color,
|
||||
label=label if dim_idx == 0 else "", # Only show label once
|
||||
alpha=alpha,
|
||||
linewidth=linewidth,
|
||||
marker=marker,
|
||||
markersize=markersize,
|
||||
)
|
||||
else:
|
||||
ax.plot(
|
||||
x_indices,
|
||||
tensor_np[:, dim_idx],
|
||||
color=color,
|
||||
label=label if dim_idx == 0 else "", # Only show label once
|
||||
alpha=alpha,
|
||||
linewidth=linewidth,
|
||||
)
|
||||
|
||||
# Add grid and labels if not already present
|
||||
if not ax.xaxis.get_label().get_text():
|
||||
ax.set_xlabel("Step", fontsize=10)
|
||||
if not ax.yaxis.get_label().get_text():
|
||||
ax.set_ylabel(f"Dim {dim_idx}", fontsize=10)
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# Add legend if label provided and this is the first dimension
|
||||
if label and dim_idx == 0:
|
||||
ax.legend(loc="best", fontsize=8)
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 1.3 MiB |
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user