mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 02:41:24 +00:00
Compare commits
8 Commits
security-f
...
security-f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aba8beddda | ||
|
|
85de893fa7 | ||
|
|
a4c66e530b | ||
|
|
a225127527 | ||
|
|
e489ba24fc | ||
|
|
d324ffe810 | ||
|
|
1a24f770d3 | ||
|
|
92fba37225 |
3
.github/workflows/fast_tests.yml
vendored
3
.github/workflows/fast_tests.yml
vendored
@@ -44,7 +44,7 @@ permissions:
|
|||||||
# Sets up the environment variables
|
# Sets up the environment variables
|
||||||
env:
|
env:
|
||||||
UV_VERSION: "0.8.0"
|
UV_VERSION: "0.8.0"
|
||||||
PYTHON_VERSION: "3.10"
|
PYTHON_VERSION: "3.12"
|
||||||
|
|
||||||
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
|
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
|
||||||
concurrency:
|
concurrency:
|
||||||
@@ -91,6 +91,7 @@ jobs:
|
|||||||
run: uv sync --extra "test"
|
run: uv sync --extra "test"
|
||||||
|
|
||||||
- name: Login to Hugging Face
|
- name: Login to Hugging Face
|
||||||
|
if: env.HF_USER_TOKEN != ''
|
||||||
run: |
|
run: |
|
||||||
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||||
uv run hf auth whoami
|
uv run hf auth whoami
|
||||||
|
|||||||
9
.github/workflows/full_tests.yml
vendored
9
.github/workflows/full_tests.yml
vendored
@@ -37,7 +37,7 @@ permissions:
|
|||||||
# Sets up the environment variables
|
# Sets up the environment variables
|
||||||
env:
|
env:
|
||||||
UV_VERSION: "0.8.0"
|
UV_VERSION: "0.8.0"
|
||||||
PYTHON_VERSION: "3.10"
|
PYTHON_VERSION: "3.12"
|
||||||
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu
|
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu
|
||||||
|
|
||||||
# Ensures that only the latest action is built, canceling older runs.
|
# Ensures that only the latest action is built, canceling older runs.
|
||||||
@@ -89,6 +89,7 @@ jobs:
|
|||||||
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
|
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
|
||||||
|
|
||||||
- name: Login to Hugging Face
|
- name: Login to Hugging Face
|
||||||
|
if: env.HF_USER_TOKEN != ''
|
||||||
run: |
|
run: |
|
||||||
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||||
uv run hf auth whoami
|
uv run hf auth whoami
|
||||||
@@ -181,11 +182,12 @@ jobs:
|
|||||||
working-directory: /lerobot
|
working-directory: /lerobot
|
||||||
steps:
|
steps:
|
||||||
- name: Login to Hugging Face
|
- name: Login to Hugging Face
|
||||||
|
if: env.HF_USER_TOKEN != ''
|
||||||
run: |
|
run: |
|
||||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||||
hf auth whoami
|
hf auth whoami
|
||||||
- name: Fix ptxas permissions
|
- name: Fix ptxas permissions
|
||||||
run: chmod +x /lerobot/.venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/ptxas
|
run: chmod +x /lerobot/.venv/lib/python3.12/site-packages/triton/backends/nvidia/bin/ptxas
|
||||||
- name: Run pytest on GPU
|
- name: Run pytest on GPU
|
||||||
run: pytest tests -vv --maxfail=10
|
run: pytest tests -vv --maxfail=10
|
||||||
- name: Run end-to-end tests
|
- name: Run end-to-end tests
|
||||||
@@ -200,7 +202,6 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Get Docker Hub Token and Delete Image
|
- name: Get Docker Hub Token and Delete Image
|
||||||
# zizmor: ignore[template-injection]
|
|
||||||
env:
|
env:
|
||||||
DOCKERHUB_LEROBOT_USERNAME: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
DOCKERHUB_LEROBOT_USERNAME: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||||
DOCKERHUB_LEROBOT_PASSWORD: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
DOCKERHUB_LEROBOT_PASSWORD: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||||
@@ -232,4 +233,4 @@ jobs:
|
|||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# TODO(Steven): Check dockerimages pull in ubuntu
|
# TODO(Steven): Check dockerimages pull in ubuntu
|
||||||
8
.github/workflows/nightly.yml
vendored
8
.github/workflows/nightly.yml
vendored
@@ -28,7 +28,7 @@ on:
|
|||||||
# Sets up the environment variables
|
# Sets up the environment variables
|
||||||
env:
|
env:
|
||||||
UV_VERSION: "0.8.0"
|
UV_VERSION: "0.8.0"
|
||||||
PYTHON_VERSION: "3.10"
|
PYTHON_VERSION: "3.12"
|
||||||
DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-cpu:latest
|
DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-cpu:latest
|
||||||
DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-gpu:latest
|
DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-gpu:latest
|
||||||
|
|
||||||
@@ -132,6 +132,7 @@ jobs:
|
|||||||
working-directory: /lerobot
|
working-directory: /lerobot
|
||||||
steps:
|
steps:
|
||||||
- name: Login to Hugging Face
|
- name: Login to Hugging Face
|
||||||
|
if: env.HF_USER_TOKEN != ''
|
||||||
run: |
|
run: |
|
||||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||||
hf auth whoami
|
hf auth whoami
|
||||||
@@ -164,6 +165,7 @@ jobs:
|
|||||||
working-directory: /lerobot
|
working-directory: /lerobot
|
||||||
steps:
|
steps:
|
||||||
- name: Login to Hugging Face
|
- name: Login to Hugging Face
|
||||||
|
if: env.HF_USER_TOKEN != ''
|
||||||
run: |
|
run: |
|
||||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||||
hf auth whoami
|
hf auth whoami
|
||||||
@@ -197,6 +199,7 @@ jobs:
|
|||||||
working-directory: /lerobot
|
working-directory: /lerobot
|
||||||
steps:
|
steps:
|
||||||
- name: Login to Hugging Face
|
- name: Login to Hugging Face
|
||||||
|
if: env.HF_USER_TOKEN != ''
|
||||||
run: |
|
run: |
|
||||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||||
hf auth whoami
|
hf auth whoami
|
||||||
@@ -206,5 +209,4 @@ jobs:
|
|||||||
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
|
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
|
||||||
|
|
||||||
- name: Run multi-GPU training tests
|
- name: Run multi-GPU training tests
|
||||||
# TODO(Steven): Investigate why motors tests are failing in multi-GPU setup
|
run: pytest -vv tests/training/
|
||||||
run: pytest tests -vv --maxfail=10 --ignore=tests/motors/
|
|
||||||
|
|||||||
2
.github/workflows/quality.yml
vendored
2
.github/workflows/quality.yml
vendored
@@ -50,7 +50,7 @@ jobs:
|
|||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v6
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.12'
|
||||||
|
|
||||||
- name: Run pre-commit hooks
|
- name: Run pre-commit hooks
|
||||||
uses: pre-commit/action@v3.0.1 # zizmor: ignore[unpinned-uses]
|
uses: pre-commit/action@v3.0.1 # zizmor: ignore[unpinned-uses]
|
||||||
|
|||||||
12
.github/workflows/release.yml
vendored
12
.github/workflows/release.yml
vendored
@@ -22,7 +22,7 @@ on:
|
|||||||
# Sets up the environment variables
|
# Sets up the environment variables
|
||||||
env:
|
env:
|
||||||
UV_VERSION: "0.8.0"
|
UV_VERSION: "0.8.0"
|
||||||
PYTHON_VERSION: "3.10"
|
PYTHON_VERSION: "3.12"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
# This job builds the Python package and publishes it to PyPI
|
# This job builds the Python package and publishes it to PyPI
|
||||||
@@ -45,7 +45,7 @@ jobs:
|
|||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v6
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.12'
|
||||||
|
|
||||||
- name: Extract Version
|
- name: Extract Version
|
||||||
id: extract_info
|
id: extract_info
|
||||||
@@ -83,14 +83,6 @@ jobs:
|
|||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Remove Tags with Git dependencies
|
|
||||||
# TODO(Steven): Temporary patch to remove pi from PyPi 0.4.0 release due to its reliance on git dependencies.
|
|
||||||
run: |
|
|
||||||
echo "::info:: Checking for Git dependencies to remove from pyproject.toml..."
|
|
||||||
grep -E '@ git\+https|lerobot\[pi\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
|
|
||||||
sed -E -i '/@ git\+https|lerobot\[pi\]/d' pyproject.toml
|
|
||||||
echo "::info:: Git dependencies removed. Proceeding with build."
|
|
||||||
|
|
||||||
- name: Install build dependencies
|
- name: Install build dependencies
|
||||||
run: python -m pip install build
|
run: python -m pip install build
|
||||||
|
|
||||||
|
|||||||
4
.github/workflows/unbound_deps_tests.yml
vendored
4
.github/workflows/unbound_deps_tests.yml
vendored
@@ -29,7 +29,7 @@ permissions:
|
|||||||
# Sets up the environment variables
|
# Sets up the environment variables
|
||||||
env:
|
env:
|
||||||
UV_VERSION: "0.8.0"
|
UV_VERSION: "0.8.0"
|
||||||
PYTHON_VERSION: "3.10"
|
PYTHON_VERSION: "3.12"
|
||||||
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu:unbound
|
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu:unbound
|
||||||
|
|
||||||
# Ensures that only the latest action is built, canceling older runs.
|
# Ensures that only the latest action is built, canceling older runs.
|
||||||
@@ -81,6 +81,7 @@ jobs:
|
|||||||
- name: Install lerobot with all extras
|
- name: Install lerobot with all extras
|
||||||
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
|
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
|
||||||
- name: Login to Hugging Face
|
- name: Login to Hugging Face
|
||||||
|
if: env.HF_USER_TOKEN != ''
|
||||||
run: |
|
run: |
|
||||||
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||||
uv run hf auth whoami
|
uv run hf auth whoami
|
||||||
@@ -154,6 +155,7 @@ jobs:
|
|||||||
working-directory: /lerobot
|
working-directory: /lerobot
|
||||||
steps:
|
steps:
|
||||||
- name: Login to Hugging Face
|
- name: Login to Hugging Face
|
||||||
|
if: env.HF_USER_TOKEN != ''
|
||||||
run: |
|
run: |
|
||||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||||
hf auth whoami
|
hf auth whoami
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
default_language_version:
|
default_language_version:
|
||||||
python: python3.10
|
python: python3.12
|
||||||
|
|
||||||
exclude: "tests/artifacts/.*\\.safetensors$"
|
exclude: "tests/artifacts/.*\\.safetensors$"
|
||||||
|
|
||||||
@@ -55,7 +55,7 @@ repos:
|
|||||||
rev: v3.21.0
|
rev: v3.21.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: pyupgrade
|
- id: pyupgrade
|
||||||
args: [--py310-plus]
|
args: [--py312-plus]
|
||||||
|
|
||||||
##### Markdown Quality #####
|
##### Markdown Quality #####
|
||||||
- repo: https://github.com/rbubley/mirrors-prettier
|
- repo: https://github.com/rbubley/mirrors-prettier
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ ARG OS_VERSION=22.04
|
|||||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
||||||
|
|
||||||
# Define Python version argument
|
# Define Python version argument
|
||||||
ARG PYTHON_VERSION=3.10
|
ARG PYTHON_VERSION=3.12
|
||||||
|
|
||||||
# Configure environment variables
|
# Configure environment variables
|
||||||
ENV DEBIAN_FRONTEND=noninteractive \
|
ENV DEBIAN_FRONTEND=noninteractive \
|
||||||
|
|||||||
@@ -19,7 +19,7 @@
|
|||||||
# docker run -it --rm lerobot-user
|
# docker run -it --rm lerobot-user
|
||||||
|
|
||||||
# Configure the base image
|
# Configure the base image
|
||||||
ARG PYTHON_VERSION=3.10
|
ARG PYTHON_VERSION=3.12
|
||||||
FROM python:${PYTHON_VERSION}-slim
|
FROM python:${PYTHON_VERSION}-slim
|
||||||
|
|
||||||
# Configure environment variables
|
# Configure environment variables
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
# your policy-specific dependencies
|
# your policy-specific dependencies
|
||||||
]
|
]
|
||||||
requires-python = ">= 3.11"
|
requires-python = ">= 3.12"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
build-backend = # your-build-backend
|
build-backend = # your-build-backend
|
||||||
@@ -82,7 +82,7 @@ Create your policy implementation by inheriting from LeRobot's base `PreTrainedP
|
|||||||
# modeling_my_custom_policy.py
|
# modeling_my_custom_policy.py
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import Dict, Any
|
from typing import Any
|
||||||
|
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
||||||
@@ -91,7 +91,7 @@ class MyCustomPolicy(PreTrainedPolicy):
|
|||||||
config_class = MyCustomPolicyConfig
|
config_class = MyCustomPolicyConfig
|
||||||
name = "my_custom_policy"
|
name = "my_custom_policy"
|
||||||
|
|
||||||
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: Dict[str, Any] = None):
|
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
|
||||||
super().__init__(config, dataset_stats)
|
super().__init__(config, dataset_stats)
|
||||||
...
|
...
|
||||||
```
|
```
|
||||||
@@ -102,7 +102,7 @@ Create processor functions:
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
# processor_my_custom_policy.py
|
# processor_my_custom_policy.py
|
||||||
from typing import Dict, Any
|
from typing import Any
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ The EarthRover Mini Plus is a fully open source mobile robot that connects throu
|
|||||||
### Hardware
|
### Hardware
|
||||||
|
|
||||||
- EarthRover Mini robot
|
- EarthRover Mini robot
|
||||||
- Computer with Python 3.10 or newer
|
- Computer with Python 3.12 or newer
|
||||||
- Internet connection
|
- Internet connection
|
||||||
|
|
||||||
### Setting Up the Frodobots SDK
|
### Setting Up the Frodobots SDK
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# Installation
|
# Installation
|
||||||
|
|
||||||
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
|
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.12 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
|
||||||
|
|
||||||
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
|
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
|
||||||
|
|
||||||
@@ -11,10 +11,10 @@ bash Miniforge3-$(uname)-$(uname -m).sh
|
|||||||
|
|
||||||
## Step 2: Environment Setup
|
## Step 2: Environment Setup
|
||||||
|
|
||||||
Create a virtual environment with Python 3.10, using conda:
|
Create a virtual environment with Python 3.12, using conda:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
conda create -y -n lerobot python=3.10
|
conda create -y -n lerobot python=3.12
|
||||||
```
|
```
|
||||||
|
|
||||||
Then activate your conda environment, you have to do this each time you open a shell to use lerobot:
|
Then activate your conda environment, you have to do this each time you open a shell to use lerobot:
|
||||||
@@ -90,9 +90,6 @@ _Replace `[...]` with your desired features._
|
|||||||
For a full list of optional dependencies, see:
|
For a full list of optional dependencies, see:
|
||||||
https://pypi.org/project/lerobot/
|
https://pypi.org/project/lerobot/
|
||||||
|
|
||||||
> [!NOTE]
|
|
||||||
> For lerobot 0.4.0, if you want to install pi, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`
|
|
||||||
|
|
||||||
### Troubleshooting
|
### Troubleshooting
|
||||||
|
|
||||||
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
||||||
|
|||||||
@@ -34,11 +34,6 @@ As described by Physical Intelligence, while AI has achieved remarkable success
|
|||||||
pip install -e ".[pi]"
|
pip install -e ".[pi]"
|
||||||
```
|
```
|
||||||
|
|
||||||
> [!NOTE]
|
|
||||||
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
|
||||||
>
|
|
||||||
> This will be solved in the next patch release
|
|
||||||
|
|
||||||
## Training Data and Capabilities
|
## Training Data and Capabilities
|
||||||
|
|
||||||
π₀ is trained on the largest robot interaction dataset to date, combining three key data sources:
|
π₀ is trained on the largest robot interaction dataset to date, combining three key data sources:
|
||||||
|
|||||||
@@ -36,11 +36,6 @@ This diverse training mixture creates a "curriculum" that enables generalization
|
|||||||
pip install -e ".[pi]"
|
pip install -e ".[pi]"
|
||||||
```
|
```
|
||||||
|
|
||||||
> [!NOTE]
|
|
||||||
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
|
||||||
>
|
|
||||||
> This will be solved in the next patch release
|
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
To use π₀.₅ in your LeRobot configuration, specify the policy type as:
|
To use π₀.₅ in your LeRobot configuration, specify the policy type as:
|
||||||
|
|||||||
@@ -43,11 +43,6 @@ This approach can transform **any existing VLM** into a VLA by training it to pr
|
|||||||
pip install -e ".[pi]"
|
pip install -e ".[pi]"
|
||||||
```
|
```
|
||||||
|
|
||||||
> [!NOTE]
|
|
||||||
> For lerobot 0.4.0, if you want to install the pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
|
||||||
>
|
|
||||||
> This will be solved in the next patch release
|
|
||||||
|
|
||||||
## Training a Custom FAST Tokenizer
|
## Training a Custom FAST Tokenizer
|
||||||
|
|
||||||
You have two options for the FAST tokenizer:
|
You have two options for the FAST tokenizer:
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ SSH into the robot and install LeRobot:
|
|||||||
```bash
|
```bash
|
||||||
ssh unitree@<YOUR_ROBOT_IP>
|
ssh unitree@<YOUR_ROBOT_IP>
|
||||||
|
|
||||||
conda create -y -n lerobot python=3.10
|
conda create -y -n lerobot python=3.12
|
||||||
conda activate lerobot
|
conda activate lerobot
|
||||||
git clone https://github.com/huggingface/lerobot.git
|
git clone https://github.com/huggingface/lerobot.git
|
||||||
cd lerobot
|
cd lerobot
|
||||||
@@ -153,7 +153,7 @@ With the robot server running, you can now control the robot remotely. Let's lau
|
|||||||
### Step 1: Install LeRobot on your machine
|
### Step 1: Install LeRobot on your machine
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
conda create -y -n lerobot python=3.10
|
conda create -y -n lerobot python=3.12
|
||||||
conda activate lerobot
|
conda activate lerobot
|
||||||
git clone https://github.com/huggingface/lerobot.git
|
git clone https://github.com/huggingface/lerobot.git
|
||||||
cd lerobot
|
cd lerobot
|
||||||
|
|||||||
490
examples/dataset/slurm_compute_rabc.py
Normal file
490
examples/dataset/slurm_compute_rabc.py
Normal file
@@ -0,0 +1,490 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
SLURM-distributed SARM RA-BC annotation pipeline.
|
||||||
|
|
||||||
|
Computes SARM progress values for all frames in a dataset, distributed across
|
||||||
|
SLURM workers, then merges the shards into a single sarm_progress.parquet.
|
||||||
|
|
||||||
|
Two subcommands, each a separate SLURM submission:
|
||||||
|
|
||||||
|
compute – N workers, each computes progress for a subset of episodes
|
||||||
|
aggregate – 1 worker, merges N shards into sarm_progress.parquet, pushes to hub
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python slurm_compute_rabc.py compute \\
|
||||||
|
--repo-id user/dataset --reward-model-path user/sarm_model \\
|
||||||
|
--stride 10 --device cpu --workers 50 --partition cpu
|
||||||
|
|
||||||
|
python slurm_compute_rabc.py aggregate \\
|
||||||
|
--repo-id user/dataset --reward-model-path user/sarm_model \\
|
||||||
|
--partition cpu --push-to-hub
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from datatrove.executor import LocalPipelineExecutor
|
||||||
|
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||||
|
from datatrove.pipeline.base import PipelineStep
|
||||||
|
|
||||||
|
|
||||||
|
class ComputeProgressShards(PipelineStep):
|
||||||
|
"""Each worker computes SARM progress for its assigned episodes."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, repo_id, reward_model_path, stride=1, head_mode="sparse", device="cpu", shard_dir="rabc_shards"
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if stride < 1:
|
||||||
|
raise ValueError(f"stride must be >= 1, got {stride}")
|
||||||
|
self.repo_id = repo_id
|
||||||
|
self.reward_model_path = reward_model_path
|
||||||
|
self.stride = stride
|
||||||
|
self.head_mode = head_mode
|
||||||
|
self.device = device
|
||||||
|
self.shard_dir = shard_dir
|
||||||
|
|
||||||
|
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pyarrow as pa
|
||||||
|
import pyarrow.parquet as pq
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from lerobot.policies.sarm.compute_rabc_weights import (
|
||||||
|
generate_all_frame_indices,
|
||||||
|
interpolate_progress,
|
||||||
|
load_sarm_resources,
|
||||||
|
)
|
||||||
|
from lerobot.utils.utils import init_logging
|
||||||
|
|
||||||
|
init_logging()
|
||||||
|
|
||||||
|
dataset, reward_model, preprocess = load_sarm_resources(
|
||||||
|
self.repo_id,
|
||||||
|
self.reward_model_path,
|
||||||
|
self.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(preprocess, "eval"):
|
||||||
|
preprocess.eval()
|
||||||
|
for step in preprocess.steps:
|
||||||
|
if hasattr(step, "eval"):
|
||||||
|
step.eval()
|
||||||
|
|
||||||
|
image_key = reward_model.config.image_key
|
||||||
|
state_key = reward_model.config.state_key
|
||||||
|
frame_gap = reward_model.config.frame_gap
|
||||||
|
center_idx = reward_model.config.n_obs_steps // 2
|
||||||
|
|
||||||
|
dual_mode = reward_model.config.uses_dual_heads
|
||||||
|
compute_sparse = self.head_mode in ("sparse", "both") or not dual_mode
|
||||||
|
compute_dense = self.head_mode in ("dense", "both") and dual_mode
|
||||||
|
|
||||||
|
my_episodes = list(range(dataset.num_episodes))[rank::world_size]
|
||||||
|
if not my_episodes:
|
||||||
|
logging.info(f"Rank {rank}: no episodes assigned")
|
||||||
|
return
|
||||||
|
logging.info(f"Rank {rank}: {len(my_episodes)} / {dataset.num_episodes} episodes")
|
||||||
|
|
||||||
|
all_rows = []
|
||||||
|
|
||||||
|
for ep_idx in tqdm(my_episodes, desc=f"Rank {rank}"):
|
||||||
|
ep = dataset.meta.episodes[ep_idx]
|
||||||
|
ep_start, ep_end = ep["dataset_from_index"], ep["dataset_to_index"]
|
||||||
|
task = dataset[ep_start].get("task", "perform the task")
|
||||||
|
|
||||||
|
all_ep_indices = generate_all_frame_indices(ep_start, ep_end, frame_gap)
|
||||||
|
if self.stride > 1:
|
||||||
|
compute_indices = [i for i in all_ep_indices if (i - ep_start) % self.stride == 0]
|
||||||
|
if (ep_end - 1) not in compute_indices:
|
||||||
|
compute_indices.append(ep_end - 1)
|
||||||
|
compute_indices = sorted(set(compute_indices))
|
||||||
|
else:
|
||||||
|
compute_indices = all_ep_indices
|
||||||
|
|
||||||
|
frame_results = {}
|
||||||
|
for qi in tqdm(compute_indices, desc=f" Ep {ep_idx}", leave=False):
|
||||||
|
try:
|
||||||
|
sample = dataset[qi]
|
||||||
|
batch = {
|
||||||
|
image_key: sample[image_key],
|
||||||
|
"task": task,
|
||||||
|
"index": qi,
|
||||||
|
"episode_index": ep_idx,
|
||||||
|
}
|
||||||
|
if state_key in sample:
|
||||||
|
batch[state_key] = sample[state_key]
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
processed = preprocess(batch)
|
||||||
|
vf = processed["video_features"].to(self.device)
|
||||||
|
tf = processed["text_features"].to(self.device)
|
||||||
|
sf = processed.get("state_features")
|
||||||
|
if sf is not None:
|
||||||
|
sf = sf.to(self.device)
|
||||||
|
lengths = processed.get("lengths")
|
||||||
|
|
||||||
|
sparse_val = dense_val = np.nan
|
||||||
|
if compute_sparse:
|
||||||
|
r = reward_model.calculate_rewards(
|
||||||
|
text_embeddings=tf,
|
||||||
|
video_embeddings=vf,
|
||||||
|
state_features=sf,
|
||||||
|
lengths=lengths,
|
||||||
|
return_all_frames=True,
|
||||||
|
head_mode="sparse",
|
||||||
|
)
|
||||||
|
sparse_val = float(r[0, center_idx] if r.ndim == 2 else r[center_idx])
|
||||||
|
if compute_dense:
|
||||||
|
r = reward_model.calculate_rewards(
|
||||||
|
text_embeddings=tf,
|
||||||
|
video_embeddings=vf,
|
||||||
|
state_features=sf,
|
||||||
|
lengths=lengths,
|
||||||
|
return_all_frames=True,
|
||||||
|
head_mode="dense",
|
||||||
|
)
|
||||||
|
dense_val = float(r[0, center_idx] if r.ndim == 2 else r[center_idx])
|
||||||
|
|
||||||
|
frame_results[qi] = (sparse_val, dense_val)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Failed frame {qi}: {e}")
|
||||||
|
|
||||||
|
if not frame_results:
|
||||||
|
logging.warning(f"Episode {ep_idx}: all frames failed, skipping")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Interpolate to all frames in this episode
|
||||||
|
computed_idx = np.array(sorted(frame_results.keys()))
|
||||||
|
all_frame_arr = np.arange(ep_start, ep_end)
|
||||||
|
|
||||||
|
sparse_vals = np.array([frame_results[i][0] for i in computed_idx]) if compute_sparse else None
|
||||||
|
dense_vals = np.array([frame_results[i][1] for i in computed_idx]) if compute_dense else None
|
||||||
|
|
||||||
|
if self.stride > 1 and len(computed_idx) > 1:
|
||||||
|
if compute_sparse:
|
||||||
|
sparse_vals = interpolate_progress(computed_idx, sparse_vals, all_frame_arr)
|
||||||
|
if compute_dense:
|
||||||
|
dense_vals = interpolate_progress(computed_idx, dense_vals, all_frame_arr)
|
||||||
|
output_frames = all_frame_arr
|
||||||
|
else:
|
||||||
|
# Use only successfully computed frames to avoid indexing mismatch on failures
|
||||||
|
output_frames = computed_idx
|
||||||
|
|
||||||
|
for i, fi in enumerate(output_frames):
|
||||||
|
row = {"index": int(fi), "episode_index": ep_idx, "frame_index": int(fi - ep_start)}
|
||||||
|
if compute_sparse:
|
||||||
|
row["progress_sparse"] = float(sparse_vals[i])
|
||||||
|
if compute_dense:
|
||||||
|
row["progress_dense"] = float(dense_vals[i])
|
||||||
|
all_rows.append(row)
|
||||||
|
|
||||||
|
if all_rows:
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
df = pd.DataFrame(all_rows).sort_values("index").reset_index(drop=True)
|
||||||
|
table = pa.Table.from_pandas(df, preserve_index=False)
|
||||||
|
table = table.replace_schema_metadata({b"reward_model_path": self.reward_model_path.encode()})
|
||||||
|
shard_dir = Path(self.shard_dir)
|
||||||
|
shard_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
out = shard_dir / f"shard_{rank:05d}.parquet"
|
||||||
|
pq.write_table(table, out)
|
||||||
|
logging.info(f"Rank {rank}: saved {len(df)} rows to {out}")
|
||||||
|
|
||||||
|
|
||||||
|
class AggregateProgress(PipelineStep):
|
||||||
|
"""Merge all shard parquets into final sarm_progress.parquet."""
|
||||||
|
|
||||||
|
def __init__(self, repo_id, reward_model_path, shard_dir="rabc_shards", push_to_hub=False):
|
||||||
|
super().__init__()
|
||||||
|
self.repo_id = repo_id
|
||||||
|
self.reward_model_path = reward_model_path
|
||||||
|
self.shard_dir = shard_dir
|
||||||
|
self.push_to_hub = push_to_hub
|
||||||
|
|
||||||
|
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||||
|
import datetime
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import pyarrow as pa
|
||||||
|
import pyarrow.parquet as pq
|
||||||
|
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.utils.utils import init_logging
|
||||||
|
|
||||||
|
init_logging()
|
||||||
|
if rank != 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
shard_dir = Path(self.shard_dir)
|
||||||
|
shards = sorted(shard_dir.glob("shard_*.parquet"))
|
||||||
|
if not shards:
|
||||||
|
raise FileNotFoundError(f"No shards found in {shard_dir}")
|
||||||
|
|
||||||
|
# Log shard modification time range to help detect stale files
|
||||||
|
mtimes = [os.path.getmtime(s) for s in shards]
|
||||||
|
oldest = datetime.datetime.fromtimestamp(min(mtimes)).isoformat(timespec="seconds")
|
||||||
|
newest = datetime.datetime.fromtimestamp(max(mtimes)).isoformat(timespec="seconds")
|
||||||
|
logging.info(f"Aggregating {len(shards)} shards (oldest: {oldest}, newest: {newest})")
|
||||||
|
|
||||||
|
df = pd.concat([pd.read_parquet(s) for s in shards], ignore_index=True)
|
||||||
|
df = df.sort_values("index").reset_index(drop=True)
|
||||||
|
|
||||||
|
table = pa.Table.from_pandas(df, preserve_index=False)
|
||||||
|
table = table.replace_schema_metadata({b"reward_model_path": self.reward_model_path.encode()})
|
||||||
|
|
||||||
|
temp_ds = LeRobotDataset(self.repo_id, download_videos=False)
|
||||||
|
out_path = Path(temp_ds.root) / "sarm_progress.parquet"
|
||||||
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
pq.write_table(table, out_path)
|
||||||
|
logging.info(f"Saved {len(df)} rows to {out_path}")
|
||||||
|
|
||||||
|
for col in ["progress_sparse", "progress_dense"]:
|
||||||
|
if col in df.columns:
|
||||||
|
v = df[col].dropna()
|
||||||
|
logging.info(
|
||||||
|
f"{col}: mean={v.mean():.4f} std={v.std():.4f} min={v.min():.4f} max={v.max():.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.push_to_hub:
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
|
||||||
|
api = HfApi()
|
||||||
|
hub_path = "sarm_progress.parquet"
|
||||||
|
logging.info(f"Uploading to {self.repo_id}/{hub_path}")
|
||||||
|
api.upload_file(
|
||||||
|
path_or_fileobj=str(out_path),
|
||||||
|
path_in_repo=hub_path,
|
||||||
|
repo_id=self.repo_id,
|
||||||
|
repo_type="dataset",
|
||||||
|
)
|
||||||
|
logging.info(f"Uploaded: https://huggingface.co/datasets/{self.repo_id}/blob/main/{hub_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def make_compute_executor(
|
||||||
|
repo_id,
|
||||||
|
reward_model_path,
|
||||||
|
stride,
|
||||||
|
head_mode,
|
||||||
|
device,
|
||||||
|
shard_dir,
|
||||||
|
logs_dir,
|
||||||
|
job_name,
|
||||||
|
slurm,
|
||||||
|
workers,
|
||||||
|
partition,
|
||||||
|
cpus_per_task,
|
||||||
|
mem_per_cpu,
|
||||||
|
):
|
||||||
|
kwargs = {
|
||||||
|
"pipeline": [
|
||||||
|
ComputeProgressShards(repo_id, reward_model_path, stride, head_mode, device, str(shard_dir)),
|
||||||
|
],
|
||||||
|
"logging_dir": str(logs_dir / job_name),
|
||||||
|
}
|
||||||
|
|
||||||
|
if slurm:
|
||||||
|
kwargs.update(
|
||||||
|
{
|
||||||
|
"job_name": job_name,
|
||||||
|
"tasks": workers,
|
||||||
|
"workers": workers,
|
||||||
|
"time": "24:00:00",
|
||||||
|
"partition": partition,
|
||||||
|
"cpus_per_task": cpus_per_task,
|
||||||
|
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return SlurmPipelineExecutor(**kwargs)
|
||||||
|
|
||||||
|
kwargs.update({"tasks": workers, "workers": 1})
|
||||||
|
return LocalPipelineExecutor(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def make_aggregate_executor(
|
||||||
|
repo_id,
|
||||||
|
reward_model_path,
|
||||||
|
shard_dir,
|
||||||
|
logs_dir,
|
||||||
|
job_name,
|
||||||
|
slurm,
|
||||||
|
partition,
|
||||||
|
cpus_per_task,
|
||||||
|
mem_per_cpu,
|
||||||
|
push_to_hub,
|
||||||
|
):
|
||||||
|
kwargs = {
|
||||||
|
"pipeline": [
|
||||||
|
AggregateProgress(repo_id, reward_model_path, str(shard_dir), push_to_hub),
|
||||||
|
],
|
||||||
|
"logging_dir": str(logs_dir / job_name),
|
||||||
|
}
|
||||||
|
|
||||||
|
if slurm:
|
||||||
|
kwargs.update(
|
||||||
|
{
|
||||||
|
"job_name": job_name,
|
||||||
|
"tasks": 1,
|
||||||
|
"workers": 1,
|
||||||
|
"time": "02:00:00",
|
||||||
|
"partition": partition,
|
||||||
|
"cpus_per_task": cpus_per_task,
|
||||||
|
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return SlurmPipelineExecutor(**kwargs)
|
||||||
|
|
||||||
|
kwargs.update({"tasks": 1, "workers": 1})
|
||||||
|
return LocalPipelineExecutor(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _add_shared_args(p):
|
||||||
|
p.add_argument(
|
||||||
|
"--repo-id",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Hugging Face repository identifier, e.g. 'user/dataset'.",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--shard-dir",
|
||||||
|
type=Path,
|
||||||
|
default=Path("rabc_shards"),
|
||||||
|
help="Directory to read/write per-rank parquet shards.",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--logs-dir",
|
||||||
|
type=Path,
|
||||||
|
default=Path("logs"),
|
||||||
|
help="Directory for datatrove logs.",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--job-name",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="SLURM job name (defaults to rabc_<subcommand>).",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--slurm",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="1 = submit via SLURM; 0 = run locally (useful for debugging).",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--partition",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="SLURM partition to submit to.",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--cpus-per-task",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="Number of CPUs per SLURM task.",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--mem-per-cpu",
|
||||||
|
type=str,
|
||||||
|
default="4G",
|
||||||
|
help="Memory per CPU, e.g. '4G' or '1950M'.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="SLURM-distributed SARM RA-BC annotation pipeline",
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
)
|
||||||
|
sub = parser.add_subparsers(dest="command", required=True)
|
||||||
|
|
||||||
|
# compute subcommand
|
||||||
|
cp = sub.add_parser(
|
||||||
|
"compute",
|
||||||
|
help="Distribute progress computation across SLURM workers.",
|
||||||
|
)
|
||||||
|
_add_shared_args(cp)
|
||||||
|
cp.add_argument(
|
||||||
|
"--reward-model-path",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path or HF repo id of the SARM reward model.",
|
||||||
|
)
|
||||||
|
cp.add_argument(
|
||||||
|
"--stride",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Compute every Nth frame; intermediate frames are interpolated (must be >= 1).",
|
||||||
|
)
|
||||||
|
cp.add_argument(
|
||||||
|
"--head-mode",
|
||||||
|
type=str,
|
||||||
|
default="sparse",
|
||||||
|
choices=["sparse", "dense", "both"],
|
||||||
|
help="Which reward head(s) to compute.",
|
||||||
|
)
|
||||||
|
cp.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default="cpu",
|
||||||
|
help="Device for reward model inference, e.g. 'cpu' or 'cuda'.",
|
||||||
|
)
|
||||||
|
cp.add_argument(
|
||||||
|
"--workers",
|
||||||
|
type=int,
|
||||||
|
default=50,
|
||||||
|
help="Number of parallel SLURM tasks (one shard per worker).",
|
||||||
|
)
|
||||||
|
|
||||||
|
# aggregate subcommand
|
||||||
|
ap = sub.add_parser(
|
||||||
|
"aggregate",
|
||||||
|
help="Merge per-rank shards into a single sarm_progress.parquet.",
|
||||||
|
)
|
||||||
|
_add_shared_args(ap)
|
||||||
|
ap.add_argument(
|
||||||
|
"--reward-model-path",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path or HF repo id of the SARM reward model (stored in parquet metadata).",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--push-to-hub",
|
||||||
|
action="store_true",
|
||||||
|
help="Upload sarm_progress.parquet to the Hugging Face Hub after aggregation.",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
job_name = args.job_name or f"rabc_{args.command}"
|
||||||
|
kwargs = vars(args)
|
||||||
|
kwargs["slurm"] = kwargs.pop("slurm") == 1
|
||||||
|
kwargs["job_name"] = job_name
|
||||||
|
command = kwargs.pop("command")
|
||||||
|
|
||||||
|
executor = make_compute_executor(**kwargs) if command == "compute" else make_aggregate_executor(**kwargs)
|
||||||
|
|
||||||
|
executor.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -29,7 +29,7 @@ version = "0.4.5"
|
|||||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||||
dynamic = ["readme"]
|
dynamic = ["readme"]
|
||||||
license = { text = "Apache-2.0" }
|
license = { text = "Apache-2.0" }
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.12"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Rémi Cadène", email = "re.cadene@gmail.com" },
|
{ name = "Rémi Cadène", email = "re.cadene@gmail.com" },
|
||||||
{ name = "Simon Alibert", email = "alibert.sim@gmail.com" },
|
{ name = "Simon Alibert", email = "alibert.sim@gmail.com" },
|
||||||
@@ -50,7 +50,8 @@ classifiers = [
|
|||||||
"Intended Audience :: Education",
|
"Intended Audience :: Education",
|
||||||
"Intended Audience :: Science/Research",
|
"Intended Audience :: Science/Research",
|
||||||
"License :: OSI Approved :: Apache Software License",
|
"License :: OSI Approved :: Apache Software License",
|
||||||
"Programming Language :: Python :: 3.10",
|
"Programming Language :: Python :: 3.12",
|
||||||
|
"Programming Language :: Python :: 3.13",
|
||||||
"Topic :: Software Development :: Build Tools",
|
"Topic :: Software Development :: Build Tools",
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
]
|
]
|
||||||
@@ -61,26 +62,28 @@ dependencies = [
|
|||||||
# Hugging Face dependencies
|
# Hugging Face dependencies
|
||||||
"datasets>=4.0.0,<5.0.0",
|
"datasets>=4.0.0,<5.0.0",
|
||||||
"diffusers>=0.27.2,<0.36.0",
|
"diffusers>=0.27.2,<0.36.0",
|
||||||
"huggingface-hub[cli]>=1.0.0,<2.0.0",
|
"huggingface-hub>=1.0.0,<2.0.0",
|
||||||
"accelerate>=1.10.0,<2.0.0",
|
"accelerate>=1.10.0,<2.0.0",
|
||||||
|
|
||||||
# Core dependencies
|
# Core dependencies
|
||||||
|
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
|
||||||
"setuptools>=71.0.0,<81.0.0",
|
"setuptools>=71.0.0,<81.0.0",
|
||||||
"cmake>=3.29.0.1,<4.2.0",
|
"cmake>=3.29.0.1,<4.2.0",
|
||||||
|
"packaging>=24.2,<26.0",
|
||||||
|
|
||||||
|
"torch>=2.2.1,<2.11.0",
|
||||||
|
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
|
||||||
|
"torchvision>=0.21.0,<0.26.0",
|
||||||
|
|
||||||
"einops>=0.8.0,<0.9.0",
|
"einops>=0.8.0,<0.9.0",
|
||||||
"opencv-python-headless>=4.9.0,<4.13.0",
|
"opencv-python-headless>=4.9.0,<4.13.0",
|
||||||
"av>=15.0.0,<16.0.0",
|
"av>=15.0.0,<16.0.0",
|
||||||
"jsonlines>=4.0.0,<5.0.0",
|
"jsonlines>=4.0.0,<5.0.0",
|
||||||
"packaging>=24.2,<26.0",
|
"pynput>=1.7.8,<1.9.0",
|
||||||
"pynput>=1.7.7,<1.9.0",
|
|
||||||
"pyserial>=3.5,<4.0",
|
"pyserial>=3.5,<4.0",
|
||||||
|
|
||||||
"wandb>=0.24.0,<0.25.0",
|
"wandb>=0.24.0,<0.25.0",
|
||||||
|
"draccus==0.10.0", # TODO: Relax version constraint
|
||||||
"torch>=2.2.1,<2.11.0", # TODO: Bump dependency
|
|
||||||
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
|
|
||||||
"torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency
|
|
||||||
|
|
||||||
"draccus==0.10.0", # TODO: Remove ==
|
|
||||||
"gymnasium>=1.1.1,<2.0.0",
|
"gymnasium>=1.1.1,<2.0.0",
|
||||||
"rerun-sdk>=0.24.0,<0.27.0",
|
"rerun-sdk>=0.24.0,<0.27.0",
|
||||||
|
|
||||||
@@ -95,13 +98,14 @@ dependencies = [
|
|||||||
|
|
||||||
# Common
|
# Common
|
||||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||||
placo-dep = ["placo>=0.9.6,<0.10.0"]
|
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||||
transformers-dep = ["transformers>=5.3.0,<6.0.0"]
|
transformers-dep = ["transformers>=5.3.0,<6.0.0"]
|
||||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||||
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
|
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
|
||||||
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
|
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
|
||||||
|
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster.
|
||||||
|
|
||||||
# Motors
|
# Motors
|
||||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
||||||
@@ -119,16 +123,16 @@ unitree_g1 = [
|
|||||||
"onnxruntime>=1.16.0,<2.0.0",
|
"onnxruntime>=1.16.0,<2.0.0",
|
||||||
"pin>=3.0.0,<4.0.0",
|
"pin>=3.0.0,<4.0.0",
|
||||||
"meshcat>=0.3.0,<0.4.0",
|
"meshcat>=0.3.0,<0.4.0",
|
||||||
"matplotlib>=3.9.0,<4.0.0",
|
"lerobot[matplotlib-dep]",
|
||||||
"casadi>=3.6.0,<4.0.0",
|
"casadi>=3.6.0,<4.0.0",
|
||||||
]
|
]
|
||||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||||
kinematics = ["lerobot[placo-dep]"]
|
kinematics = ["lerobot[placo-dep]"]
|
||||||
intelrealsense = [
|
intelrealsense = [
|
||||||
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
|
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
|
||||||
"pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
|
"pyrealsense2-macosx>=2.54,<2.57.0 ; sys_platform == 'darwin'",
|
||||||
]
|
]
|
||||||
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
|
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0", "lerobot[scipy-dep]"]
|
||||||
|
|
||||||
# Policies
|
# Policies
|
||||||
wallx = [
|
wallx = [
|
||||||
@@ -151,12 +155,12 @@ groot = [
|
|||||||
"ninja>=1.11.1,<2.0.0",
|
"ninja>=1.11.1,<2.0.0",
|
||||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||||
]
|
]
|
||||||
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "lerobot[qwen-vl-utils-dep]"]
|
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||||
xvla = ["lerobot[transformers-dep]"]
|
xvla = ["lerobot[transformers-dep]"]
|
||||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||||
|
|
||||||
# Features
|
# Features
|
||||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||||
|
|
||||||
# Development
|
# Development
|
||||||
@@ -165,13 +169,19 @@ test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0
|
|||||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||||
|
|
||||||
# Simulation
|
# Simulation
|
||||||
aloha = ["gym-aloha>=0.1.2,<0.2.0"]
|
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
|
||||||
|
aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
||||||
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||||
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0"]
|
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||||
metaworld = ["metaworld==3.0.0"]
|
metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||||
|
|
||||||
# All
|
# All
|
||||||
all = [
|
all = [
|
||||||
|
# NOTE(resolver hint): scipy is pulled in transitively via lerobot[scipy-dep] through
|
||||||
|
# multiple extras (aloha, metaworld, pi, wallx, phone). Listing it explicitly
|
||||||
|
# helps pip's resolver converge by constraining scipy early, before it encounters
|
||||||
|
# the loose scipy requirements from transitive deps like dm-control and metaworld.
|
||||||
|
"scipy>=1.14.0,<2.0.0",
|
||||||
"lerobot[dynamixel]",
|
"lerobot[dynamixel]",
|
||||||
"lerobot[gamepad]",
|
"lerobot[gamepad]",
|
||||||
"lerobot[hopejr]",
|
"lerobot[hopejr]",
|
||||||
@@ -192,7 +202,7 @@ all = [
|
|||||||
"lerobot[aloha]",
|
"lerobot[aloha]",
|
||||||
"lerobot[pusht]",
|
"lerobot[pusht]",
|
||||||
"lerobot[phone]",
|
"lerobot[phone]",
|
||||||
"lerobot[libero]",
|
"lerobot[libero]; sys_platform == 'linux'",
|
||||||
"lerobot[metaworld]",
|
"lerobot[metaworld]",
|
||||||
"lerobot[sarm]",
|
"lerobot[sarm]",
|
||||||
"lerobot[peft]",
|
"lerobot[peft]",
|
||||||
@@ -224,7 +234,7 @@ lerobot = ["envs/*.json"]
|
|||||||
where = ["src"]
|
where = ["src"]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
target-version = "py310"
|
target-version = "py312"
|
||||||
line-length = 110
|
line-length = 110
|
||||||
exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
|
exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
|
||||||
|
|
||||||
@@ -316,7 +326,7 @@ default.extend-ignore-identifiers-re = [
|
|||||||
# Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations
|
# Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations
|
||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
python_version = "3.10"
|
python_version = "3.12"
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
follow_imports = "skip"
|
follow_imports = "skip"
|
||||||
# warn_return_any = true
|
# warn_return_any = true
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from collections import deque
|
|||||||
from collections.abc import Iterable, Iterator
|
from collections.abc import Iterable, Iterator
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Any, Generic, TypeVar
|
from typing import Any
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -78,8 +78,6 @@ DEFAULT_FEATURES = {
|
|||||||
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||||
}
|
}
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
|
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
|
||||||
metadata = pq.read_metadata(parquet_path)
|
metadata = pq.read_metadata(parquet_path)
|
||||||
@@ -1234,7 +1232,7 @@ class LookAheadError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Backtrackable(Generic[T]):
|
class Backtrackable[T]:
|
||||||
"""
|
"""
|
||||||
Wrap any iterator/iterable so you can step back up to `history` items
|
Wrap any iterator/iterable so you can step back up to `history` items
|
||||||
and look ahead up to `lookahead` items.
|
and look ahead up to `lookahead` items.
|
||||||
|
|||||||
@@ -228,7 +228,6 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
|
|||||||
|
|
||||||
# Reset for the next file
|
# Reset for the next file
|
||||||
size_in_mb = 0
|
size_in_mb = 0
|
||||||
num_frames += ep_num_frames # Still need to accumulate total frames
|
|
||||||
paths_to_cat = []
|
paths_to_cat = []
|
||||||
|
|
||||||
# Now create metadata with correct chunk/file indices
|
# Now create metadata with correct chunk/file indices
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from dataclasses import dataclass
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Protocol, TypeAlias
|
from typing import Protocol
|
||||||
|
|
||||||
import serial
|
import serial
|
||||||
from deepdiff import DeepDiff
|
from deepdiff import DeepDiff
|
||||||
@@ -38,8 +38,8 @@ from tqdm import tqdm
|
|||||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||||
|
|
||||||
NameOrID: TypeAlias = str | int
|
type NameOrID = str | int
|
||||||
Value: TypeAlias = int | float
|
type Value = int | float
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -1277,4 +1277,4 @@ class SerialMotorsBus(MotorsBusBase):
|
|||||||
|
|
||||||
|
|
||||||
# Backward compatibility alias
|
# Backward compatibility alias
|
||||||
MotorsBus: TypeAlias = SerialMotorsBus
|
MotorsBus = SerialMotorsBus
|
||||||
|
|||||||
@@ -18,10 +18,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, TypedDict
|
from typing import Any, TypedDict, Unpack
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing_extensions import Unpack
|
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.types import FeatureType
|
from lerobot.configs.types import FeatureType
|
||||||
|
|||||||
@@ -4,10 +4,9 @@
|
|||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
# --------------------------------------------------------
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
|
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from transformers.image_processing_utils import (
|
from transformers.image_processing_utils import (
|
||||||
BatchFeature,
|
BatchFeature,
|
||||||
get_patch_output_size,
|
get_patch_output_size,
|
||||||
@@ -165,11 +164,11 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
|||||||
|
|
||||||
def _resize_for_patching(
|
def _resize_for_patching(
|
||||||
self,
|
self,
|
||||||
image: "torch.Tensor",
|
image: torch.Tensor,
|
||||||
target_resolution: tuple,
|
target_resolution: tuple,
|
||||||
interpolation: "F.InterpolationMode",
|
interpolation: F.InterpolationMode,
|
||||||
input_data_format: ChannelDimension,
|
input_data_format: ChannelDimension,
|
||||||
) -> "torch.Tensor":
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Resizes an image to a target resolution while maintaining aspect ratio.
|
Resizes an image to a target resolution while maintaining aspect ratio.
|
||||||
|
|
||||||
@@ -219,8 +218,8 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
|||||||
return best_ratio
|
return best_ratio
|
||||||
|
|
||||||
def _pad_for_patching(
|
def _pad_for_patching(
|
||||||
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
|
self, image: torch.Tensor, target_resolution: tuple, input_data_format: ChannelDimension
|
||||||
) -> "torch.Tensor":
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Pad an image to a target resolution while maintaining aspect ratio.
|
Pad an image to a target resolution while maintaining aspect ratio.
|
||||||
"""
|
"""
|
||||||
@@ -236,15 +235,15 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
|||||||
|
|
||||||
def _get_image_patches(
|
def _get_image_patches(
|
||||||
self,
|
self,
|
||||||
image: "torch.Tensor",
|
image: torch.Tensor,
|
||||||
min_num: int,
|
min_num: int,
|
||||||
max_num: int,
|
max_num: int,
|
||||||
size: tuple,
|
size: tuple,
|
||||||
tile_size: int,
|
tile_size: int,
|
||||||
use_thumbnail: bool,
|
use_thumbnail: bool,
|
||||||
interpolation: "F.InterpolationMode",
|
interpolation: F.InterpolationMode,
|
||||||
pad_during_tiling: bool,
|
pad_during_tiling: bool,
|
||||||
) -> list["torch.Tensor"]:
|
) -> list[torch.Tensor]:
|
||||||
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
|
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
|
||||||
orig_height, orig_width = image_size
|
orig_height, orig_width = image_size
|
||||||
aspect_ratio = orig_width / orig_height
|
aspect_ratio = orig_width / orig_height
|
||||||
@@ -305,8 +304,8 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
|||||||
|
|
||||||
def _pad_for_batching(
|
def _pad_for_batching(
|
||||||
self,
|
self,
|
||||||
pixel_values: list["torch.Tensor"],
|
pixel_values: list[torch.Tensor],
|
||||||
) -> list["torch.Tensor"]:
|
) -> list[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
|
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
|
||||||
|
|
||||||
@@ -327,14 +326,14 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
|||||||
|
|
||||||
def _preprocess(
|
def _preprocess(
|
||||||
self,
|
self,
|
||||||
images: list["torch.Tensor"],
|
images: list[torch.Tensor],
|
||||||
do_resize: bool,
|
do_resize: bool,
|
||||||
size: SizeDict,
|
size: SizeDict,
|
||||||
max_dynamic_tiles: int,
|
max_dynamic_tiles: int,
|
||||||
min_dynamic_tiles: int,
|
min_dynamic_tiles: int,
|
||||||
use_thumbnail: bool,
|
use_thumbnail: bool,
|
||||||
pad_during_tiling: bool,
|
pad_during_tiling: bool,
|
||||||
interpolation: Optional["F.InterpolationMode"],
|
interpolation: F.InterpolationMode | None,
|
||||||
do_center_crop: bool,
|
do_center_crop: bool,
|
||||||
crop_size: SizeDict,
|
crop_size: SizeDict,
|
||||||
do_rescale: bool,
|
do_rescale: bool,
|
||||||
|
|||||||
@@ -20,12 +20,11 @@ import logging
|
|||||||
import math
|
import math
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from typing_extensions import Unpack
|
|
||||||
|
|
||||||
from lerobot.utils.import_utils import _transformers_available
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
|
|
||||||
|
|||||||
@@ -20,12 +20,11 @@ import logging
|
|||||||
import math
|
import math
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from typing_extensions import Unpack
|
|
||||||
|
|
||||||
from lerobot.utils.import_utils import _transformers_available
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
|
|
||||||
|
|||||||
@@ -19,13 +19,12 @@ import logging
|
|||||||
import math
|
import math
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from typing_extensions import Unpack
|
|
||||||
|
|
||||||
from lerobot.utils.import_utils import _scipy_available, _transformers_available
|
from lerobot.utils.import_utils import _scipy_available, _transformers_available
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import os
|
|||||||
from importlib.resources import files
|
from importlib.resources import files
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import TypedDict, TypeVar
|
from typing import TypedDict, TypeVar, Unpack
|
||||||
|
|
||||||
import packaging
|
import packaging
|
||||||
import safetensors
|
import safetensors
|
||||||
@@ -28,7 +28,6 @@ from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
|||||||
from huggingface_hub.errors import HfHubHTTPError
|
from huggingface_hub.errors import HfHubHTTPError
|
||||||
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from typing_extensions import Unpack
|
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
|
|||||||
@@ -54,12 +54,11 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import TypedDict
|
from typing import TypedDict, Unpack
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from typing_extensions import Unpack
|
|
||||||
|
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, TypeAlias, TypedDict
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -36,10 +36,10 @@ class TransitionKey(str, Enum):
|
|||||||
COMPLEMENTARY_DATA = "complementary_data"
|
COMPLEMENTARY_DATA = "complementary_data"
|
||||||
|
|
||||||
|
|
||||||
PolicyAction: TypeAlias = torch.Tensor
|
PolicyAction = torch.Tensor
|
||||||
RobotAction: TypeAlias = dict[str, Any]
|
RobotAction = dict[str, Any]
|
||||||
EnvAction: TypeAlias = np.ndarray
|
EnvAction = np.ndarray
|
||||||
RobotObservation: TypeAlias = dict[str, Any]
|
RobotObservation = dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
EnvTransition = TypedDict(
|
EnvTransition = TypedDict(
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ from collections.abc import Callable, Iterable, Sequence
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, cast
|
from typing import Any, TypedDict, TypeVar, cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
@@ -251,7 +251,7 @@ class ProcessorMigrationError(Exception):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
|
class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||||
"""A sequential pipeline for processing data, integrated with the Hugging Face Hub.
|
"""A sequential pipeline for processing data, integrated with the Hugging Face Hub.
|
||||||
|
|
||||||
This class chains together multiple `ProcessorStep` instances to form a complete
|
This class chains together multiple `ProcessorStep` instances to form a complete
|
||||||
@@ -1432,8 +1432,8 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
|
|||||||
|
|
||||||
|
|
||||||
# Type aliases for semantic clarity.
|
# Type aliases for semantic clarity.
|
||||||
RobotProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput]
|
RobotProcessorPipeline = DataProcessorPipeline[TInput, TOutput]
|
||||||
PolicyProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput]
|
PolicyProcessorPipeline = DataProcessorPipeline[TInput, TOutput]
|
||||||
|
|
||||||
|
|
||||||
class ObservationProcessorStep(ProcessorStep, ABC):
|
class ObservationProcessorStep(ProcessorStep, ABC):
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TypeAlias
|
|
||||||
|
|
||||||
from lerobot.cameras import CameraConfig
|
from lerobot.cameras import CameraConfig
|
||||||
|
|
||||||
@@ -50,5 +49,5 @@ class SOFollowerRobotConfig(RobotConfig, SOFollowerConfig):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
SO100FollowerConfig: TypeAlias = SOFollowerRobotConfig
|
SO100FollowerConfig = SOFollowerRobotConfig
|
||||||
SO101FollowerConfig: TypeAlias = SOFollowerRobotConfig
|
SO101FollowerConfig = SOFollowerRobotConfig
|
||||||
|
|||||||
@@ -17,7 +17,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import TypeAlias
|
|
||||||
|
|
||||||
from lerobot.cameras.utils import make_cameras_from_configs
|
from lerobot.cameras.utils import make_cameras_from_configs
|
||||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||||
@@ -230,5 +229,5 @@ class SOFollower(Robot):
|
|||||||
logger.info(f"{self} disconnected.")
|
logger.info(f"{self} disconnected.")
|
||||||
|
|
||||||
|
|
||||||
SO100Follower: TypeAlias = SOFollower
|
SO100Follower = SOFollower
|
||||||
SO101Follower: TypeAlias = SOFollower
|
SO101Follower = SOFollower
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TypeAlias
|
|
||||||
|
|
||||||
from ..config import TeleoperatorConfig
|
from ..config import TeleoperatorConfig
|
||||||
|
|
||||||
@@ -38,5 +37,5 @@ class SOLeaderTeleopConfig(TeleoperatorConfig, SOLeaderConfig):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
SO100LeaderConfig: TypeAlias = SOLeaderTeleopConfig
|
SO100LeaderConfig = SOLeaderTeleopConfig
|
||||||
SO101LeaderConfig: TypeAlias = SOLeaderTeleopConfig
|
SO101LeaderConfig = SOLeaderTeleopConfig
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import TypeAlias
|
|
||||||
|
|
||||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||||
from lerobot.motors.feetech import (
|
from lerobot.motors.feetech import (
|
||||||
@@ -156,5 +155,5 @@ class SOLeader(Teleoperator):
|
|||||||
logger.info(f"{self} disconnected.")
|
logger.info(f"{self} disconnected.")
|
||||||
|
|
||||||
|
|
||||||
SO100Leader: TypeAlias = SOLeader
|
SO100Leader = SOLeader
|
||||||
SO101Leader: TypeAlias = SOLeader
|
SO101Leader = SOLeader
|
||||||
|
|||||||
@@ -16,12 +16,10 @@
|
|||||||
import json
|
import json
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TypeVar
|
|
||||||
|
|
||||||
import imageio
|
import imageio
|
||||||
|
|
||||||
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
|
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
|
||||||
T = TypeVar("T", bound=JsonLike)
|
|
||||||
|
|
||||||
|
|
||||||
def write_video(video_path, stacked_frames, fps):
|
def write_video(video_path, stacked_frames, fps):
|
||||||
@@ -33,7 +31,7 @@ def write_video(video_path, stacked_frames, fps):
|
|||||||
imageio.mimsave(video_path, stacked_frames, fps=fps)
|
imageio.mimsave(video_path, stacked_frames, fps=fps)
|
||||||
|
|
||||||
|
|
||||||
def deserialize_json_into_object(fpath: Path, obj: T) -> T:
|
def deserialize_json_into_object[T: JsonLike](fpath: Path, obj: T) -> T:
|
||||||
"""
|
"""
|
||||||
Loads the JSON data from `fpath` and recursively fills `obj` with the
|
Loads the JSON data from `fpath` and recursively fills `obj` with the
|
||||||
corresponding values (strictly matching structure and types).
|
corresponding values (strictly matching structure and types).
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ from lerobot.utils.constants import (
|
|||||||
OBS_LANGUAGE_TOKENS,
|
OBS_LANGUAGE_TOKENS,
|
||||||
OBS_STATE,
|
OBS_STATE,
|
||||||
) # noqa: E402
|
) # noqa: E402
|
||||||
from tests.utils import require_cuda # noqa: E402
|
from tests.utils import require_cuda, require_hf_token # noqa: E402
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
DUMMY_ACTION_DIM = 7
|
DUMMY_ACTION_DIM = 7
|
||||||
@@ -65,6 +65,7 @@ EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.0000, 0.3536, 0.0707, 0.0000, 0.0000]
|
|||||||
|
|
||||||
|
|
||||||
@require_cuda
|
@require_cuda
|
||||||
|
@require_hf_token
|
||||||
def set_seed_all(seed: int):
|
def set_seed_all(seed: int):
|
||||||
"""Set random seed for all RNG sources to ensure reproducibility."""
|
"""Set random seed for all RNG sources to ensure reproducibility."""
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
@@ -82,6 +83,7 @@ def set_seed_all(seed: int):
|
|||||||
|
|
||||||
|
|
||||||
@require_cuda
|
@require_cuda
|
||||||
|
@require_hf_token
|
||||||
def instantiate_lerobot_pi0_fast(
|
def instantiate_lerobot_pi0_fast(
|
||||||
from_pretrained: bool = False,
|
from_pretrained: bool = False,
|
||||||
model_path: str = MODEL_PATH_LEROBOT,
|
model_path: str = MODEL_PATH_LEROBOT,
|
||||||
@@ -125,6 +127,7 @@ def instantiate_lerobot_pi0_fast(
|
|||||||
|
|
||||||
|
|
||||||
@require_cuda
|
@require_cuda
|
||||||
|
@require_hf_token
|
||||||
def create_dummy_data(device=DEVICE):
|
def create_dummy_data(device=DEVICE):
|
||||||
"""Create dummy data for testing both implementations."""
|
"""Create dummy data for testing both implementations."""
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
@@ -157,6 +160,7 @@ def create_dummy_data(device=DEVICE):
|
|||||||
# Pytest fixtures
|
# Pytest fixtures
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
@require_cuda
|
@require_cuda
|
||||||
|
@require_hf_token
|
||||||
def pi0_fast_components():
|
def pi0_fast_components():
|
||||||
"""Fixture to instantiate and provide all PI0Fast components for tests."""
|
"""Fixture to instantiate and provide all PI0Fast components for tests."""
|
||||||
print(f"\nTesting with DEVICE='{DEVICE}'")
|
print(f"\nTesting with DEVICE='{DEVICE}'")
|
||||||
@@ -168,6 +172,7 @@ def pi0_fast_components():
|
|||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
@require_cuda
|
@require_cuda
|
||||||
|
@require_hf_token
|
||||||
def policy(pi0_fast_components):
|
def policy(pi0_fast_components):
|
||||||
"""Fixture to provide the PI0Fast policy for tests."""
|
"""Fixture to provide the PI0Fast policy for tests."""
|
||||||
return pi0_fast_components[0]
|
return pi0_fast_components[0]
|
||||||
@@ -175,12 +180,14 @@ def policy(pi0_fast_components):
|
|||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
@require_cuda
|
@require_cuda
|
||||||
|
@require_hf_token
|
||||||
def preprocessor(pi0_fast_components):
|
def preprocessor(pi0_fast_components):
|
||||||
"""Fixture to provide the PI0Fast preprocessor for tests."""
|
"""Fixture to provide the PI0Fast preprocessor for tests."""
|
||||||
return pi0_fast_components[1]
|
return pi0_fast_components[1]
|
||||||
|
|
||||||
|
|
||||||
@require_cuda
|
@require_cuda
|
||||||
|
@require_hf_token
|
||||||
def test_pi0_fast_preprocessor_alignment(policy, preprocessor):
|
def test_pi0_fast_preprocessor_alignment(policy, preprocessor):
|
||||||
"""Test that LeRobot PI0Fast preprocessor produces expected outputs."""
|
"""Test that LeRobot PI0Fast preprocessor produces expected outputs."""
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
@@ -228,6 +235,7 @@ def test_pi0_fast_preprocessor_alignment(policy, preprocessor):
|
|||||||
|
|
||||||
|
|
||||||
@require_cuda
|
@require_cuda
|
||||||
|
@require_hf_token
|
||||||
def test_pi0_fast_action_generation(policy, preprocessor):
|
def test_pi0_fast_action_generation(policy, preprocessor):
|
||||||
"""Test PI0Fast LeRobot implementation generates expected actions."""
|
"""Test PI0Fast LeRobot implementation generates expected actions."""
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
@@ -306,6 +314,7 @@ def test_pi0_fast_action_generation(policy, preprocessor):
|
|||||||
|
|
||||||
|
|
||||||
@require_cuda
|
@require_cuda
|
||||||
|
@require_hf_token
|
||||||
def test_pi0_fast_inference_reproducibility(policy, preprocessor):
|
def test_pi0_fast_inference_reproducibility(policy, preprocessor):
|
||||||
"""Test that PI0Fast inference is reproducible with the same seed."""
|
"""Test that PI0Fast inference is reproducible with the same seed."""
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
@@ -347,6 +356,7 @@ def test_pi0_fast_inference_reproducibility(policy, preprocessor):
|
|||||||
|
|
||||||
|
|
||||||
@require_cuda
|
@require_cuda
|
||||||
|
@require_hf_token
|
||||||
def test_pi0_fast_forward_pass_logits(policy, preprocessor):
|
def test_pi0_fast_forward_pass_logits(policy, preprocessor):
|
||||||
"""Test PI0Fast forward pass and compare logits against expected values."""
|
"""Test PI0Fast forward pass and compare logits against expected values."""
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
@@ -396,6 +406,7 @@ def test_pi0_fast_forward_pass_logits(policy, preprocessor):
|
|||||||
|
|
||||||
|
|
||||||
@require_cuda
|
@require_cuda
|
||||||
|
@require_hf_token
|
||||||
def test_pi0_fast_action_token_sampling(policy, preprocessor):
|
def test_pi0_fast_action_token_sampling(policy, preprocessor):
|
||||||
"""Test PI0Fast action token sampling (autoregressive decoding)."""
|
"""Test PI0Fast action token sampling (autoregressive decoding)."""
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
@@ -452,6 +463,7 @@ def test_pi0_fast_action_token_sampling(policy, preprocessor):
|
|||||||
|
|
||||||
|
|
||||||
@require_cuda
|
@require_cuda
|
||||||
|
@require_hf_token
|
||||||
def test_pi0_fast_detokenization(policy, preprocessor):
|
def test_pi0_fast_detokenization(policy, preprocessor):
|
||||||
"""Test PI0Fast action detokenization (FAST decoding)."""
|
"""Test PI0Fast action detokenization (FAST decoding)."""
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
|
|||||||
@@ -14,10 +14,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Test script to verify PI0 policy integration with LeRobot, only meant to be run locally!"""
|
"""Test script to verify PI0 policy integration with LeRobot"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
pytest.importorskip("transformers")
|
||||||
|
|
||||||
from lerobot.policies.factory import make_policy_config # noqa: E402
|
from lerobot.policies.factory import make_policy_config # noqa: E402
|
||||||
from lerobot.policies.pi0 import ( # noqa: E402
|
from lerobot.policies.pi0 import ( # noqa: E402
|
||||||
PI0Config,
|
PI0Config,
|
||||||
@@ -25,10 +28,11 @@ from lerobot.policies.pi0 import ( # noqa: E402
|
|||||||
make_pi0_pre_post_processors, # noqa: E402
|
make_pi0_pre_post_processors, # noqa: E402
|
||||||
)
|
)
|
||||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||||
from tests.utils import require_cuda # noqa: E402
|
from tests.utils import require_cuda, require_hf_token # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
@require_cuda
|
@require_cuda
|
||||||
|
@require_hf_token
|
||||||
def test_policy_instantiation():
|
def test_policy_instantiation():
|
||||||
# Create config
|
# Create config
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
@@ -105,6 +109,7 @@ def test_policy_instantiation():
|
|||||||
|
|
||||||
|
|
||||||
@require_cuda
|
@require_cuda
|
||||||
|
@require_hf_token
|
||||||
def test_config_creation():
|
def test_config_creation():
|
||||||
"""Test policy config creation through factory."""
|
"""Test policy config creation through factory."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -14,10 +14,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Test script to verify PI0.5 (pi05) support in PI0 policy, only meant to be run locally!"""
|
"""Test script to verify PI0.5 (pi05) support in PI0 policy"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
pytest.importorskip("transformers")
|
||||||
|
|
||||||
from lerobot.policies.factory import make_policy_config # noqa: E402
|
from lerobot.policies.factory import make_policy_config # noqa: E402
|
||||||
from lerobot.policies.pi05 import ( # noqa: E402
|
from lerobot.policies.pi05 import ( # noqa: E402
|
||||||
PI05Config,
|
PI05Config,
|
||||||
@@ -25,10 +28,11 @@ from lerobot.policies.pi05 import ( # noqa: E402
|
|||||||
make_pi05_pre_post_processors, # noqa: E402
|
make_pi05_pre_post_processors, # noqa: E402
|
||||||
)
|
)
|
||||||
from lerobot.utils.random_utils import set_seed
|
from lerobot.utils.random_utils import set_seed
|
||||||
from tests.utils import require_cuda # noqa: E402
|
from tests.utils import require_cuda, require_hf_token # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
@require_cuda
|
@require_cuda
|
||||||
|
@require_hf_token
|
||||||
def test_policy_instantiation():
|
def test_policy_instantiation():
|
||||||
# Create config
|
# Create config
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
@@ -141,6 +145,7 @@ def test_policy_instantiation():
|
|||||||
|
|
||||||
|
|
||||||
@require_cuda
|
@require_cuda
|
||||||
|
@require_hf_token
|
||||||
def test_config_creation():
|
def test_config_creation():
|
||||||
"""Test policy config creation through factory."""
|
"""Test policy config creation through factory."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
|
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Test script to verify PI0 policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
|
"""Test script to verify PI0 policy integration with LeRobot vs the original implementation"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|||||||
@@ -143,12 +143,18 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
|||||||
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
|
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
|
||||||
and for now we add tests as we see fit.
|
and for now we add tests as we see fit.
|
||||||
"""
|
"""
|
||||||
|
if policy_name == "vqbet" and DEVICE == "mps":
|
||||||
|
pytest.skip("VQBet does not support MPS backend")
|
||||||
|
if policy_name == "act" and "aloha" in ds_repo_id and DEVICE == "mps":
|
||||||
|
pytest.skip("ACT with aloha has batch mutation issues on MPS")
|
||||||
|
|
||||||
train_cfg = TrainPipelineConfig(
|
train_cfg = TrainPipelineConfig(
|
||||||
# TODO(rcadene, aliberts): remove dataset download
|
# TODO(rcadene, aliberts): remove dataset download
|
||||||
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
|
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
|
||||||
policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs),
|
policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs),
|
||||||
env=make_env_config(env_name, **env_kwargs),
|
env=make_env_config(env_name, **env_kwargs),
|
||||||
)
|
)
|
||||||
|
train_cfg.policy.device = DEVICE
|
||||||
train_cfg.validate()
|
train_cfg.validate()
|
||||||
|
|
||||||
# Check that we can make the policy object.
|
# Check that we can make the policy object.
|
||||||
@@ -227,6 +233,7 @@ def test_act_backbone_lr():
|
|||||||
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
|
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
|
||||||
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001, push_to_hub=False),
|
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001, push_to_hub=False),
|
||||||
)
|
)
|
||||||
|
cfg.policy.device = DEVICE
|
||||||
cfg.validate() # Needed for auto-setting some parameters
|
cfg.validate() # Needed for auto-setting some parameters
|
||||||
|
|
||||||
assert cfg.policy.optimizer_lr == 0.01
|
assert cfg.policy.optimizer_lr == 0.01
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Test script to verify Wall-X policy integration with LeRobot, only meant to be run locally!"""
|
"""Test script to verify Wall-X policy integration with LeRobot"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@@ -29,10 +29,11 @@ from lerobot.policies.wall_x import WallXConfig # noqa: E402
|
|||||||
from lerobot.policies.wall_x.modeling_wall_x import WallXPolicy # noqa: E402
|
from lerobot.policies.wall_x.modeling_wall_x import WallXPolicy # noqa: E402
|
||||||
from lerobot.policies.wall_x.processor_wall_x import make_wall_x_pre_post_processors # noqa: E402
|
from lerobot.policies.wall_x.processor_wall_x import make_wall_x_pre_post_processors # noqa: E402
|
||||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||||
from tests.utils import require_cuda # noqa: E402
|
from tests.utils import require_cuda, require_hf_token # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
@require_cuda
|
@require_cuda
|
||||||
|
@require_hf_token
|
||||||
def test_policy_instantiation():
|
def test_policy_instantiation():
|
||||||
# Create config
|
# Create config
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
@@ -118,6 +119,7 @@ def test_policy_instantiation():
|
|||||||
|
|
||||||
|
|
||||||
@require_cuda
|
@require_cuda
|
||||||
|
@require_hf_token
|
||||||
def test_config_creation():
|
def test_config_creation():
|
||||||
"""Test policy config creation through factory."""
|
"""Test policy config creation through factory."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Test script to verify XVLA policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
|
"""Test script to verify XVLA policy integration with LeRobot vs the original implementation"""
|
||||||
# ruff: noqa: E402
|
# ruff: noqa: E402
|
||||||
|
|
||||||
import random
|
import random
|
||||||
|
|||||||
@@ -1870,9 +1870,7 @@ class NonCallableStep(ProcessorStep):
|
|||||||
|
|
||||||
def test_construction_rejects_step_without_call():
|
def test_construction_rejects_step_without_call():
|
||||||
"""Test that DataProcessorPipeline rejects steps that don't inherit from ProcessorStep."""
|
"""Test that DataProcessorPipeline rejects steps that don't inherit from ProcessorStep."""
|
||||||
with pytest.raises(
|
with pytest.raises(TypeError, match=r"Can't instantiate abstract class NonCallableStep"):
|
||||||
TypeError, match=r"Can't instantiate abstract class NonCallableStep with abstract method __call_"
|
|
||||||
):
|
|
||||||
DataProcessorPipeline([NonCallableStep()])
|
DataProcessorPipeline([NonCallableStep()])
|
||||||
|
|
||||||
with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"):
|
with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"):
|
||||||
|
|||||||
@@ -22,8 +22,9 @@ import torch
|
|||||||
|
|
||||||
from lerobot import available_cameras, available_motors, available_robots
|
from lerobot import available_cameras, available_motors, available_robots
|
||||||
from lerobot.utils.import_utils import is_package_available
|
from lerobot.utils.import_utils import is_package_available
|
||||||
|
from lerobot.utils.utils import auto_select_torch_device
|
||||||
|
|
||||||
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
|
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", str(auto_select_torch_device()))
|
||||||
|
|
||||||
TEST_ROBOT_TYPES = []
|
TEST_ROBOT_TYPES = []
|
||||||
for robot_type in available_robots:
|
for robot_type in available_robots:
|
||||||
@@ -107,6 +108,22 @@ def require_cuda(func):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def require_hf_token(func):
|
||||||
|
"""
|
||||||
|
Decorator that skips the test if no Hugging Face Hub token is available.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
from huggingface_hub import get_token
|
||||||
|
|
||||||
|
if get_token() is None:
|
||||||
|
pytest.skip("requires HF token for gated model access")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def require_env(func):
|
def require_env(func):
|
||||||
"""
|
"""
|
||||||
Decorator that skips the test if the required environment package is not installed.
|
Decorator that skips the test if the required environment package is not installed.
|
||||||
|
|||||||
Reference in New Issue
Block a user