Compare commits

...

13 Commits

Author SHA1 Message Date
Steven Palma
224be5be9a Merge branch 'main' into feat/add_macos_ci 2025-10-14 18:52:04 +02:00
Steven Palma
a6ff3cfebb chore(deps): libero dep pointing to main (#2201) 2025-10-14 18:19:49 +02:00
Jade Choghari
271d92dcaa feat(sim): add metaworld env (#2088)
* add metaworld

* smol update

Signed-off-by: Jade Choghari <chogharijade@gmail.com>

* update design

* Update src/lerobot/envs/metaworld.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Jade Choghari <chogharijade@gmail.com>

* update

* small changes

* iterate on review

* small fix

* small fix

* add docs

* update doc

* add better gif

* smol doc fix

* updage gymnasium

* add note

* depreciate gym-xarm

* more changes

* update doc

* comply with mypy

* more fixes

* update readme

* precommit

* update pusht

* add pusht instead

* changes

* style

* add changes

* update

* revert

* update v2

* chore(envs): move metaworld config to its own file + remove comments + simplify _format_raw_obs (#2200)

* update final changes

---------

Signed-off-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-10-14 17:21:18 +02:00
Michel Aractingi
8e940bf361 Feat/expand add features (#2202)
* make add_feature take multiple features at a time and rename to add_features

* - New function: modify_features that was a combination of remove features and add features.
 - This function is important for when we want to add a feature and remove another so we can do it in one time to avoid copying and creating the dataset multiple times
2025-10-14 16:19:50 +02:00
Steven Palma
6e8be57eb2 chore(policies): deprecate pi0fast (#2203) 2025-10-14 16:00:42 +02:00
Francesco Capuano
723013c71b feat(scripts): Introduce build_inference_frame/make_robot_action util to easily allow API-based Inference (#2143)
* fix: expose a function explicitly building a frame for inference

* fix: first make dataset frame, then make ready for inference

* fix: reducing reliance on lerobot record for policy's ouptuts too

* fix: encapsulating squeezing out + device handling from predict action

* fix: remove duplicated call to build_inference_frame and add a function to only perform data type handling (whole conversion is: keys matching + data type conversion)

* fix(policies): right utils signature + docstrings (#2198)

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-10-14 15:47:32 +02:00
Steven Palma
bf6ac5e110 fix(datasets): conversion script function naming (#2199)
Co-authored-by: gagalo123 <bamianweifen@gmail.com>
2025-10-14 14:36:32 +02:00
Steven Palma
3ce5bcf24d feat(deps): add setuptools dependency (#2187) 2025-10-14 14:00:52 +02:00
Francesco Capuano
6f5bb4d4a4 fix outdated example in docs (#2182)
* fix outdated example

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>

* Update docs/source/il_robots.mdx

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>

---------

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-13 16:43:23 +02:00
Francesco Capuano
f29311ccb0 fix: very minor fix but hey devil is in details (#2168)
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2025-10-13 10:44:53 +02:00
Michel Aractingi
0c79cf8f4e Add missing finalize calls in example (#2175)
- add missing calls to dataset.finalize in the example recording scripts
- add section in the dataset docs on calling dataset.finalize
2025-10-11 21:15:43 +02:00
Steven Palma
67269e33a5 ci: add more env flags 2025-10-08 17:14:20 +02:00
Steven Palma
66936f278f feat(ci): add macos runner testing 2025-10-08 14:58:55 +02:00
42 changed files with 1241 additions and 1540 deletions

View File

@@ -57,7 +57,11 @@ jobs:
# It runs everytime we commit to a PR or push to main
fast-pytest-tests:
name: Fast Pytest Tests
runs-on: ubuntu-latest
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
env:
MUJOCO_GL: egl
steps:
@@ -67,12 +71,21 @@ jobs:
lfs: true
# TODO(Steven): Evaluate the need of these dependencies
- name: Install apt dependencies
- name: Install dependencies
run: |
sudo apt-get update && sudo apt-get install -y build-essential git \
curl libglib2.0-0 libegl1-mesa-dev ffmpeg \
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev
if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
sudo apt-get update && sudo apt-get install -y build-essential \
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
speech-dispatcher libgeos-dev portaudio19-dev
elif [[ "${{ matrix.os }}" == 'macos-latest' ]]; then
brew update && brew install git geos portaudio ffmpeg@7
# Add ffmpeg@7 paths for subsequent steps
echo "PATH=/opt/homebrew/opt/ffmpeg@7/bin:$PATH" >> $GITHUB_ENV
echo "LDFLAGS=-L/opt/homebrew/opt/ffmpeg@7/lib" >> $GITHUB_ENV
echo "CPPFLAGS=-I/opt/homebrew/opt/ffmpeg@7/include" >> $GITHUB_ENV
echo "PKG_CONFIG_PATH=/opt/homebrew/opt/ffmpeg@7/lib/pkgconfig" >> $GITHUB_ENV
echo "DYLD_LIBRARY_PATH=/opt/homebrew/opt/ffmpeg@7/lib:/opt/homebrew/lib:/usr/local/lib:$DYLD_LIBRARY_PATH" >> $GITHUB_ENV
fi
- name: Setup uv and Python
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
with:

View File

@@ -51,7 +51,11 @@ jobs:
# It runs everytime a PR is approved or a push to main
full-tests:
name: Full Tests
runs-on: ubuntu-latest
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
if: |
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved') ||
github.event_name == 'push' ||
@@ -64,11 +68,16 @@ jobs:
lfs: true
persist-credentials: false
- name: Install apt dependencies
- name: Install dependencies
run: |
sudo apt-get update && sudo apt-get install -y build-essential \
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
speech-dispatcher libgeos-dev portaudio19-dev
if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
sudo apt-get update && sudo apt-get install -y build-essential \
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
speech-dispatcher libgeos-dev portaudio19-dev
elif [[ "${{ matrix.os }}" == 'macos-latest' ]]; then
brew update && brew install git geos portaudio ffmpeg@7
echo "DYLD_LIBRARY_PATH=/opt/homebrew/opt/ffmpeg@7/lib:/opt/homebrew/lib:/usr/local/lib:$DYLD_LIBRARY_PATH" >> $GITHUB_ENV
fi
- name: Setup uv and Python
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]

View File

@@ -120,7 +120,11 @@ jobs:
test-release:
name: Test Release
needs: [build-and-publish]
runs-on: ubuntu-latest
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
permissions:
contents: read
env:
@@ -130,11 +134,16 @@ jobs:
with:
lfs: true
persist-credentials: false
- name: Install apt dependencies
- name: Install dependencies
run: |
sudo apt-get update && sudo apt-get install -y build-essential \
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
speech-dispatcher libgeos-dev portaudio19-dev
if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
sudo apt-get update && sudo apt-get install -y build-essential \
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
speech-dispatcher libgeos-dev portaudio19-dev
elif [[ "${{ matrix.os }}" == 'macos-latest' ]]; then
brew update && brew install git geos portaudio ffmpeg@7
echo "DYLD_LIBRARY_PATH=/opt/homebrew/opt/ffmpeg@7/lib:/opt/homebrew/lib:/usr/local/lib:$DYLD_LIBRARY_PATH" >> $GITHUB_ENV
fi
- name: Setup uv and Python
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
with:

View File

@@ -42,7 +42,11 @@ jobs:
# This job runs the E2E tests + pytest with all unbound extras
full-tests:
name: Full Unbound Tests
runs-on: ubuntu-latest
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
env:
MUJOCO_GL: egl
steps:
@@ -51,11 +55,16 @@ jobs:
lfs: true
persist-credentials: false
- name: Install apt dependencies
- name: Install dependencies
run: |
sudo apt-get update && sudo apt-get install -y build-essential \
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
speech-dispatcher libgeos-dev portaudio19-dev
if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
sudo apt-get update && sudo apt-get install -y build-essential \
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
speech-dispatcher libgeos-dev portaudio19-dev
elif [[ "${{ matrix.os }}" == 'macos-latest' ]]; then
brew update && brew install git geos portaudio ffmpeg@7
echo "DYLD_LIBRARY_PATH=/opt/homebrew/opt/ffmpeg@7/lib:/opt/homebrew/lib:/usr/local/lib:$DYLD_LIBRARY_PATH" >> $GITHUB_ENV
fi
- name: Setup uv and Python
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]

View File

@@ -72,7 +72,6 @@ post it.
Look at our implementations for [datasets](./src/lerobot/datasets/), [policies](./src/lerobot/policies/),
environments ([aloha](https://github.com/huggingface/gym-aloha),
[xarm](https://github.com/huggingface/gym-xarm),
[pusht](https://github.com/huggingface/gym-pusht))
and follow the same api design.

View File

@@ -119,10 +119,9 @@ test-tdmpc-ete-train:
--policy.type=tdmpc \
--policy.device=$(DEVICE) \
--policy.push_to_hub=false \
--env.type=xarm \
--env.task=XarmLift-v0 \
--env.type=pusht \
--env.episode_length=5 \
--dataset.repo_id=lerobot/xarm_lift_medium \
--dataset.repo_id=lerobot/pusht_image \
--dataset.image_transforms.enable=true \
--dataset.episodes="[0]" \
--batch_size=2 \
@@ -140,9 +139,10 @@ test-tdmpc-ete-eval:
lerobot-eval \
--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
--policy.device=$(DEVICE) \
--env.type=xarm \
--env.type=pusht \
--env.episode_length=5 \
--env.task=XarmLift-v0 \
--env.observation_height=96 \
--env.observation_width=96 \
--eval.n_episodes=1 \
--eval.batch_size=1

View File

@@ -7,8 +7,6 @@
- sections:
- local: il_robots
title: Imitation Learning for Robots
- local: il_sim
title: Imitation Learning in Sim
- local: cameras
title: Cameras
- local: integrate_hardware
@@ -37,9 +35,15 @@
title: π₀ (Pi0)
- local: pi05
title: π₀.₅ (Pi05)
title: "Policies"
- sections:
- local: il_sim
title: Imitation Learning in Sim
- local: libero
title: Using Libero
title: "Policies"
- local: metaworld
title: Using MetaWorld
title: "Simulation"
- sections:
- local: introduction_processors
title: Introduction to Robot Processors

View File

@@ -513,13 +513,14 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.scripts.lerobot_record import record_loop
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
from lerobot.record import record_loop
from lerobot.policies.factory import make_processor
NUM_EPISODES = 5
FPS = 30
@@ -562,7 +563,7 @@ init_rerun(session_name="recording")
# Connect the robot
robot.connect()
preprocessor, postprocessor = make_processor(
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,

View File

@@ -91,7 +91,7 @@ LeRobot provides optional extras for specific functionalities. Multiple extras c
### Simulations
Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), `xarm` ([gym-xarm](https://github.com/huggingface/gym-xarm)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht))
Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht))
Example:
```bash

View File

@@ -279,3 +279,36 @@ python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DAT
- Aggregates parquet files: `episode-0000.parquet`, `episode-0001.parquet`, … → **`file-0000.parquet`**, …
- Aggregates mp4 files: `episode-0000.mp4`, `episode-0001.mp4`, … → **`file-0000.mp4`**, …
- Updates `meta/episodes/*` (chunked Parquet) with perepisode lengths, tasks, and byte/frame offsets.
## Common Issues
### Always call `finalize()` before pushing
When creating or recording datasets, you **must** call `dataset.finalize()` to properly close parquet writers. See the [PR #1903](https://github.com/huggingface/lerobot/pull/1903) for more details.
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Create dataset and record episodes
dataset = LeRobotDataset.create(...)
for episode in range(num_episodes):
# Record frames
for frame in episode_data:
dataset.add_frame(frame)
dataset.save_episode()
# Call finalize() when done recording and before push_to_hub()
dataset.finalize() # Closes parquet writers, writes metadata footers
dataset.push_to_hub()
```
**Why is this necessary?**
Dataset v3.0 uses incremental parquet writing with buffered metadata for efficiency. The `finalize()` method:
- Flushes any buffered episode metadata to disk
- Closes parquet writers to write footer metadata, otherwise the parquet files will be corrupt
- Ensures the dataset is valid for loading
Without calling `finalize()`, your parquet files will be incomplete and the dataset won't load properly.

View File

@@ -137,7 +137,7 @@ The finetuned model can be found here:
We then evaluate the finetuned model using the LeRobot LIBERO implementation, by running the following command:
```bash
python src/lerobot/scripts/eval.py \
lerobot-eval \
--output_dir=/logs/ \
--env.type=libero \
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \

80
docs/source/metaworld.mdx Normal file
View File

@@ -0,0 +1,80 @@
# Meta-World
Meta-World is a well-designed, open-source simulation benchmark for multi-task and meta reinforcement learning in continuous-control robotic manipulation. It gives researchers a shared, realistic playground to test whether algorithms can _learn many different tasks_ and _generalize quickly to new ones_ — two central challenges for real-world robotics.
- 📄 [MetaWorld paper](https://arxiv.org/pdf/1910.10897)
- 💻 [Original MetaWorld repo](https://github.com/Farama-Foundation/Metaworld)
![MetaWorld MT10 demo](https://meta-world.github.io/figures/ml45.gif)
## Why Meta-World matters
- **Diverse, realistic tasks.** Meta-World bundles a large suite of simulated manipulation tasks (50 in the MT50 suite) using everyday objects and a common tabletop Sawyer arm. This diversity exposes algorithms to a wide variety of dynamics, contacts and goal specifications while keeping a consistent control and observation structure.
- **Focus on generalization and multi-task learning.** By evaluating across task distributions that share structure but differ in goals and objects, Meta-World reveals whether an agent truly learns transferable skills rather than overfitting to a narrow task.
- **Standardized evaluation protocol.** It provides clear evaluation modes and difficulty splits, so different methods can be compared fairly across easy, medium, hard and very-hard regimes.
- **Empirical insight.** Past evaluations on Meta-World show impressive progress on some fronts, but also highlight that current multi-task and meta-RL methods still struggle with large, diverse task sets. That gap points to important research directions.
## What it enables in LeRobot
In LeRobot, you can evaluate any policy or vision-language-action (VLA) model on Meta-World tasks and get a clear success-rate measure. The integration is designed to be straightforward:
- We provide a LeRobot-ready dataset for Meta-World (MT50) on the HF Hub: `https://huggingface.co/datasets/lerobot/metaworld_mt50`.
- This dataset is formatted for the MT50 evaluation that uses all 50 tasks (the most challenging multi-task setting).
- MT50 gives the policy a one-hot task vector and uses fixed object/goal positions for consistency.
- Task descriptions and the exact keys required for evaluation are available in the repo/dataset — use these to ensure your policy outputs the right success signals.
## Quick start, train a SmolVLA policy on Meta-World
Example command to train a SmolVLA policy on a subset of tasks:
```bash
lerobot-train \
--policy.type=smolvla \
--policy.repo_id=${HF_USER}/metaworld-test \
--policy.load_vlm_weights=true \
--dataset.repo_id=lerobot/metaworld_mt50 \
--env.type=metaworld \
--env.task=assembly-v3,dial-turn-v3,handle-press-side-v3 \
--output_dir=./outputs/ \
--steps=100000 \
--batch_size=4 \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--eval_freq=1000
```
Notes:
- `--env.task` accepts explicit task lists (comma separated) or difficulty groups (e.g., `env.task="hard"`).
- Adjust `batch_size`, `steps`, and `eval_freq` to match your compute budget.
- **Gymnasium Assertion Error**: if you encounter an error like
`AssertionError: ['human', 'rgb_array', 'depth_array']` when running MetaWorld environments, this comes from a mismatch between MetaWorld and your Gymnasium version.
We recommend using:
```bash
pip install "gymnasium==1.1.0"
```
to ensure proper compatibility.
## Quick start — evaluate a trained policy
To evaluate a trained policy on the Meta-World medium difficulty split:
```bash
lerobot-eval \
--policy.path="your-policy-id" \
--env.type=metaworld \
--env.task=medium \
--eval.batch_size=1 \
--eval.n_episodes=2
```
This will run episodes and return per-task success rates using the standard Meta-World evaluation keys.
## Practical tips
- If you care about generalization, run on the full MT50 suite — its intentionally challenging and reveals strengths/weaknesses better than a few narrow tasks.
- Use the one-hot task conditioning for multi-task training (MT10 / MT50 conventions) so policies have explicit task context.
- Inspect the dataset task descriptions and the `info["is_success"]` keys when writing post-processing or logging so your success metrics line up with the benchmark.

View File

@@ -30,9 +30,10 @@ Usage:
import numpy as np
from lerobot.datasets.dataset_tools import (
add_feature,
add_features,
delete_episodes,
merge_datasets,
modify_features,
remove_feature,
split_dataset,
)
@@ -57,50 +58,56 @@ def main():
print(f"Train split: {splits['train'].meta.total_episodes} episodes")
print(f"Val split: {splits['val'].meta.total_episodes} episodes")
print("\n3. Adding a reward feature...")
print("\n3. Adding features...")
reward_values = np.random.randn(dataset.meta.total_frames).astype(np.float32)
dataset_with_reward = add_feature(
dataset,
feature_name="reward",
feature_values=reward_values,
feature_info={
"dtype": "float32",
"shape": (1,),
"names": None,
},
repo_id="lerobot/pusht_with_reward",
)
def compute_success(row_dict, episode_index, frame_index):
episode_length = 10
return float(frame_index >= episode_length - 10)
dataset_with_success = add_feature(
dataset_with_reward,
feature_name="success",
feature_values=compute_success,
feature_info={
"dtype": "float32",
"shape": (1,),
"names": None,
dataset_with_features = add_features(
dataset,
features={
"reward": (
reward_values,
{"dtype": "float32", "shape": (1,), "names": None},
),
"success": (
compute_success,
{"dtype": "float32", "shape": (1,), "names": None},
),
},
repo_id="lerobot/pusht_with_reward_and_success",
repo_id="lerobot/pusht_with_features",
)
print(f"New features: {list(dataset_with_success.meta.features.keys())}")
print(f"New features: {list(dataset_with_features.meta.features.keys())}")
print("\n4. Removing the success feature...")
dataset_cleaned = remove_feature(
dataset_with_success, feature_names="success", repo_id="lerobot/pusht_cleaned"
dataset_with_features, feature_names="success", repo_id="lerobot/pusht_cleaned"
)
print(f"Features after removal: {list(dataset_cleaned.meta.features.keys())}")
print("\n5. Merging train and val splits back together...")
print("\n5. Using modify_features to add and remove features simultaneously...")
dataset_modified = modify_features(
dataset_with_features,
add_features={
"discount": (
np.ones(dataset.meta.total_frames, dtype=np.float32) * 0.99,
{"dtype": "float32", "shape": (1,), "names": None},
),
},
remove_features="reward",
repo_id="lerobot/pusht_modified",
)
print(f"Modified features: {list(dataset_modified.meta.features.keys())}")
print("\n6. Merging train and val splits back together...")
merged = merge_datasets([splits["train"], splits["val"]], output_repo_id="lerobot/pusht_merged")
print(f"Merged dataset: {merged.meta.total_episodes} episodes")
print("\n6. Complex workflow example...")
print("\n7. Complex workflow example...")
if len(dataset.meta.camera_keys) > 1:
camera_to_remove = dataset.meta.camera_keys[0]

View File

@@ -133,4 +133,6 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say("Stop recording")
robot.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()

View File

@@ -130,4 +130,6 @@ robot.disconnect()
leader_arm.disconnect()
keyboard.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()

View File

@@ -194,4 +194,6 @@ for episode_idx in range(NUM_EPISODES):
log_say("Stop recording")
robot.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()

View File

@@ -200,4 +200,6 @@ log_say("Stop recording")
robot.disconnect()
phone.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()

View File

@@ -362,6 +362,8 @@ def port_droid(
lerobot_dataset.save_episode()
logging.info("Save_episode")
lerobot_dataset.finalize()
if push_to_hub:
lerobot_dataset.push_to_hub(
# Add openx tag, since it belongs to the openx collection of datasets

View File

@@ -195,4 +195,6 @@ for episode_idx in range(NUM_EPISODES):
log_say("Stop recording")
robot.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()

View File

@@ -199,4 +199,6 @@ log_say("Stop recording")
leader.disconnect()
follower.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()

View File

@@ -64,6 +64,7 @@ dependencies = [
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
# Core dependencies
"setuptools>=71.0.0,<81.0.0",
"cmake>=3.29.0.1,<4.2.0",
"einops>=0.8.0,<0.9.0",
"opencv-python-headless>=4.9.0,<4.13.0",
@@ -79,7 +80,7 @@ dependencies = [
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
"draccus==0.10.0", # TODO: Remove ==
"gymnasium>=0.29.1,<1.0.0", # TODO: Bumb dependency
"gymnasium>=1.0.0",
"rerun-sdk>=0.21.0,<0.23.0", # TODO: Bumb dependency
# Support dependencies
@@ -132,11 +133,10 @@ test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
# Simulation
aloha = ["gym-aloha>=0.1.1,<0.2.0"]
aloha = ["gym-aloha>=0.1.2,<0.2.0"]
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
xarm = ["gym-xarm>=0.1.1,<0.2.0"]
libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"]
metaworld = ["metaworld>=3.0.0"]
# All
all = [
@@ -156,9 +156,9 @@ all = [
"lerobot[video_benchmark]",
"lerobot[aloha]",
"lerobot[pusht]",
"lerobot[xarm]",
"lerobot[phone]",
"lerobot[libero]",
"lerobot[metaworld]",
]
[project.scripts]

View File

@@ -57,7 +57,6 @@ available_tasks_per_env = {
"AlohaTransferCube-v0",
],
"pusht": ["PushT-v0"],
"xarm": ["XarmLift-v0"],
}
available_envs = list(available_tasks_per_env.keys())
@@ -75,16 +74,6 @@ available_datasets_per_env = {
# TODO(alexander-soare): Add "lerobot/pusht_keypoints". Right now we can't because this is too tightly
# coupled with tests.
"pusht": ["lerobot/pusht", "lerobot/pusht_image"],
"xarm": [
"lerobot/xarm_lift_medium",
"lerobot/xarm_lift_medium_replay",
"lerobot/xarm_push_medium",
"lerobot/xarm_push_medium_replay",
"lerobot/xarm_lift_medium_image",
"lerobot/xarm_lift_medium_replay_image",
"lerobot/xarm_push_medium_image",
"lerobot/xarm_push_medium_replay_image",
],
}
available_real_world_datasets = [
@@ -195,7 +184,6 @@ available_motors = [
available_policies_per_env = {
"aloha": ["act"],
"pusht": ["diffusion", "vqbet"],
"xarm": ["tdmpc"],
"koch_real": ["act_koch_real"],
"aloha_real": ["act_aloha_real"],
}

View File

@@ -28,8 +28,10 @@ import shutil
from collections.abc import Callable
from pathlib import Path
import datasets
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import torch
from tqdm import tqdm
@@ -43,7 +45,6 @@ from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
get_parquet_file_size_in_mb,
load_episodes,
to_parquet_with_hf_images,
update_chunk_file_indices,
write_info,
write_stats,
@@ -268,39 +269,79 @@ def merge_datasets(
return merged_dataset
def add_feature(
def modify_features(
dataset: LeRobotDataset,
feature_name: str,
feature_values: np.ndarray | torch.Tensor | Callable,
feature_info: dict,
add_features: dict[str, tuple[np.ndarray | torch.Tensor | Callable, dict]] | None = None,
remove_features: str | list[str] | None = None,
output_dir: str | Path | None = None,
repo_id: str | None = None,
) -> LeRobotDataset:
"""Add a new feature to a LeRobotDataset.
"""Modify a LeRobotDataset by adding and/or removing features in a single pass.
This is the most efficient way to modify features, as it only copies the dataset once
regardless of how many features are being added or removed.
Args:
dataset: The source LeRobotDataset.
feature_name: Name of the new feature.
feature_values: Either:
- Array/tensor of shape (num_frames, ...) with values for each frame
- Callable that takes (frame_dict, episode_index, frame_index) and returns feature value
feature_info: Dictionary with feature metadata (dtype, shape, names).
add_features: Optional dict mapping feature names to (feature_values, feature_info) tuples.
remove_features: Optional feature name(s) to remove. Can be a single string or list.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
Returns:
New dataset with features modified.
Example:
new_dataset = modify_features(
dataset,
add_features={
"reward": (reward_array, {"dtype": "float32", "shape": [1], "names": None}),
},
remove_features=["old_feature"],
output_dir="./output",
)
"""
if feature_name in dataset.meta.features:
raise ValueError(f"Feature '{feature_name}' already exists in dataset")
if add_features is None and remove_features is None:
raise ValueError("Must specify at least one of add_features or remove_features")
remove_features_list: list[str] = []
if remove_features is not None:
remove_features_list = [remove_features] if isinstance(remove_features, str) else remove_features
if add_features:
required_keys = {"dtype", "shape"}
for feature_name, (_, feature_info) in add_features.items():
if feature_name in dataset.meta.features:
raise ValueError(f"Feature '{feature_name}' already exists in dataset")
if not required_keys.issubset(feature_info.keys()):
raise ValueError(f"feature_info for '{feature_name}' must contain keys: {required_keys}")
if remove_features_list:
for name in remove_features_list:
if name not in dataset.meta.features:
raise ValueError(f"Feature '{name}' not found in dataset")
required_features = {"timestamp", "frame_index", "episode_index", "index", "task_index"}
if any(name in required_features for name in remove_features_list):
raise ValueError(f"Cannot remove required features: {required_features}")
if repo_id is None:
repo_id = f"{dataset.repo_id}_modified"
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
required_keys = {"dtype", "shape"}
if not required_keys.issubset(feature_info.keys()):
raise ValueError(f"feature_info must contain keys: {required_keys}")
new_features = dataset.meta.features.copy()
new_features[feature_name] = feature_info
if remove_features_list:
for name in remove_features_list:
new_features.pop(name, None)
if add_features:
for feature_name, (_, feature_info) in add_features.items():
new_features[feature_name] = feature_info
video_keys_to_remove = [name for name in remove_features_list if name in dataset.meta.video_keys]
remaining_video_keys = [k for k in dataset.meta.video_keys if k not in video_keys_to_remove]
new_meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
@@ -308,17 +349,18 @@ def add_feature(
features=new_features,
robot_type=dataset.meta.robot_type,
root=output_dir,
use_videos=len(dataset.meta.video_keys) > 0,
use_videos=len(remaining_video_keys) > 0,
)
_copy_data_with_feature_changes(
dataset=dataset,
new_meta=new_meta,
add_features={feature_name: (feature_values, feature_info)},
add_features=add_features,
remove_features=remove_features_list if remove_features_list else None,
)
if dataset.meta.video_keys:
_copy_videos(dataset, new_meta)
if new_meta.video_keys:
_copy_videos(dataset, new_meta, exclude_keys=video_keys_to_remove if video_keys_to_remove else None)
new_dataset = LeRobotDataset(
repo_id=repo_id,
@@ -331,6 +373,46 @@ def add_feature(
return new_dataset
def add_features(
dataset: LeRobotDataset,
features: dict[str, tuple[np.ndarray | torch.Tensor | Callable, dict]],
output_dir: str | Path | None = None,
repo_id: str | None = None,
) -> LeRobotDataset:
"""Add multiple features to a LeRobotDataset in a single pass.
This is more efficient than calling add_feature() multiple times, as it only
copies the dataset once regardless of how many features are being added.
Args:
dataset: The source LeRobotDataset.
features: Dictionary mapping feature names to (feature_values, feature_info) tuples.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
Returns:
New dataset with all features added.
Example:
features = {
"task_embedding": (task_emb_array, {"dtype": "float32", "shape": [384], "names": None}),
"cam1_embedding": (cam1_emb_array, {"dtype": "float32", "shape": [768], "names": None}),
"cam2_embedding": (cam2_emb_array, {"dtype": "float32", "shape": [768], "names": None}),
}
new_dataset = add_features(dataset, features, output_dir="./output", repo_id="my_dataset")
"""
if not features:
raise ValueError("No features provided")
return modify_features(
dataset=dataset,
add_features=features,
remove_features=None,
output_dir=output_dir,
repo_id=repo_id,
)
def remove_feature(
dataset: LeRobotDataset,
feature_names: str | list[str],
@@ -345,56 +427,17 @@ def remove_feature(
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
Returns:
New dataset with features removed.
"""
if isinstance(feature_names, str):
feature_names = [feature_names]
for name in feature_names:
if name not in dataset.meta.features:
raise ValueError(f"Feature '{name}' not found in dataset")
required_features = {"timestamp", "frame_index", "episode_index", "index", "task_index"}
if any(name in required_features for name in feature_names):
raise ValueError(f"Cannot remove required features: {required_features}")
if repo_id is None:
repo_id = f"{dataset.repo_id}_modified"
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
new_features = {k: v for k, v in dataset.meta.features.items() if k not in feature_names}
video_keys_to_remove = [name for name in feature_names if name in dataset.meta.video_keys]
remaining_video_keys = [k for k in dataset.meta.video_keys if k not in video_keys_to_remove]
new_meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=dataset.meta.fps,
features=new_features,
robot_type=dataset.meta.robot_type,
root=output_dir,
use_videos=len(remaining_video_keys) > 0,
)
_copy_data_with_feature_changes(
return modify_features(
dataset=dataset,
new_meta=new_meta,
add_features=None,
remove_features=feature_names,
)
if new_meta.video_keys:
_copy_videos(dataset, new_meta, exclude_keys=video_keys_to_remove)
new_dataset = LeRobotDataset(
output_dir=output_dir,
repo_id=repo_id,
root=output_dir,
image_transforms=dataset.image_transforms,
delta_timestamps=dataset.delta_timestamps,
tolerance_s=dataset.tolerance_s,
)
return new_dataset
def _fractions_to_episode_indices(
total_episodes: int,
@@ -501,10 +544,7 @@ def _copy_and_reindex_data(
dst_path = dst_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
dst_path.parent.mkdir(parents=True, exist_ok=True)
if len(dst_meta.image_keys) > 0:
to_parquet_with_hf_images(df, dst_path)
else:
df.to_parquet(dst_path, index=False)
_write_parquet(df, dst_path, dst_meta)
for ep_old_idx in episodes_to_keep:
ep_new_idx = episode_mapping[ep_old_idx]
@@ -862,6 +902,25 @@ def _copy_and_reindex_episodes_metadata(
write_stats(filtered_stats, dst_meta.root)
def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -> None:
"""Write DataFrame to parquet
This ensures images are properly embedded and the file can be loaded correctly by HF datasets.
"""
from lerobot.datasets.utils import embed_images, get_hf_features_from_features
hf_features = get_hf_features_from_features(meta.features)
ep_dataset = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=hf_features, split="train")
if len(meta.image_keys) > 0:
ep_dataset = embed_images(ep_dataset)
table = ep_dataset.with_format("arrow")[:]
writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)
writer.write_table(table)
writer.close()
def _save_data_chunk(
df: pd.DataFrame,
meta: LeRobotDatasetMetadata,
@@ -877,10 +936,7 @@ def _save_data_chunk(
path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
path.parent.mkdir(parents=True, exist_ok=True)
if len(meta.image_keys) > 0:
to_parquet_with_hf_images(df, path)
else:
df.to_parquet(path, index=False)
_write_parquet(df, path, meta)
episode_metadata = {}
for ep_idx in df["episode_index"].unique():
@@ -906,19 +962,34 @@ def _copy_data_with_feature_changes(
remove_features: list[str] | None = None,
) -> None:
"""Copy data while adding or removing features."""
file_paths = set()
if dataset.meta.episodes is None:
dataset.meta.episodes = load_episodes(dataset.meta.root)
# Map file paths to episode indices to extract chunk/file indices
file_to_episodes: dict[Path, set[int]] = {}
for ep_idx in range(dataset.meta.total_episodes):
file_paths.add(dataset.meta.get_data_file_path(ep_idx))
file_path = dataset.meta.get_data_file_path(ep_idx)
if file_path not in file_to_episodes:
file_to_episodes[file_path] = set()
file_to_episodes[file_path].add(ep_idx)
frame_idx = 0
for src_path in tqdm(sorted(file_paths), desc="Processing data files"):
for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"):
df = pd.read_parquet(dataset.root / src_path).reset_index(drop=True)
# Get chunk_idx and file_idx from the source file's first episode
episodes_in_file = file_to_episodes[src_path]
first_ep_idx = min(episodes_in_file)
src_ep = dataset.meta.episodes[first_ep_idx]
chunk_idx = src_ep["data/chunk_index"]
file_idx = src_ep["data/file_index"]
if remove_features:
df = df.drop(columns=remove_features, errors="ignore")
if add_features:
end_idx = frame_idx + len(df)
for feature_name, (values, _) in add_features.items():
if callable(values):
feature_values = []
@@ -931,15 +1002,18 @@ def _copy_data_with_feature_changes(
feature_values.append(value)
df[feature_name] = feature_values
else:
end_idx = frame_idx + len(df)
feature_slice = values[frame_idx:end_idx]
if len(feature_slice.shape) > 1 and feature_slice.shape[1] == 1:
df[feature_name] = feature_slice.flatten()
else:
df[feature_name] = feature_slice
frame_idx = end_idx
frame_idx = end_idx
_save_data_chunk(df, new_meta)
# Write using the preserved chunk_idx and file_idx from source
dst_path = new_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
dst_path.parent.mkdir(parents=True, exist_ok=True)
_write_parquet(df, dst_path, new_meta)
_copy_episodes_metadata_and_stats(dataset, new_meta)

View File

@@ -69,9 +69,9 @@ from lerobot.datasets.utils import (
LEGACY_TASKS_PATH,
cast_stats_to_numpy,
flatten_dict,
get_file_size_in_mb,
get_parquet_file_size_in_mb,
get_parquet_num_frames,
get_video_size_in_mb,
load_info,
update_chunk_file_indices,
write_episodes,
@@ -310,7 +310,7 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
episodes_metadata = []
for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
ep_size_in_mb = get_video_size_in_mb(ep_path)
ep_size_in_mb = get_file_size_in_mb(ep_path)
ep_duration_in_s = get_video_duration_in_s(ep_path)
# Check if adding this episode would exceed the limit

View File

@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401
from .configs import AlohaEnv, EnvConfig, PushtEnv # noqa: F401

View File

@@ -133,45 +133,6 @@ class PushtEnv(EnvConfig):
}
@EnvConfig.register_subclass("xarm")
@dataclass
class XarmEnv(EnvConfig):
task: str | None = "XarmLift-v0"
fps: int = 15
episode_length: int = 200
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
visualization_width: int = 384
visualization_height: int = 384
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
"pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
ACTION: ACTION,
"agent_pos": OBS_STATE,
"pixels": OBS_IMAGE,
}
)
def __post_init__(self):
if self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
@property
def gym_kwargs(self) -> dict:
return {
"obs_type": self.obs_type,
"render_mode": self.render_mode,
"visualization_width": self.visualization_width,
"visualization_height": self.visualization_height,
"max_episode_steps": self.episode_length,
}
@dataclass
class ImagePreprocessingConfig:
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
@@ -306,3 +267,45 @@ class LiberoEnv(EnvConfig):
"obs_type": self.obs_type,
"render_mode": self.render_mode,
}
@EnvConfig.register_subclass("metaworld")
@dataclass
class MetaworldEnv(EnvConfig):
task: str = "metaworld-push-v2" # add all tasks
fps: int = 80
episode_length: int = 400
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
multitask_eval: bool = True
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"agent_pos": OBS_STATE,
"top": f"{OBS_IMAGE}",
"pixels/top": f"{OBS_IMAGE}",
}
)
def __post_init__(self):
if self.obs_type == "pixels":
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3))
elif self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3))
else:
raise ValueError(f"Unsupported obs_type: {self.obs_type}")
@property
def gym_kwargs(self) -> dict:
return {
"obs_type": self.obs_type,
"render_mode": self.render_mode,
}

View File

@@ -17,7 +17,7 @@ import importlib
import gymnasium as gym
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv, XarmEnv
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -25,8 +25,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
return AlohaEnv(**kwargs)
elif env_type == "pusht":
return PushtEnv(**kwargs)
elif env_type == "xarm":
return XarmEnv(**kwargs)
elif env_type == "libero":
return LiberoEnv(**kwargs)
else:
@@ -74,7 +72,18 @@ def make_env(
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
)
elif "metaworld" in cfg.type:
from lerobot.envs.metaworld import create_metaworld_envs
if cfg.task is None:
raise ValueError("MetaWorld requires a task to be specified")
return create_metaworld_envs(
task=cfg.task,
n_envs=n_envs,
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
)
package_name = f"gym_{cfg.type}"
try:
importlib.import_module(package_name)
@@ -87,7 +96,7 @@ def make_env(
def _make_one():
return gym.make(gym_handle, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
vec = env_cls([_make_one for _ in range(n_envs)])
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
# normalize to {suite: {task_id: vec_env}} for consistency
suite_name = cfg.type # e.g., "pusht", "aloha"

View File

@@ -260,19 +260,23 @@ class LiberoEnv(gym.Env):
is_success = self._env.check_success()
terminated = done or is_success
info["is_success"] = is_success
info.update(
{
"task": self.task,
"task_id": self.task_id,
"done": done,
"is_success": is_success,
}
)
observation = self._format_raw_obs(raw_obs)
if done:
if terminated:
info["final_info"] = {
"task": self.task,
"task_id": self.task_id,
"done": bool(done),
"is_success": bool(is_success),
}
self.reset()
info.update(
{
"task": self.task,
"task_id": self.task_id,
"done": done,
"is_success": is_success,
}
)
truncated = False
return observation, reward, terminated, truncated, info

View File

@@ -0,0 +1,313 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from collections import defaultdict
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Any
import gymnasium as gym
import metaworld
import metaworld.policies as policies
import numpy as np
from gymnasium import spaces
# ---- Load configuration data from the external JSON file ----
CONFIG_PATH = Path(__file__).parent / "metaworld_config.json"
try:
with open(CONFIG_PATH) as f:
data = json.load(f)
except FileNotFoundError as err:
raise FileNotFoundError(
"Could not find 'metaworld_config.json'. "
"Please ensure the configuration file is in the same directory as the script."
) from err
except json.JSONDecodeError as err:
raise ValueError(
"Failed to decode 'metaworld_config.json'. Please ensure it is a valid JSON file."
) from err
# ---- Process the loaded data ----
# extract and type-check top-level dicts
task_descriptions_obj = data.get("TASK_DESCRIPTIONS")
if not isinstance(task_descriptions_obj, dict):
raise TypeError("Expected TASK_DESCRIPTIONS to be a dict[str, str]")
TASK_DESCRIPTIONS: dict[str, str] = task_descriptions_obj
task_name_to_id_obj = data.get("TASK_NAME_TO_ID")
if not isinstance(task_name_to_id_obj, dict):
raise TypeError("Expected TASK_NAME_TO_ID to be a dict[str, int]")
TASK_NAME_TO_ID: dict[str, int] = task_name_to_id_obj
# difficulty -> tasks mapping
difficulty_to_tasks = data.get("DIFFICULTY_TO_TASKS")
if not isinstance(difficulty_to_tasks, dict):
raise TypeError("Expected 'DIFFICULTY_TO_TASKS' to be a dict[str, list[str]]")
DIFFICULTY_TO_TASKS: dict[str, list[str]] = difficulty_to_tasks
# convert policy strings -> actual policy classes
task_policy_mapping = data.get("TASK_POLICY_MAPPING")
if not isinstance(task_policy_mapping, dict):
raise TypeError("Expected 'TASK_POLICY_MAPPING' to be a dict[str, str]")
TASK_POLICY_MAPPING: dict[str, Any] = {
task_name: getattr(policies, policy_class_name)
for task_name, policy_class_name in task_policy_mapping.items()
}
ACTION_DIM = 4
OBS_DIM = 4
class MetaworldEnv(gym.Env):
metadata = {"render_modes": ["rgb_array"], "render_fps": 80}
def __init__(
self,
task,
camera_name="corner2",
obs_type="pixels",
render_mode="rgb_array",
observation_width=480,
observation_height=480,
visualization_width=640,
visualization_height=480,
):
super().__init__()
self.task = task.replace("metaworld-", "")
self.obs_type = obs_type
self.render_mode = render_mode
self.observation_width = observation_width
self.observation_height = observation_height
self.visualization_width = visualization_width
self.visualization_height = visualization_height
self.camera_name = camera_name
self._env = self._make_envs_task(self.task)
self._max_episode_steps = self._env.max_path_length
self.task_description = TASK_DESCRIPTIONS[self.task]
self.expert_policy = TASK_POLICY_MAPPING[self.task]()
if self.obs_type == "state":
raise NotImplementedError()
elif self.obs_type == "pixels":
self.observation_space = spaces.Dict(
{
"pixels": spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
)
}
)
elif self.obs_type == "pixels_agent_pos":
self.observation_space = spaces.Dict(
{
"pixels": spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
),
"agent_pos": spaces.Box(
low=-1000.0,
high=1000.0,
shape=(OBS_DIM,),
dtype=np.float64,
),
}
)
self.action_space = spaces.Box(low=-1, high=1, shape=(ACTION_DIM,), dtype=np.float32)
def render(self) -> np.ndarray:
"""
Render the current environment frame.
Returns:
np.ndarray: The rendered RGB image from the environment.
"""
image = self._env.render()
if self.camera_name == "corner2":
# Images from this camera are flipped — correct them
image = np.flip(image, (0, 1))
return image
def _make_envs_task(self, env_name: str):
mt1 = metaworld.MT1(env_name, seed=42)
env = mt1.train_classes[env_name](render_mode="rgb_array", camera_name=self.camera_name)
env.set_task(mt1.train_tasks[0])
if self.camera_name == "corner2":
env.model.cam_pos[2] = [
0.75,
0.075,
0.7,
] # corner2 position, similar to https://arxiv.org/pdf/2206.14244
env.reset()
env._freeze_rand_vec = False # otherwise no randomization
return env
def _format_raw_obs(self, raw_obs: np.ndarray) -> dict[str, Any]:
image = None
if self._env is not None:
image = self._env.render()
if self.camera_name == "corner2":
# NOTE: The "corner2" camera in MetaWorld environments outputs images with both axes inverted.
image = np.flip(image, (0, 1))
agent_pos = raw_obs[:4]
if self.obs_type == "state":
raise NotImplementedError(
"'state' obs_type not implemented for MetaWorld. Use pixel modes instead."
)
elif self.obs_type in ("pixels", "pixels_agent_pos"):
assert image is not None, (
"Expected `image` to be rendered before constructing pixel-based observations. "
"This likely means `env.render()` returned None or the environment was not provided."
)
if self.obs_type == "pixels":
obs = {"pixels": image.copy()}
else: # pixels_agent_pos
obs = {
"pixels": image.copy(),
"agent_pos": agent_pos,
}
else:
raise ValueError(f"Unknown obs_type: {self.obs_type}")
return obs
def reset(
self,
seed: int | None = None,
**kwargs,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""
Reset the environment to its initial state.
Args:
seed (Optional[int]): Random seed for environment initialization.
Returns:
observation (Dict[str, Any]): The initial formatted observation.
info (Dict[str, Any]): Additional info about the reset state.
"""
super().reset(seed=seed)
raw_obs, info = self._env.reset(seed=seed)
observation = self._format_raw_obs(raw_obs)
info = {"is_success": False}
return observation, info
def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
"""
Perform one environment step.
Args:
action (np.ndarray): The action to execute, must be 1-D with shape (action_dim,).
Returns:
observation (Dict[str, Any]): The formatted observation after the step.
reward (float): The scalar reward for this step.
terminated (bool): Whether the episode terminated successfully.
truncated (bool): Whether the episode was truncated due to a time limit.
info (Dict[str, Any]): Additional environment info.
"""
if action.ndim != 1:
raise ValueError(
f"Expected action to be 1-D (shape (action_dim,)), "
f"but got shape {action.shape} with ndim={action.ndim}"
)
raw_obs, reward, done, truncated, info = self._env.step(action)
# Determine whether the task was successful
is_success = bool(info.get("success", 0))
terminated = done or is_success
info.update(
{
"task": self.task,
"done": done,
"is_success": is_success,
}
)
# Format the raw observation into the expected structure
observation = self._format_raw_obs(raw_obs)
if terminated:
info["final_info"] = {
"task": self.task,
"done": bool(done),
"is_success": bool(is_success),
}
self.reset()
return observation, reward, terminated, truncated, info
def close(self):
self._env.close()
# ---- Main API ----------------------------------------------------------------
def create_metaworld_envs(
task: str,
n_envs: int,
gym_kwargs: dict[str, Any] | None = None,
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
) -> dict[str, dict[int, Any]]:
"""
Create vectorized Meta-World environments with a consistent return shape.
Returns:
dict[task_group][task_id] -> vec_env (env_cls([...]) with exactly n_envs factories)
Notes:
- n_envs is the number of rollouts *per task* (episode_index = 0..n_envs-1).
- `task` can be a single difficulty group (e.g., "easy", "medium", "hard") or a comma-separated list.
- If a task name is not in DIFFICULTY_TO_TASKS, we treat it as a single custom task.
"""
if env_cls is None or not callable(env_cls):
raise ValueError("env_cls must be a callable that wraps a list of environment factory callables.")
if not isinstance(n_envs, int) or n_envs <= 0:
raise ValueError(f"n_envs must be a positive int; got {n_envs}.")
gym_kwargs = dict(gym_kwargs or {})
task_groups = [t.strip() for t in task.split(",") if t.strip()]
if not task_groups:
raise ValueError("`task` must contain at least one Meta-World task or difficulty group.")
print(f"Creating Meta-World envs | task_groups={task_groups} | n_envs(per task)={n_envs}")
out: dict[str, dict[int, Any]] = defaultdict(dict)
for group in task_groups:
# if not in difficulty presets, treat it as a single custom task
tasks = DIFFICULTY_TO_TASKS.get(group, [group])
for tid, task_name in enumerate(tasks):
print(f"Building vec env | group={group} | task_id={tid} | task={task_name}")
# build n_envs factories
fns = [(lambda tn=task_name: MetaworldEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)]
out[group][tid] = env_cls(fns)
# return a plain dict for consistency
return {group: dict(task_map) for group, task_map in out.items()}

View File

@@ -0,0 +1,121 @@
{
"TASK_DESCRIPTIONS": {
"assembly-v3": "Pick up a nut and place it onto a peg",
"basketball-v3": "Dunk the basketball into the basket",
"bin-picking-v3": "Grasp the puck from one bin and place it into another bin",
"box-close-v3": "Grasp the cover and close the box with it",
"button-press-topdown-v3": "Press a button from the top",
"button-press-topdown-wall-v3": "Bypass a wall and press a button from the top",
"button-press-v3": "Press a button",
"button-press-wall-v3": "Bypass a wall and press a button",
"coffee-button-v3": "Push a button on the coffee machine",
"coffee-pull-v3": "Pull a mug from a coffee machine",
"coffee-push-v3": "Push a mug under a coffee machine",
"dial-turn-v3": "Rotate a dial 180 degrees",
"disassemble-v3": "Pick a nut out of a peg",
"door-close-v3": "Close a door with a revolving joint",
"door-lock-v3": "Lock the door by rotating the lock clockwise",
"door-open-v3": "Open a door with a revolving joint",
"door-unlock-v3": "Unlock the door by rotating the lock counter-clockwise",
"hand-insert-v3": "Insert the gripper into a hole",
"drawer-close-v3": "Push and close a drawer",
"drawer-open-v3": "Open a drawer",
"faucet-open-v3": "Rotate the faucet counter-clockwise",
"faucet-close-v3": "Rotate the faucet clockwise",
"hammer-v3": "Hammer a screw on the wall",
"handle-press-side-v3": "Press a handle down sideways",
"handle-press-v3": "Press a handle down",
"handle-pull-side-v3": "Pull a handle up sideways",
"handle-pull-v3": "Pull a handle up",
"lever-pull-v3": "Pull a lever down 90 degrees",
"peg-insert-side-v3": "Insert a peg sideways",
"pick-place-wall-v3": "Pick a puck, bypass a wall and place the puck",
"pick-out-of-hole-v3": "Pick up a puck from a hole",
"reach-v3": "Reach a goal position",
"push-back-v3": "Push the puck to a goal",
"push-v3": "Push the puck to a goal",
"pick-place-v3": "Pick and place a puck to a goal",
"plate-slide-v3": "Slide a plate into a cabinet",
"plate-slide-side-v3": "Slide a plate into a cabinet sideways",
"plate-slide-back-v3": "Get a plate from the cabinet",
"plate-slide-back-side-v3": "Get a plate from the cabinet sideways",
"peg-unplug-side-v3": "Unplug a peg sideways",
"soccer-v3": "Kick a soccer into the goal",
"stick-push-v3": "Grasp a stick and push a box using the stick",
"stick-pull-v3": "Grasp a stick and pull a box with the stick",
"push-wall-v3": "Bypass a wall and push a puck to a goal",
"reach-wall-v3": "Bypass a wall and reach a goal",
"shelf-place-v3": "Pick and place a puck onto a shelf",
"sweep-into-v3": "Sweep a puck into a hole",
"sweep-v3": "Sweep a puck off the table",
"window-open-v3": "Push and open a window",
"window-close-v3": "Push and close a window"
},
"TASK_NAME_TO_ID": {
"assembly-v3": 0, "basketball-v3": 1, "bin-picking-v3": 2, "box-close-v3": 3,
"button-press-topdown-v3": 4, "button-press-topdown-wall-v3": 5, "button-press-v3": 6,
"button-press-wall-v3": 7, "coffee-button-v3": 8, "coffee-pull-v3": 9, "coffee-push-v3": 10,
"dial-turn-v3": 11, "disassemble-v3": 12, "door-close-v3": 13, "door-lock-v3": 14,
"door-open-v3": 15, "door-unlock-v3": 16, "drawer-close-v3": 17, "drawer-open-v3": 18,
"faucet-close-v3": 19, "faucet-open-v3": 20, "hammer-v3": 21, "hand-insert-v3": 22,
"handle-press-side-v3": 23, "handle-press-v3": 24, "handle-pull-side-v3": 25,
"handle-pull-v3": 26, "lever-pull-v3": 27, "peg-insert-side-v3": 28, "peg-unplug-side-v3": 29,
"pick-out-of-hole-v3": 30, "pick-place-v3": 31, "pick-place-wall-v3": 32,
"plate-slide-back-side-v3": 33, "plate-slide-back-v3": 34, "plate-slide-side-v3": 35,
"plate-slide-v3": 36, "push-back-v3": 37, "push-v3": 38, "push-wall-v3": 39, "reach-v3": 40,
"reach-wall-v3": 41, "shelf-place-v3": 42, "soccer-v3": 43, "stick-pull-v3": 44,
"stick-push-v3": 45, "sweep-into-v3": 46, "sweep-v3": 47, "window-open-v3": 48,
"window-close-v3": 49
},
"DIFFICULTY_TO_TASKS": {
"easy": [
"button-press-v3", "button-press-topdown-v3", "button-press-topdown-wall-v3",
"button-press-wall-v3", "coffee-button-v3", "dial-turn-v3", "door-close-v3",
"door-lock-v3", "door-open-v3", "door-unlock-v3", "drawer-close-v3", "drawer-open-v3",
"faucet-close-v3", "faucet-open-v3", "handle-press-v3", "handle-press-side-v3",
"handle-pull-v3", "handle-pull-side-v3", "lever-pull-v3", "plate-slide-v3",
"plate-slide-back-v3", "plate-slide-back-side-v3", "plate-slide-side-v3", "reach-v3",
"reach-wall-v3", "window-close-v3", "window-open-v3", "peg-unplug-side-v3"
],
"medium": [
"basketball-v3", "bin-picking-v3", "box-close-v3", "coffee-pull-v3", "coffee-push-v3",
"hammer-v3", "peg-insert-side-v3", "push-wall-v3", "soccer-v3", "sweep-v3", "sweep-into-v3"
],
"hard": [
"assembly-v3", "hand-insert-v3", "pick-out-of-hole-v3", "pick-place-v3", "push-v3", "push-back-v3"
],
"very_hard": [
"shelf-place-v3", "disassemble-v3", "stick-pull-v3", "stick-push-v3", "pick-place-wall-v3"
]
},
"TASK_POLICY_MAPPING": {
"assembly-v3": "SawyerAssemblyV3Policy", "basketball-v3": "SawyerBasketballV3Policy",
"bin-picking-v3": "SawyerBinPickingV3Policy", "box-close-v3": "SawyerBoxCloseV3Policy",
"button-press-topdown-v3": "SawyerButtonPressTopdownV3Policy",
"button-press-topdown-wall-v3": "SawyerButtonPressTopdownWallV3Policy",
"button-press-v3": "SawyerButtonPressV3Policy", "button-press-wall-v3": "SawyerButtonPressWallV3Policy",
"coffee-button-v3": "SawyerCoffeeButtonV3Policy", "coffee-pull-v3": "SawyerCoffeePullV3Policy",
"coffee-push-v3": "SawyerCoffeePushV3Policy", "dial-turn-v3": "SawyerDialTurnV3Policy",
"disassemble-v3": "SawyerDisassembleV3Policy", "door-close-v3": "SawyerDoorCloseV3Policy",
"door-lock-v3": "SawyerDoorLockV3Policy", "door-open-v3": "SawyerDoorOpenV3Policy",
"door-unlock-v3": "SawyerDoorUnlockV3Policy", "drawer-close-v3": "SawyerDrawerCloseV3Policy",
"drawer-open-v3": "SawyerDrawerOpenV3Policy", "faucet-close-v3": "SawyerFaucetCloseV3Policy",
"faucet-open-v3": "SawyerFaucetOpenV3Policy", "hammer-v3": "SawyerHammerV3Policy",
"hand-insert-v3": "SawyerHandInsertV3Policy", "handle-press-side-v3": "SawyerHandlePressSideV3Policy",
"handle-press-v3": "SawyerHandlePressV3Policy", "handle-pull-side-v3": "SawyerHandlePullSideV3Policy",
"handle-pull-v3": "SawyerHandlePullV3Policy", "lever-pull-v3": "SawyerLeverPullV3Policy",
"peg-insert-side-v3": "SawyerPegInsertionSideV3Policy", "peg-unplug-side-v3": "SawyerPegUnplugSideV3Policy",
"pick-out-of-hole-v3": "SawyerPickOutOfHoleV3Policy", "pick-place-v3": "SawyerPickPlaceV3Policy",
"pick-place-wall-v3": "SawyerPickPlaceWallV3Policy",
"plate-slide-back-side-v3": "SawyerPlateSlideBackSideV3Policy",
"plate-slide-back-v3": "SawyerPlateSlideBackV3Policy",
"plate-slide-side-v3": "SawyerPlateSlideSideV3Policy", "plate-slide-v3": "SawyerPlateSlideV3Policy",
"push-back-v3": "SawyerPushBackV3Policy", "push-v3": "SawyerPushV3Policy",
"push-wall-v3": "SawyerPushWallV3Policy", "reach-v3": "SawyerReachV3Policy",
"reach-wall-v3": "SawyerReachWallV3Policy", "shelf-place-v3": "SawyerShelfPlaceV3Policy",
"soccer-v3": "SawyerSoccerV3Policy", "stick-pull-v3": "SawyerStickPullV3Policy",
"stick-push-v3": "SawyerStickPushV3Policy", "sweep-into-v3": "SawyerSweepIntoV3Policy",
"sweep-v3": "SawyerSweepV3Policy", "window-open-v3": "SawyerWindowOpenV3Policy",
"window-close-v3": "SawyerWindowCloseV3Policy"
}
}

View File

@@ -31,7 +31,6 @@ from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
@@ -58,7 +57,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"vqbet", "pi0", "pi0fast", "sac", "reward_classifier", "smolvla".
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla".
Returns:
The policy class corresponding to the given name.
@@ -82,10 +81,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
return VQBeTPolicy
elif name == "pi0fast":
from lerobot.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
return PI0FASTPolicy
elif name == "pi0":
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
@@ -119,7 +114,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"diffusion", "act", "vqbet", "pi0", "pi0fast", "sac", "smolvla",
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
"reward_classifier".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
@@ -137,8 +132,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return ACTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
elif policy_type == "pi0fast":
return PI0FASTConfig(**kwargs)
elif policy_type == "pi0":
return PI0Config(**kwargs)
elif policy_type == "pi05":
@@ -260,14 +253,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, PI0FASTConfig):
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors
processors = make_pi0fast_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, PI0Config):
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors

View File

@@ -897,7 +897,7 @@ class PI0Policy(PreTrainedPolicy):
) -> T:
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
print(
"The PI05 model is a direct port of the OpenPI implementation. \n"
"The PI0 model is a direct port of the OpenPI implementation. \n"
"This implementation follows the original OpenPI structure for compatibility. \n"
"Original implementation: https://github.com/Physical-Intelligence/openpi"
)

View File

@@ -1,153 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.utils.constants import OBS_IMAGES
@PreTrainedConfig.register_subclass("pi0fast")
@dataclass
class PI0FASTConfig(PreTrainedConfig):
# Input / output structure.
n_obs_steps: int = 1
chunk_size: int = 10
n_action_steps: int = 5
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Shorter state and action vectors will be padded
max_state_dim: int = 32 # 32
max_action_dim: int = 32 # 32
# Image preprocessing
resize_imgs_with_padding: tuple[int, int] = (224, 224)
interpolate_like_pi: bool = False
# Add empty images. Used by pi0_aloha_sim which adds the empty
# left and right wrist cameras in addition to the top camera.
empty_cameras: int = 0
# Converts the joint and gripper values from the standard Aloha space to
# the space used by the pi internal runtime which was used to train the base model.
adapt_to_pi_aloha: bool = False
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
# Gripper dimensions will remain in absolute values.
use_delta_joint_actions_aloha: bool = False
# Tokenizer
tokenizer_max_length: int = 48
# Projector
proj_width: int = 1024
# Decoding
max_decoding_steps: int = 256
fast_skip_tokens: int = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
max_input_seq_len: int = 256 # 512
# Utils
use_cache: bool = True
# Frozen parameters
freeze_vision_encoder: bool = True
freeze_lm_head: bool = True
# Training presets
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-5
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
checkpoint_path: str = None
padding_side: str = "right"
precision: str = "bfloat16"
grad_clip_norm: float = 1
# Allows padding/truncation of generated action tokens during detokenization to ensure decoding.
# In the original version, tensors of 0s were generated if shapes didn't match for stable decoding.
relaxed_action_decoding: bool = True
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
def validate_features(self) -> None:
for i in range(self.empty_cameras):
key = f"{OBS_IMAGES}.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 480, 640),
)
self.input_features[key] = empty_camera
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.grad_clip_norm,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -1,980 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
π0+FAST: Efficient Action Tokenization for Vision-Language-Action Models
[Paper](https://huggingface.co/papers/2501.09747)
[Jax code](https://github.com/Physical-Intelligence/openpi)
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
Disclaimer: It is not expected to perform as well as the original implementation.
Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
```bash
lerobot-train \
--policy.path=lerobot/pi0fast_base \
--dataset.repo_id=danaaubakirova/koch_test
```
Example of training the pi0+FAST neural network with from scratch:
```bash
lerobot-train \
--policy.type=pi0fast \
--dataset.repo_id=danaaubakirova/koch_test
```
Example of using the pi0 pretrained model outside LeRobot training framework:
```python
policy = PI0FASTPolicy.from_pretrained("lerobot/pi0fast_base")
```
"""
from collections import deque
from functools import partial
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from PIL import Image
from scipy.fft import idct
from torch import Tensor, nn
from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration
from transformers.cache_utils import HybridCache, StaticCache
from transformers.models.auto import CONFIG_MAPPING
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION, OBS_STATE
PRECISION = {
"float16": torch.float16,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
def normalize(x, min_val, max_val):
return (x - min_val) / (max_val - min_val)
def unnormalize(x, min_val, max_val):
return x * (max_val - min_val) + min_val
def safe_arcsin(value):
# This ensures that the input stays within
# [1,1] to avoid invalid values for arcsin
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
def aloha_gripper_to_angular(value):
# Aloha transforms the gripper positions into a linear space. The following code
# reverses this transformation to be consistent with pi0 which is pretrained in
# angular space.
#
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
# This is the inverse of the angular to linear transformation inside the Interbotix code.
def linear_to_radian(linear_position, arm_length, horn_radius):
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
return safe_arcsin(value)
# The constants are taken from the Interbotix code.
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
# Normalize to [0, 1].
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
return normalize(value, min_val=0.4, max_val=1.5)
def aloha_gripper_from_angular(value):
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
# Note that the units are still angular but the range is different.
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
value = unnormalize(value, min_val=0.4, max_val=1.5)
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
return normalize(value, min_val=-0.6213, max_val=1.4910)
def aloha_gripper_from_angular_inv(value):
# Directly inverts the gripper_from_angular function.
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
return normalize(value, min_val=0.4, max_val=1.5)
class PI0FASTPolicy(PreTrainedPolicy):
"""Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot."""
config_class = PI0FASTConfig
name = "pi0fast"
def __init__(
self,
config: PI0FASTConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
self.model = PI0FAST(config)
self.reset()
def reset(self):
"""This should be called whenever the environment is reset."""
self._action_queue = deque([], maxlen=self.config.n_action_steps)
@classmethod
def from_pretrained(cls, *args, **kwargs):
"""Override the from_pretrained method to display important disclaimer."""
print(
"⚠️ DISCLAIMER: The PI0FAST model is ported from JAX by the Hugging Face team. \n"
" It is not expected to perform as well as the original implementation. \n"
" Original implementation: https://github.com/Physical-Intelligence/openpi"
)
return super().from_pretrained(*args, **kwargs)
def get_optim_params(self) -> dict:
return self.parameters()
def _pi_aloha_decode_state(self, state):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
state[:, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
return state
def _pi_aloha_encode_actions(self, actions):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
return actions
def _pi_aloha_encode_actions_inv(self, actions):
# Flip the joints again.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
return actions
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("Currently not implemented for PI0FAST")
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
self.eval()
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
actions = self.model.generate_actions(batch)
actions = actions[:, : self.config.n_action_steps]
original_action_dim = self.config.action_feature.shape[
0
] # self.config.max_action_dim # self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
loss_dict = self.model.forward(batch)
return loss_dict["loss"], loss_dict
def block_causal_update_causal_mask(
attention_mask,
token_type_ids=None,
past_key_values=None,
cache_position=None,
input_tensor=None,
attn_implementation: str = "eager",
dtype: torch.dtype = "float32",
):
"""
Update the causal mask during training and generation. It can be customized to different attention masks.
"""
if attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
using_static_cache = isinstance(past_key_values, StaticCache)
min_dtype = torch.finfo(dtype).min
if input_tensor is None:
input_tensor = attention_mask
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
if using_static_cache or isinstance(past_key_values, HybridCache):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else cache_position[0] + sequence_length + 1
)
# Handle precomputed attention masks
if attention_mask is not None and attention_mask.dim() == 4:
return attention_mask
# Causal mask initialization
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
# Standard causal masking (triu ensures tokens can only attend to past)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
# Apply block causal mask
if token_type_ids is not None:
token_type_ids = token_type_ids.to(causal_mask.device).bool()
cumsum = torch.cumsum(token_type_ids, dim=1)
block_causal_mask = cumsum[:, None, :] <= cumsum[:, :, None]
# Combine causal_mask with block-wise attention mask
causal_mask = torch.where(block_causal_mask, 0.0, causal_mask)
causal_mask = causal_mask[:, None, :, :]
else:
# Apply past cache position constraint
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
else:
# Apply past cache position constraint
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # Copy to contiguous memory for in-place edits
mask_length = attention_mask.shape[-1]
# Apply padding mask
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
def prepare_inputs_for_generation(
# self,
input_ids,
past_key_values=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
pixel_values=None,
attention_mask=None,
token_type_ids=None,
use_cache=True,
num_logits_to_keep=None,
labels=None,
self=None,
**kwargs,
):
# create block causal attention
if cache_position[0] > 0 and input_ids.shape[1] > 0:
input_tensor = input_ids[:, -1:]
new_positions = (
torch.ones(
(position_ids.shape[0], input_ids.shape[1]),
dtype=position_ids.dtype,
device=position_ids.device,
).cumsum(-1)
+ position_ids[:, -1:]
)
position_ids = torch.cat([position_ids, new_positions], dim=-1)
else:
input_tensor = inputs_embeds
attention_mask = block_causal_update_causal_mask(
attention_mask=attention_mask,
past_key_values=past_key_values,
cache_position=cache_position,
input_tensor=input_tensor,
token_type_ids=token_type_ids,
dtype=self.dtype,
attn_implementation=self.config.text_config._attn_implementation,
)
# Overwritten -- custom `position_ids` and `pixel_values` handling
model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
cache_position=cache_position,
use_cache=use_cache,
num_logits_to_keep=num_logits_to_keep,
token_type_ids=token_type_ids,
**kwargs,
)
# Position_ids in Paligemma are 1-indexed
if model_inputs.get("position_ids") is not None:
model_inputs["position_ids"] += 1
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
if cache_position[0] == 0:
model_inputs["pixel_values"] = pixel_values
is_training = token_type_ids is not None and labels is not None
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
causal_mask = self._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
)
model_inputs["attention_mask"] = causal_mask
return model_inputs
class PI0FAST(nn.Module):
def __init__(self, config: PI0FASTConfig):
super().__init__()
self.config = config
# TODO: move tokenizers in Policy
fast_tokenizer_path = "physical-intelligence/fast"
pi0_paligemma_path = "google/paligemma-3b-pt-224"
self.paligemma_tokenizer = AutoTokenizer.from_pretrained(pi0_paligemma_path)
self.processor = AutoProcessor.from_pretrained(pi0_paligemma_path)
self.fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
self.fast_skip_tokens = self.config.fast_skip_tokens
self.max_input_seq_len = self.config.max_input_seq_len
self.action_horizon = self.config.chunk_size
self.action_dim = self.config.action_feature.shape[
0
] # self.config.max_action_dim # self.config.action_feature.shape[0]
precision = config.precision
torch_precision = PRECISION.get(precision, torch.float32)
self.pad_token_id = (
self.paligemma_tokenizer.pad_token_id
if hasattr(self.paligemma_tokenizer, "pad_token_id")
else self.paligemma_tokenizer.eos_token_id
)
paligemma_config = CONFIG_MAPPING["paligemma"](
transformers_version="4.48.1",
_vocab_size=257152,
bos_token_id=2,
eos_token_id=1,
hidden_size=2048,
image_token_index=257152,
model_type="paligemma",
pad_token_id=0,
projection_dim=2048,
text_config={
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 2048,
"intermediate_size": 16384,
"model_type": "gemma",
"num_attention_heads": 8,
"num_hidden_layers": 18,
"num_image_tokens": 256,
"num_key_value_heads": 1,
"torch_dtype": precision,
"vocab_size": 257152,
"_attn_implementation": "eager",
},
vision_config={
"hidden_size": 1152,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"num_image_tokens": 256,
"patch_size": 14,
"projection_dim": 2048,
"projector_hidden_act": "gelu_pytorch_tanh",
"torch_dtype": precision,
"vision_use_head": False,
},
)
self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)
self.pi0_paligemma.prepare_inputs_for_generation = partial(
prepare_inputs_for_generation, self=self.pi0_paligemma
)
# change important stuff in bf16
params_to_change_dtype = [
"language_model",
"vision_tower",
"multi_modal",
]
for name, param in self.pi0_paligemma.named_parameters():
if any(selector in name for selector in params_to_change_dtype):
param.data = param.data.to(dtype=torch_precision)
self.set_requires_grad()
self.image_keys = self.config.image_features.keys()
# TODO: Remove this once we bump transformers to >4.52.0 because the attribute will be removed
# AttributeError: 'PaliGemmaConfig' object has no attribute 'ignore_index'
self.ignore_index = self.pi0_paligemma.config.ignore_index
self.padding_side = self.config.padding_side
def set_requires_grad(self):
if self.config.freeze_vision_encoder:
self.pi0_paligemma.vision_tower.eval()
for params in self.pi0_paligemma.vision_tower.parameters():
params.requires_grad = False
# To avoid unused params issue with distributed training
if self.config.freeze_lm_head:
for name, params in self.pi0_paligemma.named_parameters():
if "embed_tokens" in name: # lm heads and embedding layer are tied
params.requires_grad = False
def embed_tokens(self, tokens: torch.Tensor):
return self.pi0_paligemma.language_model.model.embed_tokens(tokens)
def prepare_inputs_for_generation(self, *args, **kwargs):
return self.pi0_paligemma.prepare_inputs_for_generation(*args, **kwargs)
def prepare_images(self, batch):
"""Preprocess LeRobot batch into Pi0 inputs"""
images = []
img_masks = []
present_img_keys = [key for key in self.image_keys if key in batch]
if len(present_img_keys) == 0:
raise ValueError(
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
)
# Preprocess image features present in the batch
num_empty_cameras = 0
for key in self.image_keys:
if key in present_img_keys:
img = batch[key]
if self.config.resize_imgs_with_padding is not None:
img = resize_with_pad(
img,
*self.config.resize_imgs_with_padding,
pad_value=0,
interpolate_like_pi=self.config.interpolate_like_pi,
)
# Normalize from range [0,1] to [-1,1] as expected by siglip
img = img * 2.0 - 1.0
bsize = img.shape[0]
device = img.device
mask = torch.ones(bsize, dtype=torch.bool, device=device)
else:
if num_empty_cameras >= self.config.empty_cameras:
continue
img = torch.ones_like(img) * -1
bsize = img.shape[0]
device = img.device
mask = torch.ones(bsize, dtype=torch.bool, device=device)
num_empty_cameras += 1
images.append(img)
img_masks.append(mask)
return images, img_masks
def normalize_actions(self, actions: torch.Tensor) -> torch.Tensor:
mins = actions.amin(dim=(1, 2), keepdim=True) # [0]
maxs = actions.amax(dim=(1, 2), keepdim=True) # [0]
return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1
def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
return out
def fast_tokenizer_wrapper(self, actions_norm):
"""
A wrapper for self.fast_tokenizer that ensures batch processing,
conversion to PyTorch tensors, and returns a dictionary without padding.
"""
batch_tokens = self.fast_tokenizer(actions_norm)
fast_out = self.processor.tokenizer.pad({"input_ids": batch_tokens}, return_tensors="pt")
return fast_out
def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int) -> torch.Tensor:
token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool)
# Compute cumulative sum mask
cumsum_mask = (padded_mask != 0).cumsum(dim=1)
# Suffix block (everything after prefix_len)
suffix_mask = cumsum_mask > prefix_len
token_type_ids = suffix_mask
return token_type_ids
def create_input_tokens(self, state, lang_text, actions=None):
bsize = state.shape[0]
device = state.device
bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
discretized = torch.bucketize(state, bins) - 1
discretized = discretized[:, :32]
prefix_texts = []
state_text = []
for txt, disc in zip(lang_text, discretized, strict=False):
cleaned = txt.lower().strip().replace("_", " ")
state_str = " ".join(str(val.item()) for val in disc)
prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
state_text.append(f"State: {state_str};\n")
prefix_out = self.paligemma_tokenizer(
prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False
)
prefix_ids = prefix_out["input_ids"].to(device)
prefix_mask = prefix_out["attention_mask"].to(device)
prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()
if actions is not None:
actions_norm = self.normalize_actions(actions)
actions_pad = F.pad(
actions_norm, (0, max(0, self.config.max_action_dim - actions_norm.shape[2])), value=0
)[:, :, : self.config.max_action_dim]
fast_out = self.fast_tokenizer_wrapper(
actions_pad.cpu(),
)
act_ids = fast_out["input_ids"]
act_mask = fast_out["attention_mask"].to(device)
act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device)
# Replace action with 0 to pad tokens
act_ids = torch.where(
act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
self.pad_token_id,
act_ids,
)
eos_token = torch.tensor(
[self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
).expand(bsize, -1)
eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt")
bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device)
bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device)
act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
act_mask = act_mask.to(device)
else:
act_ids = torch.empty(bsize, self.pad_token_id, dtype=torch.long, device=device)
act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)
final_ids = torch.cat([prefix_ids, act_ids], dim=1)
final_mask = torch.cat([prefix_mask, act_mask], dim=1)
batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}
# Use tokenizer pad function
padded_output = self.paligemma_tokenizer.pad(
batch_inputs, padding="longest", max_length=180, return_tensors="pt"
)
padded_mask = padded_output["attention_mask"]
# define tensor of padding lengths
att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens
token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens)
padded_output["padded_mask"] = padded_output.pop("attention_mask")
padded_output["attention_mask"] = att_mask
# loss is computed not on prefix, and not on padding
padded_output["loss_mask"] = att_mask & padded_output["padded_mask"]
padded_output["token_type_ids"] = token_type_ids
return padded_output
def shift_padding_side(
self,
tokens: torch.Tensor,
ar_mask: torch.Tensor,
padding_mask: torch.Tensor,
loss_mask: torch.Tensor,
targets: torch.Tensor,
token_type_ids: torch.Tensor,
padding_side: str = "right",
) -> tuple[torch.Tensor]:
if padding_side not in ["right", "left"]:
return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids
new_tokens = torch.empty_like(tokens)
new_ar_masks = torch.empty_like(ar_mask)
new_padding_mask = torch.empty_like(padding_mask)
new_loss_mask = torch.empty_like(loss_mask)
new_targets = torch.empty_like(targets)
new_token_type_ids = torch.empty_like(token_type_ids)
batch_size = tokens.shape[0]
for i in range(batch_size):
padding_indices = torch.where(padding_mask[i] == 0)[0]
non_padding_indices = torch.where(padding_mask[i] == 1)[0]
if padding_side == "left":
new_indices = torch.cat((padding_indices, non_padding_indices), dim=0)
else:
new_indices = torch.cat((non_padding_indices, padding_indices), dim=0)
new_tokens[i] = tokens[i].index_select(0, new_indices)
new_ar_masks[i] = ar_mask[i].index_select(0, new_indices)
new_padding_mask[i] = padding_mask[i].index_select(0, new_indices)
new_loss_mask[i] = loss_mask[i].index_select(0, new_indices)
new_targets[i] = targets[i].index_select(0, new_indices)
new_token_type_ids[i] = token_type_ids[i].index_select(0, new_indices)
return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids
def forward(self, batch: dict[str, Tensor]):
device = batch[OBS_STATE].device
# TODO: keep like this or move to the policy .forward
images, img_masks = self.prepare_images(batch)
padded_outs = self.create_input_tokens(
state=batch[OBS_STATE],
lang_text=batch["task"],
actions=batch[ACTION],
)
embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
images,
img_masks,
padded_outs["input_ids"],
padded_outs["padded_mask"],
padded_outs["attention_mask"],
padded_outs["loss_mask"],
padded_outs["token_type_ids"],
padding_side=self.padding_side,
)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
token_type_ids = token_type_ids.to(dtype=torch.int64)
past_seen_tokens = 0
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + embs.shape[1], device=embs.device)
pad_masks = block_causal_update_causal_mask(
attention_mask=pad_masks,
past_key_values=None,
cache_position=cache_position,
input_tensor=embs,
token_type_ids=token_type_ids,
dtype=self.pi0_paligemma.dtype,
attn_implementation=self.pi0_paligemma.config.text_config._attn_implementation,
)
outputs = self.pi0_paligemma.forward(
input_ids=None,
token_type_ids=None,
attention_mask=pad_masks,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=embs,
use_cache=False,
labels=None,
)
logits = outputs.logits
loss_fct = nn.CrossEntropyLoss(reduction="none")
# Shift left for next-step prediction
logits = logits[:, :-1, :]
targets = targets[:, 1:].to(device) # Shift targets
loss_mask = loss_mask[:, 1:].to(device) # Ensure correct shape
# Compute per-token loss
token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
# Apply loss mask
token_loss = token_loss * loss_mask.reshape(-1)
# Compute final loss
loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1)
# Return loss dictionary
loss_dict = {"ce_loss": loss.item(), "loss": loss}
return loss_dict
def decode_actions_with_fast(
self,
tokens: list[list[int]],
*,
time_horizon: int | None = None,
action_dim: int | None = None,
relaxed_decoding: bool = True,
) -> np.array:
"""
Adapt original decoding in FAST to always return actions instead of zeros.
"""
self.time_horizon = (
time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon
)
self.action_dim = (
action_dim or self.fast_tokenizer.action_dim or self.fast_tokenizer.called_action_dim
)
# Cache the time horizon and action dimension for the next call
self.called_time_horizon = self.time_horizon
self.called_action_dim = self.action_dim
assert self.time_horizon is not None and self.action_dim is not None, (
"Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
)
decoded_actions = []
for token in tokens:
try:
decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token)
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token
if relaxed_decoding:
# Expected sequence length
expected_seq_len = self.time_horizon * self.action_dim
diff = expected_seq_len - decoded_dct_coeff.shape[0]
# Apply truncation if too long
if diff < 0:
decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # Truncate on the right
# Apply padding if too short
elif diff > 0:
decoded_dct_coeff = np.pad(
decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
)
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
assert decoded_dct_coeff.shape == (
self.time_horizon,
self.action_dim,
), (
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
)
except Exception as e:
print(f"Error decoding tokens: {e}")
print(f"Tokens: {token}")
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
decoded_actions.append(idct(decoded_dct_coeff / self.fast_tokenizer.scale, axis=0, norm="ortho"))
return np.stack(decoded_actions)
def extract_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor:
"""
Extracts actions from predicted output tokens using the FAST model.
Args:
tokens (torch.Tensor): The input tensor of tokenized outputs.
action_horizon (int): The number of timesteps for actions.
action_dim (int): The dimensionality of each action.
Returns:
torch.Tensor: The extracted actions as a tensor of shape (action_horizon, action_dim).
"""
# Decode predicted output tokens
decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)
cleaned_tokens = [
tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
for tokens_sequence in decoded_tokens
]
raw_action_tokens = [
self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False)
for sample_tokens in cleaned_tokens
] # something like this should be robust #looks good
action_tokens = [
self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens
]
# returns the tensor of decoded actions per sample in a list
decoded_actions = [
torch.tensor(
self.decode_actions_with_fast(
tok.tolist(),
time_horizon=action_horizon,
action_dim=action_dim,
relaxed_decoding=self.config.relaxed_action_decoding,
),
device=tokens.device,
).squeeze(0)
for tok in action_tokens
]
return torch.stack(
decoded_actions,
dim=0,
)
def generate_actions(self, batch: dict[str, Tensor]):
# TODO: keep like this or move to the policy .forward
images, img_masks = self.prepare_images(batch)
padded_outs = self.create_input_tokens(state=batch[OBS_STATE], lang_text=batch["task"], actions=None)
embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs(
images,
img_masks,
padded_outs["input_ids"],
padded_outs["padded_mask"],
padded_outs["attention_mask"],
padded_outs["loss_mask"],
padded_outs["token_type_ids"],
padding_side="left",
)
token_type_ids = token_type_ids.to(dtype=torch.int64)
prefix_position_ids = torch.cumsum(pad_masks, dim=1) - 1
output_tokens = self.pi0_paligemma.generate(
input_ids=None,
attention_mask=pad_masks,
position_ids=prefix_position_ids,
past_key_values=None,
inputs_embeds=embs,
use_cache=self.config.use_cache,
max_new_tokens=self.config.max_decoding_steps,
do_sample=False,
num_beams=1,
token_type_ids=token_type_ids,
)
actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim)
return actions
def embed_image(self, image: torch.Tensor):
# Handle different transformers versions
if hasattr(self.pi0_paligemma, "get_image_features"):
return self.pi0_paligemma.get_image_features(image)
else:
return self.pi0_paligemma.model.get_image_features(image)
def embed_inputs(
self,
images,
img_masks,
tokens,
pad_mask,
ar_mask,
loss_mask,
token_type_ids,
padding_side: str = "right",
):
# TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
# images are a list of same size
# vectorizing everything!
device = images[0].device
image_embedding_dim = images[0].shape[-1] # TODO should be from self.config
all_images = torch.stack(images, dim=1).to(device)
b, n, c, h, w = all_images.shape
all_images = all_images.view(b * n, c, h, w)
embedded = self.embed_image(all_images).to(device)
b_n, p, image_embedding_dim = embedded.shape # Extract current dimensions
m = b_n // b # Compute the number of images per sample dynamically
# Reshape dynamically
embedded = embedded.view(b, m, p, image_embedding_dim)
tokens_embs = self.embed_tokens(tokens.to(device))
img_masks = torch.stack(img_masks, dim=1).unsqueeze(-1).to(device)
num_img_emb = embedded.shape[2]
img_pad_masks = img_masks.repeat(1, 1, num_img_emb).view(b, -1)
img_att_masks = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
image_target_tokens = (
torch.ones((b, n, num_img_emb), dtype=torch.long, device=device) * self.pad_token_id
).reshape(b, -1)
image_loss_mask = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
embedded = embedded.reshape(b, n * num_img_emb, image_embedding_dim) # Shape: (B, N*P, D)
embs = torch.cat([embedded, tokens_embs], dim=1).to(device)
pad_masks = torch.cat([img_pad_masks, pad_mask.to(device)], dim=1)
att_masks = torch.cat([img_att_masks, ar_mask.to(device)], dim=1)
loss_masks = torch.cat([image_loss_mask, loss_mask.to(device)], dim=1)
targets = torch.cat([image_target_tokens, tokens.to(device)], dim=1)
token_type_ids = torch.cat([img_att_masks, token_type_ids.to(device)], dim=1)
# Shift pad tokens to the left (.generate()) or right (.train())
embs, att_masks, pad_masks, loss_masks, targets, token_type_ids = self.shift_padding_side(
embs, att_masks, pad_masks, loss_masks, targets, token_type_ids, padding_side=padding_side
)
targets = torch.where(targets == self.pad_token_id, self.ignore_index, targets)
return embs, pad_masks, att_masks, targets, loss_masks, token_type_ids
def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True):
# assume no-op when width height fits already
if img.ndim != 4:
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
cur_height, cur_width = img.shape[2:]
ratio = max(cur_width / width, cur_height / height)
resized_height = int(cur_height / ratio)
resized_width = int(cur_width / ratio)
if interpolate_like_pi:
img = (img * 255.0).to(dtype=torch.uint8)
img = img.permute(0, 2, 3, 1)
original_device = img.device
img = img.to(device="cpu").numpy()
imgs = []
for sub_img in img:
sub_img = Image.fromarray(sub_img)
resized_img = sub_img.resize((resized_width, resized_height), resample=2)
resized_img = torch.from_numpy(np.array(resized_img))
imgs.append(resized_img)
img = torch.stack(imgs, dim=0)
img = img.permute(0, 3, 1, 2)
resized_img = img.to(device=original_device, dtype=torch.float32) / 255.0
else:
resized_img = F.interpolate(
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
)
pad_height = max(0, int(height - resized_height))
pad_width = max(0, int(width - resized_width))
# pad on left and top of image
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
return padded_img

View File

@@ -1,92 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_pi0fast_pre_post_processors(
config: PI0FASTConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Constructs pre-processor and post-processor pipelines for the PI0Fast policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the PI0Fast policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
input_steps = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)

View File

@@ -16,10 +16,16 @@
import logging
from collections import deque
from typing import Any
import numpy as np
import torch
from torch import nn
from lerobot.datasets.utils import build_dataset_frame
from lerobot.processor import PolicyAction, RobotAction, RobotObservation
from lerobot.utils.constants import ACTION, OBS_STR
def populate_queues(
queues: dict[str, deque], batch: dict[str, torch.Tensor], exclude_keys: list[str] | None = None
@@ -85,3 +91,110 @@ def log_model_loading_keys(missing_keys: list[str], unexpected_keys: list[str])
logging.warning(f"Missing key(s) when loading model: {missing_keys}")
if unexpected_keys:
logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}")
# TODO(Steven): Move this function to a proper preprocessor step
def prepare_observation_for_inference(
observation: dict[str, np.ndarray],
device: torch.device,
task: str | None = None,
robot_type: str | None = None,
) -> RobotObservation:
"""Converts observation data to model-ready PyTorch tensors.
This function takes a dictionary of NumPy arrays, performs necessary
preprocessing, and prepares it for model inference. The steps include:
1. Converting NumPy arrays to PyTorch tensors.
2. Normalizing and permuting image data (if any).
3. Adding a batch dimension to each tensor.
4. Moving all tensors to the specified compute device.
5. Adding task and robot type information to the dictionary.
Args:
observation: A dictionary mapping observation names (str) to NumPy
array data. For images, the format is expected to be (H, W, C).
device: The PyTorch device (e.g., 'cpu' or 'cuda') to which the
tensors will be moved.
task: An optional string identifier for the current task.
robot_type: An optional string identifier for the robot being used.
Returns:
A dictionary where values are PyTorch tensors preprocessed for
inference, residing on the target device. Image tensors are reshaped
to (C, H, W) and normalized to a [0, 1] range.
"""
for name in observation:
observation[name] = torch.from_numpy(observation[name])
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)
observation["task"] = task if task else ""
observation["robot_type"] = robot_type if robot_type else ""
return observation
def build_inference_frame(
observation: dict[str, Any],
device: torch.device,
ds_features: dict[str, dict],
task: str | None = None,
robot_type: str | None = None,
) -> RobotObservation:
"""Constructs a model-ready observation tensor dict from a raw observation.
This utility function orchestrates the process of converting a raw,
unstructured observation from an environment into a structured,
tensor-based format suitable for passing to a policy model.
Args:
observation: The raw observation dictionary, which may contain
superfluous keys.
device: The target PyTorch device for the final tensors.
ds_features: A configuration dictionary that specifies which features
to extract from the raw observation.
task: An optional string identifier for the current task.
robot_type: An optional string identifier for the robot being used.
Returns:
A dictionary of preprocessed tensors ready for model inference.
"""
# Extracts the correct keys from the incoming raw observation
observation = build_dataset_frame(ds_features, observation, prefix=OBS_STR)
# Performs the necessary conversions to the observation
observation = prepare_observation_for_inference(observation, device, task, robot_type)
return observation
def make_robot_action(action_tensor: PolicyAction, ds_features: dict[str, dict]) -> RobotAction:
"""Converts a policy's output tensor into a dictionary of named actions.
This function translates the numerical output from a policy model into a
human-readable and robot-consumable format, where each dimension of the
action tensor is mapped to a named motor or actuator command.
Args:
action_tensor: A PyTorch tensor representing the policy's action,
typically with a batch dimension (e.g., shape [1, action_dim]).
ds_features: A configuration dictionary containing metadata, including
the names corresponding to each index of the action tensor.
Returns:
A dictionary mapping action names (e.g., "joint_1_motor") to their
corresponding floating-point values, ready to be sent to a robot
controller.
"""
# TODO(Steven): Check if these steps are already in all postprocessor policies
action_tensor = action_tensor.squeeze(0)
action_tensor = action_tensor.to("cpu")
action_names = ds_features[ACTION]["names"]
act_processed_policy: RobotAction = {
f"{name}": float(action_tensor[i]) for i, name in enumerate(action_names)
}
return act_processed_policy

View File

@@ -180,9 +180,15 @@ def rollout(
render_callback(env)
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
# available of none of the envs finished.
# available if none of the envs finished.
if "final_info" in info:
successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
final_info = info["final_info"]
if not isinstance(final_info, dict):
raise RuntimeError(
"Unsupported `final_info` format: expected dict (Gymnasium >= 1.0). "
"You're likely using an older version of gymnasium (< 1.0). Please upgrade."
)
successes = final_info["is_success"].tolist()
else:
successes = [False] * env.num_envs

View File

@@ -79,6 +79,7 @@ from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import make_robot_action
from lerobot.processor import (
PolicyAction,
PolicyProcessorPipeline,
@@ -316,10 +317,7 @@ def record_loop(
robot_type=robot.robot_type,
)
action_names = dataset.features[ACTION]["names"]
act_processed_policy: RobotAction = {
f"{name}": float(action_values[i]) for i, name in enumerate(action_names)
}
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
elif policy is None and isinstance(teleop, Teleoperator):
act = teleop.get_action()

View File

@@ -19,8 +19,6 @@
[Diffusion Policy](https://huggingface.co/papers/2303.04137) treats visuomotor control as a generative diffusion process, producing smooth, multi-step action trajectories that excel at contact-rich manipulation.
{% elif model_name == "vqbet" %}
[VQ-BET](https://huggingface.co/papers/2403.03181) combines vector-quantised action tokens with Behaviour Transformers to discretise control and achieve data-efficient imitation across diverse skills.
{% elif model_name == "pi0fast" %}
[Pi0-Fast](https://huggingface.co/papers/2501.09747) is a variant of Pi0 that uses a new tokenization method called FAST, which enables training of an autoregressive vision-language-action policy for high-frequency robotic tasks with improved performance and reduced training time.
{% elif model_name == "pi0" %}
**π₀ (Pi0)**

View File

@@ -31,6 +31,7 @@ from deepdiff import DeepDiff
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import DEFAULT_FEATURES
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import prepare_observation_for_inference
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.robots import Robot
@@ -102,17 +103,7 @@ def predict_action(
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation:
observation[name] = torch.from_numpy(observation[name])
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)
observation["task"] = task if task else ""
observation["robot_type"] = robot_type if robot_type else ""
observation = prepare_observation_for_inference(observation, device, task, robot_type)
observation = preprocessor(observation)
# Compute the next action with the policy
@@ -121,12 +112,6 @@ def predict_action(
action = postprocessor(action)
# Remove batch dimension
action = action.squeeze(0)
# Move to cpu, if not already the case
action = action.to("cpu")
return action

View File

@@ -22,9 +22,10 @@ import pytest
import torch
from lerobot.datasets.dataset_tools import (
add_feature,
add_features,
delete_episodes,
merge_datasets,
modify_features,
remove_feature,
split_dataset,
)
@@ -292,7 +293,7 @@ def test_merge_empty_list(tmp_path):
merge_datasets([], output_repo_id="merged", output_dir=tmp_path)
def test_add_feature_with_values(sample_dataset, tmp_path):
def test_add_features_with_values(sample_dataset, tmp_path):
"""Test adding a feature with pre-computed values."""
num_frames = sample_dataset.meta.total_frames
reward_values = np.random.randn(num_frames, 1).astype(np.float32)
@@ -302,6 +303,9 @@ def test_add_feature_with_values(sample_dataset, tmp_path):
"shape": (1,),
"names": None,
}
features = {
"reward": (reward_values, feature_info),
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
@@ -310,11 +314,9 @@ def test_add_feature_with_values(sample_dataset, tmp_path):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "with_reward")
new_dataset = add_feature(
sample_dataset,
feature_name="reward",
feature_values=reward_values,
feature_info=feature_info,
new_dataset = add_features(
dataset=sample_dataset,
features=features,
output_dir=tmp_path / "with_reward",
)
@@ -327,7 +329,7 @@ def test_add_feature_with_values(sample_dataset, tmp_path):
assert isinstance(sample_item["reward"], torch.Tensor)
def test_add_feature_with_callable(sample_dataset, tmp_path):
def test_add_features_with_callable(sample_dataset, tmp_path):
"""Test adding a feature with a callable."""
def compute_reward(frame_dict, episode_idx, frame_idx):
@@ -338,7 +340,9 @@ def test_add_feature_with_callable(sample_dataset, tmp_path):
"shape": (1,),
"names": None,
}
features = {
"reward": (compute_reward, feature_info),
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
@@ -346,11 +350,9 @@ def test_add_feature_with_callable(sample_dataset, tmp_path):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "with_reward")
new_dataset = add_feature(
sample_dataset,
feature_name="reward",
feature_values=compute_reward,
feature_info=feature_info,
new_dataset = add_features(
dataset=sample_dataset,
features=features,
output_dir=tmp_path / "with_reward",
)
@@ -368,31 +370,88 @@ def test_add_feature_with_callable(sample_dataset, tmp_path):
def test_add_existing_feature(sample_dataset, tmp_path):
"""Test error when adding an existing feature."""
feature_info = {"dtype": "float32", "shape": (1,)}
features = {
"action": (np.zeros(50), feature_info),
}
with pytest.raises(ValueError, match="Feature 'action' already exists"):
add_feature(
sample_dataset,
feature_name="action",
feature_values=np.zeros(50),
feature_info=feature_info,
add_features(
dataset=sample_dataset,
features=features,
output_dir=tmp_path / "modified",
)
def test_add_feature_invalid_info(sample_dataset, tmp_path):
"""Test error with invalid feature info."""
with pytest.raises(ValueError, match="feature_info must contain keys"):
add_feature(
sample_dataset,
feature_name="reward",
feature_values=np.zeros(50),
feature_info={"dtype": "float32"},
with pytest.raises(ValueError, match="feature_info for 'reward' must contain keys"):
add_features(
dataset=sample_dataset,
features={
"reward": (np.zeros(50), {"dtype": "float32"}),
},
output_dir=tmp_path / "modified",
)
def test_remove_single_feature(sample_dataset, tmp_path):
"""Test removing a single feature."""
def test_modify_features_add_and_remove(sample_dataset, tmp_path):
"""Test modifying features by adding and removing simultaneously."""
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "modified")
# First add a feature we'll later remove
dataset_with_reward = add_features(
sample_dataset,
features={"reward": (np.random.randn(50, 1).astype(np.float32), feature_info)},
output_dir=tmp_path / "with_reward",
)
# Now use modify_features to add "success" and remove "reward" in one pass
modified_dataset = modify_features(
dataset_with_reward,
add_features={
"success": (np.random.randn(50, 1).astype(np.float32), feature_info),
},
remove_features="reward",
output_dir=tmp_path / "modified",
)
assert "success" in modified_dataset.meta.features
assert "reward" not in modified_dataset.meta.features
assert len(modified_dataset) == 50
def test_modify_features_only_add(sample_dataset, tmp_path):
"""Test that modify_features works with only add_features."""
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "modified")
modified_dataset = modify_features(
sample_dataset,
add_features={
"reward": (np.random.randn(50, 1).astype(np.float32), feature_info),
},
output_dir=tmp_path / "modified",
)
assert "reward" in modified_dataset.meta.features
assert len(modified_dataset) == 50
def test_modify_features_only_remove(sample_dataset, tmp_path):
"""Test that modify_features works with only remove_features."""
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
with (
@@ -402,11 +461,46 @@ def test_remove_single_feature(sample_dataset, tmp_path):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
dataset_with_reward = add_feature(
dataset_with_reward = add_features(
sample_dataset,
feature_name="reward",
feature_values=np.random.randn(50, 1).astype(np.float32),
feature_info=feature_info,
features={"reward": (np.random.randn(50, 1).astype(np.float32), feature_info)},
output_dir=tmp_path / "with_reward",
)
modified_dataset = modify_features(
dataset_with_reward,
remove_features="reward",
output_dir=tmp_path / "modified",
)
assert "reward" not in modified_dataset.meta.features
def test_modify_features_no_changes(sample_dataset, tmp_path):
"""Test error when modify_features is called with no changes."""
with pytest.raises(ValueError, match="Must specify at least one of add_features or remove_features"):
modify_features(
sample_dataset,
output_dir=tmp_path / "modified",
)
def test_remove_single_feature(sample_dataset, tmp_path):
"""Test removing a single feature."""
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
features = {
"reward": (np.random.randn(50, 1).astype(np.float32), feature_info),
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
dataset_with_reward = add_features(
dataset=sample_dataset,
features=features,
output_dir=tmp_path / "with_reward",
)
@@ -432,20 +526,19 @@ def test_remove_multiple_features(sample_dataset, tmp_path):
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
dataset = sample_dataset
features = {}
for feature_name in ["reward", "success"]:
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
dataset = add_feature(
dataset,
feature_name=feature_name,
feature_values=np.random.randn(dataset.meta.total_frames, 1).astype(np.float32),
feature_info=feature_info,
output_dir=tmp_path / f"with_{feature_name}",
features[feature_name] = (
np.random.randn(dataset.meta.total_frames, 1).astype(np.float32),
feature_info,
)
dataset_with_features = add_features(
dataset, features=features, output_dir=tmp_path / "with_features"
)
dataset_clean = remove_feature(
dataset,
feature_names=["reward", "success"],
output_dir=tmp_path / "clean",
dataset_with_features, feature_names=["reward", "success"], output_dir=tmp_path / "clean"
)
assert "reward" not in dataset_clean.meta.features
@@ -509,11 +602,14 @@ def test_complex_workflow_integration(sample_dataset, tmp_path):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
dataset = add_feature(
dataset = add_features(
sample_dataset,
feature_name="reward",
feature_values=np.random.randn(50, 1).astype(np.float32),
feature_info={"dtype": "float32", "shape": (1,), "names": None},
features={
"reward": (
np.random.randn(50, 1).astype(np.float32),
{"dtype": "float32", "shape": (1,), "names": None},
)
},
output_dir=tmp_path / "step1",
)
@@ -753,7 +849,7 @@ def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_f
assert "std" in merged.meta.stats[feature]
def test_add_feature_preserves_existing_stats(sample_dataset, tmp_path):
def test_add_features_preserves_existing_stats(sample_dataset, tmp_path):
"""Test that adding a feature preserves existing stats."""
num_frames = sample_dataset.meta.total_frames
reward_values = np.random.randn(num_frames, 1).astype(np.float32)
@@ -763,6 +859,9 @@ def test_add_feature_preserves_existing_stats(sample_dataset, tmp_path):
"shape": (1,),
"names": None,
}
features = {
"reward": (reward_values, feature_info),
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
@@ -771,11 +870,9 @@ def test_add_feature_preserves_existing_stats(sample_dataset, tmp_path):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "with_reward")
new_dataset = add_feature(
sample_dataset,
feature_name="reward",
feature_values=reward_values,
feature_info=feature_info,
new_dataset = add_features(
dataset=sample_dataset,
features=features,
output_dir=tmp_path / "with_reward",
)
@@ -797,11 +894,11 @@ def test_remove_feature_updates_stats(sample_dataset, tmp_path):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
dataset_with_reward = add_feature(
dataset_with_reward = add_features(
sample_dataset,
feature_name="reward",
feature_values=np.random.randn(50, 1).astype(np.float32),
feature_info=feature_info,
features={
"reward": (np.random.randn(50, 1).astype(np.float32), feature_info),
},
output_dir=tmp_path / "with_reward",
)
@@ -893,3 +990,60 @@ def test_split_all_episodes_assigned(sample_dataset, tmp_path):
total_episodes = sum(ds.meta.total_episodes for ds in result.values())
assert total_episodes == sample_dataset.meta.total_episodes
def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
"""Test that modifying features preserves chunk_idx and file_idx from source dataset."""
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
def mock_snapshot(repo_id, **kwargs):
return str(kwargs.get("local_dir", tmp_path / repo_id.split("/")[-1]))
mock_snapshot_download.side_effect = mock_snapshot
# First split the dataset to create a non-zero starting chunk/file structure
splits = split_dataset(
sample_dataset,
splits={"train": [0, 1, 2], "val": [3, 4]},
output_dir=tmp_path / "splits",
)
train_dataset = splits["train"]
# Get original chunk/file indices from first episode
if train_dataset.meta.episodes is None:
from lerobot.datasets.utils import load_episodes
train_dataset.meta.episodes = load_episodes(train_dataset.meta.root)
original_chunk_indices = [ep["data/chunk_index"] for ep in train_dataset.meta.episodes]
original_file_indices = [ep["data/file_index"] for ep in train_dataset.meta.episodes]
# Now add a feature to the split dataset
modified_dataset = add_features(
train_dataset,
features={
"reward": (
np.random.randn(train_dataset.meta.total_frames, 1).astype(np.float32),
feature_info,
),
},
output_dir=tmp_path / "modified",
)
# Check that chunk/file indices are preserved
if modified_dataset.meta.episodes is None:
from lerobot.datasets.utils import load_episodes
modified_dataset.meta.episodes = load_episodes(modified_dataset.meta.root)
new_chunk_indices = [ep["data/chunk_index"] for ep in modified_dataset.meta.episodes]
new_file_indices = [ep["data/file_index"] for ep in modified_dataset.meta.episodes]
assert new_chunk_indices == original_chunk_indices, "Chunk indices should be preserved"
assert new_file_indices == original_file_indices, "File indices should be preserved"
assert "reward" in modified_dataset.meta.features

View File

@@ -95,7 +95,6 @@ def test_get_policy_and_config_classes(policy_name: str):
@pytest.mark.parametrize(
"ds_repo_id,env_name,env_kwargs,policy_name,policy_kwargs",
[
("lerobot/xarm_lift_medium", "xarm", {}, "tdmpc", {"use_mpc": True}),
("lerobot/pusht", "pusht", {}, "diffusion", {}),
("lerobot/pusht", "pusht", {}, "vqbet", {}),
("lerobot/pusht", "pusht", {}, "act", {}),
@@ -328,8 +327,6 @@ def test_multikey_construction(multikey: bool):
# TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it
# was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
# to test with `policy.use_mpc=false`.
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
# ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
# TODO(rcadene): the diffusion model was normalizing the image in mean=0.5 std=0.5 which is a hack supposed to
# to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference
# that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass.