Compare commits

..

1 Commits

Author SHA1 Message Date
Jade Choghari
cbb380df34 draft changes 2025-12-26 14:06:30 +00:00
220 changed files with 3746 additions and 11125 deletions

View File

@@ -22,21 +22,20 @@ Short, imperative summary (e.g., "fix(robots): handle None in sensor parser"). S
- Short, concrete bullets of the modifications (files/behaviour). - Short, concrete bullets of the modifications (files/behaviour).
- Short note if this introduces breaking changes and migration steps. - Short note if this introduces breaking changes and migration steps.
## How was this tested (or how to run locally) ## How was this tested
- Tests added: list new tests or test files. - Tests added: list new tests or test files.
- Manual checks / dataset runs performed. - Manual checks / dataset runs performed.
- Instructions for the reviewer
Example: ## How to run locally (reviewer)
- Ran the relevant tests: - Run the relevant tests:
```bash ```bash
pytest -q tests/ -k <keyword> pytest -q tests/ -k <keyword>
``` ```
- Reproduce with a quick example or CLI (if applicable): - Run a quick example or CLI (if applicable):
```bash ```bash
lerobot-train --some.option=true lerobot-train --some.option=true

View File

@@ -18,11 +18,6 @@ name: Documentation
on: on:
# Allows running this workflow manually from the Actions tab # Allows running this workflow manually from the Actions tab
workflow_dispatch: workflow_dispatch:
inputs:
version:
description: 'Version tag (e.g. v0.1.2) - Leave empty for standard main build'
required: false
type: string
# Triggers the workflow on push events to main for the docs folder # Triggers the workflow on push events to main for the docs folder
push: push:
@@ -59,13 +54,7 @@ jobs:
with: with:
commit_sha: ${{ github.sha }} commit_sha: ${{ github.sha }}
package: lerobot package: lerobot
additional_args: >- additional_args: --not_python_module ${{ github.event_name == 'release' && format('--version {0}', github.event.release.tag_name) || '' }}
--not_python_module
${{
(github.event_name == 'release' && format('--version {0}', github.event.release.tag_name)) ||
(inputs.version != '' && format('--version {0}', inputs.version)) ||
''
}}
secrets: secrets:
token: ${{ secrets.HUGGINGFACE_PUSH }} token: ${{ secrets.HUGGINGFACE_PUSH }}
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}

View File

@@ -186,18 +186,15 @@ jobs:
steps: steps:
- name: Get Docker Hub Token and Delete Image - name: Get Docker Hub Token and Delete Image
# zizmor: ignore[template-injection] # zizmor: ignore[template-injection]
env:
DOCKERHUB_LEROBOT_USERNAME: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
DOCKERHUB_LEROBOT_PASSWORD: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
IMAGE_FULL: ${{ needs.build-and-push-docker.outputs.image_tag }}
run: | run: |
IMAGE_NAME=$(echo "$IMAGE_FULL" | cut -d':' -f1) IMAGE_NAME=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f1)
IMAGE_TAG=$(echo "$IMAGE_FULL" | cut -d':' -f2-) IMAGE_TAG=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f2)
echo "Attempting to delete image: $IMAGE_NAME:$IMAGE_TAG" echo "Attempting to delete image: $IMAGE_NAME:$IMAGE_TAG"
TOKEN=$(curl -s -H "Content-Type: application/json" \ TOKEN=$(curl -s -H "Content-Type: application/json" \
-X POST \ -X POST \
-d "{\"username\": \"$DOCKERHUB_LEROBOT_USERNAME\", \"password\": \"$DOCKERHUB_LEROBOT_PASSWORD\"}" \ -d '{"username": "${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}", "password": "${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}"}' \
https://hub.docker.com/v2/users/login/ | jq -r .token) https://hub.docker.com/v2/users/login/ | jq -r .token)
if [ "$TOKEN" == "null" ] || [ -z "$TOKEN" ]; then if [ "$TOKEN" == "null" ] || [ -z "$TOKEN" ]; then
@@ -208,7 +205,7 @@ jobs:
HTTP_RESPONSE=$(curl -s -o /dev/null -w "%{http_code}" \ HTTP_RESPONSE=$(curl -s -o /dev/null -w "%{http_code}" \
-H "Authorization: JWT ${TOKEN}" \ -H "Authorization: JWT ${TOKEN}" \
-X DELETE \ -X DELETE \
https://hub.docker.com/v2/repositories/${IMAGE_NAME}/tags/$IMAGE_TAG) https://hub.docker.com/v2/repositories/${IMAGE_NAME}/tags/${IMAGE_TAG}/)
if [ "$HTTP_RESPONSE" -eq 204 ]; then if [ "$HTTP_RESPONSE" -eq 204 ]; then
echo "Successfully deleted Docker image tag: $IMAGE_NAME:$IMAGE_TAG" echo "Successfully deleted Docker image tag: $IMAGE_NAME:$IMAGE_TAG"

View File

@@ -20,8 +20,8 @@ on:
workflow_dispatch: workflow_dispatch:
# Run on the 1st and 15th of every month at 02:00 UTC # Run on the 1st and 15th of every month at 02:00 UTC
# schedule: schedule:
# - cron: '0 2 1,15 * *' - cron: '0 2 1,15 * *'
permissions: permissions:
contents: read contents: read
@@ -162,19 +162,15 @@ jobs:
steps: steps:
- name: Get Docker Hub Token and Delete Image - name: Get Docker Hub Token and Delete Image
# zizmor: ignore[template-injection] # zizmor: ignore[template-injection]
env:
DOCKERHUB_LEROBOT_USERNAME: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
DOCKERHUB_LEROBOT_PASSWORD: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
IMAGE_FULL: ${{ needs.build-and-push-docker.outputs.image_tag }}
run: | run: |
IMAGE_NAME=$(echo "$IMAGE_FULL" | cut -d':' -f1) IMAGE_NAME=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f1)
IMAGE_TAG=$(echo "$IMAGE_FULL" | cut -d':' -f2) IMAGE_TAG=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f2)
echo "Attempting to delete image: $IMAGE_NAME:$IMAGE_TAG" echo "Attempting to delete image: $IMAGE_NAME:$IMAGE_TAG"
TOKEN=$(curl -s -H "Content-Type: application/json" \ TOKEN=$(curl -s -H "Content-Type: application/json" \
-X POST \ -X POST \
-d "{\"username\": \"$DOCKERHUB_LEROBOT_USERNAME\", \"password\": \"$DOCKERHUB_LEROBOT_PASSWORD\"}" \ -d '{"username": "${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}", "password": "${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}"}' \
https://hub.docker.com/v2/users/login/ | jq -r .token) https://hub.docker.com/v2/users/login/ | jq -r .token)
if [ "$TOKEN" == "null" ] || [ -z "$TOKEN" ]; then if [ "$TOKEN" == "null" ] || [ -z "$TOKEN" ]; then
@@ -185,7 +181,7 @@ jobs:
HTTP_RESPONSE=$(curl -s -o /dev/null -w "%{http_code}" \ HTTP_RESPONSE=$(curl -s -o /dev/null -w "%{http_code}" \
-H "Authorization: JWT ${TOKEN}" \ -H "Authorization: JWT ${TOKEN}" \
-X DELETE \ -X DELETE \
https://hub.docker.com/v2/repositories/${IMAGE_NAME}/tags/$IMAGE_TAG) https://hub.docker.com/v2/repositories/${IMAGE_NAME}/tags/${IMAGE_TAG}/)
if [ "$HTTP_RESPONSE" -eq 204 ]; then if [ "$HTTP_RESPONSE" -eq 204 ]; then
echo "Successfully deleted Docker image tag: $IMAGE_NAME:$IMAGE_TAG" echo "Successfully deleted Docker image tag: $IMAGE_NAME:$IMAGE_TAG"

View File

@@ -14,7 +14,7 @@ You can contribute in many ways:
- **Documentation:** Improve examples, guides, and docstrings. - **Documentation:** Improve examples, guides, and docstrings.
- **Feedback:** Submit tickets related to bugs or desired new features. - **Feedback:** Submit tickets related to bugs or desired new features.
If you are unsure where to start, join our [Discord Channel](https://discord.gg/q8Dzzpym3f). If you are unsure where to start, join our [Discord Channel](https://discord.gg/JkrYNdmw).
## Development Setup ## Development Setup

View File

@@ -10,7 +10,6 @@
[![Status](https://img.shields.io/pypi/status/lerobot)](https://pypi.org/project/lerobot/) [![Status](https://img.shields.io/pypi/status/lerobot)](https://pypi.org/project/lerobot/)
[![Version](https://img.shields.io/pypi/v/lerobot)](https://pypi.org/project/lerobot/) [![Version](https://img.shields.io/pypi/v/lerobot)](https://pypi.org/project/lerobot/)
[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-v2.1-ff69b4.svg)](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md) [![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-v2.1-ff69b4.svg)](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md)
[![Discord](https://img.shields.io/badge/Discord-Join_Us-5865F2?style=flat&logo=discord&logoColor=white)](https://discord.gg/q8Dzzpym3f)
</div> </div>
@@ -100,11 +99,11 @@ lerobot-train \
--dataset.repo_id=lerobot/aloha_mobile_cabinet --dataset.repo_id=lerobot/aloha_mobile_cabinet
``` ```
| Category | Models | | Category | Models |
| -------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | | -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) | | **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) | | **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) | | **VLAs Models** | [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
@@ -128,8 +127,7 @@ Learn how to implement your own simulation environment or benchmark and distribu
## Resources ## Resources
- **[Documentation](https://huggingface.co/docs/lerobot/index):** The complete guide to tutorials & API. - **[Documentation](https://huggingface.co/docs/lerobot/index):** The complete guide to tutorials & API.
- **[Chinese Tutorials: LeRobot+SO-ARM101中文教程-同济子豪兄](https://zihao-ai.feishu.cn/wiki/space/7589642043471924447)** Detailed doc for assembling, teleoperate, dataset, train, deploy. Verified by Seed Studio and 5 global hackathon players. - **[Discord](https://discord.gg/3gxM6Avj):** Join the `LeRobot` server to discuss with the community.
- **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community.
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments. - **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot. - **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.

View File

@@ -1,48 +0,0 @@
# Security Policy
## Project Status & Philosophy
`lerobot` has so far been primarily a research and prototyping tool, which is why deployment security hasn't been a strong focus until now. As `lerobot` continues to be adopted and deployed in production, we are paying much closer attention to these kinds of issues.
Fortunately, being an open-source project, the community can also help by reporting and fixing vulnerabilities. We appreciate your efforts to responsibly disclose your findings and will make every effort to acknowledge your contributions.
## Reporting a Vulnerability
To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/huggingface/lerobot/security/advisories/new) tab.
The `lerobot` team will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance.
#### Hugging Face Security Team
Since this project is part of the Hugging Face ecosystem, feel free to submit vulnerability reports directly to: **[security@huggingface.co](mailto:security@huggingface.co)**. Someone from the HF security team will review the report and recommend next steps.
#### Open Source Disclosures
If reporting a vulnerability specific to the open-source codebase (and not the underlying Hub infrastructure), you may also use [Huntr](https://huntr.com), a vulnerability disclosure program for open source software.
## Supported Versions
Currently, we treat `lerobot` as a rolling release. We prioritize security updates for the latest available version (`main` branch).
| Version | Supported |
| -------- | --------- |
| Latest | ✅ |
| < Latest | ❌ |
## Secure Usage Guidelines
`lerobot` is tightly coupled to the Hugging Face Hub for sharing data and pretrained policies. When downloading artifacts uploaded by others, you expose yourself to risks. Please read below for recommendations to keep your runtime and robot environment safe.
### Remote Artefacts (Weights & Policies)
Models and policies uploaded to the Hugging Face Hub come in different formats. We strongly recommend uploading and downloading models in the [`safetensors`](https://github.com/huggingface/safetensors) format.
`safetensors` was developed specifically to prevent arbitrary code execution on your system, which is critical when running software on physical hardware/robots.
To avoid loading models from unsafe formats (e.g., `pickle`), you should ensure you are prioritizing `safetensors` files.
### Remote Code
Some models or environments on the Hub may require `trust_remote_code=True` to run custom architecture code.
Please **always** verify the content of the modeling files when using this argument. We recommend setting a specific `revision` (commit hash) when loading remote code to ensure you protect yourself from unverified updates to the repository.

View File

@@ -73,7 +73,7 @@ ENV HOME=/home/user_lerobot \
RUN uv venv --python python${PYTHON_VERSION} RUN uv venv --python python${PYTHON_VERSION}
# Install Python dependencies for caching # Install Python dependencies for caching
COPY --chown=user_lerobot:user_lerobot setup.py pyproject.toml README.md MANIFEST.in ./ COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md MANIFEST.in ./
COPY --chown=user_lerobot:user_lerobot src/ src/ COPY --chown=user_lerobot:user_lerobot src/ src/
ARG UNBOUND_DEPS=false ARG UNBOUND_DEPS=false

View File

@@ -59,7 +59,7 @@ ENV HOME=/home/user_lerobot \
RUN uv venv RUN uv venv
# Install Python dependencies for caching # Install Python dependencies for caching
COPY --chown=user_lerobot:user_lerobot setup.py pyproject.toml README.md MANIFEST.in ./ COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md MANIFEST.in ./
COPY --chown=user_lerobot:user_lerobot src/ src/ COPY --chown=user_lerobot:user_lerobot src/ src/
ARG UNBOUND_DEPS=false ARG UNBOUND_DEPS=false

View File

@@ -19,8 +19,6 @@
title: Train RL in Simulation title: Train RL in Simulation
- local: multi_gpu_training - local: multi_gpu_training
title: Multi GPU training title: Multi GPU training
- local: peft_training
title: Training with PEFT (e.g., LoRA)
title: "Tutorials" title: "Tutorials"
- sections: - sections:
- local: lerobot-dataset-v3 - local: lerobot-dataset-v3
@@ -37,8 +35,6 @@
title: SmolVLA title: SmolVLA
- local: pi0 - local: pi0
title: π₀ (Pi0) title: π₀ (Pi0)
- local: pi0fast
title: π₀-FAST (Pi0Fast)
- local: pi05 - local: pi05
title: π₀.₅ (Pi05) title: π₀.₅ (Pi05)
- local: groot - local: groot
@@ -63,8 +59,6 @@
title: Environments from the Hub title: Environments from the Hub
- local: envhub_leisaac - local: envhub_leisaac
title: Control & Train Robots in Sim (LeIsaac) title: Control & Train Robots in Sim (LeIsaac)
- local: envhub_isaaclab_arena
title: NVIDIA IsaacLab Arena Environments
- local: libero - local: libero
title: Using Libero title: Using Libero
- local: metaworld - local: metaworld

View File

@@ -169,7 +169,7 @@ python -m lerobot.async_inference.robot_client \
<!-- prettier-ignore-start --> <!-- prettier-ignore-start -->
```python ```python
import threading import threading
from lerobot.robots.so_follower import SO100FollowerConfig from lerobot.robots.so100_follower import SO100FollowerConfig
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.async_inference.configs import RobotClientConfig from lerobot.async_inference.configs import RobotClientConfig
from lerobot.async_inference.robot_client import RobotClient from lerobot.async_inference.robot_client import RobotClient
@@ -195,7 +195,6 @@ client_cfg = RobotClientConfig(
robot=robot_cfg, robot=robot_cfg,
server_address="localhost:8080", server_address="localhost:8080",
policy_device="mps", policy_device="mps",
client_device="cpu",
policy_type="smolvla", policy_type="smolvla",
pretrained_name_or_path="<user>/smolvla_async", pretrained_name_or_path="<user>/smolvla_async",
chunk_size_threshold=0.5, chunk_size_threshold=0.5,

View File

@@ -12,42 +12,23 @@ The EarthRover Mini Plus is a fully open source mobile robot that connects throu
### Setting Up the Frodobots SDK ### Setting Up the Frodobots SDK
The robot needs the [Frodobots SDK](https://github.com/frodobots-org/earth-rovers-sdk) running on your computer. Here's how: The robot needs the [Frodobots SDK](https://github.com/Frodobots/earth-rovers-sdk) running on your computer. Here's how:
1. Download and install the SDK: 1. Download and install the SDK:
```bash ```bash
git clone https://github.com/frodobots-org/earth-rovers-sdk.git git clone https://github.com/Frodobots/earth-rovers-sdk.git
cd earth-rovers-sdk cd earth-rovers-sdk
pip install -r requirements.txt pip install -r requirements.txt
``` ```
2. Save Credentials: 2. Start the SDK:
Write your .env variables with the SDK API key and bot name provided by the Frodobots team.
```bash
SDK_API_TOKEN=your_sdk_api_token_here
BOT_SLUG=your_bot_slug_here
CHROME_EXECUTABLE_PATH=/path/to/chrome_or_chromium
# Default value is MAP_ZOOM_LEVEL=18 https://wiki.openstreetmap.org/wiki/Zoom_levels
MAP_ZOOM_LEVEL=18
MISSION_SLUG=your_mission_slug_here
# Image quality between 0.1 and 1.0 (default: 0.8)
# Recommended: 0.8 for better performance
IMAGE_QUALITY=0.8
# Image format: jpeg, png or webp (default: png)
# Recommended: jpeg for better performance and lower bandwidth usage
IMAGE_FORMAT=jpeg
```
3. Start the SDK:
```bash ```bash
hypercorn main:app --reload hypercorn main:app --reload
``` ```
4. Open your web browser and go to `http://localhost:8000`, then click "Join" 3. Open your web browser and go to `http://localhost:8000`, then click "Join"
The SDK gives you: The SDK gives you:

View File

@@ -2,32 +2,14 @@
The **EnvHub** feature allows you to load simulation environments directly from the Hugging Face Hub with a single line of code. This unlocks a powerful new model for collaboration: instead of environments being locked away inside monolithic libraries, anyone can publish custom environments and share them with the community. The **EnvHub** feature allows you to load simulation environments directly from the Hugging Face Hub with a single line of code. This unlocks a powerful new model for collaboration: instead of environments being locked away inside monolithic libraries, anyone can publish custom environments and share them with the community.
## What is EnvHub? ## Overview
EnvHub lets you create custom robotics simulation environments with your own robot models and scenarios, and make them easily usable by anyone through the LeRobot framework. With EnvHub, you can:
EnvHub packages are stored on the Hugging Face Hub, and can be seamlessly pulled and used in your AI robotics projects through LeRobot with a single line of code. - Load environments from the Hub instantly
- Share your custom simulation tasks with the community
Thanks to EnvHub, you can: - Version control your environments using Git
- Distribute complex physics simulations without packaging hassles
1. **Create and publish environments** to the Hugging Face Hub as Git repositories, and distribute complex physics simulations without packaging hassles
2. **Load environments** dynamically, without installing them as packages
3. **Version and track** environment changes using Git semantics
4. **Discover** new simulation tasks shared by the community
This design means you can go from discovering an interesting environment on the Hub to running experiments in seconds, or create your own custom robot and environment without worrying about dependency conflicts or complex installation procedures.
When you create an EnvHub package, you can build anything you want inside it and use any simulation tool you like: this is your own space to play with. The only requirement is that the package contains an `env.py` file that defines the environment and allows LeRobot to load and use your EnvHub package.
This `env.py` file needs to expose a small API so LeRobot can load and run it. In particular, you must provide a `make_env(n_envs: int = 1, use_async_envs: bool = False)` or `make_env(n_envs: int = 1, use_async_envs: bool = False, cfg: EnvConfig)` function, which is the main entry point for LeRobot. It should return one of:
- A `gym.vector.VectorEnv` (most common)
- A single `gym.Env` (will be automatically wrapped)
- A dict mapping `{suite_name: {task_id: VectorEnv}}` (for multi-task benchmarks)
You can also pass an `EnvConfig` object to `make_env` to configure the environment (e.g. the number of environments, task, camera name, initial states, control mode, episode length, etc.).
Finally, your environment must implement the standard `gym.vector.VectorEnv` interface so it works with LeRobot, including methods like `reset` and `step`.
## Quick Start ## Quick Start
@@ -47,6 +29,17 @@ env = make_env("lerobot/cartpole-env", trust_remote_code=True)
hash for reproducibility and security. hash for reproducibility and security.
</Tip> </Tip>
## What is EnvHub?
EnvHub is a framework that allows researchers and developers to:
1. **Publish environments** to the Hugging Face Hub as Git repositories
2. **Load environments** dynamically without installing them as packages
3. **Version and track** environment changes using Git semantics
4. **Discover** new simulation tasks shared by the community
This design means you can go from discovering an interesting environment on the Hub to running experiments in seconds, without worrying about dependency conflicts or complex installation procedures.
## Repository Structure ## Repository Structure
To make your environment loadable from the Hub, your repository must contain at minimum: To make your environment loadable from the Hub, your repository must contain at minimum:

View File

@@ -1,510 +0,0 @@
# NVIDIA IsaacLab Arena & LeRobot
LeRobot EnvHub now supports **GPU-accelerated simulation** with IsaacLab Arena for policy evaluation at scale.
Train and evaluate imitation learning policies with high-fidelity simulation — all integrated into the LeRobot ecosystem.
<img
src="https://huggingface.co/nvidia/isaaclab-arena-envs/resolve/main/assets/Gr1OpenMicrowaveEnvironment.png"
alt="IsaacLab Arena - GR1 Microwave Environment"
style={{ maxWidth: "100%", borderRadius: "8px", marginBottom: "1rem" }}
/>
[IsaacLab Arena](https://github.com/isaac-sim/IsaacLab-Arena) integrates with NVIDIA IsaacLab to provide:
- 🤖 **Humanoid embodiments**: GR1, G1, Galileo with various configurations
- 🎯 **Manipulation & loco-manipulation tasks**: Door opening, pick-and-place, button pressing, and more
- ⚡ **GPU-accelerated rollouts**: Parallel environment execution on NVIDIA GPUs
- 🖼️ **RTX Rendering**: Evaluate vision-based policies with realistic rendering, reflections and refractions
- 📦 **LeRobot-compatible datasets**: Ready for training with GR00T N1x, PI0, SmolVLA, ACT, and Diffusion policies
- 🔄 **EnvHub integration**: Load environments from HuggingFace EnvHub with one line
## Installation
### Prerequisites
Hardware requirements are shared with Isaac Sim, and are detailed in [Isaac Sim Requirements](https://docs.isaacsim.omniverse.nvidia.com/5.1.0/installation/requirements.html).
- NVIDIA GPU with CUDA support
- NVIDIA driver compatible with IsaacSim 5.1.0
- Linux (Ubuntu 22.04 / 24.04)
### Setup
```bash
# 1. Create conda environment
conda create -y -n lerobot-arena python=3.11
conda activate lerobot-arena
conda install -y -c conda-forge ffmpeg=7.1.1
# 2. Install Isaac Sim 5.1.0
pip install "isaacsim[all,extscache]==5.1.0" --extra-index-url https://pypi.nvidia.com
# Accept NVIDIA EULA (required)
export ACCEPT_EULA=Y
export PRIVACY_CONSENT=Y
# 3. Install IsaacLab 2.3.0
git clone https://github.com/isaac-sim/IsaacLab.git
cd IsaacLab
git checkout v2.3.0
./isaaclab.sh -i
cd ..
# 4. Install IsaacLab Arena
git clone https://github.com/isaac-sim/IsaacLab-Arena.git
cd IsaacLab-Arena
git checkout release/0.1.1
pip install -e .
cd ..
# 5. Install LeRobot
git clone https://github.com/huggingface/lerobot.git
cd lerobot
pip install -e .
cd ..
# 6. Install additional dependencies
pip install onnxruntime==1.23.2 lightwheel-sdk==1.0.1 vuer[all]==0.0.70 qpsolvers==4.8.1
pip install numpy==1.26.0 # Isaac Sim 5.1 depends on numpy==1.26.0, this will be fixed in next release
```
## Evaluating Policies
### Pre-trained Policies
The following trained policies are available:
| Policy | Architecture | Task | Link |
| :-------------------------- | :----------- | :------------ | :----------------------------------------------------------------------- |
| pi05-arena-gr1-microwave | PI0.5 | GR1 Microwave | [HuggingFace](https://huggingface.co/nvidia/pi05-arena-gr1-microwave) |
| smolvla-arena-gr1-microwave | SmolVLA | GR1 Microwave | [HuggingFace](https://huggingface.co/nvidia/smolvla-arena-gr1-microwave) |
### Evaluate SmolVLA
```bash
pip install -e ".[smolvla]"
pip install numpy==1.26.0 # revert numpy to version 1.26
```
```bash
lerobot-eval \
--policy.path=nvidia/smolvla-arena-gr1-microwave \
--env.type=isaaclab_arena \
--env.hub_path=nvidia/isaaclab-arena-envs \
--rename_map='{"observation.images.robot_pov_cam_rgb": "observation.images.robot_pov_cam"}' \
--policy.device=cuda \
--env.environment=gr1_microwave \
--env.embodiment=gr1_pink \
--env.object=mustard_bottle \
--env.headless=false \
--env.enable_cameras=true \
--env.video=true \
--env.video_length=10 \
--env.video_interval=15 \
--env.state_keys=robot_joint_pos \
--env.camera_keys=robot_pov_cam_rgb \
--trust_remote_code=True \
--eval.batch_size=1
```
### Evaluate PI0.5
```bash
pip install -e ".[pi]"
pip install numpy==1.26.0 # revert numpy to version 1.26
```
<Tip>PI0.5 requires disabling torch compile for evaluation:</Tip>
```bash
TORCH_COMPILE_DISABLE=1 TORCHINDUCTOR_DISABLE=1 lerobot-eval \
--policy.path=nvidia/pi05-arena-gr1-microwave \
--env.type=isaaclab_arena \
--env.hub_path=nvidia/isaaclab-arena-envs \
--rename_map='{"observation.images.robot_pov_cam_rgb": "observation.images.robot_pov_cam"}' \
--policy.device=cuda \
--env.environment=gr1_microwave \
--env.embodiment=gr1_pink \
--env.object=mustard_bottle \
--env.headless=false \
--env.enable_cameras=true \
--env.video=true \
--env.video_length=15 \
--env.video_interval=15 \
--env.state_keys=robot_joint_pos \
--env.camera_keys=robot_pov_cam_rgb \
--trust_remote_code=True \
--eval.batch_size=1
```
<Tip>
To change the number of parallel environments, use the ```--eval.batch_size```
flag.
</Tip>
### What to Expect
During evaluation, you will see a progress bar showing the running success rate:
```
Stepping through eval batches: 8%|██████▍ | 4/50 [00:45<08:06, 10.58s/it, running_success_rate=25.0%]
```
### Video Recording
To enable video recording during evaluation, add the following flags to your command:
```bash
--env.video=true \
--env.video_length=15 \
--env.video_interval=15
```
For more details on video recording, see the [IsaacLab Recording Documentation](https://isaac-sim.github.io/IsaacLab/main/source/how-to/record_video.html).
<Tip>
When running headless with `--env.headless=true`, you must also enable cameras explicitly for camera-enabled environments:
```bash
--env.headless=true --env.enable_cameras=true
```
</Tip>
### Output Directory
Evaluation videos are saved to the output directory with the following structure:
```
outputs/eval/<date>/<timestamp>_<env>_<policy>/videos/<task>_<env_id>/eval_episode_<n>.mp4
```
For example:
```
outputs/eval/2026-01-02/14-38-01_isaaclab_arena_smolvla/videos/gr1_microwave_0/eval_episode_0.mp4
```
## Training Policies
To learn more about training policies with LeRobot, please refer to the training documentation:
- [SmolVLA](./smolvla)
- [Pi0.5](./pi05)
- [GR00T N1.5](./groot)
Sample IsaacLab Arena datasets are available on HuggingFace Hub for experimentation:
| Dataset | Description | Frames |
| :-------------------------------------------------------------------------------------------------------- | :------------------------- | :----- |
| [Arena-GR1-Manipulation-Task](https://huggingface.co/datasets/nvidia/Arena-GR1-Manipulation-Task-v3) | GR1 microwave manipulation | ~4K |
| [Arena-G1-Loco-Manipulation-Task](https://huggingface.co/datasets/nvidia/Arena-G1-Loco-Manipulation-Task) | G1 loco-manipulation | ~4K |
## Environment Configuration
### Full Configuration Options
```python
from lerobot.envs.configs import IsaaclabArenaEnv
config = IsaaclabArenaEnv(
# Environment selection
environment="gr1_microwave", # Task environment
embodiment="gr1_pink", # Robot embodiment
object="power_drill", # Object to manipulate
# Simulation settings
episode_length=300, # Max steps per episode
headless=True, # Run without GUI
device="cuda:0", # GPU device
seed=42, # Random seed
# Observation configuration
state_keys="robot_joint_pos", # State observation keys (comma-separated)
camera_keys="robot_pov_cam_rgb", # Camera observation keys (comma-separated)
state_dim=54, # Expected state dimension
action_dim=36, # Expected action dimension
camera_height=512, # Camera image height
camera_width=512, # Camera image width
enable_cameras=True, # Enable camera observations
# Video recording
video=False, # Enable video recording
video_length=100, # Frames per video
video_interval=200, # Steps between recordings
# Advanced
mimic=False, # Enable mimic mode
teleop_device=None, # Teleoperation device
disable_fabric=False, # Disable fabric optimization
enable_pinocchio=True, # Enable Pinocchio for IK
)
```
### Using Environment Hub directly for advanced usage
Create a file called `test_env_load_arena.py` or [download from the EnvHub](https://huggingface.co/nvidia/isaaclab-arena-envs/blob/main/tests/test_env_load_arena.py):
```python
import logging
from dataclasses import asdict
from pprint import pformat
import torch
import tqdm
from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
@parser.wrap()
def main(cfg: EvalPipelineConfig):
"""Run random action rollout for IsaacLab Arena environment."""
logging.info(pformat(asdict(cfg)))
from lerobot.envs.factory import make_env
env_dict = make_env(
cfg.env,
n_envs=cfg.env.num_envs,
trust_remote_code=True,
)
env = next(iter(env_dict.values()))[0]
env.reset()
for _ in tqdm.tqdm(range(cfg.env.episode_length)):
with torch.inference_mode():
actions = env.action_space.sample()
obs, rewards, terminated, truncated, info = env.step(actions)
if terminated.any() or truncated.any():
obs, info = env.reset()
env.close()
if __name__ == "__main__":
main()
```
Run with:
```bash
python test_env_load_arena.py \
--env.environment=g1_locomanip_pnp \
--env.embodiment=gr1_pink \
--env.object=cracker_box \
--env.num_envs=4 \
--env.enable_cameras=true \
--env.seed=1000 \
--env.video=true \
--env.video_length=10 \
--env.video_interval=15 \
--env.headless=false \
--env.hub_path=nvidia/isaaclab-arena-envs \
--env.type=isaaclab_arena
```
## Creating New Environments
First create a new IsaacLab Arena environment by following the [IsaacLab Arena Documentation](https://isaac-sim.github.io/IsaacLab-Arena/release/0.1.1/index.html).
Clone our EnvHub repo:
```bash
git clone https://huggingface.co/nvidia/isaaclab-arena-envs
```
Modify the `example_envs.yaml` file based on your new environment.
[Upload](./envhub#step-3-upload-to-the-hub) your modified repo to HuggingFace EnvHub.
<Tip>
Your IsaacLab Arena environment code must be locally available during
evaluation. Users can clone your environment repository separately, or you can
bundle the environment code and assets directly in your EnvHub repo.
</Tip>
Then, when evaluating, use your new environment:
```bash
lerobot-eval \
--env.hub_path=<your-env-hub-path>/isaaclab-arena-envs \
--env.environment=<your new environment> \
...other flags...
```
We look forward to your contributions!
## Troubleshooting
### CUDA out of memory
Reduce `batch_size` or use a GPU with more VRAM:
```bash
--eval.batch_size=1
```
### EULA not accepted
Set environment variables before running:
```bash
export ACCEPT_EULA=Y
export PRIVACY_CONSENT=Y
```
### Video recording not working
Enable cameras when running headless:
```bash
--env.video=true --env.enable_cameras=true --env.headless=true
```
### Policy output dimension mismatch
Ensure `action_dim` matches your policy:
```bash
--env.action_dim=36
```
### libGLU.so.1 Errors during Isaac Sim initialization
This error typically occurs on headless machines. Ensure the following dependencies are installed:
```bash
sudo apt update && sudo apt install -y libglu1-mesa libxt6
```
## See Also
- [EnvHub Documentation](./envhub.mdx) - General EnvHub usage
- [IsaacLab Arena GitHub](https://github.com/isaac-sim/IsaacLab-Arena)
- [IsaacLab Documentation](https://isaac-sim.github.io/IsaacLab/)
## Lightwheel LW-BenchHub
[Lightwheel](https://www.lightwheel.ai) is bringing `Lightwheel-Libero-Tasks` and `Lightwheel-RoboCasa-Tasks` with 268 tasks to the LeRobot ecosystem.
LW-BenchHub collects and generates large-scale datasets via teleoperation that comply with the LeRobot specification, enabling out-of-the-box training and evaluation workflows.
With the unified interface provided by EnvHub, developers can quickly build end-to-end experimental pipelines.
### Install
Assuming you followed the [Installation](#installation) steps, you can install LW-BenchHub with:
```bash
conda install pinocchio -c conda-forge -y
pip install numpy==1.26.0 # revert numpy to version 1.26
sudo apt-get install git-lfs && git lfs install
git clone https://github.com/LightwheelAI/lw_benchhub
cd lw_benchhub
git lfs pull # Ensure LFS files (e.g., .usd assets) are downloaded
pip install -e .
```
For more detailed instructions, please refer to the [LW-BenchHub Documentation](https://docs.lightwheel.net/lw_benchhub/usage/Installation).
### Lightwheel Tasks Dataset
LW-BenchHub datasets are available on HuggingFace Hub:
| Dataset | Description | Tasks | Frames |
| :------------------------------------------------------------------------------------------------------------ | :---------------------- | :---- | :----- |
| [Lightwheel-Tasks-X7S](https://huggingface.co/datasets/LightwheelAI/Lightwheel-Tasks-X7S) | X7S LIBERO and RoboCasa | 117 | ~10.3M |
| [Lightwheel-Tasks-Double-Piper](https://huggingface.co/datasets/LightwheelAI/Lightwheel-Tasks-Double-Piper) | Double-Piper LIBERO | 130 | ~6.0M |
| [Lightwheel-Tasks-G1-Controller](https://huggingface.co/datasets/LightwheelAI/Lightwheel-Tasks-G1-Controller) | G1-Controller LIBERO | 62 | ~2.7M |
| [Lightwheel-Tasks-G1-WBC](https://huggingface.co/datasets/LightwheelAI/Lightwheel-Tasks-G1-WBC) | G1-WBC RoboCasa | 32 | ~1.5M |
For training policies, refer to the [Training Policies](#training-policies) section.
### Evaluating Policies
#### Pre-trained Policies
The following trained policies are available:
| Policy | Architecture | Task | Layout | Robot | Link |
| :----------------------- | :----------- | :----------------------------- | :--------- | :-------------- | :------------------------------------------------------------------------------------ |
| smolvla-double-piper-pnp | SmolVLA | L90K1PutTheBlackBowlOnThePlate | libero-1-1 | DoublePiper-Abs | [HuggingFace](https://huggingface.co/LightwheelAI/smolvla-double-piper-pnp/tree/main) |
#### Evaluate SmolVLA
```bash
lerobot-eval \
--policy.path=LightwheelAI/smolvla-double-piper-pnp \
--env.type=isaaclab_arena \
--rename_map='{"observation.images.left_hand_camera_rgb": "observation.images.left_hand", "observation.images.right_hand_camera_rgb": "observation.images.right_hand", "observation.images.first_person_camera_rgb": "observation.images.first_person"}' \
--env.hub_path=LightwheelAI/lw_benchhub_env \
--env.kwargs='{"config_path": "configs/envhub/example.yml"}' \
--trust_remote_code=true \
--env.state_keys=joint_pos \
--env.action_dim=12 \
--env.camera_keys=left_hand_camera_rgb,right_hand_camera_rgb,first_person_camera_rgb \
--policy.device=cuda \
--eval.batch_size=10 \
--eval.n_episodes=100
```
### Environment Configuration
Evaluation can be quickly launched by modifying the `robot`, `task`, and `layout` settings in the configuration file.
#### Full Configuration Options
```yml
# =========================
# Basic Settings
# =========================
disable_fabric: false
device: cuda:0
sensitivity: 1.0
step_hz: 50
enable_cameras: true
execute_mode: eval
episode_length_s: 20.0 # Episode length in seconds, increase if episodes timeout during eval
# =========================
# Robot Settings
# =========================
robot: DoublePiper-Abs # Robot type, DoublePiper-Abs, X7S-Abs, G1-Controller or G1-Controller-DecoupledWBC
robot_scale: 1.0
# =========================
# Task & Scene Settings
# =========================
task: L90K1PutTheBlackBowlOnThePlate # Task name
scene_backend: robocasa
task_backend: robocasa
debug_assets: null
layout: libero-1-1 # Layout and style ID
sources:
- objaverse
- lightwheel
- aigen_objs
object_projects: []
usd_simplify: false
seed: 42
# =========================
# Object Placement Retry Settings
# =========================
max_scene_retry: 4
max_object_placement_retry: 3
resample_objects_placement_on_reset: true
resample_robot_placement_on_reset: true
# =========================
# Replay Configuration Settings
# =========================
replay_cfgs:
add_camera_to_observation: true
render_resolution: [640, 480]
```
### See Also
- [LW-BenchHub GitHub](https://github.com/LightwheelAI/LW-BenchHub)
- [LW-BenchHub Documentation](https://docs.lightwheel.net/lw_benchhub/)

View File

@@ -137,8 +137,7 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator, Teleoperator,
TeleoperatorConfig, TeleoperatorConfig,
make_teleoperator_from_config, make_teleoperator_from_config,
so_leader, so101_leader,
bi_so_leader,
) )
from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import init_logging from lerobot.utils.utils import init_logging
@@ -197,7 +196,7 @@ def teleop_loop(teleop: Teleoperator, env: gym.Env, fps: int):
obs, info = env.reset() obs, info = env.reset()
dt_s = time.perf_counter() - loop_start dt_s = time.perf_counter() - loop_start
precise_sleep(max(1 / fps - dt_s, 0.0)) precise_sleep(1 / fps - dt_s)
loop_s = time.perf_counter() - loop_start loop_s = time.perf_counter() - loop_start
print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)") print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
@@ -223,7 +222,7 @@ def teleoperate(cfg: TeleoperateConfig):
def main(): def main():
teleoperate(TeleoperateConfig( teleoperate(TeleoperateConfig(
teleop=so_leader.SO101LeaderConfig( teleop=so101_leader.SO101LeaderConfig(
port="/dev/ttyACM0", port="/dev/ttyACM0",
id='leader', id='leader',
use_degrees=False, use_degrees=False,

View File

@@ -12,12 +12,6 @@ Developers and researchers can post-train GR00T N1.5 with their own real or synt
GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception. GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-groot-paper1%20(1).png"
alt="An overview of GR00T"
width="80%"
/>
Its strong performance comes from being trained on an expansive and diverse humanoid dataset, which includes: Its strong performance comes from being trained on an expansive and diverse humanoid dataset, which includes:
- Real captured data from robots. - Real captured data from robots.
@@ -109,7 +103,7 @@ Once you have trained your model using your parameters you can run inference in
```bash ```bash
lerobot-record \ lerobot-record \
--robot.type=bi_so_follower \ --robot.type=bi_so100_follower \
--robot.left_arm_port=/dev/ttyACM1 \ --robot.left_arm_port=/dev/ttyACM1 \
--robot.right_arm_port=/dev/ttyACM0 \ --robot.right_arm_port=/dev/ttyACM0 \
--robot.id=bimanual_follower \ --robot.id=bimanual_follower \

View File

@@ -58,8 +58,8 @@ lerobot-teleoperate \
<!-- prettier-ignore-start --> <!-- prettier-ignore-start -->
```python ```python
from lerobot.teleoperators.so_leader import SO101LeaderConfig, SO101Leader from lerobot.teleoperators.so101_leader import SO101LeaderConfig, SO101Leader
from lerobot.robots.so_follower import SO101FollowerConfig, SO101Follower from lerobot.robots.so101_follower import SO101FollowerConfig, SO101Follower
robot_config = SO101FollowerConfig( robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem58760431541", port="/dev/tty.usbmodem58760431541",
@@ -195,9 +195,9 @@ lerobot-record \
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.robots.so100_follower import SO100Follower, SO100FollowerConfig
from lerobot.teleoperators.so_leader.config_so100_leader import SO100LeaderConfig from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig
from lerobot.teleoperators.so_leader.so100_leader import SO100Leader from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun from lerobot.utils.visualization_utils import init_rerun
@@ -408,8 +408,8 @@ lerobot-replay \
import time import time
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.robots.so_follower.config_so100_follower import SO100FollowerConfig from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so_follower.so100_follower import SO100Follower from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say from lerobot.utils.utils import log_say
@@ -432,7 +432,7 @@ for idx in range(dataset.num_frames):
} }
robot.send_action(action) robot.send_action(action)
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0))
robot.disconnect() robot.disconnect()
``` ```
@@ -531,8 +531,8 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors from lerobot.policies.factory import make_pre_post_processors
from lerobot.robots.so_follower.config_so100_follower import SO100FollowerConfig from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so_follower.so100_follower import SO100Follower from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.scripts.lerobot_record import record_loop from lerobot.scripts.lerobot_record import record_loop
from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say from lerobot.utils.utils import log_say

View File

@@ -18,7 +18,7 @@ If you're using Feetech or Dynamixel motors, LeRobot provides built-in bus inter
- [`DynamixelMotorsBus`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/dynamixel/dynamixel.py) for controlling Dynamixel servos - [`DynamixelMotorsBus`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/dynamixel/dynamixel.py) for controlling Dynamixel servos
Please refer to the [`MotorsBus`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/motors_bus.py) abstract class to learn about its API. Please refer to the [`MotorsBus`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/motors_bus.py) abstract class to learn about its API.
For a good example of how it can be used, you can have a look at our own [SO101 follower implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/robots/so_follower/so101_follower/so101_follower.py) For a good example of how it can be used, you can have a look at our own [SO101 follower implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/robots/so101_follower/so101_follower.py)
Use these if compatible. Otherwise, you'll need to find or write a Python interface (not covered in this tutorial): Use these if compatible. Otherwise, you'll need to find or write a Python interface (not covered in this tutorial):

View File

@@ -204,7 +204,7 @@ lerobot-calibrate \
<!-- prettier-ignore-start --> <!-- prettier-ignore-start -->
```python ```python
from lerobot.teleoperators.so_leader import SO100LeaderConfig, SO100Leader from lerobot.teleoperators.so100_leader import SO100LeaderConfig, SO100Leader
config = SO100LeaderConfig( config = SO100LeaderConfig(
port="/dev/tty.usbmodem58760431551", port="/dev/tty.usbmodem58760431551",

View File

@@ -1,62 +0,0 @@
# Parameter-efficient fine-tuning with 🤗 PEFT
[🤗 PEFT](https://github.com/huggingface/peft) (Parameter-Efficient Fine-Tuning) is a library for efficiently adapting
large pretrained models such as pre-trained policies (e.g., SmolVLA, π₀, ...) to new tasks without training all
of the model's parameters while yielding comparable performance.
Install the `lerobot[peft]` optional package to enable PEFT support.
To read about all the possible adaptation methods, please refer to the [🤗 PEFT docs](https://huggingface.co/docs/peft/index).
## Training SmolVLA
In this section we'll show you how to train a pre-trained SmolVLA policy with PEFT on the libero dataset.
For brevity we're only training on the `libero_spatial` subset. We will use `lerobot/smolvla_base` as the model
to fine-tune in a parameter-efficient way:
```
lerobot-train \
--policy.path=lerobot/smolvla_base \
--policy.repo_id=your_hub_name/my_libero_smolvla \
--dataset.repo_id=HuggingFaceVLA/libero \
--policy.output_features=null \
--policy.input_features=null \
--policy.optimizer_lr=1e-3 \
--policy.scheduler_decay_lr=1e-4 \
--env.type=libero \
--env.task=libero_spatial \
--steps=100000 \
--batch_size=32 \
--peft.method_type=LORA \
--peft.r=64
```
Note the `--peft.method_type` parameter that lets you select which PEFT method to use. Here we use
[LoRA](https://huggingface.co/docs/peft/main/en/package_reference/lora) (Low-Rank Adaptation), which is probably the most
popular fine-tuning method to date. Low-rank adaptation means that we only fine-tune a matrix with comparatively low rank
instead of the full weight matrix. This rank can be specified using the `--peft.r` parameter. The higher the rank,
the closer you get to full fine-tuning.
There are more complex methods that have more parameters. These are not yet supported; feel free to raise an issue
if you want to see a specific PEFT method supported.
By default, PEFT will target the `q_proj` and `v_proj` layers of the LM expert in SmolVLA. It will also target the
state and action projection matrices as they are most likely task-dependent. If you need to target different layers
you can use `--peft.target_modules` to specify which layers to target. You can refer to the respective PEFT method's
documentation to see what inputs are supported (e.g., [LoRA's target_modules documentation](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules)).
Usually a list of suffixes or a regex is supported. For example, to target the MLPs of the `lm_expert` instead of
the `q` and `v` projections, use:
```
--peft.target_modules='(model\.vlm_with_expert\.lm_expert\..*\.(down|gate|up)_proj|.*\.(state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out))'
```
In case you need to fully fine-tune a layer instead of just adapting it, you can supply a list of layer suffixes
to the `--peft.full_training_modules` parameter:
```
--peft.full_training_modules=["state_proj"]
```
The learning rate and the scheduled target learning rate can usually be scaled up by a factor of 10 compared to the
learning rate used for full fine-tuning (e.g., 1e-4 for full fine-tuning, so 1e-3 when using LoRA).

View File

@@ -44,7 +44,7 @@ Modify the examples to use `PhoneOS.IOS` or `PhoneOS.ANDROID` in `PhoneConfig`.
Teleoperation example: Teleoperation example:
```python ```36:43:examples/phone_so100_teleop.py
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
@@ -103,7 +103,7 @@ Additionally you can customize mapping or safety limits by editing the processor
- Kinematics are used in multiple steps. We use [Placo](https://github.com/Rhoban/placo) which is a wrapper around Pinocchio for handling our kinematics. We construct the kinematics object by passing the robot's URDF and target frame. We set `target_frame_name` to the gripper frame. - Kinematics are used in multiple steps. We use [Placo](https://github.com/Rhoban/placo) which is a wrapper around Pinocchio for handling our kinematics. We construct the kinematics object by passing the robot's URDF and target frame. We set `target_frame_name` to the gripper frame.
```python ```examples/phone_to_so100/teleoperate.py
kinematics_solver = RobotKinematics( kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf", urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link", target_frame_name="gripper_frame_link",
@@ -114,7 +114,7 @@ Additionally you can customize mapping or safety limits by editing the processor
- The `MapPhoneActionToRobotAction` step converts the calibrated phone pose and inputs into target deltas and gripper commands, below is shown what the step outputs. - The `MapPhoneActionToRobotAction` step converts the calibrated phone pose and inputs into target deltas and gripper commands, below is shown what the step outputs.
```python ```src/lerobot/teleoperators/phone/phone_processor.py
action["enabled"] = enabled action["enabled"] = enabled
action["target_x"] = -pos[1] if enabled else 0.0 action["target_x"] = -pos[1] if enabled else 0.0
action["target_y"] = pos[0] if enabled else 0.0 action["target_y"] = pos[0] if enabled else 0.0
@@ -127,7 +127,7 @@ Additionally you can customize mapping or safety limits by editing the processor
- The `EEReferenceAndDelta` step converts target deltas to an absolute desired EE pose, storing a reference on enable, the `end_effector_step_sizes` are the step sizes for the EE pose and can be modified to change the motion speed. - The `EEReferenceAndDelta` step converts target deltas to an absolute desired EE pose, storing a reference on enable, the `end_effector_step_sizes` are the step sizes for the EE pose and can be modified to change the motion speed.
```python ```examples/phone_to_so100/teleoperate.py
EEReferenceAndDelta( EEReferenceAndDelta(
kinematics=kinematics_solver, kinematics=kinematics_solver,
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5}, end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
@@ -138,7 +138,7 @@ Additionally you can customize mapping or safety limits by editing the processor
- The `EEBoundsAndSafety` step clamps EE motion to a workspace and checks for large ee step jumps to ensure safety. The `end_effector_bounds` are the bounds for the EE pose and can be modified to change the workspace. The `max_ee_step_m` are the step limits for the EE pose and can be modified to change the safety limits. - The `EEBoundsAndSafety` step clamps EE motion to a workspace and checks for large ee step jumps to ensure safety. The `end_effector_bounds` are the bounds for the EE pose and can be modified to change the workspace. The `max_ee_step_m` are the step limits for the EE pose and can be modified to change the safety limits.
```python ```examples/phone_to_so100/teleoperate.py
EEBoundsAndSafety( EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10, max_ee_step_m=0.10,
@@ -147,7 +147,7 @@ Additionally you can customize mapping or safety limits by editing the processor
- The `GripperVelocityToJoint` step turns a velocity-like gripper input into absolute gripper position using the current measured state. The `speed_factor` is the factor by which the velocity is multiplied. - The `GripperVelocityToJoint` step turns a velocity-like gripper input into absolute gripper position using the current measured state. The `speed_factor` is the factor by which the velocity is multiplied.
```python ```examples/phone_to_so100/teleoperate.py
GripperVelocityToJoint(speed_factor=20.0) GripperVelocityToJoint(speed_factor=20.0)
``` ```
@@ -157,7 +157,7 @@ We use different IK initial guesses in the kinematic steps. As initial guess eit
- Closed loop (used in record/eval): sets `initial_guess_current_joints=True` so IK starts from the measured joints each frame. - Closed loop (used in record/eval): sets `initial_guess_current_joints=True` so IK starts from the measured joints each frame.
```python ```examples/phone_to_so100/record.py
InverseKinematicsEEToJoints( InverseKinematicsEEToJoints(
kinematics=kinematics_solver, kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()), motor_names=list(robot.bus.motors.keys()),
@@ -167,7 +167,7 @@ We use different IK initial guesses in the kinematic steps. As initial guess eit
- Open loop (used in replay): sets `initial_guess_current_joints=False` so IK continues from the previous IK solution rather than the measured state. This preserves action stability when we replay without feedback. - Open loop (used in replay): sets `initial_guess_current_joints=False` so IK continues from the previous IK solution rather than the measured state. This preserves action stability when we replay without feedback.
```python ```examples/phone_to_so100/replay.py
InverseKinematicsEEToJoints( InverseKinematicsEEToJoints(
kinematics=kinematics_solver, kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()), motor_names=list(robot.bus.motors.keys()),

View File

@@ -6,12 +6,6 @@
π₀ represents a breakthrough in robotics as the first general-purpose robot foundation model developed by [Physical Intelligence](https://www.physicalintelligence.company/blog/pi0). Unlike traditional robot programs that are narrow specialists programmed for repetitive motions, π₀ is designed to be a generalist policy that can understand visual inputs, interpret natural language instructions, and control a variety of different robots across diverse tasks. π₀ represents a breakthrough in robotics as the first general-purpose robot foundation model developed by [Physical Intelligence](https://www.physicalintelligence.company/blog/pi0). Unlike traditional robot programs that are narrow specialists programmed for repetitive motions, π₀ is designed to be a generalist policy that can understand visual inputs, interpret natural language instructions, and control a variety of different robots across diverse tasks.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-pi0%20(1).png"
alt="An overview of Pi0"
width="85%"
/>
### The Vision for Physical Intelligence ### The Vision for Physical Intelligence
As described by Physical Intelligence, while AI has achieved remarkable success in digital domains, from chess-playing to drug discovery, human intelligence still dramatically outpaces AI in the physical world. To paraphrase Moravec's paradox, winning a game of chess represents an "easy" problem for AI, but folding a shirt or cleaning up a table requires solving some of the most difficult engineering problems ever conceived. π₀ represents a first step toward developing artificial physical intelligence that enables users to simply ask robots to perform any task they want, just like they can with large language models. As described by Physical Intelligence, while AI has achieved remarkable success in digital domains, from chess-playing to drug discovery, human intelligence still dramatically outpaces AI in the physical world. To paraphrase Moravec's paradox, winning a game of chess represents an "easy" problem for AI, but folding a shirt or cleaning up a table requires solving some of the most difficult engineering problems ever conceived. π₀ represents a first step toward developing artificial physical intelligence that enables users to simply ask robots to perform any task they want, just like they can with large language models.
@@ -70,8 +64,6 @@ python src/lerobot/scripts/lerobot_train.py \
--policy.compile_model=true \ --policy.compile_model=true \
--policy.gradient_checkpointing=true \ --policy.gradient_checkpointing=true \
--policy.dtype=bfloat16 \ --policy.dtype=bfloat16 \
--policy.freeze_vision_encoder=false \
--policy.train_expert_only=false \
--steps=3000 \ --steps=3000 \
--policy.device=cuda \ --policy.device=cuda \
--batch_size=32 --batch_size=32
@@ -87,15 +79,6 @@ python src/lerobot/scripts/lerobot_train.py \
- [lerobot/pi0_base](https://huggingface.co/lerobot/pi0_base) - [lerobot/pi0_base](https://huggingface.co/lerobot/pi0_base)
- [lerobot/pi0_libero](https://huggingface.co/lerobot/pi0_libero) (specifically trained on the Libero dataset) - [lerobot/pi0_libero](https://huggingface.co/lerobot/pi0_libero) (specifically trained on the Libero dataset)
### Training Parameters Explained
| Parameter | Default | Description |
| ----------------------- | ------- | ------------------------------------------- |
| `freeze_vision_encoder` | `false` | Do not freeze the vision encoder |
| `train_expert_only` | `false` | Do not freeze the VLM, train all parameters |
**💡 Tip**: Setting `train_expert_only=true` freezes the VLM and trains only the action expert and projections, allowing finetuning with reduced memory usage.
## License ## License
This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi). This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).

View File

@@ -67,8 +67,6 @@ python src/lerobot/scripts/lerobot_train.py\
--policy.gradient_checkpointing=true \ --policy.gradient_checkpointing=true \
--wandb.enable=true \ --wandb.enable=true \
--policy.dtype=bfloat16 \ --policy.dtype=bfloat16 \
--policy.freeze_vision_encoder=false \
--policy.train_expert_only=false \
--steps=3000 \ --steps=3000 \
--policy.device=cuda \ --policy.device=cuda \
--batch_size=32 --batch_size=32
@@ -84,15 +82,6 @@ python src/lerobot/scripts/lerobot_train.py\
- [lerobot/pi05_base](https://huggingface.co/lerobot/pi05_base) - [lerobot/pi05_base](https://huggingface.co/lerobot/pi05_base)
- [lerobot/pi05_libero](https://huggingface.co/lerobot/pi05_libero) (specifically trained on the Libero dataset) - [lerobot/pi05_libero](https://huggingface.co/lerobot/pi05_libero) (specifically trained on the Libero dataset)
### Training Parameters Explained
| Parameter | Default | Description |
| ----------------------- | ------- | ------------------------------------------- |
| `freeze_vision_encoder` | `false` | Do not freeze the vision encoder |
| `train_expert_only` | `false` | Do not freeze the VLM, train all parameters |
**💡 Tip**: Setting `train_expert_only=true` freezes the VLM and trains only the action expert and projections, allowing finetuning with reduced memory usage.
If your dataset is not converted with `quantiles`, you can convert it with the following command: If your dataset is not converted with `quantiles`, you can convert it with the following command:
```bash ```bash

View File

@@ -1,246 +0,0 @@
# π₀-FAST (Pi0-FAST)
π₀-FAST is a **Vision-Language-Action model for general robot control** that uses autoregressive next-token prediction to model continuous robot actions.
## Model Overview
π₀-FAST combines the power of Vision-Language Models with a novel action tokenization approach called **FAST (Frequency-space Action Sequence Tokenization)**. This enables training autoregressive VLAs on highly dexterous tasks that are impossible with standard binning-based discretization, while training **up to 5x faster** than diffusion-based approaches like π₀.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-pifast.png"
alt="An overview of Pi0-FAST"
width="85%"
/>
### Why FAST?
Standard approaches for robot action tokenization use simple per-dimension, per-timestep binning schemes. While passable for simple behaviors, this rapidly breaks down for complex and dexterous skills that require precision and high-frequency control.
FAST solves this by compressing action sequences using signal processing techniques, resulting in a dense sequence of action tokens that can be predicted autoregressively—just like language tokens.
### How FAST Tokenization Works
The FAST tokenizer compresses action sequences through the following steps:
1. **Normalize**: Take a continuous action chunk of shape `(H, D)` where `H` is the horizon and `D` is the action dimension. Normalize using one of the supported normalization methods (Quantiles recommended to handle outliers).
2. **Discrete Cosine Transform (DCT)**: Apply DCT (via scipy) to each action dimension separately. DCT is a compression algorithm commonly used in image and audio codecs (JPEG, MP3).
3. **Quantization**: Round and remove insignificant coefficients for each action dimension, producing a sparse frequency matrix.
4. **Flatten**: Flatten the matrix into a 1D vector, with low-frequency components first.
5. **Byte Pair Encoding (BPE)**: Train a BPE tokenizer to compress the DCT coefficients into dense action tokens, typically achieving **10x compression** over prior tokenization approaches.
This approach can transform **any existing VLM** into a VLA by training it to predict these FAST tokens.
## Installation Requirements
1. Install LeRobot by following our [Installation Guide](./installation).
2. Install π₀-FAST dependencies by running:
```bash
pip install -e ".[pi]"
```
> [!NOTE]
> For lerobot 0.4.0, if you want to install the pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be solved in the next patch release
## Training a Custom FAST Tokenizer
You have two options for the FAST tokenizer:
1. **Use the pre-trained tokenizer**: The `physical-intelligence/fast` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.
2. **Train your own tokenizer**: For maximum performance on your specific dataset, you can finetune the tokenizer on your own data.
### Training Your Own Tokenizer
```bash
lerobot-train-tokenizer \
--repo_id "user/my-lerobot-dataset" \
--action_horizon 10 \
--encoded_dims "0:6" \
--vocab_size 1024 \
--scale 10.0 \
--normalization_mode QUANTILES \
--output_dir "./my_fast_tokenizer" \
--push_to_hub \
--hub_repo_id "username/my-action-tokenizer"
```
### Key Tokenizer Parameters
| Parameter | Description | Default |
| ---------------------- | --------------------------------------------------------------------------------- | ------------ |
| `--repo_id` | LeRobot dataset repository ID | Required |
| `--action_horizon` | Number of future actions in each chunk | `10` |
| `--encoded_dims` | Comma-separated dimension ranges to encode (e.g., `"0:6,7:23"`) | `"0:6,7:23"` |
| `--vocab_size` | BPE vocabulary size | `1024` |
| `--scale` | DCT scaling factor for quantization | `10.0` |
| `--normalization_mode` | Normalization mode (`MEAN_STD`, `MIN_MAX`, `QUANTILES`, `QUANTILE10`, `IDENTITY`) | `QUANTILES` |
| `--sample_fraction` | Fraction of chunks to sample per episode | `0.1` |
## Usage
To use π₀-FAST in LeRobot, specify the policy type as:
```python
policy.type=pi0_fast
```
## Training
For training π₀-FAST, you can use the LeRobot training script:
```bash
lerobot-train \
--dataset.repo_id=your_dataset \
--policy.type=pi0_fast \
--output_dir=./outputs/pi0fast_training \
--job_name=pi0fast_training \
--policy.pretrained_path=lerobot/pi0_fast_base \
--policy.dtype=bfloat16 \
--policy.gradient_checkpointing=true \
--policy.chunk_size=10 \
--policy.n_action_steps=10 \
--policy.max_action_tokens=256 \
--steps=100000 \
--batch_size=4 \
--policy.device=cuda
```
### Key Training Parameters
| Parameter | Description | Default |
| -------------------------------------- | -------------------------------------------------- | ---------------------------- |
| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` |
| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` |
| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` |
| `--policy.n_action_steps` | Number of action steps to execute | `50` |
| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` |
| `--policy.action_tokenizer_name` | FAST tokenizer to use | `physical-intelligence/fast` |
| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` |
## Inference
### KV-Caching for Fast Inference
π₀-FAST supports **KV-caching**, a widely used optimization in LLM inference. This caches the key-value pairs from the attention mechanism, avoiding redundant computation during autoregressive decoding.
```python
# KV-caching is enabled by default
policy.use_kv_cache=true
```
### Inference Example
```python
from lerobot.policies.pi0_fast import PI0FastPolicy, PI0FastConfig
# Load the policy
policy = PI0FastPolicy.from_pretrained("your-model-path")
# During inference
actions = policy.predict_action_chunk(batch)
```
## Model Architecture
π₀-FAST uses a PaliGemma-based architecture:
- **Vision Encoder**: SigLIP vision tower for image understanding
- **Language Model**: Gemma 2B for processing language instructions and predicting action tokens
The model takes images, text instructions, and robot state as input, and outputs discrete FAST tokens that are decoded back to continuous actions.
## Configuration Options
| Parameter | Description | Default |
| -------------------- | ----------------------------------------------- | ---------- |
| `paligemma_variant` | VLM backbone variant (`gemma_300m`, `gemma_2b`) | `gemma_2b` |
| `max_state_dim` | Maximum state vector dimension (padded) | `32` |
| `max_action_dim` | Maximum action vector dimension (padded) | `32` |
| `temperature` | Sampling temperature (0.0 for greedy) | `0.0` |
| `max_decoding_steps` | Maximum decoding steps | `256` |
| `use_kv_cache` | Enable KV caching for faster inference | `true` |
## Comparison with π₀
| Feature | π₀ | π₀-FAST |
| --------------------- | ------------------------- | ---------------------------- |
| Action Representation | Flow Matching (Diffusion) | Autoregressive Tokens (FAST) |
| Training Speed | 1x | **5x faster** |
| Dexterity | High | High |
| Inference Method | Iterative Denoising | Autoregressive Decoding |
| KV-Caching | N/A | Supported |
## Reproducing π₀Fast results
We reproduce the results of π₀Fast on the LIBERO benchmark using the LeRobot implementation. We take the LeRobot PiFast base model [lerobot/pi0fast-base](https://huggingface.co/lerobot/pi0fast-base) and finetune for an additional 40k steps in bfloat16, with a batch size of 256 on 8 H100 GPUs using the [HuggingFace LIBERO dataset](https://huggingface.co/datasets/HuggingFaceVLA/libero).
The finetuned model can be found here:
- **π₀Fast LIBERO**: [lerobot/pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero)
With the following training command:
```bash
lerobot-train \
--dataset.repo_id=lerobot/libero \
--output_dir=outputs/libero_pi0fast \
--job_name=libero_pi0fast \
--policy.path=lerobot/pi0fast_base \
--policy.dtype=bfloat16 \
--steps=100000 \
--save_freq=20000 \
--batch_size=4 \
--policy.device=cuda \
--policy.scheduler_warmup_steps=4000 \
--policy.scheduler_decay_steps=100000 \
--policy.scheduler_decay_lr=1e-5 \
--policy.gradient_checkpointing=true \
--policy.chunk_size=10 \
--policy.n_action_steps=10 \
--policy.max_action_tokens=256 \
--policy.empty_cameras=1
```
We then evaluate the finetuned model using the LeRobot LIBERO implementation, by running the following command:
```bash
tasks="libero_object,libero_spatial,libero_goal,libero_10"
lerobot-eval \
--policy.path=lerobot/pi0fast-libero \
--policy.max_action_tokens=256 \
--env.type=libero \
--policy.gradient_checkpointing=false \
--env.task=${tasks} \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--rename_map='{"observation.images.image":"observation.images.base_0_rgb","observation.images.image2":"observation.images.left_wrist_0_rgb"}'
```
**Note:** We set `n_action_steps=10`, similar to the original OpenPI implementation.
### Results
We obtain the following results on the LIBERO benchmark:
| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
| ----------- | -------------- | ------------- | ----------- | --------- | -------- |
| **π₀-fast** | 70.0 | 100.0 | 100.0 | 60.0 | **82.5** |
The full evaluation output folder, including videos, is available [here](https://drive.google.com/drive/folders/1HXpwPTRm4hx6g1sF2P7OOqGG0TwPU7LQ?usp=sharing)
## License
This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
## References
- [FAST: Efficient Robot Action Tokenization](https://www.physicalintelligence.company/research/fast) - Physical Intelligence Blog
- [OpenPI Repository](https://github.com/Physical-Intelligence/openpi) - Original implementation
- [FAST Tokenizer on Hugging Face](https://huggingface.co/physical-intelligence/fast) - Pre-trained tokenizer

View File

@@ -1,30 +1,20 @@
# WALL-OSS # WALL-OSS
This repository contains the Hugging Face port of [**WALL-OSS**](https://x2robot.com/en/research/68bc2cde8497d7f238dde690), a Vision-Language-Action model for cross-embodiment robotic control based on Qwen2.5-VL with flow matching/FAST action prediction. This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Language-Action model for cross-embodiment robotic control based on Qwen2.5-VL with flow matching/FAST action prediction.
--- ---
## Model Overview ## Model Overview
| Feature | Description | | Feature | Description |
| ------------------ | ----------------------------------------------------- | | ------------------ | ----------------------------------------------------- | --- |
| Base Model | Qwen2.5-VL (Vision-Language Model) | | Base Model | Qwen2.5-VL (Vision-Language Model) |
| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) | | Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) |
| Architecture | Mixture of Experts (MoE) with action-specific routing | | Architecture | Mixture of Experts (MoE) with action-specific routing | |
| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception | | Multi-Modal Inputs | Vision (images/videos), Language, Proprioception |
--- ---
## Additional Resources
Paper: https://arxiv.org/pdf/2509.11766
Official Repository: https://github.com/X-Square-Robot/wall-x
Hugging Face: https://huggingface.co/x-square-robot
---
## Citation ## Citation
If you use this work, please cite: If you use this work, please cite:
@@ -42,4 +32,4 @@ If you use this work, please cite:
## License ## License
This model follows the **Apache 2.0 License**, consistent with the original [WallX repository](https://github.com/X-Square-Robot/wall-x). This port follows the **Apache 2.0 License**.

View File

@@ -30,7 +30,7 @@ Each of these pipelines handle different conversions between different action an
Below is an example of the three pipelines that we use in the phone to SO-100 follower examples: Below is an example of the three pipelines that we use in the phone to SO-100 follower examples:
```python ```69:90:examples/phone_so100_record.py
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[RobotAction, RobotAction]( # teleop -> dataset action phone_to_robot_ee_pose_processor = RobotProcessorPipeline[RobotAction, RobotAction]( # teleop -> dataset action
steps=[ steps=[
MapPhoneActionToRobotAction(platform=teleop_config.phone_os), MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
@@ -84,7 +84,7 @@ Dataset features are determined by the keys saved in the dataset. Each step can
Below is an example of how we declare features with the `transform_features` method in the phone to SO-100 follower examples: Below is an example of how we declare features with the `transform_features` method in the phone to SO-100 follower examples:
```python ```src/lerobot/robots/so100_follower/robot_kinematic_processor.py
def transform_features( def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
@@ -103,7 +103,7 @@ Here we declare what PolicyFeatures we modify in this step, so we know what feat
Below is an example of how we aggregate and merge features in the phone to SO-100 record example: Below is an example of how we aggregate and merge features in the phone to SO-100 record example:
```python ```121:145:examples/phone_so100_record.py
features=combine_feature_dicts( features=combine_feature_dicts(
# Run the feature contract of the pipelines # Run the feature contract of the pipelines
# This tells you how the features would look like after the pipeline steps # This tells you how the features would look like after the pipeline steps

View File

@@ -38,7 +38,6 @@ docker run --rm -it \
start_rviz:=true start_sdk_server:=true mujoco:=true start_rviz:=true start_sdk_server:=true mujoco:=true
``` ```
> [!NOTE]
> If MuJoCo runs slowly (low simulation frequency), append `-e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \` to the previous command to improve performance: > If MuJoCo runs slowly (low simulation frequency), append `-e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \` to the previous command to improve performance:
> >
> ``` > ```
@@ -142,7 +141,7 @@ If you choose this option but still want to use the VR teleoperation application
First add reachy2 and reachy2_teleoperator to the imports of the record script. Then you can use the following command: First add reachy2 and reachy2_teleoperator to the imports of the record script. Then you can use the following command:
```bash ```bash
lerobot-record \ python -m lerobot.record \
--robot.type=reachy2 \ --robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \ --robot.ip_address=192.168.0.200 \
--robot.id=r2-0000 \ --robot.id=r2-0000 \
@@ -151,7 +150,6 @@ lerobot-record \
--teleop.type=reachy2_teleoperator \ --teleop.type=reachy2_teleoperator \
--teleop.ip_address=192.168.0.200 \ --teleop.ip_address=192.168.0.200 \
--teleop.with_mobile_base=false \ --teleop.with_mobile_base=false \
--robot.with_torso_camera=true \
--dataset.repo_id=pollen_robotics/record_test \ --dataset.repo_id=pollen_robotics/record_test \
--dataset.single_task="Reachy 2 recording test" \ --dataset.single_task="Reachy 2 recording test" \
--dataset.num_episodes=1 \ --dataset.num_episodes=1 \
@@ -167,7 +165,7 @@ lerobot-record \
**Extended setup overview (all options included):** **Extended setup overview (all options included):**
```bash ```bash
lerobot-record \ python -m lerobot.record \
--robot.type=reachy2 \ --robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \ --robot.ip_address=192.168.0.200 \
--robot.use_external_commands=true \ --robot.use_external_commands=true \
@@ -179,8 +177,6 @@ lerobot-record \
--robot.with_left_teleop_camera=true \ --robot.with_left_teleop_camera=true \
--robot.with_right_teleop_camera=true \ --robot.with_right_teleop_camera=true \
--robot.with_torso_camera=false \ --robot.with_torso_camera=false \
--robot.camera_width=640 \
--robot.camera_height=480 \
--robot.disable_torque_on_disconnect=false \ --robot.disable_torque_on_disconnect=false \
--robot.max_relative_target=5.0 \ --robot.max_relative_target=5.0 \
--teleop.type=reachy2_teleoperator \ --teleop.type=reachy2_teleoperator \
@@ -216,10 +212,9 @@ Must be set to true if a compliant Reachy 2 is used to control another one.
From our initial tests, recording **all** joints when only some are moving can reduce model quality with certain policies. From our initial tests, recording **all** joints when only some are moving can reduce model quality with certain policies.
To avoid this, you can exclude specific parts from recording and replay using: To avoid this, you can exclude specific parts from recording and replay using:
```bash ````
--robot.with_<part>=false --robot.with_<part>=false
``` ```,
with `<part>` being one of: `mobile_base`, `l_arm`, `r_arm`, `neck`, `antennas`. with `<part>` being one of: `mobile_base`, `l_arm`, `r_arm`, `neck`, `antennas`.
It determines whether the corresponding part is recorded in the observations. True if not set. It determines whether the corresponding part is recorded in the observations. True if not set.
@@ -227,60 +222,49 @@ By default, **all parts are recorded**.
The same per-part mechanism is available in `reachy2_teleoperator` as well. The same per-part mechanism is available in `reachy2_teleoperator` as well.
```bash ````
--teleop.with\_<part>
```
--teleop.with\_<part>
```
with `<part>` being one of: `mobile_base`, `l_arm`, `r_arm`, `neck`, `antennas`. with `<part>` being one of: `mobile_base`, `l_arm`, `r_arm`, `neck`, `antennas`.
Determines whether the corresponding part is recorded in the actions. True if not set. Determines whether the corresponding part is recorded in the actions. True if not set.
> **Important:** In a given session, the **enabled parts must match** on both the robot and the teleoperator. > **Important:** In a given session, the **enabled parts must match** on both the robot and the teleoperator.
> For example, if the robot runs with `--robot.with_mobile_base=false`, the teleoperator must disable the same part with `--teleop.with_mobile_base=false`. For example, if the robot runs with `--robot.with_mobile_base=false`, the teleoperator must disable the same part with `--teleop.with_mobile_base=false`.
##### Use the relevant cameras ##### Use the relevant cameras
You can do the same for **cameras**. Enable or disable each camera with default parameters using: You can do the same for **cameras**. By default, only the **teleoperation cameras** are recorded (both `left_teleop_camera` and `right_teleop_camera`). Enable or disable each camera with:
```bash ```
--robot.with_left_teleop_camera=<true|false> \
--robot.with_right_teleop_camera=<true|false> \ --robot.with_left_teleop_camera=<true|false>
--robot.with_right_teleop_camera=<true|false>
--robot.with_torso_camera=<true|false> --robot.with_torso_camera=<true|false>
```
By default, no camera is recorded, all camera arguments are set to `false`. ````
If you want to, you can use custom `width` and `height` parameters for Reachy 2's cameras using the `--robot.camera_width` and `--robot.camera_height` arguments:
```bash
--robot.camera_width=1920 \
--robot.camera_height=1080
```
This will change the resolution of all 3 default robot cameras (enabled by the above bool arguments).
If you want, you can add additional cameras other than the ones in the robot as usual with:
```bash
--robot.cameras="{ extra: {type: opencv, index_or_path: 42, width: 640, height: 480, fps: 30}}" \
```
## Step 2: Replay ## Step 2: Replay
Make sure the robot is configured with the same parts as the dataset: Make sure the robot is configured with the same parts as the dataset:
```bash ```bash
lerobot-replay \ python -m lerobot.replay \
--robot.type=reachy2 \ --robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \ --robot.ip_address=192.168.0.200 \
--robot.use_external_commands=false \ --robot.use_external_commands=false \
--robot.with_mobile_base=false \ --robot.with_mobile_base=false \
--dataset.repo_id=pollen_robotics/record_test \ --dataset.repo_id=pollen_robotics/record_test \
--dataset.episode=0 --dataset.episode=0
``` --display_data=true
````
## Step 3: Train ## Step 3: Train
```bash ```bash
lerobot-train \ python -m lerobot.scripts.train \
--dataset.repo_id=pollen_robotics/record_test \ --dataset.repo_id=pollen_robotics/record_test \
--policy.type=act \ --policy.type=act \
--output_dir=outputs/train/reachy2_test \ --output_dir=outputs/train/reachy2_test \
@@ -293,9 +277,10 @@ lerobot-train \
## Step 4: Evaluate ## Step 4: Evaluate
```bash ```bash
lerobot-eval \ python -m lerobot.record \
--robot.type=reachy2 \ --robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \ --robot.ip_address=192.168.0.200 \
--display_data=false \
--dataset.repo_id=pollen_robotics/eval_record_test \ --dataset.repo_id=pollen_robotics/eval_record_test \
--dataset.single_task="Evaluate reachy2 policy" \ --dataset.single_task="Evaluate reachy2 policy" \
--dataset.num_episodes=10 \ --dataset.num_episodes=10 \

View File

@@ -4,12 +4,6 @@ SARM (Stage-Aware Reward Modeling) is a video-based reward modeling framework fo
**Paper**: [SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation](https://arxiv.org/abs/2509.25358) **Paper**: [SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation](https://arxiv.org/abs/2509.25358)
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-sarm.png"
alt="An overview of SARM"
width="80%"
/>
## Why Reward Models? ## Why Reward Models?
Standard behavior cloning treats all demonstration frames equally, but real-world robot datasets are messy. They contain hesitations, corrections, and variable-quality trajectories. Reward models solve this by learning a generalizable notion of **task progress** from demonstrations: given video frames and a task description, they predict how close the robot is to completing the task (0→1). This learned "progress signal" can be used in multiple ways, two promising applications are: (1) **weighted imitation learning** (RA-BC), where high-progress frames receive more weight during policy training, and (2) **reinforcement learning**, where the reward model provides dense rewards for online or offline policy improvement. Standard behavior cloning treats all demonstration frames equally, but real-world robot datasets are messy. They contain hesitations, corrections, and variable-quality trajectories. Reward models solve this by learning a generalizable notion of **task progress** from demonstrations: given video frames and a task description, they predict how close the robot is to completing the task (0→1). This learned "progress signal" can be used in multiple ways, two promising applications are: (1) **weighted imitation learning** (RA-BC), where high-progress frames receive more weight during policy training, and (2) **reinforcement learning**, where the reward model provides dense rewards for online or offline policy improvement.

View File

@@ -103,7 +103,7 @@ lerobot-setup-motors \
<!-- prettier-ignore-start --> <!-- prettier-ignore-start -->
```python ```python
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.robots.so100_follower import SO100Follower, SO100FollowerConfig
config = SO100FollowerConfig( config = SO100FollowerConfig(
port="/dev/tty.usbmodem585A0076841", port="/dev/tty.usbmodem585A0076841",
@@ -177,7 +177,7 @@ lerobot-setup-motors \
<!-- prettier-ignore-start --> <!-- prettier-ignore-start -->
```python ```python
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
config = SO100LeaderConfig( config = SO100LeaderConfig(
port="/dev/tty.usbmodem585A0076841", port="/dev/tty.usbmodem585A0076841",
@@ -579,7 +579,7 @@ lerobot-calibrate \
<!-- prettier-ignore-start --> <!-- prettier-ignore-start -->
```python ```python
from lerobot.robots.so_follower import SO100FollowerConfig, SO100Follower from lerobot.robots.so100_follower import SO100FollowerConfig, SO100Follower
config = SO100FollowerConfig( config = SO100FollowerConfig(
port="/dev/tty.usbmodem585A0076891", port="/dev/tty.usbmodem585A0076891",
@@ -617,7 +617,7 @@ lerobot-calibrate \
<!-- prettier-ignore-start --> <!-- prettier-ignore-start -->
```python ```python
from lerobot.teleoperators.so_leader import SO100LeaderConfig, SO100Leader from lerobot.teleoperators.so100_leader import SO100LeaderConfig, SO100Leader
config = SO100LeaderConfig( config = SO100LeaderConfig(
port="/dev/tty.usbmodem58760431551", port="/dev/tty.usbmodem58760431551",

View File

@@ -125,7 +125,7 @@ lerobot-setup-motors \
<!-- prettier-ignore-start --> <!-- prettier-ignore-start -->
```python ```python
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig from lerobot.robots.so101_follower import SO101Follower, SO101FollowerConfig
config = SO101FollowerConfig( config = SO101FollowerConfig(
port="/dev/tty.usbmodem585A0076841", port="/dev/tty.usbmodem585A0076841",
@@ -201,7 +201,7 @@ lerobot-setup-motors \
<!-- prettier-ignore-start --> <!-- prettier-ignore-start -->
```python ```python
from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig from lerobot.teleoperators.so101_leader import SO101Leader, SO101LeaderConfig
config = SO101LeaderConfig( config = SO101LeaderConfig(
port="/dev/tty.usbmodem585A0076841", port="/dev/tty.usbmodem585A0076841",
@@ -364,7 +364,7 @@ lerobot-calibrate \
<!-- prettier-ignore-start --> <!-- prettier-ignore-start -->
```python ```python
from lerobot.robots.so_follower import SO101FollowerConfig, SO101Follower from lerobot.robots.so101_follower import SO101FollowerConfig, SO101Follower
config = SO101FollowerConfig( config = SO101FollowerConfig(
port="/dev/tty.usbmodem585A0076891", port="/dev/tty.usbmodem585A0076891",
@@ -413,7 +413,7 @@ lerobot-calibrate \
<!-- prettier-ignore-start --> <!-- prettier-ignore-start -->
```python ```python
from lerobot.teleoperators.so_leader import SO101LeaderConfig, SO101Leader from lerobot.teleoperators.so101_leader import SO101LeaderConfig, SO101Leader
config = SO101LeaderConfig( config = SO101LeaderConfig(
port="/dev/tty.usbmodem58760431551", port="/dev/tty.usbmodem58760431551",

View File

@@ -1,21 +1,21 @@
# Unitree G1 # Unitree G1 Robot Setup and Control
This guide covers the complete setup process for the Unitree G1 humanoid, from initial connection to running gr00t_wbc locomotion. This guide covers the complete setup process for the Unitree G1 humanoid, from initial connection to running gr00t_wbc locomotion.
## About ## About the Unitree G1
We support both the 29 and 23 DOF G1 EDU versions. We introduce: We offer support for both the 29 and 23 DOF G1. We introduce:
- **`unitree g1` robot class, handling low level read/write from/to the humanoid** - **`unitree g1` robot class, handling low level communication with the humanoid**
- **ZMQ socket bridge** for remote communication and camera streaming, allowing for remote policy deployment over wlan, eth or directly on the robot - **ZMQ socket bridge** for remote communication over WiFi, allowing one to deploy policies remotely instead of over ethernet or directly on the Orin
- **Locomotion policies** from NVIDIA gr00t and Amazon FAR Holosoma - **GR00T locomotion policy** for bipedal walking and balance
- **Simulation mode** for testing policies without the physical robot in mujoco - **MuJoCo simulation mode** for testing policies without the physical robot
--- ---
## Connection guide ## Part 1: Connect to Robot over Ethernet
### Step 1: Configure Ethernet Interface ### Step 1: Configure Your Computer's Ethernet Interface
Set a static IP on the same subnet as the robot: Set a static IP on the same subnet as the robot:
@@ -26,7 +26,7 @@ sudo ip addr add 192.168.123.200/24 dev enp131s0
sudo ip link set enp131s0 up sudo ip link set enp131s0 up
``` ```
**Note**: The G1's Ethernet IP is fixed at `192.168.123.164`. Your computer must use `192.168.123.x` with x ≠ 164. **Note**: The robot's Ethernet IP is fixed at `192.168.123.164`. Your computer must use `192.168.123.x` where x ≠ 164.
### Step 2: SSH into the Robot ### Step 2: SSH into the Robot
@@ -35,24 +35,25 @@ ssh unitree@192.168.123.164
# Password: 123 # Password: 123
``` ```
You should now be connected to the G1's Orin. You should now be connected to the robot's onboard computer.
--- ---
## Part 2: Enable WiFi on the Robot ## Part 2: Enable WiFi on the Robot
Wlan0 is disabled by default on the G1. To enable it: Once connected via Ethernet, follow these steps to enable WiFi:
### Step 1: Enable WiFi Hardware ### Step 1: Enable WiFi Hardware
```bash ```bash
# Unblock WiFi radio
sudo rfkill unblock wifi sudo rfkill unblock wifi
sudo rfkill unblock all sudo rfkill unblock all
# Bring up wlan0 # Bring up WiFi interface
sudo ip link set wlan0 up sudo ip link set wlan0 up
# Enable NetworkManager control of wlan0 # Enable NetworkManager control
sudo nmcli radio wifi on sudo nmcli radio wifi on
sudo nmcli device set wlan0 managed yes sudo nmcli device set wlan0 managed yes
sudo systemctl restart NetworkManager sudo systemctl restart NetworkManager
@@ -72,7 +73,7 @@ sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTA
sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT
``` ```
**On the G1:** **On the robot:**
```bash ```bash
# Add laptop as default gateway # Add laptop as default gateway
@@ -110,7 +111,7 @@ ssh unitree@<YOUR_ROBOT_IP>
# Password: 123 # Password: 123
``` ```
Replace `<YOUR_ROBOT_IP>` with your robot's actual WiFi IP address. Replace `<YOUR_ROBOT_IP>` with your robot's actual WiFi IP address (e.g., `172.18.129.215`).
--- ---
@@ -146,9 +147,9 @@ python src/lerobot/robots/unitree_g1/run_g1_server.py
--- ---
## Part 4: Controlling the robot ## Part 4: Running GR00T Locomotion
With the robot server running, you can now control the robot remotely. Let's launch a locomotion policy With the robot server running, you can now control the robot from your laptop.
### Step 1: Install LeRobot on your machine ### Step 1: Install LeRobot on your machine
@@ -171,30 +172,34 @@ Edit the config file to match your robot's WiFi IP:
robot_ip: str = "<YOUR_ROBOT_IP>" # Replace with your robot's WiFi IP. robot_ip: str = "<YOUR_ROBOT_IP>" # Replace with your robot's WiFi IP.
``` ```
**Note**: When running directly on the G1 (not remotely), set `robot_ip: str = "127.0.0.1"` instead.
### Step 3: Run the Locomotion Policy ### Step 3: Run the Locomotion Policy
```bash ```bash
# Run GR00T locomotion controller # Run GR00T locomotion controller
python examples/unitree_g1/gr00t_locomotion.py --repo-id "nepyope/GR00T-WholeBodyControl_g1" python examples/unitree_g1/gr00t_locomotion.py --repo-id "nepyope/GR00T-WholeBodyControl_g1"
# Run Holosoma locomotion controller
python examples/unitree_g1/holosoma_locomotion.py
``` ```
### Step 4: Control with Remote
- **Left stick**: Forward/backward and left/right movement
- **Right stick**: Rotation
- **R1 button**: Raise waist height
- **R2 button**: Lower waist height
Press `Ctrl+C` to stop the policy. Press `Ctrl+C` to stop the policy.
--- ---
## Running in Simulation Mode (MuJoCo) ## Extra: Running in Simulation Mode (MuJoCo)
You can now test policies before unleashing them on the physical robot using MuJoCo. To do so, simply set `is_simulation=True` in the config. You can now test and develop policies without a physical robot using MuJoCo. To do so, set `is_simulation=True` in the config.
## Additional Resources ## Additional Resources
- [Unitree SDK Documentation](https://github.com/unitreerobotics/unitree_sdk2_python) - [Unitree SDK Documentation](https://github.com/unitreerobotics/unitree_sdk2_python)
- [GR00T-WholeBodyControl](https://github.com/NVlabs/GR00T-WholeBodyControl) - [GR00T Policy Repository](https://huggingface.co/nepyope/GR00T-WholeBodyControl_g1)
- [Holosoma](https://github.com/amazon-far/holosoma)
- [LeRobot Documentation](https://github.com/huggingface/lerobot) - [LeRobot Documentation](https://github.com/huggingface/lerobot)
- [Unitree_IL_Lerobot](https://github.com/unitreerobotics/unitree_IL_lerobot) - [Unitree_IL_Lerobot](https://github.com/unitreerobotics/unitree_IL_lerobot)

View File

@@ -95,26 +95,26 @@ Convert an image-based dataset to video format, creating a new LeRobotDataset wh
# Local-only: Save to a custom output directory (no hub push) # Local-only: Save to a custom output directory (no hub push)
lerobot-edit-dataset \ lerobot-edit-dataset \
--repo_id lerobot/pusht_image \ --repo_id lerobot/pusht_image \
--operation.type convert_image_to_video \ --operation.type convert_to_video \
--operation.output_dir /path/to/output/pusht_video --operation.output_dir /path/to/output/pusht_video
# Save with new repo_id (local storage) # Save with new repo_id (local storage)
lerobot-edit-dataset \ lerobot-edit-dataset \
--repo_id lerobot/pusht_image \ --repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \ --new_repo_id lerobot/pusht_video \
--operation.type convert_image_to_video --operation.type convert_to_video
# Convert and push to Hugging Face Hub # Convert and push to Hugging Face Hub
lerobot-edit-dataset \ lerobot-edit-dataset \
--repo_id lerobot/pusht_image \ --repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \ --new_repo_id lerobot/pusht_video \
--operation.type convert_image_to_video \ --operation.type convert_to_video \
--push_to_hub true --push_to_hub true
# Convert with custom video codec and quality settings # Convert with custom video codec and quality settings
lerobot-edit-dataset \ lerobot-edit-dataset \
--repo_id lerobot/pusht_image \ --repo_id lerobot/pusht_image \
--operation.type convert_image_to_video \ --operation.type convert_to_video \
--operation.output_dir outputs/pusht_video \ --operation.output_dir outputs/pusht_video \
--operation.vcodec libsvtav1 \ --operation.vcodec libsvtav1 \
--operation.pix_fmt yuv420p \ --operation.pix_fmt yuv420p \
@@ -124,23 +124,16 @@ lerobot-edit-dataset \
# Convert only specific episodes # Convert only specific episodes
lerobot-edit-dataset \ lerobot-edit-dataset \
--repo_id lerobot/pusht_image \ --repo_id lerobot/pusht_image \
--operation.type convert_image_to_video \ --operation.type convert_to_video \
--operation.output_dir outputs/pusht_video \ --operation.output_dir outputs/pusht_video \
--operation.episode_indices "[0, 1, 2, 5, 10]" --operation.episode_indices "[0, 1, 2, 5, 10]"
# Convert with multiple workers for parallel processing # Convert with multiple workers for parallel processing
lerobot-edit-dataset \ lerobot-edit-dataset \
--repo_id lerobot/pusht_image \ --repo_id lerobot/pusht_image \
--operation.type convert_image_to_video \ --operation.type convert_to_video \
--operation.output_dir outputs/pusht_video \ --operation.output_dir outputs/pusht_video \
--operation.num_workers 8 --operation.num_workers 8
# For memory-constrained systems, users can now specify limits:
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_to_video \
--operation.max_episodes_per_batch 50 \
--operation.max_frames_per_batch 10000
``` ```
**Parameters:** **Parameters:**

View File

@@ -8,12 +8,6 @@ X Square Robots WALL-OSS is now integrated into Hugging Faces LeRobot ecos
The WALL-OSS team is building the embodied foundation model to capture and compress the world's most valuable data: the continuous, high-fidelity stream of physical interaction. By creating a direct feedback loop between the model's decisions and the body's lived experience, the emergence of a truly generalizable intelligence is enabled—one that understands not just how the world works, but how to act effectively within it. The WALL-OSS team is building the embodied foundation model to capture and compress the world's most valuable data: the continuous, high-fidelity stream of physical interaction. By creating a direct feedback loop between the model's decisions and the body's lived experience, the emergence of a truly generalizable intelligence is enabled—one that understands not just how the world works, but how to act effectively within it.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/walloss-lerobot-paper.png"
alt="An overview of WALL-OSS"
width="85%"
/>
Technically, WALL-OSS introduces a tightly coupled multimodal architecture (tightly-coupled MoE structure) that integrates both discrete and continuous action modeling strategies. Through a two-stage training pipeline (Inspiration → Integration), the model gradually unifies semantic reasoning and high-frequency action generation. Its core innovations include: Technically, WALL-OSS introduces a tightly coupled multimodal architecture (tightly-coupled MoE structure) that integrates both discrete and continuous action modeling strategies. Through a two-stage training pipeline (Inspiration → Integration), the model gradually unifies semantic reasoning and high-frequency action generation. Its core innovations include:
- **Embodied perception-enhanced multimodal pretraining**: Large-scale training on unified vision-language-action data to strengthen spatial, causal, and manipulation understanding. - **Embodied perception-enhanced multimodal pretraining**: Large-scale training on unified vision-language-action data to strengthen spatial, causal, and manipulation understanding.

View File

@@ -41,7 +41,8 @@ from lerobot.robots import ( # noqa: F401
RobotConfig, RobotConfig,
koch_follower, koch_follower,
make_robot_from_config, make_robot_from_config,
so_follower, so100_follower,
so101_follower,
) )
from lerobot.utils.constants import ACTION from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.robot_utils import precise_sleep
@@ -96,7 +97,7 @@ def replay(cfg: ReplayConfig):
robot.send_action(action) robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t dt_s = time.perf_counter() - start_episode_t
precise_sleep(max(1 / dataset.fps - dt_s, 0.0)) precise_sleep(1 / dataset.fps - dt_s)
robot.disconnect() robot.disconnect()

View File

@@ -21,7 +21,7 @@ from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
from lerobot.scripts.lerobot_record import record_loop from lerobot.scripts.lerobot_record import record_loop
from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say from lerobot.utils.utils import log_say

View File

@@ -18,7 +18,7 @@ import time
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop, KeyboardTeleopConfig from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop, KeyboardTeleopConfig
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data from lerobot.utils.visualization_utils import init_rerun, log_rerun_data

View File

@@ -34,11 +34,12 @@ from lerobot.processor.converters import (
transition_to_observation, transition_to_observation,
transition_to_robot_action, transition_to_robot_action,
) )
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import ( from lerobot.robots.so100_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE, ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints, InverseKinematicsEEToJoints,
) )
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.scripts.lerobot_record import record_loop from lerobot.scripts.lerobot_record import record_loop
from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say from lerobot.utils.utils import log_say

View File

@@ -26,14 +26,15 @@ from lerobot.processor.converters import (
transition_to_observation, transition_to_observation,
transition_to_robot_action, transition_to_robot_action,
) )
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import ( from lerobot.robots.so100_follower.robot_kinematic_processor import (
EEBoundsAndSafety, EEBoundsAndSafety,
EEReferenceAndDelta, EEReferenceAndDelta,
ForwardKinematicsJointsToEE, ForwardKinematicsJointsToEE,
GripperVelocityToJoint, GripperVelocityToJoint,
InverseKinematicsEEToJoints, InverseKinematicsEEToJoints,
) )
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.scripts.lerobot_record import record_loop from lerobot.scripts.lerobot_record import record_loop
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction

View File

@@ -23,10 +23,11 @@ from lerobot.processor.converters import (
robot_action_observation_to_transition, robot_action_observation_to_transition,
transition_to_robot_action, transition_to_robot_action,
) )
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import ( from lerobot.robots.so100_follower.robot_kinematic_processor import (
InverseKinematicsEEToJoints, InverseKinematicsEEToJoints,
) )
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.constants import ACTION from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say from lerobot.utils.utils import log_say
@@ -95,7 +96,7 @@ def main():
# Send action to robot # Send action to robot
_ = robot.send_action(joint_action) _ = robot.send_action(joint_action)
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0))
# Clean up # Clean up
robot.disconnect() robot.disconnect()

View File

@@ -21,13 +21,14 @@ from lerobot.processor.converters import (
robot_action_observation_to_transition, robot_action_observation_to_transition,
transition_to_robot_action, transition_to_robot_action,
) )
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import ( from lerobot.robots.so100_follower.robot_kinematic_processor import (
EEBoundsAndSafety, EEBoundsAndSafety,
EEReferenceAndDelta, EEReferenceAndDelta,
GripperVelocityToJoint, GripperVelocityToJoint,
InverseKinematicsEEToJoints, InverseKinematicsEEToJoints,
) )
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
from lerobot.teleoperators.phone.teleop_phone import Phone from lerobot.teleoperators.phone.teleop_phone import Phone

View File

@@ -94,9 +94,9 @@ from lerobot.rl.process import ProcessSignalHandler
from lerobot.robots import ( # noqa: F401 from lerobot.robots import ( # noqa: F401
Robot, Robot,
RobotConfig, RobotConfig,
bi_so_follower,
koch_follower, koch_follower,
so_follower, so100_follower,
so101_follower,
) )
from lerobot.robots.utils import make_robot_from_config from lerobot.robots.utils import make_robot_from_config
from lerobot.utils.constants import OBS_IMAGES from lerobot.utils.constants import OBS_IMAGES
@@ -455,18 +455,7 @@ def demo_cli(cfg: RTCDemoConfig):
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0": if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
config.compile_model = cfg.use_torch_compile config.compile_model = cfg.use_torch_compile
if config.use_peft: policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
from peft import PeftConfig, PeftModel
peft_pretrained_path = cfg.policy.pretrained_path
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
policy = policy_class.from_pretrained(
pretrained_name_or_path=peft_config.base_model_name_or_path, config=config
)
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
else:
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
# Turn on RTC # Turn on RTC
policy.config.rtc_config = cfg.rtc policy.config.rtc_config = cfg.rtc

View File

@@ -34,11 +34,12 @@ from lerobot.processor.converters import (
transition_to_observation, transition_to_observation,
transition_to_robot_action, transition_to_robot_action,
) )
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import ( from lerobot.robots.so100_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE, ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints, InverseKinematicsEEToJoints,
) )
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.scripts.lerobot_record import record_loop from lerobot.scripts.lerobot_record import record_loop
from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say from lerobot.utils.utils import log_say

View File

@@ -27,14 +27,16 @@ from lerobot.processor.converters import (
transition_to_observation, transition_to_observation,
transition_to_robot_action, transition_to_robot_action,
) )
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import ( from lerobot.robots.so100_follower.robot_kinematic_processor import (
EEBoundsAndSafety, EEBoundsAndSafety,
ForwardKinematicsJointsToEE, ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints, InverseKinematicsEEToJoints,
) )
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.scripts.lerobot_record import record_loop from lerobot.scripts.lerobot_record import record_loop
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig
from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun from lerobot.utils.visualization_utils import init_rerun

View File

@@ -24,10 +24,11 @@ from lerobot.processor.converters import (
robot_action_observation_to_transition, robot_action_observation_to_transition,
transition_to_robot_action, transition_to_robot_action,
) )
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import ( from lerobot.robots.so100_follower.robot_kinematic_processor import (
InverseKinematicsEEToJoints, InverseKinematicsEEToJoints,
) )
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.constants import ACTION from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say from lerobot.utils.utils import log_say
@@ -96,7 +97,7 @@ def main():
# Send action to robot # Send action to robot
_ = robot.send_action(joint_action) _ = robot.send_action(joint_action)
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0))
# Clean up # Clean up
robot.disconnect() robot.disconnect()

View File

@@ -23,13 +23,15 @@ from lerobot.processor.converters import (
robot_action_to_transition, robot_action_to_transition,
transition_to_robot_action, transition_to_robot_action,
) )
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import ( from lerobot.robots.so100_follower.robot_kinematic_processor import (
EEBoundsAndSafety, EEBoundsAndSafety,
ForwardKinematicsJointsToEE, ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints, InverseKinematicsEEToJoints,
) )
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig
from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data from lerobot.utils.visualization_utils import init_rerun, log_rerun_data

View File

@@ -5,7 +5,8 @@ from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.utils import build_inference_frame, make_robot_action from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
MAX_EPISODES = 5 MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20 MAX_STEPS_PER_EPISODE = 20

View File

@@ -4,7 +4,7 @@ from lerobot.async_inference.configs import RobotClientConfig
from lerobot.async_inference.helpers import visualize_action_queue_size from lerobot.async_inference.helpers import visualize_action_queue_size
from lerobot.async_inference.robot_client import RobotClient from lerobot.async_inference.robot_client import RobotClient
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.robots.so_follower import SO100FollowerConfig from lerobot.robots.so100_follower import SO100FollowerConfig
def main(): def main():
@@ -30,7 +30,6 @@ def main():
robot=robot_cfg, robot=robot_cfg,
server_address=server_address, server_address=server_address,
policy_device="mps", policy_device="mps",
client_device="cpu",
policy_type="act", policy_type="act",
pretrained_name_or_path="<user>/robot_learning_tutorial_act", pretrained_name_or_path="<user>/robot_learning_tutorial_act",
chunk_size_threshold=0.5, # g chunk_size_threshold=0.5, # g

View File

@@ -5,7 +5,8 @@ from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.utils import build_inference_frame, make_robot_action from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
MAX_EPISODES = 5 MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20 MAX_STEPS_PER_EPISODE = 20

View File

@@ -5,7 +5,8 @@ from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.factory import make_pre_post_processors from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.pi0.modeling_pi0 import PI0Policy from lerobot.policies.pi0.modeling_pi0 import PI0Policy
from lerobot.policies.utils import build_inference_frame, make_robot_action from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
MAX_EPISODES = 5 MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20 MAX_STEPS_PER_EPISODE = 20

View File

@@ -14,8 +14,8 @@ from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.rl.buffer import ReplayBuffer from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.gym_manipulator import make_robot_env from lerobot.rl.gym_manipulator import make_robot_env
from lerobot.robots.so_follower import SO100FollowerConfig from lerobot.robots.so100_follower import SO100FollowerConfig
from lerobot.teleoperators.so_leader import SO100LeaderConfig from lerobot.teleoperators.so100_leader import SO100LeaderConfig
from lerobot.teleoperators.utils import TeleopEvents from lerobot.teleoperators.utils import TeleopEvents
LOG_EVERY = 10 LOG_EVERY = 10

View File

@@ -5,7 +5,8 @@ from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.factory import make_pre_post_processors from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
from lerobot.policies.utils import build_inference_frame, make_robot_action from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
MAX_EPISODES = 5 MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20 MAX_STEPS_PER_EPISODE = 20

View File

@@ -13,9 +13,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
Example: GR00T Locomotion with Pre-loaded Policies
This example demonstrates the NEW pattern for loading GR00T policies externally
and passing them to the robot class.
"""
import argparse import argparse
import logging import logging
import threading
import time import time
from collections import deque from collections import deque
@@ -24,26 +31,24 @@ import onnxruntime as ort
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1 from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
GROOT_DEFAULT_ANGLES = np.zeros(29, dtype=np.float32) GROOT_DEFAULT_ANGLES = np.zeros(29, dtype=np.float32)
GROOT_DEFAULT_ANGLES[[0, 6]] = -0.1 # Hip pitch GROOT_DEFAULT_ANGLES[[0, 6]] = -0.1 # hip pitch
GROOT_DEFAULT_ANGLES[[3, 9]] = 0.3 # Knee GROOT_DEFAULT_ANGLES[[3, 9]] = 0.3 # knee
GROOT_DEFAULT_ANGLES[[4, 10]] = -0.2 # Ankle pitch GROOT_DEFAULT_ANGLES[[4, 10]] = -0.2 # ankle pitch
MISSING_JOINTS = [] MISSING_JOINTS = []
G1_MODEL = "g1_23" # Or "g1_29" G1_MODEL = "g1_23" # or "g1_29"
if G1_MODEL == "g1_23": if G1_MODEL == "g1_23":
MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # Waist yaw/pitch, wrist pitch/yaw MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # waist yaw/pitch, wrist pitch/yaw
LOCOMOTION_ACTION_SCALE = 0.25
LOCOMOTION_CONTROL_DT = 0.02
# Control parameters
ACTION_SCALE = 0.25
CONTROL_DT = 0.02 # 50Hz
ANG_VEL_SCALE: float = 0.25 ANG_VEL_SCALE: float = 0.25
DOF_POS_SCALE: float = 1.0 DOF_POS_SCALE: float = 1.0
DOF_VEL_SCALE: float = 0.05 DOF_VEL_SCALE: float = 0.05
@@ -56,12 +61,12 @@ DEFAULT_GROOT_REPO_ID = "nepyope/GR00T-WholeBodyControl_g1"
def load_groot_policies( def load_groot_policies(
repo_id: str = DEFAULT_GROOT_REPO_ID, repo_id: str = DEFAULT_GROOT_REPO_ID,
) -> tuple[ort.InferenceSession, ort.InferenceSession]: ) -> tuple[ort.InferenceSession, ort.InferenceSession]:
"""Load GR00T dual-policy system (Balance + Walk) from the hub. """Load GR00T dual-policy system (Balance + Walk) from Hugging Face Hub.
Args: Args:
repo_id: Hugging Face Hub repository ID containing the ONNX policies. repo_id: Hugging Face Hub repository ID containing the ONNX policies.
""" """
logger.info(f"Loading GR00T dual-policy system from the hub ({repo_id})...") logger.info(f"Loading GR00T dual-policy system from Hugging Face Hub ({repo_id})...")
# Download ONNX policies from Hugging Face Hub # Download ONNX policies from Hugging Face Hub
balance_path = hf_hub_download( balance_path = hf_hub_download(
@@ -83,7 +88,15 @@ def load_groot_policies(
class GrootLocomotionController: class GrootLocomotionController:
"""GR00T lower-body locomotion controller for the Unitree G1.""" """
Handles GR00T-style locomotion control for the Unitree G1 robot.
This controller manages:
- Dual-policy system (Balance + Walk)
- 29-joint observation processing
- 15D action output (legs + waist)
- Policy inference and motor command generation
"""
def __init__(self, policy_balance, policy_walk, robot, config): def __init__(self, policy_balance, policy_walk, robot, config):
self.policy_balance = policy_balance self.policy_balance = policy_balance
@@ -91,9 +104,9 @@ class GrootLocomotionController:
self.robot = robot self.robot = robot
self.config = config self.config = config
self.cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # vx, vy, theta_dot self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # vx, vy, theta_dot
# Robot state # GR00T-specific state
self.groot_qj_all = np.zeros(29, dtype=np.float32) self.groot_qj_all = np.zeros(29, dtype=np.float32)
self.groot_dqj_all = np.zeros(29, dtype=np.float32) self.groot_dqj_all = np.zeros(29, dtype=np.float32)
self.groot_action = np.zeros(15, dtype=np.float32) self.groot_action = np.zeros(15, dtype=np.float32)
@@ -103,39 +116,47 @@ class GrootLocomotionController:
self.groot_height_cmd = 0.74 # Default base height self.groot_height_cmd = 0.74 # Default base height
self.groot_orientation_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) self.groot_orientation_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
# Input to GR00T is 6 frames (6*86D=516) # input to gr00t is 6 frames (6*86D=516)
for _ in range(6): for _ in range(6):
self.groot_obs_history.append(np.zeros(86, dtype=np.float32)) self.groot_obs_history.append(np.zeros(86, dtype=np.float32))
# Thread management
self.locomotion_running = False
self.locomotion_thread = None
logger.info("GrootLocomotionController initialized") logger.info("GrootLocomotionController initialized")
def run_step(self): def groot_locomotion_run(self):
# Get current observation # get current observation
obs = self.robot.get_observation() robot_state = self.robot.get_observation()
if not obs: if robot_state is None:
return return
# Get command from remote controller # get command from remote controller
if obs["remote.buttons"][0]: # R1 - raise waist if robot_state.wireless_remote is not None:
self.groot_height_cmd += 0.001 self.robot.remote_controller.set(robot_state.wireless_remote)
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00) if self.robot.remote_controller.button[0]: # R1 - raise waist
if obs["remote.buttons"][4]: # R2 - lower waist self.groot_height_cmd += 0.001
self.groot_height_cmd -= 0.001 self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00) if self.robot.remote_controller.button[4]: # R2 - lower waist
self.groot_height_cmd -= 0.001
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
else:
self.robot.remote_controller.lx = 0.0
self.robot.remote_controller.ly = 0.0
self.robot.remote_controller.rx = 0.0
self.robot.remote_controller.ry = 0.0
self.cmd[0] = obs["remote.ly"] # Forward/backward self.locomotion_cmd[0] = self.robot.remote_controller.ly # forward/backward
self.cmd[1] = obs["remote.lx"] * -1 # Left/right self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1 # left/right
self.cmd[2] = obs["remote.rx"] * -1 # Rotation rate self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1 # rotation rate
# Get joint positions and velocities from flat dict for i in range(29):
for motor in G1_29_JointIndex: self.groot_qj_all[i] = robot_state.motor_state[i].q
name = motor.name self.groot_dqj_all[i] = robot_state.motor_state[i].dq
idx = motor.value
self.groot_qj_all[idx] = obs[f"{name}.q"]
self.groot_dqj_all[idx] = obs[f"{name}.dq"]
# Adapt observation for g1_23dof # adapt observation for g1_23dof
for idx in MISSING_JOINTS: for idx in MISSING_JOINTS:
self.groot_qj_all[idx] = 0.0 self.groot_qj_all[idx] = 0.0
self.groot_dqj_all[idx] = 0.0 self.groot_dqj_all[idx] = 0.0
@@ -144,18 +165,18 @@ class GrootLocomotionController:
qj_obs = self.groot_qj_all.copy() qj_obs = self.groot_qj_all.copy()
dqj_obs = self.groot_dqj_all.copy() dqj_obs = self.groot_dqj_all.copy()
# Express IMU data in gravity frame of reference # express imu data in gravity frame of reference
quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]] quat = robot_state.imu_state.quaternion
ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32) ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
gravity_orientation = self.robot.get_gravity_orientation(quat) gravity_orientation = self.robot.get_gravity_orientation(quat)
# Scale joint positions and velocities before policy inference # scale joint positions and velocities before policy inference
qj_obs = (qj_obs - GROOT_DEFAULT_ANGLES) * DOF_POS_SCALE qj_obs = (qj_obs - GROOT_DEFAULT_ANGLES) * DOF_POS_SCALE
dqj_obs = dqj_obs * DOF_VEL_SCALE dqj_obs = dqj_obs * DOF_VEL_SCALE
ang_vel_scaled = ang_vel * ANG_VEL_SCALE ang_vel_scaled = ang_vel * ANG_VEL_SCALE
# Build single frame observation # build single frame observation
self.groot_obs_single[:3] = self.cmd * np.array(CMD_SCALE) self.groot_obs_single[:3] = self.locomotion_cmd * np.array(CMD_SCALE)
self.groot_obs_single[3] = self.groot_height_cmd self.groot_obs_single[3] = self.groot_height_cmd
self.groot_obs_single[4:7] = self.groot_orientation_cmd self.groot_obs_single[4:7] = self.groot_orientation_cmd
self.groot_obs_single[7:10] = ang_vel_scaled self.groot_obs_single[7:10] = ang_vel_scaled
@@ -173,76 +194,113 @@ class GrootLocomotionController:
end_idx = start_idx + 86 end_idx = start_idx + 86
self.groot_obs_stacked[start_idx:end_idx] = obs_frame self.groot_obs_stacked[start_idx:end_idx] = obs_frame
cmd_magnitude = np.linalg.norm(self.cmd) # Run policy inference (ONNX) with 516D stacked observation
cmd_magnitude = np.linalg.norm(self.locomotion_cmd)
selected_policy = ( selected_policy = (
self.policy_balance if cmd_magnitude < 0.05 else self.policy_walk self.policy_balance if cmd_magnitude < 0.05 else self.policy_walk
) # Balance/standing policy for small commands, walking policy for movement commands ) # balance/standing policy for small commands, walking policy for movement commands
# Run policy inference # run policy inference
ort_inputs = {selected_policy.get_inputs()[0].name: np.expand_dims(self.groot_obs_stacked, axis=0)} ort_inputs = {selected_policy.get_inputs()[0].name: np.expand_dims(self.groot_obs_stacked, axis=0)}
ort_outs = selected_policy.run(None, ort_inputs) ort_outs = selected_policy.run(None, ort_inputs)
self.groot_action = ort_outs[0].squeeze() self.groot_action = ort_outs[0].squeeze()
# Transform action back to target joint positions # transform action back to target joint positions
target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * ACTION_SCALE target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * LOCOMOTION_ACTION_SCALE
# Build action dict (only first 15 joints for GR00T) # command motors
action_dict = {}
for i in range(15): for i in range(15):
motor_name = G1_29_JointIndex(i).name motor_idx = i
action_dict[f"{motor_name}.q"] = float(target_dof_pos_15[i]) self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos_15[i]
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# Zero out missing joints for g1_23dof # adapt action for g1_23dof
for joint_idx in MISSING_JOINTS: for joint_idx in MISSING_JOINTS:
motor_name = G1_29_JointIndex(joint_idx).name self.robot.msg.motor_cmd[joint_idx].q = 0.0
action_dict[f"{motor_name}.q"] = 0.0 self.robot.msg.motor_cmd[joint_idx].qd = 0
self.robot.msg.motor_cmd[joint_idx].kp = self.robot.kp[joint_idx]
self.robot.msg.motor_cmd[joint_idx].kd = self.robot.kd[joint_idx]
self.robot.msg.motor_cmd[joint_idx].tau = 0
# Send action to robot # send action to robot
self.robot.send_action(action_dict) self.robot.send_action(self.robot.msg)
def _locomotion_thread_loop(self):
def run(repo_id: str = DEFAULT_GROOT_REPO_ID) -> None: """Background thread that runs the locomotion policy at specified rate."""
"""Main function to run the GR00T locomotion controller. logger.info("Locomotion thread started")
while self.locomotion_running:
Args:
repo_id: Hugging Face Hub repository ID for GR00T policies.
"""
# Load policies
policy_balance, policy_walk = load_groot_policies(repo_id=repo_id)
# Initialize robot
config = UnitreeG1Config()
robot = UnitreeG1(config)
robot.connect()
# Initialize gr00T locomotion controller
groot_controller = GrootLocomotionController(
policy_balance=policy_balance,
policy_walk=policy_walk,
robot=robot,
config=config,
)
try:
robot.reset(CONTROL_DT, GROOT_DEFAULT_ANGLES)
logger.info("Use joystick: LY=fwd/back, LX=left/right, RX=rotate, R1=raise waist, R2=lower waist")
logger.info("Press Ctrl+C to stop")
# Run step
while not robot._shutdown_event.is_set():
start_time = time.time() start_time = time.time()
groot_controller.run_step() try:
self.groot_locomotion_run()
except Exception as e:
logger.error(f"Error in locomotion loop: {e}")
# Sleep to maintain control rate
elapsed = time.time() - start_time elapsed = time.time() - start_time
sleep_time = max(0, CONTROL_DT - elapsed) sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed)
time.sleep(sleep_time) time.sleep(sleep_time)
except KeyboardInterrupt: logger.info("Locomotion thread stopped")
logger.info("Stopping locomotion...")
finally: def start_locomotion_thread(self):
if robot.is_connected: if self.locomotion_running:
robot.disconnect() logger.warning("Locomotion thread already running")
logger.info("Done!") return
logger.info("Starting locomotion control thread...")
self.locomotion_running = True
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
self.locomotion_thread.start()
logger.info("Locomotion control thread started!")
def stop_locomotion_thread(self):
    """Ask the background locomotion loop to exit and wait briefly for it.

    Does nothing when ``self.locomotion_running`` is already false. Otherwise
    the flag is cleared and the worker thread, if one exists, is joined for
    at most 2 seconds.
    """
    if not self.locomotion_running:
        return

    logger.info("Stopping locomotion control thread...")
    self.locomotion_running = False

    worker = self.locomotion_thread
    if worker:
        worker.join(timeout=2.0)
        if worker.is_alive():
            # join() returns silently on timeout, so verify before reporting success.
            logger.warning("Locomotion control thread did not stop within 2.0s")
            return
    logger.info("Locomotion control thread stopped")
def reset_robot(self, total_time: float = 3.0):
    """Ramp every joint listed in GROOT_DEFAULT_ANGLES to its default angle.

    The current joint positions are read once; a linear blend from those
    positions to ``GROOT_DEFAULT_ANGLES`` is then published once per
    ``self.robot.control_dt`` seconds for ``total_time`` seconds.

    Args:
        total_time: Duration of the ramp in seconds (default 3.0).
    """
    # At least one step so the default pose is always commanded.
    num_step = max(1, int(total_time / self.robot.control_dt))

    # NOTE(review): earlier comments called this "legs only", but every index
    # of GROOT_DEFAULT_ANGLES is commanded below - confirm the array length.
    default_pos = GROOT_DEFAULT_ANGLES
    dof_size = len(default_pos)

    # Snapshot the starting joint positions
    robot_state = self.robot.get_observation()
    init_dof_pos = np.zeros(dof_size, dtype=np.float32)
    for i in range(dof_size):
        init_dof_pos[i] = robot_state.motor_state[i].q

    # Blend from the snapshot to the default pose
    for step in range(1, num_step + 1):
        # alpha ends at exactly 1.0, so the final command is the default pose
        alpha = step / num_step
        for motor_idx in range(dof_size):
            target_pos = default_pos[motor_idx]
            self.robot.msg.motor_cmd[motor_idx].q = (
                init_dof_pos[motor_idx] * (1 - alpha) + target_pos * alpha
            )
            self.robot.msg.motor_cmd[motor_idx].qd = 0
            self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
            self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
            self.robot.msg.motor_cmd[motor_idx].tau = 0

        # The checksum is recomputed after the message has been filled in
        self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
        self.robot.lowcmd_publisher.Write(self.robot.msg)
        time.sleep(self.robot.control_dt)

    logger.info("Reached default position (legs only)")
if __name__ == "__main__": if __name__ == "__main__":
@@ -255,4 +313,35 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
run(repo_id=args.repo_id) # load policies
policy_balance, policy_walk = load_groot_policies(repo_id=args.repo_id)
# initialize robot
config = UnitreeG1Config()
robot = UnitreeG1(config)
# initialize gr00t locomotion controller
groot_controller = GrootLocomotionController(
policy_balance=policy_balance,
policy_walk=policy_walk,
robot=robot,
config=config,
)
# reset legs and start locomotion thread
try:
groot_controller.reset_robot()
groot_controller.start_locomotion_thread()
# log status
logger.info("Robot initialized with GR00T locomotion policies")
logger.info("Locomotion controller running in background thread")
logger.info("Press Ctrl+C to stop")
# keep robot alive
while True:
time.sleep(1.0)
except KeyboardInterrupt:
print("\nStopping locomotion...")
groot_controller.stop_locomotion_thread()
print("Done!")

View File

@@ -1,264 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import logging
import time
import numpy as np
import onnx
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default joint angles, indexed by joint number (29 entries, zero unless listed)
DEFAULT_ANGLES = np.zeros(29, dtype=np.float32)
for _joint_ids, _angle in (
    ([0, 6], -0.312),  # Hip pitch
    ([3, 9], 0.669),  # Knee
    ([4, 10], -0.363),  # Ankle pitch
    ([15, 22], 0.2),  # Shoulder pitch
    ([16], 0.2),  # Left shoulder roll
    ([23], -0.2),  # Right shoulder roll
    ([18, 25], 0.6),  # Elbow
):
    DEFAULT_ANGLES[_joint_ids] = _angle
del _joint_ids, _angle

G1_MODEL = "g1_23"  # Or "g1_29"
# Joint indices zeroed for the 23-DoF model: waist yaw/pitch, wrist pitch/yaw
MISSING_JOINTS = [12, 14, 20, 21, 27, 28] if G1_MODEL == "g1_23" else []

# Control parameters
ACTION_SCALE = 0.25
CONTROL_DT = 0.02  # 50Hz
ANG_VEL_SCALE = 0.25
DOF_POS_SCALE = 1.0
DOF_VEL_SCALE = 0.05
GAIT_PERIOD = 1.0

DEFAULT_HOLOSOMA_REPO_ID = "nepyope/holosoma_locomotion"

# Policy type -> ONNX filename in the Hub repository
POLICY_FILES = {
    "fastsac": "fastsac_g1_29dof.onnx",
    "ppo": "ppo_g1_29dof.onnx",
}
def load_policy(
    repo_id: str = DEFAULT_HOLOSOMA_REPO_ID,
    policy_type: str = "fastsac",
) -> tuple[ort.InferenceSession, np.ndarray, np.ndarray]:
    """Load a Holosoma locomotion policy and the PD gains stored in its metadata.

    The ONNX file named by ``POLICY_FILES[policy_type]`` is downloaded from
    the Hugging Face Hub, opened as an onnxruntime session, and its ``kp`` /
    ``kd`` metadata entries (JSON-encoded) are parsed into float32 arrays.

    Args:
        repo_id: Hugging Face Hub repo ID
        policy_type: Either "fastsac" (default) or "ppo"

    Returns:
        (policy, kp, kd) tuple

    Raises:
        ValueError: If ``policy_type`` is not a key of ``POLICY_FILES`` or the
            model metadata lacks ``kp`` or ``kd``.
    """
    if policy_type not in POLICY_FILES:
        raise ValueError(f"Unknown policy type: {policy_type}. Choose from: {list(POLICY_FILES.keys())}")

    filename = POLICY_FILES[policy_type]
    logger.info(f"Loading {policy_type.upper()} policy from: {repo_id}/{filename}")
    policy_path = hf_hub_download(repo_id=repo_id, filename=filename)

    policy = ort.InferenceSession(policy_path)
    input_shape = policy.get_inputs()[0].shape
    output_shape = policy.get_outputs()[0].shape
    # Separate the two shapes so the log line is readable
    logger.info(f"Policy loaded: {input_shape} -> {output_shape}")

    # Extract KP/KD from ONNX metadata
    model = onnx.load(policy_path)
    metadata = {prop.key: prop.value for prop in model.metadata_props}
    if "kp" not in metadata or "kd" not in metadata:
        raise ValueError("ONNX model must contain 'kp' and 'kd' in metadata")

    kp = np.array(json.loads(metadata["kp"]), dtype=np.float32)
    kd = np.array(json.loads(metadata["kd"]), dtype=np.float32)
    logger.info(f"Loaded KP/KD from ONNX ({len(kp)} joints)")

    return policy, kp, kd
class HolosomaLocomotionController:
    """Holosoma whole-body locomotion controller for Unitree G1.

    Each call to :meth:`run_step` reads one observation from the robot, builds
    the 100-element policy input, runs the ONNX policy and sends the resulting
    joint position targets back to the robot.
    """

    def __init__(self, policy, robot, kp: np.ndarray, kd: np.ndarray):
        """Store the policy and robot, and install the policy's PD gains on the robot.

        Args:
            policy: Object exposing ``get_inputs()`` and ``run()`` (used as an
                onnxruntime session in ``run_step``).
            robot: Robot providing ``get_observation``, ``get_gravity_orientation``
                and ``send_action``.
            kp: Gains assigned to ``robot.kp``.
            kd: Gains assigned to ``robot.kd``.
        """
        self.policy = policy
        self.robot = robot

        # The policy's gains replace whatever the robot was configured with
        self.robot.kp = kp
        self.robot.kd = kd

        # Command vector, refreshed from the remote on every step
        self.cmd = np.zeros(3, dtype=np.float32)

        # Joint state buffers, policy input buffer and previous policy output
        self.qj = np.zeros(29, dtype=np.float32)
        self.dqj = np.zeros(29, dtype=np.float32)
        self.obs = np.zeros(100, dtype=np.float32)
        self.last_action = np.zeros(29, dtype=np.float32)

        # Two gait phase values; phase_dt is the per-step phase increment
        self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
        self.phase_dt = 2 * np.pi / ((1.0 / CONTROL_DT) * GAIT_PERIOD)
        self.is_standing = True

    def run_step(self):
        """Run one control step: observe, infer, and send joint targets.

        Returns without acting when the robot yields an empty observation.
        """
        observation = self.robot.get_observation()
        if not observation:
            return

        # Stick values with magnitude <= 0.1 are treated as zero
        ly, lx, rx = (
            reading if abs(reading) > 0.1 else 0.0
            for reading in (observation["remote.ly"], observation["remote.lx"], observation["remote.rx"])
        )
        self.cmd[:] = [ly, -lx, -rx]

        # Copy joint positions and velocities into the state buffers
        for joint in G1_29_JointIndex:
            self.qj[joint.value] = observation[f"{joint.name}.q"]
            self.dqj[joint.value] = observation[f"{joint.name}.dq"]

        # Joints absent on g1_23dof are reported as zero
        for missing_idx in MISSING_JOINTS:
            self.qj[missing_idx] = 0.0
            self.dqj[missing_idx] = 0.0

        # IMU: orientation quaternion -> gravity vector, plus gyro readings
        quat = [
            observation["imu.quat.w"],
            observation["imu.quat.x"],
            observation["imu.quat.y"],
            observation["imu.quat.z"],
        ]
        gyro = [observation["imu.gyro.x"], observation["imu.gyro.y"], observation["imu.gyro.z"]]
        ang_vel = np.array(gyro, dtype=np.float32)
        gravity = self.robot.get_gravity_orientation(quat)

        # Scale the raw signals the way the policy input expects
        qj_obs = (self.qj - DEFAULT_ANGLES) * DOF_POS_SCALE
        dqj_obs = self.dqj * DOF_VEL_SCALE
        ang_vel_s = ang_vel * ANG_VEL_SCALE

        # Gait phase: hold while idle, restart when leaving idle, else advance
        no_command = np.linalg.norm(self.cmd[:2]) < 0.01 and abs(self.cmd[2]) < 0.01
        if no_command:
            self.phase[0, :] = np.pi
            self.is_standing = True
        elif self.is_standing:
            self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
            self.is_standing = False
        else:
            # Advance and wrap the result into [-pi, pi)
            self.phase = np.fmod(self.phase + self.phase_dt + np.pi, 2 * np.pi) - np.pi
        sin_ph = np.sin(self.phase[0])
        cos_ph = np.cos(self.phase[0])

        # Fill the policy input buffer
        self.obs[0:29] = self.last_action
        self.obs[29:32] = ang_vel_s
        self.obs[32] = self.cmd[2]
        self.obs[33:35] = self.cmd[:2]
        self.obs[35:37] = cos_ph
        self.obs[37:66] = qj_obs
        self.obs[66:95] = dqj_obs
        self.obs[95:98] = gravity
        self.obs[98:100] = sin_ph

        # Policy inference on a single-row batch
        input_name = self.policy.get_inputs()[0].name
        batch = self.obs.reshape(1, -1).astype(np.float32)
        raw_action = self.policy.run(None, {input_name: batch})[0].squeeze()
        action = np.clip(raw_action, -100.0, 100.0)
        self.last_action = action.copy()

        # Scaled action is an offset from the default joint angles
        target = DEFAULT_ANGLES + action * ACTION_SCALE

        action_dict = {f"{joint.name}.q": float(target[joint.value]) for joint in G1_29_JointIndex}

        # Joints absent on g1_23dof are commanded to zero
        for missing_idx in MISSING_JOINTS:
            action_dict[f"{G1_29_JointIndex(missing_idx).name}.q"] = 0.0

        self.robot.send_action(action_dict)
def run(repo_id: str = DEFAULT_HOLOSOMA_REPO_ID, policy_type: str = "fastsac") -> None:
    """Main function to run the Holosoma locomotion controller.

    Loads the policy, connects to the robot, resets it to ``DEFAULT_ANGLES``
    and then calls the controller once per ``CONTROL_DT`` seconds until the
    robot's shutdown event is set or Ctrl+C is pressed. The robot is
    disconnected on every exit path after a successful connect.

    Args:
        repo_id: Hugging Face Hub repository ID for Holosoma policies.
        policy_type: Policy type to use ('fastsac' or 'ppo').
    """
    # Load policy and gains
    policy, kp, kd = load_policy(repo_id=repo_id, policy_type=policy_type)

    # Initialize robot
    config = UnitreeG1Config()
    robot = UnitreeG1(config)
    robot.connect()

    try:
        # Built inside the try so a failure here still disconnects the robot
        holosoma_controller = HolosomaLocomotionController(policy, robot, kp, kd)
        robot.reset(CONTROL_DT, DEFAULT_ANGLES)

        logger.info("Use joystick: LY=fwd/back, LX=left/right, RX=rotate")
        logger.info("Press Ctrl+C to stop")

        while not robot._shutdown_event.is_set():
            # perf_counter is monotonic; wall-clock adjustments cannot skew the period
            start_time = time.perf_counter()
            holosoma_controller.run_step()
            elapsed = time.perf_counter() - start_time
            time.sleep(max(0, CONTROL_DT - elapsed))
    except KeyboardInterrupt:
        logger.info("Stopping locomotion...")
    finally:
        if robot.is_connected:
            robot.disconnect()
        logger.info("Done!")
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the Holosoma example script."""
    cli = argparse.ArgumentParser(description="Holosoma Locomotion Controller for Unitree G1")
    cli.add_argument(
        "--repo-id",
        type=str,
        default=DEFAULT_HOLOSOMA_REPO_ID,
        help=f"Hugging Face Hub repo ID for Holosoma policies (default: {DEFAULT_HOLOSOMA_REPO_ID})",
    )
    cli.add_argument(
        "--policy",
        type=str,
        choices=["fastsac", "ppo"],
        default="fastsac",
        help="Policy type to use: 'fastsac' (default) or 'ppo'",
    )
    return cli


if __name__ == "__main__":
    # Parse the command line and hand the options to run()
    cli_args = _build_arg_parser().parse_args()
    run(repo_id=cli_args.repo_id, policy_type=cli_args.policy)

View File

@@ -25,9 +25,9 @@ discord = "https://discord.gg/s3KuuzsPFb"
[project] [project]
name = "lerobot" name = "lerobot"
version = "0.4.4" version = "0.4.3"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
dynamic = ["readme"] readme = "README.md"
license = { text = "Apache-2.0" } license = { text = "Apache-2.0" }
requires-python = ">=3.10" requires-python = ">=3.10"
authors = [ authors = [
@@ -74,7 +74,7 @@ dependencies = [
"packaging>=24.2,<26.0", "packaging>=24.2,<26.0",
"pynput>=1.7.7,<1.9.0", "pynput>=1.7.7,<1.9.0",
"pyserial>=3.5,<4.0", "pyserial>=3.5,<4.0",
"wandb>=0.24.0,<0.25.0", "wandb>=0.20.0,<0.22.0", # TODO: Bump dependency (compatible with protobuf)
"torch>=2.2.1,<2.8.0", # TODO: Bump dependency "torch>=2.2.1,<2.8.0", # TODO: Bump dependency
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency "torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
@@ -97,7 +97,7 @@ dependencies = [
pygame-dep = ["pygame>=2.5.1,<2.7.0"] pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.10.0"] placo-dep = ["placo>=0.9.6,<0.10.0"]
transformers-dep = ["transformers>=4.57.1,<5.0.0"] transformers-dep = ["transformers>=4.57.1,<5.0.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"] grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bump dependency (compatible with wandb)
# Motors # Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"] feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
@@ -109,9 +109,9 @@ hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"] lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
unitree_g1 = [ unitree_g1 = [
"pyzmq>=26.2.1,<28.0.0", "pyzmq>=26.2.1,<28.0.0",
"onnxruntime>=1.16.0,<2.0.0" "onnxruntime>=1.16.0"
] ]
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"] reachy2 = ["reachy2_sdk>=1.0.14,<1.1.0"]
kinematics = ["lerobot[placo-dep]"] kinematics = ["lerobot[placo-dep]"]
intelrealsense = [ intelrealsense = [
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'", "pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
@@ -127,7 +127,7 @@ wallx = [
"torchdiffeq==0.2.5", "torchdiffeq==0.2.5",
"qwen_vl_utils==0.0.11" "qwen_vl_utils==0.0.11"
] ]
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi", "scipy>=1.10.1,<1.15"] pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"] smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
groot = [ groot = [
"lerobot[transformers-dep]", "lerobot[transformers-dep]",
@@ -140,13 +140,12 @@ groot = [
"ninja>=1.11.1,<2.0.0", "ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'" "flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
] ]
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14,<0.1.0"] sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14"]
xvla = ["lerobot[transformers-dep]"] xvla = ["lerobot[transformers-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features # Features
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"] async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
peft = ["lerobot[transformers-dep]", "peft>=0.18.0,<1.0.0"]
# Development # Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"] dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
@@ -183,8 +182,7 @@ all = [
"lerobot[phone]", "lerobot[phone]",
"lerobot[libero]", "lerobot[libero]",
"lerobot[metaworld]", "lerobot[metaworld]",
"lerobot[sarm]", "lerobot[sarm]"
"lerobot[peft]",
] ]
[project.scripts] [project.scripts]
@@ -197,7 +195,6 @@ lerobot-setup-motors="lerobot.scripts.lerobot_setup_motors:main"
lerobot-teleoperate="lerobot.scripts.lerobot_teleoperate:main" lerobot-teleoperate="lerobot.scripts.lerobot_teleoperate:main"
lerobot-eval="lerobot.scripts.lerobot_eval:main" lerobot-eval="lerobot.scripts.lerobot_eval:main"
lerobot-train="lerobot.scripts.lerobot_train:main" lerobot-train="lerobot.scripts.lerobot_train:main"
lerobot-train-tokenizer="lerobot.scripts.lerobot_train_tokenizer:main"
lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main" lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main"
lerobot-info="lerobot.scripts.lerobot_info:main" lerobot-info="lerobot.scripts.lerobot_info:main"
lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main" lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
@@ -420,10 +417,6 @@ conflicts = [
{ extra = "wallx" }, { extra = "wallx" },
{ extra = "libero" }, { extra = "libero" },
], ],
[
{ extra = "wallx" },
{ extra = "peft" },
],
[ [
{ extra = "wallx" }, { extra = "wallx" },
{ extra = "all" }, { extra = "all" },
@@ -457,10 +450,6 @@ conflicts = [
{ extra = "pi" }, { extra = "pi" },
{ extra = "libero" }, { extra = "libero" },
], ],
[
{ extra = "pi" },
{ extra = "peft" },
],
[ [
{ extra = "pi" }, { extra = "pi" },
{ extra = "all" }, { extra = "all" },

View File

@@ -1,72 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from setuptools import setup
def get_version_from_toml() -> str:
    """Return the project's version string parsed from `pyproject.toml`.

    The function scans `pyproject.toml` line-by-line for the first
    assignment whose key is exactly ``version`` (for example:
    ``version = "1.2.3"``) and returns the value without surrounding
    quotes or a trailing ``#`` comment. Keys that merely start with
    ``version`` (such as ``version_scheme``) are ignored.

    Returns:
        The version string from `pyproject.toml` (e.g. ``"1.2.3"`` ->
        ``1.2.3``).

    Raises:
        ValueError: If no ``version`` assignment is found.
    """
    with open("pyproject.toml", encoding="utf-8") as f:
        for line in f:
            key, separator, value = line.partition("=")
            if not separator or key.strip() != "version":
                continue
            # Drop any trailing comment, then whitespace and either quote style
            return value.split("#", 1)[0].strip().strip("\"'")

    raise ValueError("Version not found in pyproject.toml")
def read_long_description() -> str:
    """Build the long description passed to ``setuptools.setup``.

    `README.md` is read in full, and every ``src="./media/`` prefix is
    rewritten to an absolute raw GitHub URL under the tag ``v<version>``,
    where ``<version>`` comes from :func:`get_version_from_toml`.

    Returns:
        The README content with rewritten media links.
    """
    with open("README.md", encoding="utf-8") as readme:
        text = readme.read()

    # Point media links at the tagged release rather than the working tree
    tag = f"v{get_version_from_toml()}"
    raw_root = f"https://raw.githubusercontent.com/huggingface/lerobot/{tag}/"
    return text.replace('src="./media/', f'src="{raw_root}media/')
# Only the long description is supplied here: it is the README with its media
# links rewritten by read_long_description(), declared as Markdown.
setup(
long_description=read_long_description(),
long_description_content_type="text/markdown",
)

View File

@@ -126,12 +126,6 @@ class RobotClientConfig:
# Device configuration # Device configuration
policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"}) policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})
client_device: str = field(
default="cpu",
metadata={
"help": "Device to move actions to after receiving from server (e.g., for downstream planners)"
},
)
# Control behavior configuration # Control behavior configuration
chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"}) chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"})
@@ -167,9 +161,6 @@ class RobotClientConfig:
if not self.policy_device: if not self.policy_device:
raise ValueError("policy_device cannot be empty") raise ValueError("policy_device cannot be empty")
if not self.client_device:
raise ValueError("client_device cannot be empty")
if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1: if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1:
raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}") raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}")
@@ -193,7 +184,6 @@ class RobotClientConfig:
"policy_type": self.policy_type, "policy_type": self.policy_type,
"pretrained_name_or_path": self.pretrained_name_or_path, "pretrained_name_or_path": self.pretrained_name_or_path,
"policy_device": self.policy_device, "policy_device": self.policy_device,
"client_device": self.client_device,
"chunk_size_threshold": self.chunk_size_threshold, "chunk_size_threshold": self.chunk_size_threshold,
"fps": self.fps, "fps": self.fps,
"actions_per_chunk": self.actions_per_chunk, "actions_per_chunk": self.actions_per_chunk,

View File

@@ -23,7 +23,7 @@ DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
DEFAULT_OBS_QUEUE_TIMEOUT = 2 DEFAULT_OBS_QUEUE_TIMEOUT = 2
# All action chunking policies # All action chunking policies
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05", "groot"] SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
# TODO: Add all other robots # TODO: Add all other robots
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so_follower", "omx_follower"] SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower", "omx_follower"]

View File

@@ -18,7 +18,6 @@ import os
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any
import torch import torch
@@ -40,8 +39,8 @@ from lerobot.utils.utils import init_logging
Action = torch.Tensor Action = torch.Tensor
# observation as received from the robot (can be numpy arrays, floats, etc.) # observation as received from the robot
RawObservation = dict[str, Any] RawObservation = dict[str, torch.Tensor]
# observation as those recorded in LeRobot dataset (keys are different) # observation as those recorded in LeRobot dataset (keys are different)
LeRobotObservation = dict[str, torch.Tensor] LeRobotObservation = dict[str, torch.Tensor]

View File

@@ -381,8 +381,6 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
action_tensor = torch.stack(processed_actions, dim=1).squeeze(0) action_tensor = torch.stack(processed_actions, dim=1).squeeze(0)
self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}") self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}")
action_tensor = action_tensor.detach().cpu()
"""5. Convert to TimedAction list""" """5. Convert to TimedAction list"""
action_chunk = self._time_action_chunk( action_chunk = self._time_action_chunk(
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep() observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()

View File

@@ -25,7 +25,6 @@ python src/lerobot/async_inference/robot_client.py \
--policy_type=act \ --policy_type=act \
--pretrained_name_or_path=user/model \ --pretrained_name_or_path=user/model \
--policy_device=mps \ --policy_device=mps \
--client_device=cpu \
--actions_per_chunk=50 \ --actions_per_chunk=50 \
--chunk_size_threshold=0.5 \ --chunk_size_threshold=0.5 \
--aggregate_fn_name=weighted_average \ --aggregate_fn_name=weighted_average \
@@ -52,11 +51,12 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon
from lerobot.robots import ( # noqa: F401 from lerobot.robots import ( # noqa: F401
Robot, Robot,
RobotConfig, RobotConfig,
bi_so_follower, bi_so100_follower,
koch_follower, koch_follower,
make_robot_from_config, make_robot_from_config,
omx_follower, omx_follower,
so_follower, so100_follower,
so101_follower,
) )
from lerobot.transport import ( from lerobot.transport import (
services_pb2, # type: ignore services_pb2, # type: ignore
@@ -286,21 +286,6 @@ class RobotClient:
timed_actions = pickle.loads(actions_chunk.data) # nosec timed_actions = pickle.loads(actions_chunk.data) # nosec
deserialize_time = time.perf_counter() - deserialize_start deserialize_time = time.perf_counter() - deserialize_start
# Log device type of received actions
if len(timed_actions) > 0:
received_device = timed_actions[0].get_action().device.type
self.logger.debug(f"Received actions on device: {received_device}")
# Move actions to client_device (e.g., for downstream planners that need GPU)
client_device = self.config.client_device
if client_device != "cpu":
for timed_action in timed_actions:
if timed_action.get_action().device.type != client_device:
timed_action.action = timed_action.get_action().to(client_device)
self.logger.debug(f"Converted actions to device: {client_device}")
else:
self.logger.debug(f"Actions kept on device: {client_device}")
self.action_chunk_size = max(self.action_chunk_size, len(timed_actions)) self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
# Calculate network latency if we have matching observations # Calculate network latency if we have matching observations

View File

@@ -35,19 +35,18 @@ class Reachy2CameraConfig(CameraConfig):
name="teleop", name="teleop",
image_type="left", image_type="left",
ip_address="192.168.0.200", # IP address of the robot ip_address="192.168.0.200", # IP address of the robot
port=50065, # Port of the camera server fps=15,
width=640, width=640,
height=480, height=480,
fps=30, # Not configurable for Reachy 2 cameras
color_mode=ColorMode.RGB, color_mode=ColorMode.RGB,
) # Left teleop camera, 640x480 @ 30FPS ) # Left teleop camera, 640x480 @ 15FPS
``` ```
Attributes: Attributes:
name: Name of the camera device. Can be "teleop" or "depth". name: Name of the camera device. Can be "teleop" or "depth".
image_type: Type of image stream. For "teleop" camera, can be "left" or "right". image_type: Type of image stream. For "teleop" camera, can be "left" or "right".
For "depth" camera, can be "rgb" or "depth". (depth is not supported yet) For "depth" camera, can be "rgb" or "depth". (depth is not supported yet)
fps: Requested frames per second for the color stream. Not configurable for Reachy 2 cameras. fps: Requested frames per second for the color stream.
width: Requested frame width in pixels for the color stream. width: Requested frame width in pixels for the color stream.
height: Requested frame height in pixels for the color stream. height: Requested frame height in pixels for the color stream.
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB. color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
@@ -63,6 +62,7 @@ class Reachy2CameraConfig(CameraConfig):
color_mode: ColorMode = ColorMode.RGB color_mode: ColorMode = ColorMode.RGB
ip_address: str | None = "localhost" ip_address: str | None = "localhost"
port: int = 50065 port: int = 50065
# use_depth: bool = False
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self.name not in ["teleop", "depth"]: if self.name not in ["teleop", "depth"]:

View File

@@ -16,13 +16,12 @@
Provides the Reachy2Camera class for capturing frames from Reachy 2 cameras using Reachy 2's CameraManager. Provides the Reachy2Camera class for capturing frames from Reachy 2 cameras using Reachy 2's CameraManager.
""" """
from __future__ import annotations
import logging import logging
import os import os
import platform import platform
import time import time
from typing import TYPE_CHECKING, Any from threading import Event, Lock, Thread
from typing import Any
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
@@ -31,19 +30,10 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0" os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2 # type: ignore # TODO: add type stubs for OpenCV import cv2 # type: ignore # TODO: add type stubs for OpenCV
import numpy as np # type: ignore # TODO: add type stubs for numpy import numpy as np # type: ignore # TODO: add type stubs for numpy
from reachy2_sdk.media.camera import CameraView # type: ignore # TODO: add type stubs for reachy2_sdk
from lerobot.utils.import_utils import _reachy2_sdk_available from reachy2_sdk.media.camera_manager import ( # type: ignore # TODO: add type stubs for reachy2_sdk
CameraManager,
if TYPE_CHECKING or _reachy2_sdk_available: )
from reachy2_sdk.media.camera import CameraView
from reachy2_sdk.media.camera_manager import CameraManager
else:
CameraManager = None
class CameraView:
LEFT = 0
RIGHT = 1
from lerobot.utils.errors import DeviceNotConnectedError from lerobot.utils.errors import DeviceNotConnectedError
@@ -79,10 +69,17 @@ class Reachy2Camera(Camera):
self.config = config self.config = config
self.fps = config.fps
self.color_mode = config.color_mode self.color_mode = config.color_mode
self.cam_manager: CameraManager | None = None self.cam_manager: CameraManager | None = None
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: NDArray[Any] | None = None
self.new_frame_event: Event = Event()
def __str__(self) -> str: def __str__(self) -> str:
return f"{self.__class__.__name__}({self.config.name}, {self.config.image_type})" return f"{self.__class__.__name__}({self.config.name}, {self.config.image_type})"
@@ -103,23 +100,44 @@ class Reachy2Camera(Camera):
def connect(self, warmup: bool = True) -> None: def connect(self, warmup: bool = True) -> None:
""" """
Connects to the Reachy2 CameraManager as specified in the configuration. Connects to the Reachy2 CameraManager as specified in the configuration.
Raises:
DeviceNotConnectedError: If the camera is not connected.
""" """
self.cam_manager = CameraManager(host=self.config.ip_address, port=self.config.port) self.cam_manager = CameraManager(host=self.config.ip_address, port=self.config.port)
if self.cam_manager is None:
raise DeviceNotConnectedError(f"Could not connect to {self}.")
self.cam_manager.initialize_cameras() self.cam_manager.initialize_cameras()
logger.info(f"{self} connected.") logger.info(f"{self} connected.")
@staticmethod @staticmethod
def find_cameras() -> list[dict[str, Any]]: def find_cameras(ip_address: str = "localhost", port: int = 50065) -> list[dict[str, Any]]:
""" """
Detection not implemented for Reachy2 cameras. Detects available Reachy 2 cameras.
Returns:
List[Dict[str, Any]]: A list of dictionaries,
where each dictionary contains 'name', 'stereo',
and the default profile properties (width, height, fps).
""" """
raise NotImplementedError("Camera detection is not implemented for Reachy2 cameras.") initialized_cameras = []
camera_manager = CameraManager(host=ip_address, port=port)
for camera in [camera_manager.teleop, camera_manager.depth]:
if camera is None:
continue
height, width, _, _, _, _, _ = camera.get_parameters()
camera_info = {
"name": camera._cam_info.name,
"stereo": camera._cam_info.stereo,
"default_profile": {
"width": width,
"height": height,
"fps": 30,
},
}
initialized_cameras.append(camera_info)
camera_manager.disconnect()
return initialized_cameras
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
""" """
@@ -137,49 +155,95 @@ class Reachy2Camera(Camera):
(height, width, channels), using the specified or default (height, width, channels), using the specified or default
color mode and applying any configured rotation. color mode and applying any configured rotation.
""" """
start_time = time.perf_counter()
if not self.is_connected: if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.") raise DeviceNotConnectedError(f"{self} is not connected.")
if self.cam_manager is None: start_time = time.perf_counter()
raise DeviceNotConnectedError(f"{self} is not connected.")
frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8) frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8)
if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"): if self.cam_manager is None:
if self.config.image_type == "left": raise DeviceNotConnectedError(f"{self} is not connected.")
frame = self.cam_manager.teleop.get_frame(
CameraView.LEFT, size=(self.config.width, self.config.height)
)[0]
elif self.config.image_type == "right":
frame = self.cam_manager.teleop.get_frame(
CameraView.RIGHT, size=(self.config.width, self.config.height)
)[0]
elif self.config.name == "depth" and hasattr(self.cam_manager, "depth"):
if self.config.image_type == "depth":
frame = self.cam_manager.depth.get_depth_frame()[0]
elif self.config.image_type == "rgb":
frame = self.cam_manager.depth.get_frame(size=(self.config.width, self.config.height))[0]
else: else:
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.") if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"):
if self.config.image_type == "left":
frame = self.cam_manager.teleop.get_frame(CameraView.LEFT, size=(640, 480))[0]
elif self.config.image_type == "right":
frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT, size=(640, 480))[0]
elif self.config.name == "depth" and hasattr(self.cam_manager, "depth"):
if self.config.image_type == "depth":
frame = self.cam_manager.depth.get_depth_frame()[0]
elif self.config.image_type == "rgb":
frame = self.cam_manager.depth.get_frame(size=(640, 480))[0]
if frame is None: if frame is None:
return np.empty((0, 0, 3), dtype=np.uint8) return np.empty((0, 0, 3), dtype=np.uint8)
if self.config.color_mode == "rgb": if self.config.color_mode == "rgb":
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
read_duration_ms = (time.perf_counter() - start_time) * 1e3 read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return frame return frame
def _read_loop(self) -> None:
"""
Internal loop run by the background thread for asynchronous reading.
On each iteration:
1. Reads a color frame
2. Stores result in latest_frame (thread-safe)
3. Sets new_frame_event to notify listeners
Stops on DeviceNotConnectedError, logs other errors and continues.
"""
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
while not self.stop_event.is_set():
try:
color_image = self.read()
with self.frame_lock:
self.latest_frame = color_image
self.new_frame_event.set()
except DeviceNotConnectedError:
break
except Exception as e:
logger.warning(f"Error reading frame in background thread for {self}: {e}")
def _start_read_thread(self) -> None:
"""Starts or restarts the background read thread if it's not running."""
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=0.1)
if self.stop_event is not None:
self.stop_event.set()
self.stop_event = Event()
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
self.thread.daemon = True
self.thread.start()
def _stop_read_thread(self) -> None:
"""Signals the background read thread to stop and waits for it to join."""
if self.stop_event is not None:
self.stop_event.set()
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
self.thread = None
self.stop_event = None
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
""" """
Reads the latest available frame. Reads the latest available frame asynchronously.
This method retrieves the most recent frame available in Reachy 2's low-level software. This method retrieves the most recent frame captured by the background
read thread. It does not block waiting for the camera hardware directly,
but may wait up to timeout_ms for the background thread to provide a frame.
Args: Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame timeout_ms (float): Maximum time in milliseconds to wait for a frame
@@ -197,10 +261,22 @@ class Reachy2Camera(Camera):
if not self.is_connected: if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.") raise DeviceNotConnectedError(f"{self} is not connected.")
frame = self.read() if self.thread is None or not self.thread.is_alive():
self._start_read_thread()
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
thread_alive = self.thread is not None and self.thread.is_alive()
raise TimeoutError(
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
f"Read thread alive: {thread_alive}."
)
with self.frame_lock:
frame = self.latest_frame
self.new_frame_event.clear()
if frame is None: if frame is None:
raise RuntimeError(f"Internal error: No frame available for {self}.") raise RuntimeError(f"Internal error: Event set but no frame available for {self}.")
return frame return frame
@@ -211,9 +287,12 @@ class Reachy2Camera(Camera):
Raises: Raises:
DeviceNotConnectedError: If the camera is already disconnected. DeviceNotConnectedError: If the camera is already disconnected.
""" """
if not self.is_connected: if not self.is_connected and self.thread is None:
raise DeviceNotConnectedError(f"{self} not connected.") raise DeviceNotConnectedError(f"{self} not connected.")
if self.thread is not None:
self._stop_read_thread()
if self.cam_manager is not None: if self.cam_manager is not None:
self.cam_manager.disconnect() self.cam_manager.disconnect()

View File

@@ -43,11 +43,6 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
cameras[key] = Reachy2Camera(cfg) cameras[key] = Reachy2Camera(cfg)
elif cfg.type == "zmq":
from .zmq.camera_zmq import ZMQCamera
cameras[key] = ZMQCamera(cfg)
else: else:
try: try:
cameras[key] = cast(Camera, make_device_from_device_class(cfg)) cameras[key] = cast(Camera, make_device_from_device_class(cfg))

View File

@@ -1,235 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ZMQCamera - Captures frames from remote cameras via ZeroMQ using JSON protocol in the
following format:
{
"timestamps": {"camera_name": float},
"images": {"camera_name": "<base64-jpeg>"}
}
"""
import base64
import json
import logging
import time
from threading import Event, Lock, Thread
from typing import Any
import cv2
import numpy as np
from numpy.typing import NDArray
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..camera import Camera
from ..configs import ColorMode
from .configuration_zmq import ZMQCameraConfig
logger = logging.getLogger(__name__)
class ZMQCamera(Camera):
    """
    Camera that receives frames published by a remote ZeroMQ server.

    Every message is a JSON string whose "images" entry maps camera names to
    base64-encoded JPEG data (see the module docstring for the layout).

    Example usage:
    ```python
    from lerobot.cameras.zmq import ZMQCamera, ZMQCameraConfig

    config = ZMQCameraConfig(server_address="192.168.123.164", port=5555, camera_name="head_camera")
    camera = ZMQCamera(config)
    camera.connect()
    frame = camera.read()
    camera.disconnect()
    ```
    """

    def __init__(self, config: ZMQCameraConfig):
        """Store the connection settings from ``config``; no socket is opened here."""
        super().__init__(config)
        import zmq

        self.config = config
        self.server_address = config.server_address
        self.port = config.port
        self.camera_name = config.camera_name
        self.color_mode = config.color_mode
        self.timeout_ms = config.timeout_ms

        # ZMQ resources, created in connect() and released in _cleanup().
        self.context: zmq.Context | None = None
        self.socket: zmq.Socket | None = None
        self._connected = False

        # State shared with the background reader used by async_read().
        self.thread: Thread | None = None
        self.stop_event: Event | None = None
        self.frame_lock: Lock = Lock()
        self.latest_frame: NDArray[Any] | None = None
        self.new_frame_event: Event = Event()

    def __str__(self) -> str:
        return f"ZMQCamera({self.camera_name}@{self.server_address}:{self.port})"

    @property
    def is_connected(self) -> bool:
        """True once connect() succeeded and both the context and the socket exist."""
        return self._connected and self.context is not None and self.socket is not None

    def connect(self, warmup: bool = True) -> None:
        """Connect to ZMQ camera server.

        Opens a SUB socket subscribed to every topic, with the configured receive
        timeout and CONFLATE enabled. When width or height is unset, one frame is
        read to fill them in.

        Args:
            warmup: When True, sleep 0.1 s after connecting.

        Raises:
            DeviceAlreadyConnectedError: If the camera is already connected.
            RuntimeError: If any step of the connection fails; ZMQ resources are
                released before raising.
        """
        if self.is_connected:
            raise DeviceAlreadyConnectedError(f"{self} is already connected.")

        logger.info(f"Connecting to {self}...")

        try:
            import zmq

            self.context = zmq.Context()
            self.socket = self.context.socket(zmq.SUB)
            self.socket.setsockopt_string(zmq.SUBSCRIBE, "")
            self.socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms)
            self.socket.setsockopt(zmq.CONFLATE, True)
            self.socket.connect(f"tcp://{self.server_address}:{self.port}")
            self._connected = True

            # Auto-detect resolution
            if self.width is None or self.height is None:
                h, w = self.read().shape[:2]
                self.height = h
                self.width = w
                logger.info(f"{self} resolution: {w}x{h}")

            logger.info(f"{self} connected.")

            if warmup:
                time.sleep(0.1)

        except Exception as e:
            self._cleanup()
            raise RuntimeError(f"Failed to connect to {self}: {e}") from e

    def _cleanup(self) -> None:
        """Clean up ZMQ resources."""
        self._connected = False
        if self.socket:
            self.socket.close()
            self.socket = None
        if self.context:
            self.context.term()
            self.context = None

    @staticmethod
    def find_cameras() -> list[dict[str, Any]]:
        """ZMQ cameras require manual configuration (server address/port)."""
        return []

    def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
        """
        Read a single frame from the ZMQ camera.

        Args:
            color_mode: Accepted for interface compatibility; this implementation
                does not use it and returns the frame exactly as decoded.

        Returns:
            np.ndarray: Decoded frame (height, width, 3)

        Raises:
            DeviceNotConnectedError: If the camera is not connected.
            TimeoutError: If no message arrives within the socket receive timeout.
            RuntimeError: If the message has no usable image or decoding fails.
        """
        if not self.is_connected or self.socket is None:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        import zmq

        try:
            message = self.socket.recv_string()
        except zmq.Again as e:
            # RCVTIMEO expired without a message.
            raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e

        # Decode JSON message
        data = json.loads(message)

        if "images" not in data:
            raise RuntimeError(f"{self} invalid message: missing 'images' key")

        images = data["images"]

        # Get image by camera name or first available
        if self.camera_name in images:
            img_b64 = images[self.camera_name]
        elif images:
            img_b64 = next(iter(images.values()))
        else:
            raise RuntimeError(f"{self} no images in message")

        # Decode base64 JPEG
        img_bytes = base64.b64decode(img_b64)
        frame = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_COLOR)

        if frame is None:
            raise RuntimeError(f"{self} failed to decode image")

        return frame

    def _read_loop(self) -> None:
        """Background loop: keep the most recent frame in latest_frame until asked to stop."""
        # Bind once: _stop_read_thread() sets self.stop_event to None, and a
        # check-then-use on the attribute could dereference None mid-loop.
        stop_event = self.stop_event
        if stop_event is None:
            return

        while not stop_event.is_set():
            try:
                frame = self.read()
                with self.frame_lock:
                    self.latest_frame = frame
                self.new_frame_event.set()
            except DeviceNotConnectedError:
                break
            except TimeoutError:
                # No message within the receive timeout: just poll again.
                pass
            except Exception as e:
                logger.warning(f"Read error: {e}")

    def _start_read_thread(self) -> None:
        """Start the background reader unless one is already running."""
        if self.thread and self.thread.is_alive():
            return
        self.stop_event = Event()
        self.thread = Thread(target=self._read_loop, daemon=True)
        self.thread.start()

    def _stop_read_thread(self) -> None:
        """Signal the background reader to stop, wait up to 2 s, then reset thread state."""
        if self.stop_event:
            self.stop_event.set()
        if self.thread and self.thread.is_alive():
            self.thread.join(timeout=2.0)
        self.thread = None
        self.stop_event = None

    def async_read(self, timeout_ms: float = 10000) -> NDArray[Any]:
        """Return the latest frame captured by the background reader.

        Starts the reader thread on first use, then waits up to ``timeout_ms``
        for a new frame to be signalled.

        Raises:
            DeviceNotConnectedError: If the camera is not connected.
            TimeoutError: If no new frame is signalled within ``timeout_ms``.
            RuntimeError: If a frame was signalled but none is stored.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        if not self.thread or not self.thread.is_alive():
            self._start_read_thread()

        if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
            raise TimeoutError(f"{self} async_read timeout after {timeout_ms}ms")

        with self.frame_lock:
            frame = self.latest_frame
            self.new_frame_event.clear()

        if frame is None:
            raise RuntimeError(f"{self} no frame available")

        return frame

    def disconnect(self) -> None:
        """Disconnect from ZMQ camera.

        Raises:
            DeviceNotConnectedError: If the camera is not connected and no reader
                thread exists.
        """
        if not self.is_connected and not self.thread:
            raise DeviceNotConnectedError(f"{self} not connected.")

        self._stop_read_thread()
        self._cleanup()

        logger.info(f"{self} disconnected.")

View File

@@ -1,46 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..configs import CameraConfig, ColorMode
__all__ = ["ZMQCameraConfig", "ColorMode"]
@CameraConfig.register_subclass("zmq")
@dataclass
class ZMQCameraConfig(CameraConfig):
    """Connection settings for a ZMQ camera, registered under the "zmq" type.

    ``__post_init__`` rejects an unsupported color mode, a non-positive timeout,
    an empty server address and a port outside 1-65535, in that order.
    """

    server_address: str
    port: int = 5555
    camera_name: str = "zmq_camera"
    color_mode: ColorMode = ColorMode.RGB
    timeout_ms: int = 5000

    def __post_init__(self) -> None:
        supported_modes = (ColorMode.RGB, ColorMode.BGR)
        if self.color_mode not in supported_modes:
            raise ValueError(
                f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
            )

        timeout = self.timeout_ms
        if timeout <= 0:
            raise ValueError(f"`timeout_ms` must be positive, but {self.timeout_ms} is provided.")

        if not self.server_address:
            raise ValueError("`server_address` cannot be empty.")

        port = self.port
        if port <= 0 or port > 65535:
            raise ValueError(f"`port` must be between 1 and 65535, but {self.port} is provided.")

View File

@@ -1,114 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Streams camera images over ZMQ.
Uses lerobot's OpenCVCamera for capture, encodes images to base64 and sends them over ZMQ.
"""
import base64
import contextlib
import json
import logging
import time
from collections import deque
import cv2
import numpy as np
import zmq
from lerobot.cameras.configs import ColorMode
from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig
logger = logging.getLogger(__name__)
def encode_image(image: np.ndarray, quality: int = 80) -> str:
    """Compress ``image`` to JPEG at the given quality and return it as a base64 text string.

    NOTE(review): callers pass RGB frames, while OpenCV's encoder presumably
    treats 3-channel input as BGR - confirm the channel order expected by clients.
    """
    jpeg_params = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
    encoded = cv2.imencode(".jpg", image, jpeg_params)[1]
    return base64.b64encode(encoded).decode("utf-8")
class ImageServer:
    """Captures frames from local OpenCV cameras and publishes them on a ZMQ PUB socket.

    Each published message is a JSON string with per-camera capture timestamps
    and base64-encoded JPEG images.
    """

    def __init__(self, config: dict, port: int = 5555):
        """Connect every configured camera and bind the publisher socket.

        Args:
            config: Dict with an optional "fps" (default 30) and a "cameras"
                mapping of name -> {"device_id": ..., "shape": [height, width]}.
            port: TCP port the PUB socket binds to on all interfaces.
        """
        self.fps = config.get("fps", 30)
        self.cameras: dict[str, OpenCVCamera] = {}

        for name, cfg in config.get("cameras", {}).items():
            shape = cfg.get("shape", [480, 640])
            cam_config = OpenCVCameraConfig(
                index_or_path=cfg.get("device_id", 0),
                fps=self.fps,
                width=shape[1],
                height=shape[0],
                color_mode=ColorMode.RGB,
            )
            camera = OpenCVCamera(cam_config)
            camera.connect()
            self.cameras[name] = camera
            logger.info(f"Camera {name}: {shape[1]}x{shape[0]}")

        # ZMQ PUB socket
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.PUB)
        self.socket.setsockopt(zmq.SNDHWM, 20)
        self.socket.setsockopt(zmq.LINGER, 0)
        self.socket.bind(f"tcp://*:{port}")

        logger.info(f"ImageServer running on port {port}")

    def run(self):
        """Publish frames at roughly ``self.fps`` until interrupted, then release all resources."""
        frame_count = 0
        frame_times = deque(maxlen=60)

        try:
            while True:
                # Monotonic clock for pacing: wall-clock adjustments must not
                # distort the sleep interval or the FPS estimate.
                t0 = time.perf_counter()

                # Build message
                message = {"timestamps": {}, "images": {}}
                for name, cam in self.cameras.items():
                    frame = cam.read()  # Returns RGB
                    # Wall-clock time on purpose: this value is sent to clients.
                    message["timestamps"][name] = time.time()
                    message["images"][name] = encode_image(frame)

                # Send as JSON string (suppress if buffer full)
                with contextlib.suppress(zmq.Again):
                    self.socket.send_string(json.dumps(message), zmq.NOBLOCK)

                frame_count += 1
                frame_times.append(time.perf_counter() - t0)

                if frame_count % 60 == 0:
                    busy_time = sum(frame_times)
                    # Guard the division: the summed durations can be zero on a coarse clock.
                    if busy_time > 0:
                        logger.debug(f"FPS: {len(frame_times) / busy_time:.1f}")

                sleep = (1.0 / self.fps) - (time.perf_counter() - t0)
                if sleep > 0:
                    time.sleep(sleep)

        except KeyboardInterrupt:
            pass
        finally:
            # Close the socket and context even if a camera fails to disconnect.
            try:
                for cam in self.cameras.values():
                    cam.disconnect()
            finally:
                self.socket.close()
                self.context.term()
if __name__ == "__main__":
    # Script entry point: stream one camera until interrupted.
    logging.basicConfig(level=logging.INFO)
    # One camera named "head_camera" on capture device 4. ImageServer reads
    # "shape" as [height, width] (shape[0] -> height, shape[1] -> width).
    config = {"fps": 30, "cameras": {"head_camera": {"device_id": 4, "shape": [480, 640]}}}
    ImageServer(config, port=5555).run()

View File

@@ -67,31 +67,3 @@ class EvalConfig:
f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={self.batch_size}`), " f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={self.batch_size}`), "
f"or lower the batch size (e.g. `eval.batch_size={self.n_episodes}`)." f"or lower the batch size (e.g. `eval.batch_size={self.n_episodes}`)."
) )
@dataclass
class PeftConfig:
    # High-level options for PEFT fine-tuning of a policy.
    #
    # PEFT offers many fine-tuning methods; layer adapters are the most common and currently also the most
    # effective ones, so this high-level config interface focuses on those.

    # Either a string (module name suffix or 'all-linear'), a list of module name suffixes, or a regular expression
    # describing the module names to target with the configured PEFT method. Some policies have a default value for
    # this so that you don't *have* to choose which layers to adapt, but it might still be worthwhile depending on
    # your case.
    target_modules: list[str] | str | None = None

    # Names/suffixes of modules to fully fine-tune and store alongside adapter weights. Useful for layers that are
    # not part of a pre-trained model (e.g., action state projections). Depending on the policy this defaults to layers
    # that are newly created in pre-trained policies. If you're fine-tuning an already trained policy you might want
    # to set this to `[]`. Corresponds to PEFT's `modules_to_save`.
    full_training_modules: list[str] | None = None

    # The PEFT (adapter) method to apply to the policy. Needs to be a valid PEFT type.
    method_type: str = "LORA"

    # Adapter initialization method. Look at the specific PEFT adapter documentation for defaults.
    init_type: str | None = None

    # All PEFT adapters are expected to perform some form of rank decomposition, so this parameter specifies
    # the rank used for the adapter. In general a higher rank means more trainable parameters and behavior closer
    # to full fine-tuning.
    r: int = 16

View File

@@ -38,8 +38,6 @@ class EvalPipelineConfig:
seed: int | None = 1000 seed: int | None = 1000
# Rename map for the observation to override the image and state keys # Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict) rename_map: dict[str, str] = field(default_factory=dict)
# Explicit consent to execute remote code from the Hub (required for hub environments).
trust_remote_code: bool = False
def __post_init__(self) -> None: def __post_init__(self) -> None:
# HACK: We parse again the cli args here to get the pretrained path if there was one. # HACK: We parse again the cli args here to get the pretrained path if there was one.

View File

@@ -55,18 +55,14 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
n_obs_steps: int = 1 n_obs_steps: int = 1
# `input_features` can be set to None/null in order to infer those values from the dataset. input_features: dict[str, PolicyFeature] = field(default_factory=dict)
input_features: dict[str, PolicyFeature] | None = field(default_factory=dict) output_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] | None = field(default_factory=dict)
device: str | None = None # e.g. "cuda", "cuda:0", "cpu", or "mps" device: str | None = None # e.g. "cuda", "cuda:0", "cpu", or "mps"
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP, # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used. # automatic gradient scaling is used.
use_amp: bool = False use_amp: bool = False
# Whether the policy employed PEFT for training.
use_peft: bool = False
push_to_hub: bool = True # type: ignore[assignment] # TODO: use a different name to avoid override push_to_hub: bool = True # type: ignore[assignment] # TODO: use a different name to avoid override
repo_id: str | None = None repo_id: str | None = None
@@ -105,16 +101,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation def observation_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
raise NotImplementedError raise NotImplementedError
@property
def image_observation_delta_indices(self) -> list | None: # type: ignore[type-arg]
"""Return indices for delta image observations only.
Unlike observation_delta_indices which applies to ALL observations,
this only applies to image observations (keys starting with observation.images).
Default returns None. Override in subclass to enable.
"""
return None
@property @property
@abc.abstractmethod @abc.abstractmethod
def action_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation def action_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
@@ -139,8 +125,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
@property @property
def robot_state_feature(self) -> PolicyFeature | None: def robot_state_feature(self) -> PolicyFeature | None:
if not self.input_features:
return None
for ft_name, ft in self.input_features.items(): for ft_name, ft in self.input_features.items():
if ft.type is FeatureType.STATE and ft_name == OBS_STATE: if ft.type is FeatureType.STATE and ft_name == OBS_STATE:
return ft return ft
@@ -148,8 +132,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
@property @property
def env_state_feature(self) -> PolicyFeature | None: def env_state_feature(self) -> PolicyFeature | None:
if not self.input_features:
return None
for _, ft in self.input_features.items(): for _, ft in self.input_features.items():
if ft.type is FeatureType.ENV: if ft.type is FeatureType.ENV:
return ft return ft
@@ -157,14 +139,10 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
@property @property
def image_features(self) -> dict[str, PolicyFeature]: def image_features(self) -> dict[str, PolicyFeature]:
if not self.input_features:
return {}
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL} return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
@property @property
def action_feature(self) -> PolicyFeature | None: def action_feature(self) -> PolicyFeature | None:
if not self.output_features:
return None
for ft_name, ft in self.output_features.items(): for ft_name, ft in self.output_features.items():
if ft.type is FeatureType.ACTION and ft_name == ACTION: if ft.type is FeatureType.ACTION and ft_name == ACTION:
return ft return ft

View File

@@ -24,7 +24,7 @@ from huggingface_hub.errors import HfHubHTTPError
from lerobot import envs from lerobot import envs
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.optim import OptimizerConfig from lerobot.optim import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig from lerobot.optim.schedulers import LRSchedulerConfig
@@ -65,7 +65,6 @@ class TrainPipelineConfig(HubMixin):
scheduler: LRSchedulerConfig | None = None scheduler: LRSchedulerConfig | None = None
eval: EvalConfig = field(default_factory=EvalConfig) eval: EvalConfig = field(default_factory=EvalConfig)
wandb: WandBConfig = field(default_factory=WandBConfig) wandb: WandBConfig = field(default_factory=WandBConfig)
peft: PeftConfig | None = None
# RA-BC (Reward-Aligned Behavior Cloning) parameters # RA-BC (Reward-Aligned Behavior Cloning) parameters
use_rabc: bool = False # Enable reward-weighted training use_rabc: bool = False # Enable reward-weighted training

View File

@@ -19,7 +19,6 @@ import logging
import shutil import shutil
from pathlib import Path from pathlib import Path
import datasets
import pandas as pd import pandas as pd
import tqdm import tqdm
@@ -33,7 +32,6 @@ from lerobot.datasets.utils import (
DEFAULT_VIDEO_FILE_SIZE_IN_MB, DEFAULT_VIDEO_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH, DEFAULT_VIDEO_PATH,
get_file_size_in_mb, get_file_size_in_mb,
get_hf_features_from_features,
get_parquet_file_size_in_mb, get_parquet_file_size_in_mb,
to_parquet_with_hf_images, to_parquet_with_hf_images,
update_chunk_file_indices, update_chunk_file_indices,
@@ -404,21 +402,12 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
} }
unique_chunk_file_ids = sorted(unique_chunk_file_ids) unique_chunk_file_ids = sorted(unique_chunk_file_ids)
contains_images = len(dst_meta.image_keys) > 0
# retrieve features schema for proper image typing in parquet
hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None
for src_chunk_idx, src_file_idx in unique_chunk_file_ids: for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
src_path = src_meta.root / DEFAULT_DATA_PATH.format( src_path = src_meta.root / DEFAULT_DATA_PATH.format(
chunk_index=src_chunk_idx, file_index=src_file_idx chunk_index=src_chunk_idx, file_index=src_file_idx
) )
if contains_images: df = pd.read_parquet(src_path)
# Use HuggingFace datasets to read source data to preserve image format
src_ds = datasets.Dataset.from_parquet(str(src_path))
df = src_ds.to_pandas()
else:
df = pd.read_parquet(src_path)
df = update_data_df(df, src_meta, dst_meta) df = update_data_df(df, src_meta, dst_meta)
data_idx = append_or_create_parquet_file( data_idx = append_or_create_parquet_file(
@@ -428,9 +417,8 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
data_files_size_in_mb, data_files_size_in_mb,
chunk_size, chunk_size,
DEFAULT_DATA_PATH, DEFAULT_DATA_PATH,
contains_images=contains_images, contains_images=len(dst_meta.image_keys) > 0,
aggr_root=dst_meta.root, aggr_root=dst_meta.root,
hf_features=hf_features,
) )
return data_idx return data_idx
@@ -500,7 +488,6 @@ def append_or_create_parquet_file(
default_path: str, default_path: str,
contains_images: bool = False, contains_images: bool = False,
aggr_root: Path = None, aggr_root: Path = None,
hf_features: datasets.Features | None = None,
): ):
"""Appends data to an existing parquet file or creates a new one based on size constraints. """Appends data to an existing parquet file or creates a new one based on size constraints.
@@ -516,7 +503,6 @@ def append_or_create_parquet_file(
default_path: Format string for generating file paths. default_path: Format string for generating file paths.
contains_images: Whether the data contains images requiring special handling. contains_images: Whether the data contains images requiring special handling.
aggr_root: Root path for the aggregated dataset. aggr_root: Root path for the aggregated dataset.
hf_features: Optional HuggingFace Features schema for proper image typing.
Returns: Returns:
dict: Updated index dictionary with current chunk and file indices. dict: Updated index dictionary with current chunk and file indices.
@@ -526,7 +512,7 @@ def append_or_create_parquet_file(
if not dst_path.exists(): if not dst_path.exists():
dst_path.parent.mkdir(parents=True, exist_ok=True) dst_path.parent.mkdir(parents=True, exist_ok=True)
if contains_images: if contains_images:
to_parquet_with_hf_images(df, dst_path, features=hf_features) to_parquet_with_hf_images(df, dst_path)
else: else:
df.to_parquet(dst_path) df.to_parquet(dst_path)
return idx return idx
@@ -541,17 +527,12 @@ def append_or_create_parquet_file(
final_df = df final_df = df
target_path = new_path target_path = new_path
else: else:
if contains_images: existing_df = pd.read_parquet(dst_path)
# Use HuggingFace datasets to read existing data to preserve image format
existing_ds = datasets.Dataset.from_parquet(str(dst_path))
existing_df = existing_ds.to_pandas()
else:
existing_df = pd.read_parquet(dst_path)
final_df = pd.concat([existing_df, df], ignore_index=True) final_df = pd.concat([existing_df, df], ignore_index=True)
target_path = dst_path target_path = dst_path
if contains_images: if contains_images:
to_parquet_with_hf_images(final_df, target_path, features=hf_features) to_parquet_with_hf_images(final_df, target_path)
else: else:
final_df.to_parquet(target_path) final_df.to_parquet(target_path)

View File

@@ -26,7 +26,6 @@ This module provides utilities for:
import logging import logging
import shutil import shutil
from collections.abc import Callable from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path from pathlib import Path
import datasets import datasets
@@ -52,8 +51,7 @@ from lerobot.datasets.utils import (
write_stats, write_stats,
write_tasks, write_tasks,
) )
from lerobot.datasets.video_utils import encode_video_frames, get_video_info from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict: def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict:
@@ -1085,561 +1083,3 @@ def _copy_episodes_metadata_and_stats(
else: else:
if src_dataset.meta.stats: if src_dataset.meta.stats:
write_stats(src_dataset.meta.stats, dst_meta.root) write_stats(src_dataset.meta.stats, dst_meta.root)
def _save_episode_images_for_video(
    dataset: LeRobotDataset,
    imgs_dir: Path,
    img_key: str,
    episode_index: int,
    num_workers: int = 4,
) -> None:
    """Write one camera's frames for a single episode to disk as numbered PNG files.

    Files are named ``frame-XXXXXX.png`` (zero-padded, starting at 0 for the first
    frame of the episode), the naming used for the subsequent ``encode_video_frames``
    step.

    Args:
        dataset: The LeRobot dataset to extract images from.
        imgs_dir: Directory to save images to (created if missing).
        img_key: The image key (camera) to extract.
        episode_index: Index of the episode to save.
        num_workers: Number of threads used to save images in parallel.
    """
    imgs_dir.mkdir(parents=True, exist_ok=True)

    # Drop the dataset formatting so items expose savable image objects, and keep
    # only the requested camera column.
    camera_frames = dataset.hf_dataset.with_format(None).select_columns(img_key)

    # Row range of this episode inside the dataset.
    episodes = dataset.meta.episodes
    first_row = episodes["dataset_from_index"][episode_index]
    end_row = episodes["dataset_to_index"][episode_index]
    episode_frames = camera_frames.select(range(first_row, end_row))

    def _write_frame(position, row):
        row[img_key].save(str(imgs_dir / f"frame-{position:06d}.png"), quality=100)
        return position

    indexed_rows = list(enumerate(episode_frames))
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        pending = [pool.submit(_write_frame, position, row) for position, row in indexed_rows]
        for finished in as_completed(pending):
            # Re-raise any exception that occurred inside a worker thread.
            finished.result()
def _save_batch_episodes_images(
    dataset: LeRobotDataset,
    imgs_dir: Path,
    img_key: str,
    episode_indices: list[int],
    num_workers: int = 4,
) -> list[float]:
    """Write one camera's frames for several episodes into a single image directory.

    Frames are numbered continuously across the episodes (``frame-XXXXXX.png``), in
    the order given by ``episode_indices``, so the directory can be encoded as one
    video covering the whole batch.

    Args:
        dataset: The LeRobot dataset to extract images from.
        imgs_dir: Directory to save images to (created if missing).
        img_key: The image key (camera) to extract.
        episode_indices: Episode indices to save, in output order.
        num_workers: Number of threads used to save images in parallel.

    Returns:
        One duration in seconds per episode (episode length divided by
        ``dataset.fps``), in the same order as ``episode_indices``.
    """
    imgs_dir.mkdir(parents=True, exist_ok=True)
    camera_frames = dataset.hf_dataset.with_format(None).select_columns(img_key)

    def _write_frame(local_idx, row, offset):
        # `offset` is the number of frames already written for earlier episodes.
        row[img_key].save(str(imgs_dir / f"frame-{offset + local_idx:06d}.png"), quality=100)
        return local_idx

    durations = []
    frames_written = 0
    for ep_idx in episode_indices:
        start_row = dataset.meta.episodes["dataset_from_index"][ep_idx]
        end_row = dataset.meta.episodes["dataset_to_index"][ep_idx]
        n_frames = end_row - start_row
        durations.append(n_frames / dataset.fps)

        indexed_rows = list(enumerate(camera_frames.select(range(start_row, end_row))))
        with ThreadPoolExecutor(max_workers=num_workers) as pool:
            pending = [
                pool.submit(_write_frame, local_idx, row, frames_written)
                for local_idx, row in indexed_rows
            ]
            for finished in as_completed(pending):
                # Re-raise any exception that occurred inside a worker thread.
                finished.result()

        frames_written += n_frames

    return durations
def _iter_episode_batches(
episode_indices: list[int],
episode_lengths: dict[int, int],
size_per_frame_mb: float,
video_file_size_limit: float,
max_episodes: int | None,
max_frames: int | None,
):
"""Generator that yields batches of episode indices for video encoding.
Groups episodes into batches that respect size and memory constraints:
- Stays under video file size limit
- Respects maximum episodes per batch (if specified)
- Respects maximum frames per batch (if specified)
Args:
episode_indices: List of episode indices to batch
episode_lengths: Dictionary mapping episode index to episode length
size_per_frame_mb: Estimated size per frame in MB
video_file_size_limit: Maximum video file size in MB
max_episodes: Maximum number of episodes per batch (None = no limit)
max_frames: Maximum number of frames per batch (None = no limit)
Yields:
List of episode indices for each batch
"""
batch_episodes = []
estimated_size = 0.0
total_frames = 0
for ep_idx in episode_indices:
ep_length = episode_lengths[ep_idx]
ep_estimated_size = ep_length * size_per_frame_mb
# we check if adding this episode would exceed any constraint
would_exceed_size = estimated_size > 0 and estimated_size + ep_estimated_size >= video_file_size_limit
would_exceed_episodes = max_episodes is not None and len(batch_episodes) >= max_episodes
would_exceed_frames = max_frames is not None and total_frames + ep_length > max_frames
if batch_episodes and (would_exceed_size or would_exceed_episodes or would_exceed_frames):
# yield current batch before adding this episode
yield batch_episodes
# start a new batch with current episode
batch_episodes = [ep_idx]
estimated_size = ep_estimated_size
total_frames = ep_length
else:
# add to current batch
batch_episodes.append(ep_idx)
estimated_size += ep_estimated_size
total_frames += ep_length
# yield final batch if not empty
if batch_episodes:
yield batch_episodes
def _estimate_frame_size_via_calibration(
    dataset: LeRobotDataset,
    img_key: str,
    episode_indices: list[int],
    temp_dir: Path,
    fps: int,
    vcodec: str,
    pix_fmt: str,
    g: int,
    crf: int,
    fast_decode: int,
    num_calibration_frames: int = 30,
) -> float:
    """Estimate MB per frame by encoding a small calibration sample.

    The first frames of the middle episode of ``episode_indices`` are written to a
    scratch directory and encoded with the exact codec parameters that will be
    used for the real conversion; the resulting file size divided by the number of
    frames is the estimate. The scratch directory is always removed afterwards.

    Args:
        dataset: Source dataset with images.
        img_key: Image key to calibrate (e.g., "observation.images.top").
        episode_indices: List of episode indices being processed.
        temp_dir: Temporary directory for calibration files.
        fps: Frames per second for video encoding.
        vcodec: Video codec (libsvtav1, h264, hevc).
        pix_fmt: Pixel format (yuv420p, etc.).
        g: GOP size (group of pictures).
        crf: Constant Rate Factor (quality).
        fast_decode: Fast decode tuning parameter.
        num_calibration_frames: Number of frames to use for calibration (default: 30).

    Returns:
        Estimated size in MB per frame based on actual encoding.

    Raises:
        ValueError: If ``episode_indices`` is empty, or if the calibration sample
            would contain no frames (empty episode or non-positive
            ``num_calibration_frames``).
    """
    # Fail early with a clear message instead of an IndexError further down.
    if len(episode_indices) == 0:
        raise ValueError("Cannot calibrate frame size: `episode_indices` is empty.")

    calibration_dir = temp_dir / "calibration" / img_key
    calibration_dir.mkdir(parents=True, exist_ok=True)

    try:
        # Select a representative episode (the middle one of the selection)
        calibration_ep_idx = episode_indices[len(episode_indices) // 2]

        # Get episode range
        from_idx = dataset.meta.episodes["dataset_from_index"][calibration_ep_idx]
        to_idx = dataset.meta.episodes["dataset_to_index"][calibration_ep_idx]
        episode_length = to_idx - from_idx

        # Use up to num_calibration_frames from this episode
        num_frames = min(num_calibration_frames, episode_length)
        if num_frames <= 0:
            # Nothing to encode; dividing by num_frames below would be meaningless.
            raise ValueError(
                f"Cannot calibrate frame size for {img_key}: episode {calibration_ep_idx} "
                f"provides no frames (episode length {episode_length}, "
                f"num_calibration_frames {num_calibration_frames})."
            )

        # Restrict to the calibrated camera column (as the other image-saving
        # helpers do) so fetching a frame does not pull every other column too.
        camera_frames = dataset.hf_dataset.with_format(None).select_columns(img_key)

        # Save calibration frames
        for i, idx in enumerate(range(from_idx, from_idx + num_frames)):
            img = camera_frames[idx][img_key]
            img.save(str(calibration_dir / f"frame-{i:06d}.png"), quality=100)

        # Encode calibration video
        calibration_video_path = calibration_dir / "calibration.mp4"
        encode_video_frames(
            imgs_dir=calibration_dir,
            video_path=calibration_video_path,
            fps=fps,
            vcodec=vcodec,
            pix_fmt=pix_fmt,
            g=g,
            crf=crf,
            fast_decode=fast_decode,
            overwrite=True,
        )

        # Measure actual compressed size
        video_size_bytes = calibration_video_path.stat().st_size
        video_size_mb = video_size_bytes / BYTES_PER_MIB
        size_per_frame_mb = video_size_mb / num_frames

        logging.info(
            f" Calibration: {num_frames} frames -> {video_size_mb:.2f} MB "
            f"= {size_per_frame_mb:.4f} MB/frame for {img_key}"
        )

        return size_per_frame_mb

    finally:
        # Clean up calibration files
        if calibration_dir.exists():
            shutil.rmtree(calibration_dir)
def _copy_data_without_images(
    src_dataset: LeRobotDataset,
    dst_meta: LeRobotDatasetMetadata,
    episode_indices: list[int],
    img_keys: list[str],
) -> None:
    """Mirror the source parquet data files into the destination, minus image columns.

    Every ``*/*.parquet`` file under the source data directory is read, reduced to
    the rows whose ``episode_index`` is in ``episode_indices``, stripped of any
    column listed in ``img_keys``, and written under ``dst_meta.root`` using the
    chunk and file numbers parsed from the source path. Source files with no
    remaining rows are skipped.

    Args:
        src_dataset: Source dataset.
        dst_meta: Destination metadata (only its ``root`` is used here).
        episode_indices: Episodes to include.
        img_keys: Image keys (columns) to remove.

    Raises:
        ValueError: If no parquet file is found in the source data directory.
    """
    from lerobot.datasets.utils import DATA_DIR

    data_dir = src_dataset.root / DATA_DIR
    parquet_files = sorted(data_dir.glob("*/*.parquet"))
    if not parquet_files:
        raise ValueError(f"No parquet files found in {data_dir}")

    wanted_episodes = set(episode_indices)

    for src_path in tqdm(parquet_files, desc="Processing data files"):
        frame = pd.read_parquet(src_path).reset_index(drop=True)

        # Keep only the rows belonging to the selected episodes.
        frame = frame[frame["episode_index"].isin(wanted_episodes)].copy()
        if frame.shape[0] == 0:
            continue

        # Strip whichever image columns this file actually carries.
        present_img_cols = [key for key in img_keys if key in frame.columns]
        if present_img_cols:
            frame = frame.drop(columns=present_img_cols)

        # Recover the chunk / file numbers from "<data>/chunk-XXX/file-XXX.parquet".
        path_parts = src_path.relative_to(src_dataset.root).parts
        chunk_dir, file_name = path_parts[1], path_parts[2]
        chunk_idx = int(chunk_dir.split("-")[1])
        file_idx = int(file_name.split("-")[1].split(".")[0])

        # Write to the matching destination location, without the pandas index.
        dst_path = dst_meta.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet"
        dst_path.parent.mkdir(parents=True, exist_ok=True)
        frame.to_parquet(dst_path, index=False)
# Binary size units used when converting encoded video file sizes from bytes.
BYTES_PER_KIB = 1024
BYTES_PER_MIB = BYTES_PER_KIB**2
def convert_image_to_video_dataset(
    dataset: LeRobotDataset,
    output_dir: Path,
    repo_id: str | None = None,
    vcodec: str = "libsvtav1",
    pix_fmt: str = "yuv420p",
    g: int = 2,
    crf: int = 30,
    fast_decode: int = 0,
    episode_indices: list[int] | None = None,
    num_workers: int = 4,
    max_episodes_per_batch: int | None = None,
    max_frames_per_batch: int | None = None,
) -> LeRobotDataset:
    """Convert an image dataset into a video dataset.

    Creates a new LeRobotDataset with images encoded as videos, following the proper
    LeRobot dataset structure with videos stored in chunked MP4 files. For each
    camera, episodes are grouped into batches (bounded by the video file size limit
    and the optional per-batch limits), each batch is encoded into one video file,
    and the per-episode video location and timestamps are recorded in the episode
    metadata. The tabular data is copied without the image columns.

    Args:
        dataset: The source LeRobot dataset with images
        output_dir: Directory to save the new video dataset
        repo_id: Repository ID for the new dataset (default: original_id + "_video")
        vcodec: Video codec (default: libsvtav1)
        pix_fmt: Pixel format (default: yuv420p)
        g: Group of pictures size (default: 2)
        crf: Constant rate factor (default: 30)
        fast_decode: Fast decode tuning (default: 0)
        episode_indices: List of episode indices to convert (None = all episodes)
        num_workers: Number of threads for parallel processing (default: 4)
        max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit)
        max_frames_per_batch: Maximum frames per video batch to avoid memory issues (None = no limit)

    Returns:
        New LeRobotDataset with images encoded as videos

    Raises:
        ValueError: If the dataset already contains videos, has no image keys, or
            if there is no episode to convert.
    """
    # Check that it's an image dataset
    if len(dataset.meta.video_keys) > 0:
        raise ValueError(
            f"This operation is for image datasets only. Video dataset provided: {dataset.repo_id}"
        )

    # Get all image keys
    hf_dataset = dataset.hf_dataset.with_format(None)
    img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]

    if len(img_keys) == 0:
        raise ValueError(f"No image keys found in dataset {dataset.repo_id}")

    # Determine which episodes to process
    if episode_indices is None:
        episode_indices = list(range(dataset.meta.total_episodes))

    # Reject an empty selection before anything is written to `output_dir`;
    # otherwise the failure only surfaces later, as an IndexError during
    # calibration, after the destination metadata has already been created.
    if len(episode_indices) == 0:
        raise ValueError(f"No episodes selected for conversion in dataset {dataset.repo_id}")

    if repo_id is None:
        repo_id = f"{dataset.repo_id}_video"

    logging.info(
        f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
    )
    logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")

    # Create new features dict, converting image features to video features
    new_features = {}
    for key, value in dataset.meta.features.items():
        if key not in img_keys:
            new_features[key] = value
        else:
            # Convert image key to video format
            new_features[key] = value.copy()
            new_features[key]["dtype"] = "video"  # Change dtype from "image" to "video"
            # Video info will be updated after episodes are encoded

    # Create new metadata for video dataset
    new_meta = LeRobotDatasetMetadata.create(
        repo_id=repo_id,
        fps=dataset.meta.fps,
        features=new_features,
        robot_type=dataset.meta.robot_type,
        root=output_dir,
        use_videos=True,
        chunks_size=dataset.meta.chunks_size,
        data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
        video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
    )

    # Create temporary directory for image extraction
    temp_dir = output_dir / "temp_images"
    temp_dir.mkdir(parents=True, exist_ok=True)

    # Process all episodes and batch encode videos
    # Use dictionary for O(1) episode metadata lookups instead of O(n) linear search
    all_episode_metadata = {}
    fps = int(dataset.fps)

    try:
        # Build episode metadata entries first
        logging.info("Building episode metadata...")
        cumulative_frame_idx = 0
        for ep_idx in episode_indices:
            src_episode = dataset.meta.episodes[ep_idx]
            ep_length = src_episode["length"]
            ep_meta = {
                "episode_index": ep_idx,
                "length": ep_length,
                "dataset_from_index": cumulative_frame_idx,
                "dataset_to_index": cumulative_frame_idx + ep_length,
            }
            if "data/chunk_index" in src_episode:
                ep_meta["data/chunk_index"] = src_episode["data/chunk_index"]
                ep_meta["data/file_index"] = src_episode["data/file_index"]

            all_episode_metadata[ep_idx] = ep_meta
            cumulative_frame_idx += ep_length

        # Process each camera and batch encode multiple episodes together
        video_file_size_limit = new_meta.video_files_size_in_mb

        # Pre-compute episode lengths for batching
        episode_lengths = {ep_idx: dataset.meta.episodes["length"][ep_idx] for ep_idx in episode_indices}

        for img_key in tqdm(img_keys, desc="Processing cameras"):
            # Estimate size per frame by encoding a small calibration sample
            # This provides accurate compression ratio for the specific codec parameters
            size_per_frame_mb = _estimate_frame_size_via_calibration(
                dataset=dataset,
                img_key=img_key,
                episode_indices=episode_indices,
                temp_dir=temp_dir,
                fps=fps,
                vcodec=vcodec,
                pix_fmt=pix_fmt,
                g=g,
                crf=crf,
                fast_decode=fast_decode,
            )

            logging.info(f"Processing camera: {img_key}")
            chunk_idx, file_idx = 0, 0
            cumulative_timestamp = 0.0

            # Process episodes in batches to stay under size limit
            for batch_episodes in _iter_episode_batches(
                episode_indices=episode_indices,
                episode_lengths=episode_lengths,
                size_per_frame_mb=size_per_frame_mb,
                video_file_size_limit=video_file_size_limit,
                max_episodes=max_episodes_per_batch,
                max_frames=max_frames_per_batch,
            ):
                total_frames_in_batch = sum(episode_lengths[idx] for idx in batch_episodes)
                logging.info(
                    f" Encoding batch of {len(batch_episodes)} episodes "
                    f"({batch_episodes[0]}-{batch_episodes[-1]}) = {total_frames_in_batch} frames"
                )

                # Save images for all episodes in this batch
                imgs_dir = temp_dir / f"batch_{chunk_idx}_{file_idx}" / img_key
                episode_durations = _save_batch_episodes_images(
                    dataset=dataset,
                    imgs_dir=imgs_dir,
                    img_key=img_key,
                    episode_indices=batch_episodes,
                    num_workers=num_workers,
                )

                # Encode all batched episodes into single video
                video_path = new_meta.root / new_meta.video_path.format(
                    video_key=img_key, chunk_index=chunk_idx, file_index=file_idx
                )
                video_path.parent.mkdir(parents=True, exist_ok=True)

                encode_video_frames(
                    imgs_dir=imgs_dir,
                    video_path=video_path,
                    fps=fps,
                    vcodec=vcodec,
                    pix_fmt=pix_fmt,
                    g=g,
                    crf=crf,
                    fast_decode=fast_decode,
                    overwrite=True,
                )

                # Clean up temporary images
                shutil.rmtree(imgs_dir)

                # Update metadata for each episode in the batch
                for ep_idx, duration in zip(batch_episodes, episode_durations, strict=True):
                    from_timestamp = cumulative_timestamp
                    to_timestamp = cumulative_timestamp + duration
                    cumulative_timestamp = to_timestamp

                    # Find episode metadata entry and add video metadata (O(1) dictionary lookup)
                    ep_meta = all_episode_metadata[ep_idx]
                    ep_meta[f"videos/{img_key}/chunk_index"] = chunk_idx
                    ep_meta[f"videos/{img_key}/file_index"] = file_idx
                    ep_meta[f"videos/{img_key}/from_timestamp"] = from_timestamp
                    ep_meta[f"videos/{img_key}/to_timestamp"] = to_timestamp

                # Move to next video file for next batch; timestamps are relative
                # to the start of each video file, so they restart at zero.
                chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, new_meta.chunks_size)
                cumulative_timestamp = 0.0

        # Copy and transform data files (removing image columns)
        _copy_data_without_images(dataset, new_meta, episode_indices, img_keys)

        # Save episode metadata
        episodes_df = pd.DataFrame(list(all_episode_metadata.values()))
        episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet"
        episodes_path.parent.mkdir(parents=True, exist_ok=True)
        episodes_df.to_parquet(episodes_path, index=False)

        # Update metadata info
        new_meta.info["total_episodes"] = len(episode_indices)
        new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata.values())
        new_meta.info["total_tasks"] = dataset.meta.total_tasks
        new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}

        # Update video info for all image keys (now videos)
        # We need to manually set video info since update_video_info() checks video_keys first
        for img_key in img_keys:
            if not new_meta.features[img_key].get("info", None):
                video_path = new_meta.root / new_meta.video_path.format(
                    video_key=img_key, chunk_index=0, file_index=0
                )
                new_meta.info["features"][img_key]["info"] = get_video_info(video_path)

        write_info(new_meta.info, new_meta.root)

        # Copy stats and tasks
        if dataset.meta.stats is not None:
            # Remove image stats
            new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys}
            write_stats(new_stats, new_meta.root)

        if dataset.meta.tasks is not None:
            write_tasks(dataset.meta.tasks, new_meta.root)

    finally:
        # Clean up temporary directory
        if temp_dir.exists():
            shutil.rmtree(temp_dir)

    logging.info(f"Completed converting {dataset.repo_id} to video format")
    logging.info(f"New dataset saved to: {output_dir}")

    # Return new dataset
    return LeRobotDataset(repo_id=repo_id, root=output_dir)

View File

@@ -27,7 +27,7 @@ from lerobot.datasets.lerobot_dataset import (
) )
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms from lerobot.datasets.transforms import ImageTransforms
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_PREFIX, REWARD from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
IMAGENET_STATS = { IMAGENET_STATS = {
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1) "mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
@@ -59,12 +59,7 @@ def resolve_delta_timestamps(
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices] delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
if key == ACTION and cfg.action_delta_indices is not None: if key == ACTION and cfg.action_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices] delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
# Check for image-specific delta indices first (e.g., for video encoding)
if key.startswith(OBS_IMAGES) and cfg.image_observation_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.image_observation_delta_indices]
# Fall back to generic observation delta indices for all observations
elif key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices] delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
if len(delta_timestamps) == 0: if len(delta_timestamps) == 0:

View File

@@ -78,7 +78,6 @@ from lerobot.datasets.video_utils import (
from lerobot.utils.constants import HF_LEROBOT_HOME from lerobot.utils.constants import HF_LEROBOT_HOME
CODEBASE_VERSION = "v3.0" CODEBASE_VERSION = "v3.0"
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1"}
class LeRobotDatasetMetadata: class LeRobotDatasetMetadata:
@@ -541,13 +540,11 @@ class LeRobotDatasetMetadata:
return obj return obj
def _encode_video_worker( def _encode_video_worker(video_key: str, episode_index: int, root: Path, fps: int) -> Path:
video_key: str, episode_index: int, root: Path, fps: int, vcodec: str = "libsvtav1"
) -> Path:
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4" temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0) fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
img_dir = (root / fpath).parent img_dir = (root / fpath).parent
encode_video_frames(img_dir, temp_path, fps, vcodec=vcodec, overwrite=True) encode_video_frames(img_dir, temp_path, fps, overwrite=True)
shutil.rmtree(img_dir) shutil.rmtree(img_dir)
return temp_path return temp_path
@@ -566,7 +563,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
download_videos: bool = True, download_videos: bool = True,
video_backend: str | None = None, video_backend: str | None = None,
batch_encoding_size: int = 1, batch_encoding_size: int = 1,
vcodec: str = "libsvtav1",
): ):
""" """
2 modes are available for instantiating this class, depending on 2 different use cases: 2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -679,13 +675,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos. batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1. Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
'libsvtav1'. Defaults to 'libsvtav1'. Use 'h264' for faster encoding on systems where AV1
encoding is CPU-heavy.
""" """
super().__init__() super().__init__()
if vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
self.repo_id = repo_id self.repo_id = repo_id
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
self.image_transforms = image_transforms self.image_transforms = image_transforms
@@ -697,7 +688,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.delta_indices = None self.delta_indices = None
self.batch_encoding_size = batch_encoding_size self.batch_encoding_size = batch_encoding_size
self.episodes_since_last_encoding = 0 self.episodes_since_last_encoding = 0
self.vcodec = vcodec
# Unused attributes # Unused attributes
self.image_writer = None self.image_writer = None
@@ -935,30 +925,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
else: else:
return get_hf_features_from_features(self.features) return get_hf_features_from_features(self.features)
def _get_query_indices( def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
self, abs_idx: int, ep_idx: int
) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]:
"""Compute query indices for delta timestamps.
Args:
abs_idx: The absolute index in the full dataset (not the relative index in filtered episodes).
ep_idx: The episode index.
Returns:
A tuple of (query_indices, padding) where:
- query_indices: Dict mapping keys to lists of absolute indices to query
- padding: Dict mapping "{key}_is_pad" to boolean tensors indicating padded positions
"""
ep = self.meta.episodes[ep_idx] ep = self.meta.episodes[ep_idx]
ep_start = ep["dataset_from_index"] ep_start = ep["dataset_from_index"]
ep_end = ep["dataset_to_index"] ep_end = ep["dataset_to_index"]
query_indices = { query_indices = {
key: [max(ep_start, min(ep_end - 1, abs_idx + delta)) for delta in delta_idx] key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx]
for key, delta_idx in self.delta_indices.items() for key, delta_idx in self.delta_indices.items()
} }
padding = { # Pad values outside of current episode range padding = { # Pad values outside of current episode range
f"{key}_is_pad": torch.BoolTensor( f"{key}_is_pad": torch.BoolTensor(
[(abs_idx + delta < ep_start) | (abs_idx + delta >= ep_end) for delta in delta_idx] [(idx + delta < ep_start) | (idx + delta >= ep_end) for delta in delta_idx]
) )
for key, delta_idx in self.delta_indices.items() for key, delta_idx in self.delta_indices.items()
} }
@@ -1050,12 +1027,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
self._ensure_hf_dataset_loaded() self._ensure_hf_dataset_loaded()
item = self.hf_dataset[idx] item = self.hf_dataset[idx]
ep_idx = item["episode_index"].item() ep_idx = item["episode_index"].item()
# Use the absolute index from the dataset for delta timestamp calculations
abs_idx = item["index"].item()
query_indices = None query_indices = None
if self.delta_indices is not None: if self.delta_indices is not None:
query_indices, padding = self._get_query_indices(abs_idx, ep_idx) query_indices, padding = self._get_query_indices(idx, ep_idx)
query_result = self._query_hf_dataset(query_indices) query_result = self._query_hf_dataset(query_indices)
item = {**item, **padding} item = {**item, **padding}
for key, val in query_result.items(): for key, val in query_result.items():
@@ -1236,7 +1211,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_index, episode_index,
self.root, self.root,
self.fps, self.fps,
self.vcodec,
): video_key ): video_key
for video_key in self.meta.video_keys for video_key in self.meta.video_keys
} }
@@ -1513,7 +1487,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_index = self.episode_buffer["episode_index"] episode_index = self.episode_buffer["episode_index"]
if isinstance(episode_index, np.ndarray): if isinstance(episode_index, np.ndarray):
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0] episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
for cam_key in self.meta.image_keys: for cam_key in self.meta.camera_keys:
img_dir = self._get_image_file_dir(episode_index, cam_key) img_dir = self._get_image_file_dir(episode_index, cam_key)
if img_dir.is_dir(): if img_dir.is_dir():
shutil.rmtree(img_dir) shutil.rmtree(img_dir)
@@ -1552,7 +1526,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
since video encoding with ffmpeg is already using multithreading. since video encoding with ffmpeg is already using multithreading.
""" """
return _encode_video_worker(video_key, episode_index, self.root, self.fps, self.vcodec) return _encode_video_worker(video_key, episode_index, self.root, self.fps)
@classmethod @classmethod
def create( def create(
@@ -1568,11 +1542,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_writer_threads: int = 0, image_writer_threads: int = 0,
video_backend: str | None = None, video_backend: str | None = None,
batch_encoding_size: int = 1, batch_encoding_size: int = 1,
vcodec: str = "libsvtav1",
) -> "LeRobotDataset": ) -> "LeRobotDataset":
"""Create a LeRobot Dataset from scratch in order to record data.""" """Create a LeRobot Dataset from scratch in order to record data."""
if vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
obj = cls.__new__(cls) obj = cls.__new__(cls)
obj.meta = LeRobotDatasetMetadata.create( obj.meta = LeRobotDatasetMetadata.create(
repo_id=repo_id, repo_id=repo_id,
@@ -1589,7 +1560,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.image_writer = None obj.image_writer = None
obj.batch_encoding_size = batch_encoding_size obj.batch_encoding_size = batch_encoding_size
obj.episodes_since_last_encoding = 0 obj.episodes_since_last_encoding = 0
obj.vcodec = vcodec
if image_writer_processes or image_writer_threads: if image_writer_processes or image_writer_threads:
obj.start_image_writer(image_writer_processes, image_writer_threads) obj.start_image_writer(image_writer_processes, image_writer_threads)

View File

@@ -18,12 +18,12 @@ from typing import Any
from lerobot.configs.types import PipelineFeatureType from lerobot.configs.types import PipelineFeatureType
from lerobot.datasets.utils import hw_to_dataset_features from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.processor import DataProcessorPipeline, RobotAction, RobotObservation from lerobot.processor import DataProcessorPipeline
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
def create_initial_features( def create_initial_features(
action: RobotAction | None = None, observation: RobotObservation | None = None action: dict[str, Any] | None = None, observation: dict[str, Any] | None = None
) -> dict[PipelineFeatureType, dict[str, Any]]: ) -> dict[PipelineFeatureType, dict[str, Any]]:
""" """
Creates the initial features dict for the dataset from action and observation specs. Creates the initial features dict for the dataset from action and observation specs.

View File

@@ -1172,21 +1172,12 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
) )
def to_parquet_with_hf_images( def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None:
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
) -> None:
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset. """This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
This way, it can be loaded by HF dataset and correctly formatted images are returned. This way, it can be loaded by HF dataset and correctly formatted images are returned.
Args:
df: DataFrame to write to parquet.
path: Path to write the parquet file.
features: Optional HuggingFace Features schema. If provided, ensures image columns
are properly typed as Image() in the parquet schema.
""" """
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only # TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features) datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
ds.to_parquet(path)
def item_to_torch(item: dict) -> dict: def item_to_torch(item: dict) -> dict:

View File

@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .configs import AlohaEnv, EnvConfig, HubEnvConfig, PushtEnv # noqa: F401 from .configs import AlohaEnv, EnvConfig, PushtEnv # noqa: F401

View File

@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import abc import abc
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field
from typing import Any from typing import Any
import draccus import draccus
@@ -68,22 +68,6 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
raise NotImplementedError() raise NotImplementedError()
@dataclass
class HubEnvConfig(EnvConfig):
"""Base class for environments that delegate creation to a hub-hosted make_env.
Hub environments download and execute remote code from the HF Hub.
The hub_path points to a repository containing an env.py with a make_env function.
"""
hub_path: str | None = None # required: e.g., "username/repo" or "username/repo@branch:file.py"
@property
def gym_kwargs(self) -> dict:
# Not used for hub environments - the hub's make_env handles everything
return {}
@EnvConfig.register_subclass("aloha") @EnvConfig.register_subclass("aloha")
@dataclass @dataclass
class AlohaEnv(EnvConfig): class AlohaEnv(EnvConfig):
@@ -384,71 +368,3 @@ class MetaworldEnv(EnvConfig):
"obs_type": self.obs_type, "obs_type": self.obs_type,
"render_mode": self.render_mode, "render_mode": self.render_mode,
} }
@EnvConfig.register_subclass("isaaclab_arena")
@dataclass
class IsaaclabArenaEnv(HubEnvConfig):
hub_path: str = "nvidia/isaaclab-arena-envs"
episode_length: int = 300
num_envs: int = 1
embodiment: str | None = "gr1_pink"
object: str | None = "power_drill"
mimic: bool = False
teleop_device: str | None = None
seed: int | None = 42
device: str | None = "cuda:0"
disable_fabric: bool = False
enable_cameras: bool = False
headless: bool = False
enable_pinocchio: bool = True
environment: str | None = "gr1_microwave"
task: str | None = "Reach out to the microwave and open it."
state_dim: int = 54
action_dim: int = 36
camera_height: int = 512
camera_width: int = 512
video: bool = False
video_length: int = 100
video_interval: int = 200
# Comma-separated keys, e.g., "robot_joint_pos,left_eef_pos"
state_keys: str = "robot_joint_pos"
# Comma-separated keys, e.g., "robot_pov_cam_rgb,front_cam_rgb"
# Set to None or "" for environments without cameras
camera_keys: str | None = None
features: dict[str, PolicyFeature] = field(default_factory=dict)
features_map: dict[str, str] = field(default_factory=dict)
kwargs: dict | None = None
def __post_init__(self):
if self.kwargs:
# dynamically convert kwargs to fields in the dataclass
# NOTE! the new fields will not bee seen by the dataclass repr
field_names = {f.name for f in fields(self)}
for key, value in self.kwargs.items():
if key not in field_names and key != "kwargs":
setattr(self, key, value)
self.kwargs = None
# Set action feature
self.features[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,))
self.features_map[ACTION] = ACTION
# Set state feature
self.features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(self.state_dim,))
self.features_map[OBS_STATE] = OBS_STATE
# Add camera features for each camera key
if self.enable_cameras and self.camera_keys:
for cam_key in self.camera_keys.split(","):
cam_key = cam_key.strip()
if cam_key:
self.features[cam_key] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(self.camera_height, self.camera_width, 3),
)
self.features_map[cam_key] = f"{OBS_IMAGES}.{cam_key}"
@property
def gym_kwargs(self) -> dict:
return {}

View File

@@ -20,11 +20,11 @@ import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry from gymnasium.envs.registration import registry as gym_registry
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.configs import AlohaEnv, EnvConfig, HubEnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.policies.xvla.configuration_xvla import XVLAConfig from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import ProcessorStep from lerobot.processor import ProcessorStep
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep from lerobot.processor.env_processor import LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline from lerobot.processor.pipeline import PolicyProcessorPipeline
@@ -73,26 +73,6 @@ def make_env_pre_post_processors(
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type: if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
preprocessor_steps.append(LiberoProcessorStep()) preprocessor_steps.append(LiberoProcessorStep())
# For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
# Parse comma-separated keys (handle None for state-based policies)
if env_cfg.state_keys:
state_keys = tuple(k.strip() for k in env_cfg.state_keys.split(",") if k.strip())
else:
state_keys = ()
if env_cfg.camera_keys:
camera_keys = tuple(k.strip() for k in env_cfg.camera_keys.split(",") if k.strip())
else:
camera_keys = ()
if not state_keys and not camera_keys:
raise ValueError("At least one of state_keys or camera_keys must be specified.")
preprocessor_steps.append(
IsaaclabArenaProcessorStep(
state_keys=state_keys,
camera_keys=camera_keys,
)
)
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps) preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps) postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)
@@ -118,6 +98,7 @@ def make_env(
hub_cache_dir (str | None): Optional cache path for downloaded hub files. hub_cache_dir (str | None): Optional cache path for downloaded hub files.
trust_remote_code (bool): **Explicit consent** to execute remote code from the Hub. trust_remote_code (bool): **Explicit consent** to execute remote code from the Hub.
Default False — must be set to True to import/exec hub `env.py`. Default False — must be set to True to import/exec hub `env.py`.
Raises: Raises:
ValueError: if n_envs < 1 ValueError: if n_envs < 1
ModuleNotFoundError: If the requested env package is not installed ModuleNotFoundError: If the requested env package is not installed
@@ -131,35 +112,19 @@ def make_env(
""" """
# if user passed a hub id string (e.g., "username/repo", "username/repo@main:env.py") # if user passed a hub id string (e.g., "username/repo", "username/repo@main:env.py")
# simplified: only support hub-provided `make_env` # simplified: only support hub-provided `make_env`
# TODO: (jadechoghari): deprecate string API and remove this check
if isinstance(cfg, str): if isinstance(cfg, str):
hub_path: str | None = cfg
elif isinstance(cfg, HubEnvConfig):
hub_path = cfg.hub_path
else:
hub_path = None
# If hub_path is set, download and call hub-provided `make_env`
if hub_path:
# _download_hub_file will raise the same RuntimeError if trust_remote_code is False # _download_hub_file will raise the same RuntimeError if trust_remote_code is False
repo_id, file_path, local_file, revision = _download_hub_file( repo_id, file_path, local_file, revision = _download_hub_file(cfg, trust_remote_code, hub_cache_dir)
hub_path, trust_remote_code, hub_cache_dir
)
# import and surface clear import errors # import and surface clear import errors
module = _import_hub_module(local_file, repo_id) module = _import_hub_module(local_file, repo_id)
# call the hub-provided make_env # call the hub-provided make_env
env_cfg = None if isinstance(cfg, str) else cfg raw_result = _call_make_env(module, n_envs=n_envs, use_async_envs=use_async_envs)
raw_result = _call_make_env(module, n_envs=n_envs, use_async_envs=use_async_envs, cfg=env_cfg)
# normalize the return into {suite: {task_id: vec_env}} # normalize the return into {suite: {task_id: vec_env}}
return _normalize_hub_result(raw_result) return _normalize_hub_result(raw_result)
# At this point, cfg must be an EnvConfig (not a string) since hub_path would have been set otherwise
if isinstance(cfg, str):
raise TypeError("cfg should be an EnvConfig at this point")
if n_envs < 1: if n_envs < 1:
raise ValueError("`n_envs` must be at least 1") raise ValueError("`n_envs` must be at least 1")

View File

@@ -29,8 +29,6 @@ from gymnasium import spaces
from libero.libero import benchmark, get_libero_path from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv from libero.libero.envs import OffScreenRenderEnv
from lerobot.processor import RobotObservation
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]: def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
"""Normalize camera_name into a non-empty list of strings.""" """Normalize camera_name into a non-empty list of strings."""
@@ -239,7 +237,7 @@ class LiberoEnv(gym.Env):
env.reset() env.reset()
return env return env
def _format_raw_obs(self, raw_obs: RobotObservation) -> RobotObservation: def _format_raw_obs(self, raw_obs: dict[str, Any]) -> dict[str, Any]:
images = {} images = {}
for camera_name in self.camera_name: for camera_name in self.camera_name:
image = raw_obs[camera_name] image = raw_obs[camera_name]
@@ -293,9 +291,9 @@ class LiberoEnv(gym.Env):
def reset(self, seed=None, **kwargs): def reset(self, seed=None, **kwargs):
super().reset(seed=seed) super().reset(seed=seed)
self._env.seed(seed) self._env.seed(seed)
raw_obs = self._env.reset()
if self.init_states and self._init_states is not None: if self.init_states and self._init_states is not None:
raw_obs = self._env.set_init_state(self._init_states[self._init_state_id]) self._env.set_init_state(self._init_states[self._init_state_id])
raw_obs = self._env.reset()
# After reset, objects may be unstable (slightly floating, intersecting, etc.). # After reset, objects may be unstable (slightly floating, intersecting, etc.).
# Step the simulator with a no-op action for a few frames so everything settles. # Step the simulator with a no-op action for a few frames so everything settles.
@@ -315,7 +313,7 @@ class LiberoEnv(gym.Env):
info = {"is_success": False} info = {"is_success": False}
return observation, info return observation, info
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]: def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
if action.ndim != 1: if action.ndim != 1:
raise ValueError( raise ValueError(
f"Expected action to be 1-D (shape (action_dim,)), " f"Expected action to be 1-D (shape (action_dim,)), "

View File

@@ -25,8 +25,6 @@ import metaworld.policies as policies
import numpy as np import numpy as np
from gymnasium import spaces from gymnasium import spaces
from lerobot.processor import RobotObservation
# ---- Load configuration data from the external JSON file ---- # ---- Load configuration data from the external JSON file ----
CONFIG_PATH = Path(__file__).parent / "metaworld_config.json" CONFIG_PATH = Path(__file__).parent / "metaworld_config.json"
try: try:
@@ -163,7 +161,7 @@ class MetaworldEnv(gym.Env):
env._freeze_rand_vec = False # otherwise no randomization env._freeze_rand_vec = False # otherwise no randomization
return env return env
def _format_raw_obs(self, raw_obs: np.ndarray) -> RobotObservation: def _format_raw_obs(self, raw_obs: np.ndarray) -> dict[str, Any]:
image = None image = None
if self._env is not None: if self._env is not None:
image = self._env.render() image = self._env.render()
@@ -198,7 +196,7 @@ class MetaworldEnv(gym.Env):
self, self,
seed: int | None = None, seed: int | None = None,
**kwargs, **kwargs,
) -> tuple[RobotObservation, dict[str, Any]]: ) -> tuple[dict[str, Any], dict[str, Any]]:
""" """
Reset the environment to its initial state. Reset the environment to its initial state.
@@ -206,7 +204,7 @@ class MetaworldEnv(gym.Env):
seed (Optional[int]): Random seed for environment initialization. seed (Optional[int]): Random seed for environment initialization.
Returns: Returns:
observation (RobotObservation): The initial formatted observation. observation (Dict[str, Any]): The initial formatted observation.
info (Dict[str, Any]): Additional info about the reset state. info (Dict[str, Any]): Additional info about the reset state.
""" """
super().reset(seed=seed) super().reset(seed=seed)
@@ -218,7 +216,7 @@ class MetaworldEnv(gym.Env):
info = {"is_success": False} info = {"is_success": False}
return observation, info return observation, info
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]: def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
""" """
Perform one environment step. Perform one environment step.
@@ -226,7 +224,7 @@ class MetaworldEnv(gym.Env):
action (np.ndarray): The action to execute, must be 1-D with shape (action_dim,). action (np.ndarray): The action to execute, must be 1-D with shape (action_dim,).
Returns: Returns:
observation (RobotObservation): The formatted observation after the step. observation (Dict[str, Any]): The formatted observation after the step.
reward (float): The scalar reward for this step. reward (float): The scalar reward for this step.
terminated (bool): Whether the episode terminated successfully. terminated (bool): Whether the episode terminated successfully.
truncated (bool): Whether the episode was truncated due to a time limit. truncated (bool): Whether the episode was truncated due to a time limit.

View File

@@ -29,7 +29,6 @@ from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.envs.configs import EnvConfig from lerobot.envs.configs import EnvConfig
from lerobot.processor import RobotObservation
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR
from lerobot.utils.utils import get_channel_first_image_shape from lerobot.utils.utils import get_channel_first_image_shape
@@ -47,7 +46,7 @@ def _convert_nested_dict(d):
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]: def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
# TODO(jadechoghari, imstevenpmwork): refactor this to use features from the environment (no hardcoding) # TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
"""Convert environment observation to LeRobot format observation. """Convert environment observation to LeRobot format observation.
Args: Args:
observation: Dictionary of observation batches from a Gym vector environment. observation: Dictionary of observation batches from a Gym vector environment.
@@ -99,19 +98,11 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
if "robot_state" in observations: if "robot_state" in observations:
return_observations[f"{OBS_STR}.robot_state"] = _convert_nested_dict(observations["robot_state"]) return_observations[f"{OBS_STR}.robot_state"] = _convert_nested_dict(observations["robot_state"])
# Handle IsaacLab Arena format: observations have 'policy' and 'camera_obs' keys
if "policy" in observations:
return_observations[f"{OBS_STR}.policy"] = observations["policy"]
if "camera_obs" in observations:
return_observations[f"{OBS_STR}.camera_obs"] = observations["camera_obs"]
return return_observations return return_observations
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]: def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
# TODO(jadechoghari, imstevenpmwork): remove this hardcoding of keys and just use the nested keys as is # TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
# (need to also refactor preprocess_observation and externalize normalization from policies) # (need to also refactor preprocess_observation and externalize normalization from policies)
policy_features = {} policy_features = {}
for key, ft in env_cfg.features.items(): for key, ft in env_cfg.features.items():
@@ -153,7 +144,7 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
) )
def add_envs_task(env: gym.vector.VectorEnv, observation: RobotObservation) -> RobotObservation: def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
"""Adds task feature to the observation dict with respect to the first environment attribute.""" """Adds task feature to the observation dict with respect to the first environment attribute."""
if hasattr(env.envs[0], "task_description"): if hasattr(env.envs[0], "task_description"):
task_result = env.call("task_description") task_result = env.call("task_description")
@@ -311,7 +302,7 @@ def _import_hub_module(local_file: str, repo_id: str) -> Any:
return module return module
def _call_make_env(module: Any, n_envs: int, use_async_envs: bool, cfg: EnvConfig | None) -> Any: def _call_make_env(module: Any, n_envs: int, use_async_envs: bool) -> Any:
""" """
Ensure module exposes make_env and call it. Ensure module exposes make_env and call it.
""" """
@@ -320,11 +311,7 @@ def _call_make_env(module: Any, n_envs: int, use_async_envs: bool, cfg: EnvConfi
f"The hub module {getattr(module, '__name__', 'hub_module')} must expose `make_env(n_envs=int, use_async_envs=bool)`." f"The hub module {getattr(module, '__name__', 'hub_module')} must expose `make_env(n_envs=int, use_async_envs=bool)`."
) )
entry_fn = module.make_env entry_fn = module.make_env
# Only pass cfg if it's not None (i.e., when an EnvConfig was provided, not a string hub ID) return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs)
if cfg is not None:
return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs, cfg=cfg)
else:
return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs)
def _normalize_hub_result(result: Any) -> dict[str, dict[int, gym.vector.VectorEnv]]: def _normalize_hub_result(result: Any) -> dict[str, dict[int, gym.vector.VectorEnv]]:

View File

@@ -205,7 +205,6 @@ MODEL_BAUDRATE_TABLE = {
# Sign-Magnitude encoding bits # Sign-Magnitude encoding bits
STS_SMS_SERIES_ENCODINGS_TABLE = { STS_SMS_SERIES_ENCODINGS_TABLE = {
"Present_Load": 10,
"Homing_Offset": 11, "Homing_Offset": 11,
"Goal_Position": 15, "Goal_Position": 15,
"Goal_Velocity": 15, "Goal_Velocity": 15,

View File

@@ -32,7 +32,7 @@ import serial
from deepdiff import DeepDiff from deepdiff import DeepDiff
from tqdm import tqdm from tqdm import tqdm
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.utils import enter_pressed, move_cursor_up from lerobot.utils.utils import enter_pressed, move_cursor_up
NameOrID: TypeAlias = str | int NameOrID: TypeAlias = str | int
@@ -411,7 +411,6 @@ class MotorsBus(abc.ABC):
"""bool: `True` if the underlying serial port is open.""" """bool: `True` if the underlying serial port is open."""
return self.port_handler.is_open return self.port_handler.is_open
@check_if_already_connected
def connect(self, handshake: bool = True) -> None: def connect(self, handshake: bool = True) -> None:
"""Open the serial port and initialise communication. """Open the serial port and initialise communication.
@@ -423,6 +422,10 @@ class MotorsBus(abc.ABC):
DeviceAlreadyConnectedError: The port is already open. DeviceAlreadyConnectedError: The port is already open.
ConnectionError: The underlying SDK failed to open the port or the handshake did not succeed. ConnectionError: The underlying SDK failed to open the port or the handshake did not succeed.
""" """
if self.is_connected:
raise DeviceAlreadyConnectedError(
f"{self.__class__.__name__}('{self.port}') is already connected. Do not call `{self.__class__.__name__}.connect()` twice."
)
self._connect(handshake) self._connect(handshake)
self.set_timeout() self.set_timeout()
@@ -444,7 +447,6 @@ class MotorsBus(abc.ABC):
def _handshake(self) -> None: def _handshake(self) -> None:
pass pass
@check_if_not_connected
def disconnect(self, disable_torque: bool = True) -> None: def disconnect(self, disable_torque: bool = True) -> None:
"""Close the serial port (optionally disabling torque first). """Close the serial port (optionally disabling torque first).
@@ -453,6 +455,10 @@ class MotorsBus(abc.ABC):
closing the port. This can prevent damaging motors if they are left applying resisting torque closing the port. This can prevent damaging motors if they are left applying resisting torque
after disconnect. after disconnect.
""" """
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. Try running `{self.__class__.__name__}.connect()` first."
)
if disable_torque: if disable_torque:
self.port_handler.clearPort() self.port_handler.clearPort()
@@ -901,7 +907,6 @@ class MotorsBus(abc.ABC):
""" """
pass pass
@check_if_not_connected
def read( def read(
self, self,
data_name: str, data_name: str,
@@ -922,6 +927,10 @@ class MotorsBus(abc.ABC):
Returns: Returns:
Value: Raw or normalised value depending on *normalize*. Value: Raw or normalised value depending on *normalize*.
""" """
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
)
id_ = self.motors[motor].id id_ = self.motors[motor].id
model = self.motors[motor].model model = self.motors[motor].model
@@ -972,7 +981,6 @@ class MotorsBus(abc.ABC):
return value, comm, error return value, comm, error
@check_if_not_connected
def write( def write(
self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0 self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0
) -> None: ) -> None:
@@ -991,6 +999,10 @@ class MotorsBus(abc.ABC):
normalize (bool, optional): Enable or disable normalisation. Defaults to `True`. normalize (bool, optional): Enable or disable normalisation. Defaults to `True`.
num_retry (int, optional): Retry attempts. Defaults to `0`. num_retry (int, optional): Retry attempts. Defaults to `0`.
""" """
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
)
id_ = self.motors[motor].id id_ = self.motors[motor].id
model = self.motors[motor].model model = self.motors[motor].model
@@ -1032,7 +1044,6 @@ class MotorsBus(abc.ABC):
return comm, error return comm, error
@check_if_not_connected
def sync_read( def sync_read(
self, self,
data_name: str, data_name: str,
@@ -1052,6 +1063,10 @@ class MotorsBus(abc.ABC):
Returns: Returns:
dict[str, Value]: Mapping *motor name → value*. dict[str, Value]: Mapping *motor name → value*.
""" """
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
)
self._assert_protocol_is_compatible("sync_read") self._assert_protocol_is_compatible("sync_read")
@@ -1124,7 +1139,6 @@ class MotorsBus(abc.ABC):
# for id_ in motor_ids: # for id_ in motor_ids:
# value = self.sync_reader.getData(id_, address, length) # value = self.sync_reader.getData(id_, address, length)
@check_if_not_connected
def sync_write( def sync_write(
self, self,
data_name: str, data_name: str,
@@ -1146,6 +1160,10 @@ class MotorsBus(abc.ABC):
normalize (bool, optional): If `True` (default) convert values from the user range to raw units. normalize (bool, optional): If `True` (default) convert values from the user range to raw units.
num_retry (int, optional): Retry attempts. Defaults to `0`. num_retry (int, optional): Retry attempts. Defaults to `0`.
""" """
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
)
ids_values = self._get_ids_values_dict(values) ids_values = self._get_ids_values_dict(values)
models = [self._id_to_model(id_) for id_ in ids_values] models = [self._id_to_model(id_) for id_ in ids_values]

View File

@@ -16,7 +16,6 @@ from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .groot.configuration_groot import GrootConfig as GrootConfig from .groot.configuration_groot import GrootConfig as GrootConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
from .pi05.configuration_pi05 import PI05Config as PI05Config from .pi05.configuration_pi05 import PI05Config as PI05Config
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla.processor_smolvla import SmolVLANewLineProcessor from .smolvla.processor_smolvla import SmolVLANewLineProcessor
@@ -30,7 +29,6 @@ __all__ = [
"DiffusionConfig", "DiffusionConfig",
"PI0Config", "PI0Config",
"PI05Config", "PI05Config",
"PI0FastConfig",
"SmolVLAConfig", "SmolVLAConfig",
"SARMConfig", "SARMConfig",
"TDMPCConfig", "TDMPCConfig",

View File

@@ -35,7 +35,6 @@ from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.videovla.configuration_pi05 import PI05VideoConfig
from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sarm.configuration_sarm import SARMConfig from lerobot.policies.sarm.configuration_sarm import SARMConfig
@@ -52,11 +51,7 @@ from lerobot.processor.converters import (
transition_to_batch, transition_to_batch,
transition_to_policy_action, transition_to_policy_action,
) )
from lerobot.utils.constants import ( from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
ACTION,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
def get_policy_class(name: str) -> type[PreTrainedPolicy]: def get_policy_class(name: str) -> type[PreTrainedPolicy]:
@@ -68,7 +63,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args: Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act", name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"vqbet", "pi0", "pi05", "pi05_video", "sac", "reward_classifier", "smolvla", "wall_x". "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
Returns: Returns:
The policy class corresponding to the given name. The policy class corresponding to the given name.
@@ -96,18 +91,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.pi0.modeling_pi0 import PI0Policy from lerobot.policies.pi0.modeling_pi0 import PI0Policy
return PI0Policy return PI0Policy
elif name == "pi0_fast":
from lerobot.policies.pi0_fast.modeling_pi0_fast import PI0FastPolicy
return PI0FastPolicy
elif name == "pi05": elif name == "pi05":
from lerobot.policies.pi05.modeling_pi05 import PI05Policy from lerobot.policies.pi05.modeling_pi05 import PI05Policy
return PI05Policy return PI05Policy
elif name == "pi05_video":
from lerobot.policies.videovla.modeling_pi05 import PI05VideoPolicy
return PI05VideoPolicy
elif name == "sac": elif name == "sac":
from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.policies.sac.modeling_sac import SACPolicy
@@ -152,7 +139,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args: Args:
policy_type: The type of the policy. Supported types include "tdmpc", policy_type: The type of the policy. Supported types include "tdmpc",
"diffusion", "act", "vqbet", "pi0", "pi05", "pi05_video", "sac", "smolvla", "diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
"reward_classifier", "wall_x". "reward_classifier", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor. **kwargs: Keyword arguments to be passed to the configuration class constructor.
@@ -174,8 +161,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return PI0Config(**kwargs) return PI0Config(**kwargs)
elif policy_type == "pi05": elif policy_type == "pi05":
return PI05Config(**kwargs) return PI05Config(**kwargs)
elif policy_type == "pi05_video":
return PI05VideoConfig(**kwargs)
elif policy_type == "sac": elif policy_type == "sac":
return SACConfig(**kwargs) return SACConfig(**kwargs)
elif policy_type == "smolvla": elif policy_type == "smolvla":
@@ -261,7 +246,7 @@ def make_pre_post_processors(
} }
# Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats # Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats
env_action_dim = policy_cfg.output_features[ACTION].shape[0] env_action_dim = policy_cfg.output_features["action"].shape[0]
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = { postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = {
"stats": kwargs.get("dataset_stats"), "stats": kwargs.get("dataset_stats"),
"normalize_min_max": True, "normalize_min_max": True,
@@ -340,14 +325,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"), dataset_stats=kwargs.get("dataset_stats"),
) )
elif isinstance(policy_cfg, PI05VideoConfig):
from lerobot.policies.videovla.processor_pi05 import make_pi05_video_pre_post_processors
processors = make_pi05_video_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, SACConfig): elif isinstance(policy_cfg, SACConfig):
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
@@ -494,40 +471,11 @@ def make_policy(
if ds_meta is not None: if ds_meta is not None:
kwargs["dataset_meta"] = ds_meta kwargs["dataset_meta"] = ds_meta
if not cfg.pretrained_path and cfg.use_peft: if cfg.pretrained_path:
raise ValueError(
"Instantiating a policy with `use_peft=True` without a checkpoint is not supported since that requires "
"the PEFT config parameters to be set. For training with PEFT, see `lerobot_train.py` on how to do that."
)
if cfg.pretrained_path and not cfg.use_peft:
# Load a pretrained policy and override the config if needed (for example, if there are inference-time # Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary). # hyperparameters that we want to vary).
kwargs["pretrained_name_or_path"] = cfg.pretrained_path kwargs["pretrained_name_or_path"] = cfg.pretrained_path
policy = policy_cls.from_pretrained(**kwargs) policy = policy_cls.from_pretrained(**kwargs)
elif cfg.pretrained_path and cfg.use_peft:
# Load a pretrained PEFT model on top of the policy. The pretrained path points to the folder/repo
# of the adapter and the adapter's config contains the path to the base policy. So we need the
# adapter config first, then load the correct policy and then apply PEFT.
from peft import PeftConfig, PeftModel
logging.info("Loading policy's PEFT adapter.")
peft_pretrained_path = cfg.pretrained_path
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
kwargs["pretrained_name_or_path"] = peft_config.base_model_name_or_path
if not kwargs["pretrained_name_or_path"]:
# This means that there's a bug or we trained a policy from scratch using PEFT.
# It is more likely that this is a bug so we'll raise an error.
raise ValueError(
"No pretrained model name found in adapter config. Can't instantiate the pre-trained policy on which "
"the adapter was trained."
)
policy = policy_cls.from_pretrained(**kwargs)
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
else: else:
# Make a fresh policy. # Make a fresh policy.
policy = policy_cls(**kwargs) policy = policy_cls(**kwargs)

View File

@@ -20,7 +20,6 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
@PreTrainedConfig.register_subclass("groot") @PreTrainedConfig.register_subclass("groot")
@@ -138,14 +137,14 @@ class GrootConfig(PreTrainedConfig):
"No features of type FeatureType.VISUAL found in input_features." "No features of type FeatureType.VISUAL found in input_features."
) )
if OBS_STATE not in self.input_features: if "observation.state" not in self.input_features:
state_feature = PolicyFeature( state_feature = PolicyFeature(
type=FeatureType.STATE, type=FeatureType.STATE,
shape=(self.max_state_dim,), shape=(self.max_state_dim,),
) )
self.input_features[OBS_STATE] = state_feature self.input_features["observation.state"] = state_feature
else: else:
state_shape = self.input_features[OBS_STATE].shape state_shape = self.input_features["observation.state"].shape
state_dim = state_shape[0] if state_shape else 0 state_dim = state_shape[0] if state_shape else 0
if state_dim > self.max_state_dim: if state_dim > self.max_state_dim:
raise ValueError( raise ValueError(
@@ -153,14 +152,14 @@ class GrootConfig(PreTrainedConfig):
f"Either reduce state dimension or increase max_state_dim in config." f"Either reduce state dimension or increase max_state_dim in config."
) )
if ACTION not in self.output_features: if "action" not in self.output_features:
action_feature = PolicyFeature( action_feature = PolicyFeature(
type=FeatureType.ACTION, type=FeatureType.ACTION,
shape=(self.max_action_dim,), shape=(self.max_action_dim,),
) )
self.output_features[ACTION] = action_feature self.output_features["action"] = action_feature
else: else:
action_shape = self.output_features[ACTION].shape action_shape = self.output_features["action"].shape
action_dim = action_shape[0] if action_shape else 0 action_dim = action_shape[0] if action_shape else 0
if action_dim > self.max_action_dim: if action_dim > self.max_action_dim:
raise ValueError( raise ValueError(

View File

@@ -46,7 +46,7 @@ from lerobot.policies.groot.action_head.flow_matching_action_head import (
FlowmatchingActionHeadConfig, FlowmatchingActionHeadConfig,
) )
from lerobot.policies.groot.utils import ensure_eagle_cache_ready from lerobot.policies.groot.utils import ensure_eagle_cache_ready
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME from lerobot.utils.constants import HF_LEROBOT_HOME
DEFAULT_VENDOR_EAGLE_PATH = str((Path(__file__).resolve().parent / "eagle2_hg_model").resolve()) DEFAULT_VENDOR_EAGLE_PATH = str((Path(__file__).resolve().parent / "eagle2_hg_model").resolve())
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5" DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
@@ -227,8 +227,8 @@ class GR00TN15(PreTrainedModel):
detected_error = False detected_error = False
error_msg = ERROR_MSG error_msg = ERROR_MSG
if ACTION in inputs: if "action" in inputs:
action = inputs[ACTION] action = inputs["action"]
# In inference, action may be omitted or None; validate only when it's a tensor. # In inference, action may be omitted or None; validate only when it's a tensor.
if action is None: if action is None:
pass # allow None during inference pass # allow None during inference

View File

@@ -32,22 +32,15 @@ Notes:
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below. from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
""" """
import builtins
import os import os
from collections import deque from collections import deque
from pathlib import Path
from typing import TypeVar
import torch import torch
from torch import Tensor from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.groot.configuration_groot import GrootConfig from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.groot.groot_n1 import GR00TN15 from lerobot.policies.groot.groot_n1 import GR00TN15
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION, OBS_IMAGES
T = TypeVar("T", bound="GrootPolicy")
class GrootPolicy(PreTrainedPolicy): class GrootPolicy(PreTrainedPolicy):
@@ -96,129 +89,6 @@ class GrootPolicy(PreTrainedPolicy):
"""Reset policy state when environment resets.""" """Reset policy state when environment resets."""
self._action_queue = deque([], maxlen=self.config.n_action_steps) self._action_queue = deque([], maxlen=self.config.n_action_steps)
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
config: GrootConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
strict: bool = True,
**kwargs,
) -> T:
"""Load Groot policy from pretrained model.
Handles two cases:
1. Base GR00T models (e.g., 'nvidia/GR00T-N1.5-3B') - loads the raw model
2. Fine-tuned LeRobot checkpoints - loads config and weights from safetensors
Args:
pretrained_name_or_path: Path to the GR00T model or fine-tuned checkpoint
config: Optional GrootConfig. If None, loads from checkpoint or creates default
force_download: Force download even if cached
resume_download: Resume interrupted download
proxies: Proxy settings
token: HuggingFace authentication token
cache_dir: Cache directory path
local_files_only: Only use local files
revision: Specific model revision
strict: Strict state dict loading
**kwargs: Additional arguments (passed to config)
Returns:
Initialized GrootPolicy instance with loaded model
"""
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
print(
"The Groot policy is a wrapper around Nvidia's GR00T N1.5 model.\n"
f"Loading pretrained model from: {pretrained_name_or_path}"
)
model_id = str(pretrained_name_or_path)
is_finetuned_checkpoint = False
# Check if this is a fine-tuned LeRobot checkpoint (has model.safetensors)
try:
if os.path.isdir(model_id):
is_finetuned_checkpoint = os.path.exists(os.path.join(model_id, SAFETENSORS_SINGLE_FILE))
else:
# Try to download the safetensors file to check if it exists
try:
hf_hub_download(
repo_id=model_id,
filename=SAFETENSORS_SINGLE_FILE,
revision=revision,
cache_dir=cache_dir,
force_download=False, # Just check, don't force download
proxies=proxies,
token=token,
local_files_only=local_files_only,
)
is_finetuned_checkpoint = True
except HfHubHTTPError:
is_finetuned_checkpoint = False
except Exception:
is_finetuned_checkpoint = False
if is_finetuned_checkpoint:
# This is a fine-tuned LeRobot checkpoint - use parent class loading
print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
return super().from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
config=config,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
strict=strict,
**kwargs,
)
# This is a base GR00T model - load it fresh
print("Detected base GR00T model, loading from HuggingFace...")
if config is None:
# Create default config with the pretrained path
config = GrootConfig(base_model_path=str(pretrained_name_or_path))
# Add minimal visual feature required for validation
# validate_features() will automatically add state and action features
# These are placeholders - actual robot features come from the preprocessor
if not config.input_features:
config.input_features = {
f"{OBS_IMAGES}.camera": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224), # Default image size from config
),
}
else:
# Override the base_model_path with the provided path
config.base_model_path = str(pretrained_name_or_path)
# Pass through any additional config overrides from kwargs
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
# Create a fresh policy instance - this will automatically load the GR00T model
# in __init__ via _create_groot_model()
policy = cls(config)
policy.eval()
return policy
def get_optim_params(self) -> dict: def get_optim_params(self) -> dict:
return self.parameters() return self.parameters()
@@ -277,7 +147,7 @@ class GrootPolicy(PreTrainedPolicy):
actions = outputs.get("action_pred") actions = outputs.get("action_pred")
original_action_dim = self.config.output_features[ACTION].shape[0] original_action_dim = self.config.output_features["action"].shape[0]
actions = actions[:, :, :original_action_dim] actions = actions[:, :, :original_action_dim]
return actions return actions

View File

@@ -51,11 +51,7 @@ from lerobot.processor.converters import (
) )
from lerobot.processor.core import EnvTransition, TransitionKey from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import ( from lerobot.utils.constants import (
ACTION,
HF_LEROBOT_HOME, HF_LEROBOT_HOME,
OBS_IMAGE,
OBS_IMAGES,
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME,
) )
@@ -111,9 +107,9 @@ def make_groot_pre_post_processors(
# Define feature specs for optional normalization steps # Define feature specs for optional normalization steps
_features: dict[str, PolicyFeature] = { _features: dict[str, PolicyFeature] = {
# Observation features (only add those we may normalize) # Observation features (only add those we may normalize)
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_horizon, max_state_dim)), "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_horizon, max_state_dim)),
# Action feature # Action feature
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_horizon, max_action_dim)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(action_horizon, max_action_dim)),
} }
# Normalize STATE and ACTION with min_max (SO100-like default) # Normalize STATE and ACTION with min_max (SO100-like default)
@@ -124,7 +120,7 @@ def make_groot_pre_post_processors(
# Determine env action dimension from config (simple, object-like PolicyFeature) # Determine env action dimension from config (simple, object-like PolicyFeature)
try: try:
env_action_dim = int(config.output_features[ACTION].shape[0]) env_action_dim = int(config.output_features["action"].shape[0])
except Exception: except Exception:
env_action_dim = 0 env_action_dim = 0
@@ -272,9 +268,9 @@ class GrootPackInputsStep(ProcessorStep):
return torch.where(mask, mapped, torch.zeros_like(mapped)) return torch.where(mask, mapped, torch.zeros_like(mapped))
# 1) Video (B, T=1, V, H, W, C) uint8 # 1) Video (B, T=1, V, H, W, C) uint8
img_keys = sorted([k for k in obs if k.startswith(OBS_IMAGES)]) img_keys = sorted([k for k in obs if k.startswith("observation.images.")])
if not img_keys and OBS_IMAGE in obs: if not img_keys and "observation.image" in obs:
img_keys = [OBS_IMAGE] img_keys = ["observation.image"]
if img_keys: if img_keys:
cams = [_to_uint8_np_bhwc(obs[k]) for k in img_keys] cams = [_to_uint8_np_bhwc(obs[k]) for k in img_keys]
video = np.stack(cams, axis=1) # (B, V, H, W, C) video = np.stack(cams, axis=1) # (B, V, H, W, C)
@@ -298,14 +294,14 @@ class GrootPackInputsStep(ProcessorStep):
comp["language"] = lang comp["language"] = lang
# 3) State/state_mask -> (B, 1, max_state_dim) # 3) State/state_mask -> (B, 1, max_state_dim)
if OBS_STATE in obs: if "observation.state" in obs:
state = obs[OBS_STATE] # (B, D) state = obs["observation.state"] # (B, D)
if state.dim() != 2: if state.dim() != 2:
raise ValueError(f"state must be (B, D), got {tuple(state.shape)}") raise ValueError(f"state must be (B, D), got {tuple(state.shape)}")
bsz, d = state.shape bsz, d = state.shape
# Normalize BEFORE padding # Normalize BEFORE padding
if self.normalize_min_max: if self.normalize_min_max:
state = _min_max_norm(state, OBS_STATE) state = _min_max_norm(state, "observation.state")
state = state.unsqueeze(1) # (B, 1, D) state = state.unsqueeze(1) # (B, 1, D)
if d > self.max_state_dim: if d > self.max_state_dim:
state = state[:, :, : self.max_state_dim] state = state[:, :, : self.max_state_dim]
@@ -324,11 +320,11 @@ class GrootPackInputsStep(ProcessorStep):
# Normalize BEFORE temporal expansion/padding # Normalize BEFORE temporal expansion/padding
if self.normalize_min_max: if self.normalize_min_max:
if action.dim() == 2: if action.dim() == 2:
action = _min_max_norm(action, ACTION) action = _min_max_norm(action, "action")
elif action.dim() == 3: elif action.dim() == 3:
b, t, d = action.shape b, t, d = action.shape
flat = action.reshape(b * t, d) flat = action.reshape(b * t, d)
flat = _min_max_norm(flat, ACTION) flat = _min_max_norm(flat, "action")
action = flat.view(b, t, d) action = flat.view(b, t, d)
if action.dim() == 2: if action.dim() == 2:
action = action.unsqueeze(1).repeat(1, self.action_horizon, 1) action = action.unsqueeze(1).repeat(1, self.action_horizon, 1)
@@ -594,7 +590,7 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
# forward: y = 2 * (x - min) / denom - 1, with y=0 when denom==0 # forward: y = 2 * (x - min) / denom - 1, with y=0 when denom==0
# inverse: x = (y+1)/2 * denom + min, and when denom==0 -> x = min # inverse: x = (y+1)/2 * denom + min, and when denom==0 -> x = min
if self.normalize_min_max and self.stats is not None: if self.normalize_min_max and self.stats is not None:
stats_k = self.stats.get(ACTION, {}) stats_k = self.stats.get("action", {})
d = action.shape[-1] d = action.shape[-1]
min_v = torch.as_tensor( min_v = torch.as_tensor(
stats_k.get("min", torch.zeros(d)), dtype=action.dtype, device=action.device stats_k.get("min", torch.zeros(d)), dtype=action.dtype, device=action.device

View File

@@ -21,7 +21,7 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.utils.constants import OBS_IMAGES
DEFAULT_IMAGE_SIZE = 224 DEFAULT_IMAGE_SIZE = 224
@@ -76,10 +76,6 @@ class PI0Config(PreTrainedConfig):
compile_mode: str = "max-autotune" # Torch compile mode compile_mode: str = "max-autotune" # Torch compile mode
device: str | None = None # Device to use for the model (None = auto-detect) device: str | None = None # Device to use for the model (None = auto-detect)
# Finetuning settings
freeze_vision_encoder: bool = False # Freeze only the vision encoder
train_expert_only: bool = False # Freeze entire VLM, train only action expert and projections
# Optimizer settings: see openpi `AdamW`` # Optimizer settings: see openpi `AdamW``
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr` optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
optimizer_betas: tuple[float, float] = (0.9, 0.95) optimizer_betas: tuple[float, float] = (0.9, 0.95)
@@ -124,19 +120,19 @@ class PI0Config(PreTrainedConfig):
) )
self.input_features[key] = empty_camera self.input_features[key] = empty_camera
if OBS_STATE not in self.input_features: if "observation.state" not in self.input_features:
state_feature = PolicyFeature( state_feature = PolicyFeature(
type=FeatureType.STATE, type=FeatureType.STATE,
shape=(self.max_state_dim,), # Padded to max_state_dim shape=(self.max_state_dim,), # Padded to max_state_dim
) )
self.input_features[OBS_STATE] = state_feature self.input_features["observation.state"] = state_feature
if ACTION not in self.output_features: if "action" not in self.output_features:
action_feature = PolicyFeature( action_feature = PolicyFeature(
type=FeatureType.ACTION, type=FeatureType.ACTION,
shape=(self.max_action_dim,), # Padded to max_action_dim shape=(self.max_action_dim,), # Padded to max_action_dim
) )
self.output_features[ACTION] = action_feature self.output_features["action"] = action_feature
def get_optimizer_preset(self) -> AdamWConfig: def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig( return AdamWConfig(

View File

@@ -339,14 +339,10 @@ class PaliGemmaWithExpertModel(
use_adarms=None, use_adarms=None,
precision: Literal["bfloat16", "float32"] = "bfloat16", precision: Literal["bfloat16", "float32"] = "bfloat16",
image_size: int = DEFAULT_IMAGE_SIZE, image_size: int = DEFAULT_IMAGE_SIZE,
freeze_vision_encoder: bool = False,
train_expert_only: bool = False,
): ):
if use_adarms is None: if use_adarms is None:
use_adarms = [False, False] use_adarms = [False, False]
super().__init__() super().__init__()
self.freeze_vision_encoder = freeze_vision_encoder
self.train_expert_only = train_expert_only
vlm_config_hf = CONFIG_MAPPING["paligemma"]() vlm_config_hf = CONFIG_MAPPING["paligemma"]()
vlm_config_hf._vocab_size = 257152 # noqa: SLF001 vlm_config_hf._vocab_size = 257152 # noqa: SLF001
@@ -387,7 +383,6 @@ class PaliGemmaWithExpertModel(
self.gemma_expert.model.embed_tokens = None self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_for_selected_params(precision) self.to_bfloat16_for_selected_params(precision)
self._set_requires_grad()
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"): def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
if precision == "bfloat16": if precision == "bfloat16":
@@ -411,23 +406,6 @@ class PaliGemmaWithExpertModel(
if any(selector in name for selector in params_to_keep_float32): if any(selector in name for selector in params_to_keep_float32):
param.data = param.data.to(dtype=torch.float32) param.data = param.data.to(dtype=torch.float32)
def _set_requires_grad(self):
if self.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
for param in self.paligemma.vision_tower.parameters():
param.requires_grad = False
if self.train_expert_only:
self.paligemma.eval()
for param in self.paligemma.parameters():
param.requires_grad = False
def train(self, mode: bool = True):
super().train(mode)
if self.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
if self.train_expert_only:
self.paligemma.eval()
def embed_image(self, image: torch.Tensor): def embed_image(self, image: torch.Tensor):
return self.paligemma.model.get_image_features(image) return self.paligemma.model.get_image_features(image)
@@ -555,8 +533,6 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
use_adarms=[False, False], use_adarms=[False, False],
precision=config.dtype, precision=config.dtype,
image_size=config.image_resolution[0], image_size=config.image_resolution[0],
freeze_vision_encoder=config.freeze_vision_encoder,
train_expert_only=config.train_expert_only,
) )
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width) self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
@@ -1297,14 +1273,3 @@ class PI0Policy(PreTrainedPolicy):
loss = losses.mean() loss = losses.mean()
loss_dict["loss"] = loss.item() loss_dict["loss"] = loss.item()
return loss, loss_dict return loss, loss_dict
def _get_default_peft_targets(self) -> dict[str, any]:
"""Return default PEFT target modules for PI0 fine-tuning."""
common_projections = (
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
)
target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))"
return {
"target_modules": target_modules,
"modules_to_save": [],
}

View File

@@ -21,7 +21,6 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
DEFAULT_IMAGE_SIZE = 224 DEFAULT_IMAGE_SIZE = 224
@@ -77,10 +76,6 @@ class PI05Config(PreTrainedConfig):
compile_mode: str = "max-autotune" # Torch compile mode compile_mode: str = "max-autotune" # Torch compile mode
device: str | None = None # Device to use for the model (None = auto-detect) device: str | None = None # Device to use for the model (None = auto-detect)
# Finetuning settings
freeze_vision_encoder: bool = False # Freeze only the vision encoder
train_expert_only: bool = False # Freeze entire VLM, train only action expert and projections
# Optimizer settings: see openpi `AdamW` # Optimizer settings: see openpi `AdamW`
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr` optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
optimizer_betas: tuple[float, float] = (0.9, 0.95) optimizer_betas: tuple[float, float] = (0.9, 0.95)
@@ -118,26 +113,26 @@ class PI05Config(PreTrainedConfig):
def validate_features(self) -> None: def validate_features(self) -> None:
"""Validate and set up input/output features.""" """Validate and set up input/output features."""
for i in range(self.empty_cameras): for i in range(self.empty_cameras):
key = OBS_IMAGES + f".empty_camera_{i}" key = f"observation.images.empty_camera_{i}"
empty_camera = PolicyFeature( empty_camera = PolicyFeature(
type=FeatureType.VISUAL, type=FeatureType.VISUAL,
shape=(3, *self.image_resolution), # Use configured image resolution shape=(3, *self.image_resolution), # Use configured image resolution
) )
self.input_features[key] = empty_camera self.input_features[key] = empty_camera
if OBS_STATE not in self.input_features: if "observation.state" not in self.input_features:
state_feature = PolicyFeature( state_feature = PolicyFeature(
type=FeatureType.STATE, type=FeatureType.STATE,
shape=(self.max_state_dim,), # Padded to max_state_dim shape=(self.max_state_dim,), # Padded to max_state_dim
) )
self.input_features[OBS_STATE] = state_feature self.input_features["observation.state"] = state_feature
if ACTION not in self.output_features: if "action" not in self.output_features:
action_feature = PolicyFeature( action_feature = PolicyFeature(
type=FeatureType.ACTION, type=FeatureType.ACTION,
shape=(self.max_action_dim,), # Padded to max_action_dim shape=(self.max_action_dim,), # Padded to max_action_dim
) )
self.output_features[ACTION] = action_feature self.output_features["action"] = action_feature
def get_optimizer_preset(self) -> AdamWConfig: def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig( return AdamWConfig(

View File

@@ -337,14 +337,10 @@ class PaliGemmaWithExpertModel(
use_adarms=None, use_adarms=None,
precision: Literal["bfloat16", "float32"] = "bfloat16", precision: Literal["bfloat16", "float32"] = "bfloat16",
image_size: int = DEFAULT_IMAGE_SIZE, image_size: int = DEFAULT_IMAGE_SIZE,
freeze_vision_encoder: bool = False,
train_expert_only: bool = False,
): ):
if use_adarms is None: if use_adarms is None:
use_adarms = [False, False] use_adarms = [False, False]
super().__init__() super().__init__()
self.freeze_vision_encoder = freeze_vision_encoder
self.train_expert_only = train_expert_only
vlm_config_hf = CONFIG_MAPPING["paligemma"]() vlm_config_hf = CONFIG_MAPPING["paligemma"]()
vlm_config_hf._vocab_size = 257152 # noqa: SLF001 vlm_config_hf._vocab_size = 257152 # noqa: SLF001
@@ -385,7 +381,6 @@ class PaliGemmaWithExpertModel(
self.gemma_expert.model.embed_tokens = None self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_for_selected_params(precision) self.to_bfloat16_for_selected_params(precision)
self._set_requires_grad()
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"): def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
if precision == "bfloat16": if precision == "bfloat16":
@@ -409,23 +404,6 @@ class PaliGemmaWithExpertModel(
if any(selector in name for selector in params_to_keep_float32): if any(selector in name for selector in params_to_keep_float32):
param.data = param.data.to(dtype=torch.float32) param.data = param.data.to(dtype=torch.float32)
def _set_requires_grad(self):
if self.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
for param in self.paligemma.vision_tower.parameters():
param.requires_grad = False
if self.train_expert_only:
self.paligemma.eval()
for param in self.paligemma.parameters():
param.requires_grad = False
def train(self, mode: bool = True):
super().train(mode)
if self.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
if self.train_expert_only:
self.paligemma.eval()
def embed_image(self, image: torch.Tensor): def embed_image(self, image: torch.Tensor):
return self.paligemma.model.get_image_features(image) return self.paligemma.model.get_image_features(image)
@@ -460,8 +438,8 @@ class PaliGemmaWithExpertModel(
inputs_embeds=inputs_embeds[1], inputs_embeds=inputs_embeds[1],
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids, position_ids=position_ids,
use_cache=False, past_key_values=past_key_values,
past_key_values=None, #jadechoghari use_cache=use_cache,
adarms_cond=adarms_cond[1] if adarms_cond is not None else None, adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
) )
suffix_output = suffix_output.last_hidden_state suffix_output = suffix_output.last_hidden_state
@@ -553,8 +531,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
use_adarms=[False, True], use_adarms=[False, True],
precision=config.dtype, precision=config.dtype,
image_size=config.image_resolution[0], image_size=config.image_resolution[0],
freeze_vision_encoder=config.freeze_vision_encoder,
train_expert_only=config.train_expert_only,
) )
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width) self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
@@ -575,13 +551,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues""" msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
# try: try:
# from transformers.models.siglip import check from transformers.models.siglip import check
# if not check.check_whether_transformers_replace_is_installed_correctly(): if not check.check_whether_transformers_replace_is_installed_correctly():
# raise ValueError(msg) raise ValueError(msg)
# except ImportError: except ImportError:
# raise ValueError(msg) from None raise ValueError(msg) from None
def gradient_checkpointing_enable(self): def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization.""" """Enable gradient checkpointing for memory optimization."""
@@ -1270,14 +1246,3 @@ class PI05Policy(PreTrainedPolicy):
loss = losses.mean() loss = losses.mean()
loss_dict["loss"] = loss.item() loss_dict["loss"] = loss.item()
return loss, loss_dict return loss, loss_dict
def _get_default_peft_targets(self) -> dict[str, any]:
"""Return default PEFT target modules for PI0.5 fine-tuning."""
common_projections = (
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
)
target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))"
return {
"target_modules": target_modules,
"modules_to_save": [],
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,177 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector
from lerobot.processor import (
ActionTokenizerProcessorStep,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
@ProcessorStepRegistry.register(name="pi0_fast_prepare_state_tokenizer_processor_step")
@dataclass
class Pi0FastPrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
    """
    Processor step that folds the discretized robot state into the task prompt.

    The state found under ``OBS_STATE`` is padded, discretized into integer bins and
    written, together with the cleaned task text, into a prompt of the form
    ``Task: <task>, State: <bin> <bin> ...;`` followed by a newline. The prompts
    replace the entry stored under ``task_key`` in the complementary data.
    """

    # Length the state vector is padded to before it is discretized.
    max_state_dim: int = 32
    # Key of the task description(s) inside the complementary data.
    task_key: str = "task"

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """
        Build one prompt per task from the task text and the discretized state.

        Args:
            transition: Transition holding the state under ``OBS_STATE`` in its
                observation and the task strings under ``task_key`` in its
                complementary data.

        Returns:
            A shallow copy of ``transition`` whose complementary data is a new dict
            in which ``task_key`` maps to the list of generated prompts.

        Raises:
            ValueError: If the state or the task entry is missing.
        """
        transition = transition.copy()

        state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
        if state is None:
            raise ValueError("State is required for PI0Fast")

        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
        tasks = complementary_data.get(self.task_key)
        if tasks is None:
            raise ValueError("No task found in complementary data")

        # TODO: check whether this defensive copy is necessary
        state = deepcopy(state)

        # Prepare state (pad to max_state_dim)
        state = pad_vector(state, self.max_state_dim)

        # The state is assumed to be already normalized to [-1, 1] by the
        # NormalizerProcessorStep that runs before this step.
        # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`):
        # the 256 left bin edges span [-1, 1); subtracting 1 turns np.digitize's
        # 1-based result into indices 0..255 (values below -1 end up at -1).
        state_np = state.cpu().numpy()
        discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1

        full_prompts = []
        for i, task in enumerate(tasks):
            cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
            state_str = " ".join(map(str, discretized_states[i]))
            full_prompt = f"Task: {cleaned_text}, State: {state_str};\n"
            full_prompts.append(full_prompt)

        # `transition.copy()` above is shallow, so the nested complementary-data dict
        # is still shared with the caller. Rebind a fresh dict instead of writing
        # into the shared one; otherwise the caller's task text would be overwritten
        # and a second pass over the same data would wrap the prompt twice.
        transition[TransitionKey.COMPLEMENTARY_DATA] = {
            **complementary_data,
            self.task_key: full_prompts,
        }
        return transition

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """
        This step does not alter the feature definitions.
        """
        return features
def make_pi0_fast_pre_post_processors(
    config: PI0FastConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """
    Constructs pre-processor and post-processor pipelines for the PI0Fast policy.

    The pre-processing pipeline runs these steps, in this order:
    1. `RenameObservationsProcessorStep` with an empty rename map (kept so the
       pipeline mirrors the pretrained one).
    2. `AddBatchDimensionProcessorStep`.
    3. `NormalizerProcessorStep` over the input and output features, using
       `config.normalization_mapping` and `dataset_stats`.
    4. `Pi0FastPrepareStateAndLanguageTokenizerProcessorStep`, which builds the
       text prompt from the task and the discretized (already normalized) state.
    5. `TokenizerProcessorStep` with `config.text_tokenizer_name`, right-padded to
       `config.tokenizer_max_length`.
    6. `ActionTokenizerProcessorStep` with `config.action_tokenizer_name`.
    7. `DeviceProcessorStep` targeting `config.device`.

    The post-processing pipeline runs these steps, in this order:
    1. `UnnormalizerProcessorStep` over the output features.
    2. `DeviceProcessorStep` targeting the CPU.

    Args:
        config: The configuration object for the PI0Fast policy.
        dataset_stats: A dictionary of statistics for normalization; passed
            unchanged to the normalizer and unnormalizer steps.

    Returns:
        A tuple containing the configured pre-processor and post-processor pipelines.
    """
    input_steps: list[ProcessorStep] = [
        RenameObservationsProcessorStep(rename_map={}),  # To mimic the same processor as pretrained one
        AddBatchDimensionProcessorStep(),
        # NOTE: NormalizerProcessorStep MUST come before Pi0FastPrepareStateAndLanguageTokenizerProcessorStep
        # because the tokenizer step expects normalized state in [-1, 1] range for discretization
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        Pi0FastPrepareStateAndLanguageTokenizerProcessorStep(max_state_dim=config.max_state_dim),
        TokenizerProcessorStep(
            tokenizer_name=config.text_tokenizer_name,
            max_length=config.tokenizer_max_length,
            padding_side="right",
            padding="max_length",
        ),
        ActionTokenizerProcessorStep(
            action_tokenizer_name=config.action_tokenizer_name,
            max_action_tokens=config.max_action_tokens,
            fast_skip_tokens=config.fast_skip_tokens,
            paligemma_tokenizer_name=config.text_tokenizer_name,
        ),
        DeviceProcessorStep(device=config.device),
    ]

    output_steps: list[ProcessorStep] = [
        UnnormalizerProcessorStep(
            features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
        ),
        DeviceProcessorStep(device="cpu"),
    ]

    return (
        PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
            steps=input_steps,
            name=POLICY_PREPROCESSOR_DEFAULT_NAME,
        ),
        PolicyProcessorPipeline[PolicyAction, PolicyAction](
            steps=output_steps,
            name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
            to_transition=policy_action_to_transition,
            to_output=transition_to_policy_action,
        ),
    )

View File

@@ -14,8 +14,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .configuration_pi0_fast import PI0FastConfig from .configuration_pi05 import PI05Config
from .modeling_pi0_fast import PI0FastPolicy from .modeling_pi05 import PI05Policy
from .processor_pi0_fast import make_pi0_fast_pre_post_processors from .processor_pi05 import make_pi05_pre_post_processors
__all__ = ["PI0FastConfig", "PI0FastPolicy", "make_pi0_fast_pre_post_processors"] __all__ = ["PI05Config", "PI05Policy", "make_pi05_pre_post_processors"]

View File

@@ -21,25 +21,33 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
DEFAULT_IMAGE_SIZE = 224 DEFAULT_IMAGE_SIZE = 224
@PreTrainedConfig.register_subclass("pi0_fast") @PreTrainedConfig.register_subclass("pi05")
@dataclass @dataclass
class PI0FastConfig(PreTrainedConfig): class PI05Config(PreTrainedConfig):
paligemma_variant: str = "gemma_2b" paligemma_variant: str = "gemma_2b"
action_expert_variant: str = "gemma_300m" action_expert_variant: str = "gemma_300m"
dtype: str = "float32" # Options: "bfloat16", "float32" dtype: str = "float32" # Options: "bfloat16", "float32"
n_obs_steps: int = 1
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon" chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
n_action_steps: int = 50 # Number of action steps to execute n_action_steps: int = 50 # Number of action steps to execute
# Shorter state and action vectors will be padded to these dimensions # Shorter state and action vectors will be padded to these dimensions
max_state_dim: int = 32 max_state_dim: int = 32
max_action_dim: int = 32 max_action_dim: int = 32
max_action_tokens: int = 256
# Flow matching parameters: see openpi `PI0Pytorch`
num_inference_steps: int = 10
time_sampling_beta_alpha: float = 1.5
time_sampling_beta_beta: float = 1.0
time_sampling_scale: float = 0.999
time_sampling_offset: float = 0.001
min_period: float = 4e-3
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration # Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None rtc_config: RTCConfig | None = None
@@ -53,23 +61,12 @@ class PI0FastConfig(PreTrainedConfig):
empty_cameras: int = 0 empty_cameras: int = 0
tokenizer_max_length: int = 200 # see openpi `__post_init__` tokenizer_max_length: int = 200 # see openpi `__post_init__`
text_tokenizer_name: str = "google/paligemma-3b-pt-224"
action_tokenizer_name: str = "physical-intelligence/fast"
temperature: float = 0.0
max_decoding_steps: int = 256
fast_skip_tokens: int = 128
# Whether to validate that decoded action tokens start with "Action: " prefix
validate_action_token_prefix: bool = True
# Whether to use KV cache for faster autoregressive decoding
use_kv_cache: bool = True
normalization_mapping: dict[str, NormalizationMode] = field( normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: { default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY, "VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD, # Pi0Fast uses quantiles for state "STATE": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for state
"ACTION": NormalizationMode.MEAN_STD, # Pi0Fast uses quantiles for action "ACTION": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for action
} }
) )
@@ -93,6 +90,8 @@ class PI0FastConfig(PreTrainedConfig):
scheduler_decay_steps: int = 30_000 scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6 scheduler_decay_lr: float = 2.5e-6
tokenizer_max_length: int = 200 # see openpi `__post_init__`
def __post_init__(self): def __post_init__(self):
super().__post_init__() super().__post_init__()
@@ -105,32 +104,35 @@ class PI0FastConfig(PreTrainedConfig):
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]: if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}") raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
if self.dtype not in ["bfloat16", "float32"]: if self.dtype not in ["bfloat16", "float32"]:
raise ValueError(f"Invalid dtype: {self.dtype}") raise ValueError(f"Invalid dtype: {self.dtype}")
def validate_features(self) -> None: def validate_features(self) -> None:
"""Validate and set up input/output features.""" """Validate and set up input/output features."""
for i in range(self.empty_cameras): for i in range(self.empty_cameras):
key = OBS_IMAGES + f".empty_camera_{i}" key = f"observation.images.empty_camera_{i}"
empty_camera = PolicyFeature( empty_camera = PolicyFeature(
type=FeatureType.VISUAL, type=FeatureType.VISUAL,
shape=(3, *self.image_resolution), # Use configured image resolution shape=(3, *self.image_resolution), # Use configured image resolution
) )
self.input_features[key] = empty_camera self.input_features[key] = empty_camera
if OBS_STATE not in self.input_features: if "observation.state" not in self.input_features:
state_feature = PolicyFeature( state_feature = PolicyFeature(
type=FeatureType.STATE, type=FeatureType.STATE,
shape=(self.max_state_dim,), # Padded to max_state_dim shape=(self.max_state_dim,), # Padded to max_state_dim
) )
self.input_features[OBS_STATE] = state_feature self.input_features["observation.state"] = state_feature
if ACTION not in self.output_features: if "action" not in self.output_features:
action_feature = PolicyFeature( action_feature = PolicyFeature(
type=FeatureType.ACTION, type=FeatureType.ACTION,
shape=(self.max_action_dim,), # Padded to max_action_dim shape=(self.max_action_dim,), # Padded to max_action_dim
) )
self.output_features[ACTION] = action_feature self.output_features["action"] = action_feature
def get_optimizer_preset(self) -> AdamWConfig: def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig( return AdamWConfig(

View File

@@ -40,17 +40,8 @@ else:
GemmaForCausalLM = None GemmaForCausalLM = None
PaliGemmaForConditionalGeneration = None PaliGemmaForConditionalGeneration = None
# VideoPrism imports for video encoding
try:
from lerobot.policies.videovla.videoprism import VideoPrismVideoProcessor, VideoPrismVisionModel
_videoprism_available = True
except ImportError:
_videoprism_available = False
VideoPrismVideoProcessor = None
VideoPrismVisionModel = None
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.videovla.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05VideoConfig from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import ( from lerobot.utils.constants import (
@@ -298,60 +289,6 @@ def compute_layer_complete(
return outputs_embeds return outputs_embeds
class PerceiverResampler(nn.Module):
    """Cross-attention resampler that maps a token sequence onto a fixed number of latents.

    A bank of ``num_latents`` learnable query vectors attends to the input tokens,
    so the output sequence length is ``num_latents`` whatever the input length is.

    Args:
        dim: Feature dimension of the input tokens and of the returned latents.
        num_latents: Number of learnable queries, i.e. the output sequence length.
        num_heads: Number of attention heads of the cross-attention layer.
    """

    def __init__(self, dim: int = 768, num_latents: int = 128, num_heads: int = 8):
        super().__init__()
        self.num_latents = num_latents
        self.dim = dim

        # Learnable queries, shared by every sample of a batch.
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        # Single cross-attention layer operating on (batch, sequence, feature) tensors.
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
        # Separate normalizations for the queries and for the attended tokens.
        self.ln_q = nn.LayerNorm(dim)
        self.ln_kv = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Resample ``x`` of shape (B, N, D) into a tensor of shape (B, num_latents, D).

        Args:
            x: Input tokens; must be 3-dimensional, with any sequence length N.

        Returns:
            The attended latents, one set of ``num_latents`` vectors per batch entry.
        """
        # The 3-way unpacking also rejects inputs that are not 3-dimensional.
        batch_size, _, _ = x.shape

        queries = self.ln_q(self.latents.unsqueeze(0).expand(batch_size, -1, -1))
        keys_values = self.ln_kv(x)

        # Attention weights are not needed, only the attended values.
        resampled, _ = self.attn(queries, keys_values, keys_values, need_weights=False)
        return resampled
class GemmaConfig: # see openpi `gemma.py: Config` class GemmaConfig: # see openpi `gemma.py: Config`
"""Configuration for Gemma model variants.""" """Configuration for Gemma model variants."""
@@ -400,14 +337,10 @@ class PaliGemmaWithExpertModel(
use_adarms=None, use_adarms=None,
precision: Literal["bfloat16", "float32"] = "bfloat16", precision: Literal["bfloat16", "float32"] = "bfloat16",
image_size: int = DEFAULT_IMAGE_SIZE, image_size: int = DEFAULT_IMAGE_SIZE,
freeze_vision_encoder: bool = False,
train_expert_only: bool = False,
): ):
if use_adarms is None: if use_adarms is None:
use_adarms = [False, False] use_adarms = [False, False]
super().__init__() super().__init__()
self.freeze_vision_encoder = freeze_vision_encoder
self.train_expert_only = train_expert_only
vlm_config_hf = CONFIG_MAPPING["paligemma"]() vlm_config_hf = CONFIG_MAPPING["paligemma"]()
vlm_config_hf._vocab_size = 257152 # noqa: SLF001 vlm_config_hf._vocab_size = 257152 # noqa: SLF001
@@ -448,7 +381,6 @@ class PaliGemmaWithExpertModel(
self.gemma_expert.model.embed_tokens = None self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_for_selected_params(precision) self.to_bfloat16_for_selected_params(precision)
self._set_requires_grad()
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"): def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
if precision == "bfloat16": if precision == "bfloat16":
@@ -472,23 +404,6 @@ class PaliGemmaWithExpertModel(
if any(selector in name for selector in params_to_keep_float32): if any(selector in name for selector in params_to_keep_float32):
param.data = param.data.to(dtype=torch.float32) param.data = param.data.to(dtype=torch.float32)
def _set_requires_grad(self):
    """Put the sub-modules selected by the freezing flags in eval mode and freeze them.

    ``freeze_vision_encoder`` selects ``self.paligemma.vision_tower``;
    ``train_expert_only`` selects the whole ``self.paligemma`` module.
    """
    frozen_modules = []
    if self.freeze_vision_encoder:
        frozen_modules.append(self.paligemma.vision_tower)
    if self.train_expert_only:
        frozen_modules.append(self.paligemma)

    for module in frozen_modules:
        module.eval()
        for weight in module.parameters():
            weight.requires_grad = False
def train(self, mode: bool = True):
    """Switch the training mode while keeping the frozen sub-modules in eval mode.

    Args:
        mode: ``True`` for training mode, ``False`` for evaluation mode.

    Returns:
        ``self``, as ``nn.Module.train`` does, so that chained calls such as
        ``model = model.train()`` or ``model.eval()`` keep returning the module.
    """
    super().train(mode)
    # super().train() propagates `mode` to every child; put the frozen parts back
    # into eval mode afterwards.
    if self.freeze_vision_encoder:
        self.paligemma.vision_tower.eval()
    if self.train_expert_only:
        self.paligemma.eval()
    return self
def embed_image(self, image: torch.Tensor): def embed_image(self, image: torch.Tensor):
return self.paligemma.model.get_image_features(image) return self.paligemma.model.get_image_features(image)
@@ -523,8 +438,8 @@ class PaliGemmaWithExpertModel(
inputs_embeds=inputs_embeds[1], inputs_embeds=inputs_embeds[1],
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids, position_ids=position_ids,
use_cache=False, past_key_values=past_key_values,
past_key_values=None, #jadechoghari use_cache=use_cache,
adarms_cond=adarms_cond[1] if adarms_cond is not None else None, adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
) )
suffix_output = suffix_output.last_hidden_state suffix_output = suffix_output.last_hidden_state
@@ -597,7 +512,7 @@ class PaliGemmaWithExpertModel(
class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
"""Core PI05 PyTorch model.""" """Core PI05 PyTorch model."""
def __init__(self, config: PI05VideoConfig, rtc_processor: RTCProcessor | None = None): def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.rtc_processor = rtc_processor self.rtc_processor = rtc_processor
@@ -616,8 +531,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
use_adarms=[False, True], use_adarms=[False, True],
precision=config.dtype, precision=config.dtype,
image_size=config.image_resolution[0], image_size=config.image_resolution[0],
freeze_vision_encoder=config.freeze_vision_encoder,
train_expert_only=config.train_expert_only,
) )
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width) self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
@@ -629,47 +542,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Initialize gradient checkpointing flag # Initialize gradient checkpointing flag
self.gradient_checkpointing_enabled = False self.gradient_checkpointing_enabled = False
# Initialize VideoPrism video encoder if enabled
self.video_encoder = None
self.video_processor = None
self.video_proj = None
self.video_resampler = None
if config.use_video_encoder:
if not _videoprism_available:
raise ImportError(
"VideoPrism is not available. Please install the required dependencies."
)
logging.info(f"Initializing VideoPrism video encoder: {config.videoprism_model_name}")
self.video_processor = VideoPrismVideoProcessor.from_pretrained(config.videoprism_model_name)
self.video_encoder = VideoPrismVisionModel.from_pretrained(
config.videoprism_model_name,
torch_dtype=torch.bfloat16 if config.dtype == "bfloat16" else torch.float32,
attn_implementation="sdpa",
)
# Get the hidden size from VideoPrism config (default is 768 for base model)
video_hidden_size = self.video_encoder.config.hidden_size
# Initialize Perceiver Resampler to reduce video tokens (e.g., 4096 -> 128)
self.video_resampler = PerceiverResampler(
dim=video_hidden_size,
num_latents=config.video_num_latents,
num_heads=config.video_resampler_num_heads,
)
logging.info(
f"Initialized video resampler: {video_hidden_size}D, "
f"{config.video_num_latents} latents, {config.video_resampler_num_heads} heads"
)
# Project video embeddings to PaliGemma's hidden size
self.video_proj = nn.Linear(video_hidden_size, paligemma_config.width)
# Freeze video encoder if requested
if config.freeze_video_encoder:
self.video_encoder.eval()
for param in self.video_encoder.parameters():
param.requires_grad = False
logging.info("Video encoder weights are frozen")
# Compile model if requested # Compile model if requested
if config.compile_model: if config.compile_model:
torch.set_float32_matmul_precision("high") torch.set_float32_matmul_precision("high")
@@ -679,13 +551,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues""" msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
# try: try:
# from transformers.models.siglip import check from transformers.models.siglip import check
# if not check.check_whether_transformers_replace_is_installed_correctly(): if not check.check_whether_transformers_replace_is_installed_correctly():
# raise ValueError(msg) raise ValueError(msg)
# except ImportError: except ImportError:
# raise ValueError(msg) from None raise ValueError(msg) from None
def gradient_checkpointing_enable(self): def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization.""" """Enable gradient checkpointing for memory optimization."""
@@ -718,51 +590,51 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
"""Helper method to prepare 4D attention masks for transformer.""" """Helper method to prepare 4D attention masks for transformer."""
att_2d_masks_4d = att_2d_masks[:, None, :, :] att_2d_masks_4d = att_2d_masks[:, None, :, :]
return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE) return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
def shift_padding_side(
self,
tokens: torch.Tensor,
ar_mask: torch.Tensor,
padding_mask: torch.Tensor,
loss_mask: torch.Tensor,
targets: torch.Tensor,
token_type_ids: torch.Tensor,
padding_side: str = "right",
) -> tuple[torch.Tensor]:
if padding_side not in ["right", "left"]:
return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids
def sample_noise(self, shape, device): new_tokens = torch.empty_like(tokens)
return torch.normal( new_ar_masks = torch.empty_like(ar_mask)
mean=0.0, new_padding_mask = torch.empty_like(padding_mask)
std=1.0, new_loss_mask = torch.empty_like(loss_mask)
size=shape, new_targets = torch.empty_like(targets)
dtype=torch.float32, new_token_type_ids = torch.empty_like(token_type_ids)
device=device, batch_size = tokens.shape[0]
) for i in range(batch_size):
padding_indices = torch.where(padding_mask[i] == 0)[0]
def sample_time(self, bsize, device): non_padding_indices = torch.where(padding_mask[i] == 1)[0]
time_beta = sample_beta( if padding_side == "left":
self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device new_indices = torch.cat((padding_indices, non_padding_indices), dim=0)
) else:
time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset new_indices = torch.cat((non_padding_indices, padding_indices), dim=0)
return time.to(dtype=torch.float32, device=device) new_tokens[i] = tokens[i].index_select(0, new_indices)
new_ar_masks[i] = ar_mask[i].index_select(0, new_indices)
new_padding_mask[i] = padding_mask[i].index_select(0, new_indices)
new_loss_mask[i] = loss_mask[i].index_select(0, new_indices)
new_targets[i] = targets[i].index_select(0, new_indices)
new_token_type_ids[i] = token_type_ids[i].index_select(0, new_indices)
return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids
def embed_prefix( def embed_prefix(
self, images, img_masks, tokens, masks, video_emb: torch.Tensor | None = None self, images, img_masks, tokens, attention_mask, padded_mask
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Embed images with SigLIP, optional video with VideoPrism, and language tokens with embedding layer. """Embed images with SigLIP and language tokens with embedding layer."""
Args:
images: List of image tensors [B, C, H, W]
img_masks: List of image masks [B]
tokens: Language tokens [B, seq_len]
masks: Language attention masks [B, seq_len]
video_emb: Optional video embeddings from VideoPrism [B, num_video_tokens, hidden_dim]
Returns:
Tuple of (embeddings, pad_masks, att_masks)
"""
embs = [] embs = []
pad_masks = [] pad_masks = []
att_masks = [] att_masks = []
# Process video embeddings first (if available)
if video_emb is not None:
bsize, num_video_tokens, _ = video_emb.shape
embs.append(video_emb)
# Video tokens are always valid
video_mask = torch.ones(bsize, num_video_tokens, dtype=torch.bool, device=video_emb.device)
pad_masks.append(video_mask)
att_masks += [0] * num_video_tokens
# Process images # Process images
for img, img_mask in zip(images, img_masks, strict=True): for img, img_mask in zip(images, img_masks, strict=True):
@@ -784,351 +656,80 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
lang_emb = self._apply_checkpoint(lang_embed_func, tokens) lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
embs.append(lang_emb) embs.append(lang_emb)
pad_masks.append(masks) pad_masks.append(padded_mask)
num_lang_embs = lang_emb.shape[1] num_lang_embs = lang_emb.shape[1]
att_masks += [0] * num_lang_embs # att_masks += [0] * num_lang_embs
embs = torch.cat(embs, dim=1) embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1) pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
att_masks = torch.cat(
[att_masks, attention_mask], dim=1
)
bsize = pad_masks.shape[0] bsize = pad_masks.shape[0]
att_masks = att_masks[None, :].expand(bsize, len(att_masks)) att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks return embs, pad_masks, att_masks
def embed_video(self, video_frames: torch.Tensor) -> torch.Tensor: def forward(self, images, img_masks, tokens, masks) -> Tensor:
"""Embed video frames using VideoPrism encoder. """Do a full training forward pass and compute the loss."""
# NOTE: `tokens` also contains the tokenized actions.
Args: embs, pad_masks, att_masks = self.embed_prefix(images, img_masks, tokens, masks)
video_frames: Tensor of shape [B, T, C, H, W] where T is the number of frames.
Expected to be normalized to [0, 1].
Returns:
Video embeddings of shape [B, num_video_tokens, hidden_dim] projected to
PaliGemma's hidden dimension.
"""
if self.video_encoder is None:
raise RuntimeError("Video encoder is not initialized. Set use_video_encoder=True in config.")
device = video_frames.device
dtype = video_frames.dtype
# Move video encoder to the same device if needed
if next(self.video_encoder.parameters()).device != device:
self.video_encoder = self.video_encoder.to(device)
# VideoPrism expects pixel values in [0, 1] range and shape [B, T, C, H, W]
# Resize frames to VideoPrism expected size if needed
B, T, C, H, W = video_frames.shape
target_size = self.config.videoprism_image_size
if H != target_size or W != target_size:
# Resize each frame
video_frames = video_frames.view(B * T, C, H, W)
video_frames = F.interpolate(
video_frames,
size=(target_size, target_size),
mode="bilinear",
align_corners=False,
)
video_frames = video_frames.view(B, T, C, target_size, target_size)
# Convert to the expected dtype for the video encoder
video_encoder_dtype = next(self.video_encoder.parameters()).dtype
video_frames = video_frames.to(dtype=video_encoder_dtype)
# Run through VideoPrism
with torch.set_grad_enabled(not self.config.freeze_video_encoder):
if self.config.freeze_video_encoder:
self.video_encoder.eval()
video_outputs = self.video_encoder(pixel_values_videos=video_frames)
# Shape: [B, num_patches * num_frames, hidden_size] (e.g., [B, 4096, 768])
video_embeddings = video_outputs.last_hidden_state
# Convert to working dtype
video_embeddings = video_embeddings.to(dtype=dtype)
# Apply Perceiver Resampler to reduce tokens (e.g., 4096 -> 128)
# This uses cross-attention from learnable queries to the video tokens
video_embeddings = self.video_resampler(video_embeddings)
# Shape: [B, num_latents, hidden_size] (e.g., [B, 128, 768])
# Project to PaliGemma's hidden dimension
video_embeddings = self.video_proj(video_embeddings)
return video_embeddings
def embed_suffix(self, noisy_actions, timestep):
    """Embed the noisy actions and the timestep for the Gemma expert.

    Args:
        noisy_actions: Noisy action chunk fed through ``self.action_in_proj``.
        timestep: Flow-matching time; its dtype and device are reused for the
            time embedding and the padding mask.

    Returns:
        A tuple ``(embs, pad_masks, att_masks, adarms_cond)``:
        the projected actions, an all-True boolean mask over the action tokens,
        the ``[1, 0, ..., 0]`` pattern of length ``config.chunk_size`` expanded
        over the batch, and the timestep embedding after the time MLP.
    """
    # Sine-cosine encoding of the timestep, as wide as the action projection output.
    sinusoid = create_sinusoidal_pos_embedding(
        timestep,
        self.action_in_proj.out_features,
        min_period=self.config.min_period,
        max_period=self.config.max_period,
        device=timestep.device,
    )
    sinusoid = sinusoid.type(dtype=timestep.dtype)

    def _project_actions(actions):
        return self.action_in_proj(actions)

    def _time_mlp(features):
        hidden = F.silu(self.time_mlp_in(features))
        return F.silu(self.time_mlp_out(hidden))

    action_emb = self._apply_checkpoint(_project_actions, noisy_actions)
    # The time embedding is not added to the action tokens; it is returned
    # separately as the conditioning signal.
    adarms_cond = self._apply_checkpoint(_time_mlp, sinusoid)

    bsize, num_action_tokens = action_emb.shape[:2]
    pad_masks = torch.ones(bsize, num_action_tokens, dtype=torch.bool, device=timestep.device)

    # Only the first action token carries a 1.
    # NOTE(review): presumably this starts a new attention block so that image,
    # language and state inputs do not attend to action tokens — confirm against
    # `make_att_2d_masks`.
    block_pattern = [1] + [0] * (self.config.chunk_size - 1)
    att_masks = torch.tensor(block_pattern, dtype=action_emb.dtype, device=action_emb.device)
    att_masks = att_masks[None, :].expand(bsize, len(block_pattern))

    return action_emb, pad_masks, att_masks, adarms_cond
def forward(
self, images, img_masks, tokens, masks, actions, noise=None, time=None, video_frames=None
) -> Tensor:
"""Do a full training forward pass and compute the loss.
Args:
images: List of image tensors [B, C, H, W]
img_masks: List of image masks [B]
tokens: Language tokens [B, seq_len]
masks: Language attention masks [B, seq_len]
actions: Ground truth actions [B, chunk_size, action_dim]
noise: Optional noise tensor for flow matching
time: Optional time tensor for flow matching
video_frames: Optional video frames [B, T, C, H, W] for video encoding
"""
if noise is None:
noise = self.sample_noise(actions.shape, actions.device)
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
# Embed video if provided and video encoder is available
video_emb = None
if video_frames is not None and self.video_encoder is not None:
video_emb = self.embed_video(video_frames)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, tokens, masks, video_emb=video_emb
)
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
# TODO: add the cross-entropy loss for token prediction here.
att_2d_masks = make_att_2d_masks(pad_masks, att_masks) att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1 position_ids = torch.cumsum(pad_masks, dim=1) - 1
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks) att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
(_, suffix_out), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_masks_4d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, suffix_embs],
use_cache=False,
adarms_cond=[None, adarms_cond],
)
return suffix_out
suffix_out = self._apply_checkpoint(
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
)
suffix_out = suffix_out[:, -self.config.chunk_size :]
suffix_out = suffix_out.to(dtype=torch.float32)
def action_out_proj_func(suffix_out):
return self.action_out_proj(suffix_out)
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
return F.mse_loss(u_t, v_t, reduction="none")
@torch.no_grad()  # see openpi `sample_actions` (slightly adapted)
def sample_actions(
    self,
    images,
    img_masks,
    tokens,
    masks,
    noise=None,
    num_steps=None,
    video_frames=None,
    **kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
    """Run inference: encode the prefix once, then denoise an action chunk step by step.

    Args:
        images: List of image tensors [B, C, H, W]
        img_masks: List of image masks [B]
        tokens: Language tokens [B, seq_len]
        masks: Language attention masks [B, seq_len]
        noise: Optional noise tensor; sampled as
            (B, config.chunk_size, config.max_action_dim) when omitted
        num_steps: Number of denoising steps; defaults to
            `config.num_inference_steps`
        video_frames: Optional video frames [B, T, C, H, W] for video encoding
        **kwargs: `inference_delay`, `prev_chunk_left_over` and
            `execution_horizon` are read and forwarded to the RTC processor
            when RTC is enabled.

    Returns:
        The tensor obtained from `noise` after `num_steps` updates
        ``x_t <- x_t + dt * v_t`` with ``dt = -1 / num_steps``.
    """
    num_steps = self.config.num_inference_steps if num_steps is None else num_steps
    bsize, device = tokens.shape[0], tokens.device

    if noise is None:
        # Noise uses the padded action dimension that the internal processing works with.
        padded_shape = (bsize, self.config.chunk_size, self.config.max_action_dim)
        noise = self.sample_noise(padded_shape, device)

    # The video embedding is optional: it needs both frames and a configured encoder.
    use_video = video_frames is not None and self.video_encoder is not None
    video_emb = self.embed_video(video_frames) if use_video else None

    prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
        images, img_masks, tokens, masks, video_emb=video_emb
    )
    prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
    prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(
        make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
    )

    # Run the prefix once with caching on; every denoising step reuses the returned cache.
    self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager"  # noqa: SLF001
    _, past_key_values = self.paligemma_with_expert.forward(
        attention_mask=prefix_att_2d_masks_4d,
        position_ids=prefix_position_ids,
        past_key_values=None,
        inputs_embeds=[prefix_embs, None],
        use_cache=True,
    )

    step_size = -1.0 / num_steps
    x_t = noise
    for step_idx in range(num_steps):
        # Time starts at 1.0 and decreases by 1 / num_steps per iteration.
        t_now = 1.0 + step_idx * step_size
        timestep = torch.tensor(t_now, dtype=torch.float32, device=device).expand(bsize)

        # The default argument pins this iteration's timestep inside the closure.
        def denoise_step_partial_call(input_x_t, current_timestep=timestep):
            return self.denoise_step(
                prefix_pad_masks=prefix_pad_masks,
                past_key_values=past_key_values,
                x_t=input_x_t,
                timestep=current_timestep,
            )

        if not self._rtc_enabled():
            v_t = denoise_step_partial_call(x_t)
        else:
            # With RTC on, the processor wraps the plain denoising call.
            v_t = self.rtc_processor.denoise_step(
                x_t=x_t,
                prev_chunk_left_over=kwargs.get("prev_chunk_left_over"),
                inference_delay=kwargs.get("inference_delay"),
                time=t_now,
                original_denoise_step_partial=denoise_step_partial_call,
                execution_horizon=kwargs.get("execution_horizon"),
            )

        x_t = x_t + step_size * v_t

        # Debug tracking records the state after the update of this step.
        if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
            self.rtc_processor.track(time=t_now, x_t=x_t, v_t=v_t)

    return x_t
def denoise_step(
    self,
    prefix_pad_masks,
    past_key_values,
    x_t,
    timestep,
):
    """Apply one denoising step of the noise `x_t` at a given timestep.

    Args:
        prefix_pad_masks: Padding mask of the already-encoded prefix; indexed
            here as a 2-D [batch, prefix_len] tensor.
        past_key_values: Cache returned by the prefix pass; handed to the
            joint model unchanged.
        x_t: Current noisy action tensor.
        timestep: Timestep tensor passed to `embed_suffix` together with `x_t`.

    Returns:
        Output of `action_out_proj` applied to the trailing
        `config.chunk_size` suffix positions, computed in float32.
    """
    suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep)

    batch_size, prefix_len = prefix_pad_masks.shape[:2]
    suffix_len = suffix_pad_masks.shape[1]

    # Every suffix position gets the prefix padding mask, followed by the suffix's own 2-D mask.
    prefix_block = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
    suffix_block = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
    full_att_2d_masks = torch.cat([prefix_block, suffix_block], dim=2)

    # Suffix positions continue counting after the non-padded prefix tokens.
    prefix_lengths = torch.sum(prefix_pad_masks, dim=-1)[:, None]
    position_ids = prefix_lengths + torch.cumsum(suffix_pad_masks, dim=1) - 1

    full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
    self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager"  # noqa: SLF001

    outputs_embeds, _ = self.paligemma_with_expert.forward(
        attention_mask=full_att_2d_masks_4d,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=[None, suffix_embs],
        use_cache=False,
        adarms_cond=[None, adarms_cond],
    )

    # Only the second (suffix) output is used; keep its trailing chunk and project in float32.
    action_hidden = outputs_embeds[1][:, -self.config.chunk_size :].to(dtype=torch.float32)
    return self.action_out_proj(action_hidden)
loss_mask = loss_masks[:, 1:].to(device) # Ensure correct shape
# Compute per-token loss
token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
# Compute per-token loss
token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
# Apply loss mask
token_loss = token_loss * loss_mask.reshape(-1)
# Compute final loss
loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1)
# Return loss dictionary
return loss
class PI05VideoPolicy(PreTrainedPolicy): class PI05Policy(PreTrainedPolicy):
"""PI05 Video Policy for LeRobot with optional video encoding support.""" """PI05 Policy for LeRobot."""
config_class = PI05VideoConfig config_class = PI05Config
name = "pi05_video" name = "pi05"
def __init__( def __init__(
self, self,
config: PI05VideoConfig, config: PI05Config,
**kwargs, **kwargs,
): ):
""" """
@@ -1353,33 +954,11 @@ class PI05VideoPolicy(PreTrainedPolicy):
def _rtc_enabled(self) -> bool: def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _get_video_camera_key(self) -> str | None:
"""Get the camera key to use for video encoding.
Returns the configured video_encoder_camera_key if set,
otherwise returns the first image feature key.
"""
if not self.config.use_video_encoder:
return None
if self.config.video_encoder_camera_key is not None:
return self.config.video_encoder_camera_key
# Default to first image feature (image_features is a dict)
if self.config.image_features:
return next(iter(self.config.image_features.keys()))
return None
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]: def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
"""Preprocess images for the model. """Preprocess images for the model.
Images from LeRobot are typically in [B, C, H, W] format and normalized to [0, 1]. Images from LeRobot are typically in [B, C, H, W] format and normalized to [0, 1].
PaliGemma expects images in [B, C, H, W] format and normalized to [-1, 1]. PaliGemma expects images in [B, C, H, W] format and normalized to [-1, 1].
When video encoding is enabled:
- The video camera is skipped (processed separately by video encoder)
- Other cameras with temporal dimension have only the current frame extracted
""" """
images = [] images = []
img_masks = [] img_masks = []
@@ -1387,17 +966,10 @@ class PI05VideoPolicy(PreTrainedPolicy):
# Get device from model parameters # Get device from model parameters
device = next(self.parameters()).device device = next(self.parameters()).device
# Determine which camera is used for video encoding (to skip it)
video_camera_key = self._get_video_camera_key()
present_img_keys = [key for key in self.config.image_features if key in batch] present_img_keys = [key for key in self.config.image_features if key in batch]
missing_img_keys = [key for key in self.config.image_features if key not in batch] missing_img_keys = [key for key in self.config.image_features if key not in batch]
# Filter out the video camera key if video encoding is enabled if len(present_img_keys) == 0:
if video_camera_key is not None and video_camera_key in present_img_keys:
present_img_keys = [k for k in present_img_keys if k != video_camera_key]
if len(present_img_keys) == 0 and video_camera_key is None:
raise ValueError( raise ValueError(
f"All image features are missing from the batch. At least one expected. " f"All image features are missing from the batch. At least one expected. "
f"(batch: {batch.keys()}) (image_features: {self.config.image_features})" f"(batch: {batch.keys()}) (image_features: {self.config.image_features})"
@@ -1415,11 +987,6 @@ class PI05VideoPolicy(PreTrainedPolicy):
if img.dtype != torch.float32: if img.dtype != torch.float32:
img = img.to(torch.float32) img = img.to(torch.float32)
# Handle temporal dimension: if [B, T, C, H, W], extract current frame (last one)
if img.ndim == 5:
# Extract the last frame (current observation at index -1)
img = img[:, -1] # [B, T, C, H, W] -> [B, C, H, W]
# from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats # from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats
is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1 is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1
@@ -1446,99 +1013,13 @@ class PI05VideoPolicy(PreTrainedPolicy):
# Create image features not present in the batch as fully 0 padded images # Create image features not present in the batch as fully 0 padded images
for _num_empty_cameras in range(len(missing_img_keys)): for _num_empty_cameras in range(len(missing_img_keys)):
if len(images) > 0: img = torch.ones_like(img) * -1 # Padded with -1 for SigLIP
img = torch.ones_like(images[-1]) * -1 # Padded with -1 for SigLIP mask = torch.zeros_like(mask) # Mask is zero for empty cameras
mask = torch.zeros_like(img_masks[-1]) # Mask is zero for empty cameras
else:
# No images processed yet, create placeholder
bsize = next(iter(batch.values())).shape[0]
img = torch.ones(
bsize, 3, *self.config.image_resolution, dtype=torch.float32, device=device
) * -1
mask = torch.zeros(bsize, dtype=torch.bool, device=device)
images.append(img) images.append(img)
img_masks.append(mask) img_masks.append(mask)
return images, img_masks return images, img_masks
def _preprocess_video(self, batch: dict[str, Tensor]) -> Tensor | None:
"""Preprocess video frames for the video encoder.
When image_observation_delta_indices is set (for video encoding), the batch will contain
images with shape [B, T, C, H, W] where T is the number of frames.
This method extracts and preprocesses these frames for VideoPrism.
Handles frame padding at episode start when fewer than video_num_frames are available:
- "repeat": Repeat the first available frame to fill missing frames
- "zero": Use zero-padded frames for missing frames
Args:
batch: Training batch potentially containing multi-frame observations.
Returns:
Video frames tensor of shape [B, T, C, H, W] normalized to [0, 1],
or None if video encoding is not enabled.
"""
if not self.config.use_video_encoder:
return None
device = next(self.parameters()).device
# Get the video camera key
video_camera_key = self._get_video_camera_key()
if video_camera_key is None or video_camera_key not in batch:
return None
img = batch[video_camera_key]
# Check if we have temporal dimension (video frames)
if img.ndim == 4:
# Single frame [B, C, H, W] - expand to video by repeating
B, C, H, W = img.shape
if self.config.video_padding_mode == "repeat":
video_frames = img.unsqueeze(1).expand(B, self.config.video_num_frames, C, H, W)
else: # zero padding
video_frames = torch.zeros(
B, self.config.video_num_frames, C, H, W, dtype=img.dtype, device=img.device
)
video_frames[:, -1] = img # Put current frame at the end
elif img.ndim == 5:
# Multiple frames [B, T, C, H, W]
video_frames = img
B, T, C, H, W = video_frames.shape
# Handle case where we have fewer frames than expected (episode start)
if T < self.config.video_num_frames:
num_missing = self.config.video_num_frames - T
if self.config.video_padding_mode == "repeat":
# Repeat the first frame to fill missing frames at the beginning
first_frame = video_frames[:, 0:1] # [B, 1, C, H, W]
padding = first_frame.expand(B, num_missing, C, H, W)
video_frames = torch.cat([padding, video_frames], dim=1)
else: # zero padding
# Zero-pad at the beginning
padding = torch.zeros(
B, num_missing, C, H, W, dtype=video_frames.dtype, device=video_frames.device
)
video_frames = torch.cat([padding, video_frames], dim=1)
else:
logging.warning(f"Unexpected image shape for video camera: {img.shape}")
return None
# Ensure tensor is on the same device
if video_frames.device != device:
video_frames = video_frames.to(device)
# Ensure float32 dtype
if video_frames.dtype != torch.float32:
video_frames = video_frames.to(torch.float32)
# Video frames should be in [0, 1] range for VideoPrism
# LeRobot images are already in [0, 1] range
return video_frames
def prepare_action(self, batch): def prepare_action(self, batch):
"""Pad action""" """Pad action"""
actions = pad_vector(batch[ACTION], self.config.max_action_dim) actions = pad_vector(batch[ACTION], self.config.max_action_dim)
@@ -1570,13 +1051,10 @@ class PI05VideoPolicy(PreTrainedPolicy):
images, img_masks = self._preprocess_images(batch) images, img_masks = self._preprocess_images(batch)
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
# Preprocess video frames if video encoding is enabled
video_frames = self._preprocess_video(batch)
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05) # Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
actions = self.model.sample_actions( # now we must call .generate() method on the model
images, img_masks, tokens, masks, video_frames=video_frames, **kwargs # then detoknize
) # actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
# Unpad actions to actual action dimension # Unpad actions to actual action dimension
original_action_dim = self.config.output_features[ACTION].shape[0] original_action_dim = self.config.output_features[ACTION].shape[0]
@@ -1599,11 +1077,8 @@ class PI05VideoPolicy(PreTrainedPolicy):
actions = self.prepare_action(batch) actions = self.prepare_action(batch)
# Preprocess video frames if video encoding is enabled
video_frames = self._preprocess_video(batch)
# Compute loss (no separate state needed for PI05) # Compute loss (no separate state needed for PI05)
losses = self.model.forward(images, img_masks, tokens, masks, actions, video_frames=video_frames) losses = self.model.forward(images, img_masks, tokens, masks, actions)
# Truncate losses to actual action dimensions # Truncate losses to actual action dimensions
original_action_dim = self.config.output_features[ACTION].shape[0] original_action_dim = self.config.output_features[ACTION].shape[0]
@@ -1623,14 +1098,3 @@ class PI05VideoPolicy(PreTrainedPolicy):
loss = losses.mean() loss = losses.mean()
loss_dict["loss"] = loss.item() loss_dict["loss"] = loss.item()
return loss, loss_dict return loss, loss_dict
def _get_default_peft_targets(self) -> dict[str, any]:
"""Return default PEFT target modules for PI0.5 fine-tuning."""
common_projections = (
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
)
target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))"
return {
"target_modules": target_modules,
"modules_to_save": [],
}

Some files were not shown because too many files have changed in this diff Show More